Online Learning in Action — Building Real-Time Stock Forecasting on Lakehouse

Explore best practices for balancing model stability and adaptation in non-stationary price streams

Deep Learning · Data Science · Python

By Kuriko IWAI


Table of Contents

Introduction
What is Online Learning
What We’ll Build
Data Pipeline Architecture
Hybrid Learning Workflow
Workflow in Action
Step 1. Building an ETL Pipeline on Medallion Lakehouse
Step 2. Running Initial Batch Learning
Step 3. Creating Micro Batches
Step 4. Configuring a WebSocket Server
Step 5. Running Online Learning
Test in Local
Step 6. Scheduling the Batch Learning with Airflow
Wrapping Up

Introduction

Online learning is a powerful learning scenario for training machine learning models on continuous streams of data, allowing them to adapt to new information instantly without requiring full retraining.

However, its practical implementation can be challenging because maintaining model stability is difficult when the underlying data distribution is non-stationary, leading to potential catastrophic forgetting.

In this article, I'll demonstrate best practices for balancing real-time adaptation with model stability by building an online learning system that uses a simple baseline model to forecast stock prices.

What is Online Learning

Online learning is one of the learning scenarios in machine learning where the model is trained sequentially as new data arrives.

The below diagram illustrates how online learning (bottom) works, in comparison with traditional batch learning (top):

Figure A. Comparison of online learning and batch learning (Created by Kuriko IWAI)


Batch learning loads and processes the entire dataset infrequently (e.g., weekly or monthly), making it suitable for tasks where model update latency isn't critical.

Online learning, on the other hand, processes streaming data as micro-batches (like N=1) in real-time, while immediately discarding the used data.

The model is continuously trained on the latest micro-batch within milliseconds, making it ideal for tasks that require immediate reflection of the latest data from the data streams like WebSocket, Kafka, or sensors.

This approach enables the model to:

  • Maintain high predictive accuracy over time, making the model more robust to concept drift, where the underlying distribution of the input data shifts over time, and

  • Minimize computational overhead during training, as it does not need to load the entire dataset to train the model, nor store the old data.
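Before wiring up the full system, the core mechanic can be shown in a toy, framework-free sketch (illustrative only, not the article's model): a single-weight linear model updated once per streamed sample, with each sample discarded after the update.

def online_update(w: float, x: float, y: float, lr: float = 0.1) -> float:
    """One SGD step on the squared error (w*x - y)^2 for a single streamed sample."""
    grad = 2 * (w * x - y) * x
    return w - lr * grad

# samples "arriving" one at a time from a stream; each is used once, then discarded
stream = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
w = 0.0
for x, y in stream:
    w = online_update(w, x, y)  # train on the micro batch (N=1)

print(round(w, 3))  # w drifts toward the true slope of 2.0

Unlike batch learning, no sample is revisited: the weight simply keeps absorbing the newest observation, which is exactly what makes online learning both adaptive and vulnerable to forgetting.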

What We’ll Build

I’ll build a real-time stock prediction system where the baseline model (a simple multi-layered neural network) is continuously trained on live data streamed via WebSockets and forecasts the closing price for the next trading day.

The model is trained on both batch and online learning:

Figure B. Hybrid approach of batch and online learning (Created by Kuriko IWAI)


The system is engineered using the Medallion Lakehouse Architecture on AWS S3 to support both batch and real-time learning.

Data Pipeline Architecture

  • Bronze Layer stores raw, unstructured data directly.

  • Silver Layer checks data quality, structures data, and loads it into a Spark-processed Data Warehouse (DWH).

  • Gold Layer runs feature engineering using Spark and loads the data into the final DWH for consumption.

Hybrid Learning Workflow

The system utilizes a hybrid approach, combining initial Batch Learning with continuous Online Learning:

  1. Initial batch learning: The model is first trained using preprocessed historical data from the Gold layer.

  2. Continuous adaptation via online learning: The model then trains on a real-time data stream from the WebSocket server.

  3. Scheduled batch learning: Airflow manages a monthly scheduled batch re-training using the latest historical data.

The scheduled run mitigates the risk of catastrophic forgetting, where the model overwrites parameters crucial for historical performance, “forgetting” what it learned initially.
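The three-step hybrid workflow can be sketched as a simple control loop (the function names here are illustrative placeholders, not the project's actual API):

BATCH_RETRAIN_EVERY = 30  # e.g., a scheduled full retrain every 30 "days"

def run_hybrid(days: int) -> list:
    events = ['initial_batch_learning']            # step 1: train on history once
    for day in range(1, days + 1):
        events.append(f'online_update:{day}')      # step 2: adapt to the live stream
        if day % BATCH_RETRAIN_EVERY == 0:
            events.append('scheduled_batch_learning')  # step 3: anchor against forgetting
    return events

log = run_hybrid(60)
print(log.count('scheduled_batch_learning'))  # prints 2

The periodic full retrain re-exposes the model to the entire history, resetting any drift the continuous online updates may have accumulated.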

Workflow in Action

I’ll take the following steps to build the system:

  1. Build an ETL pipeline on medallion lakehouse.

  2. Run initial batch learning.

  3. Create micro batches for online learning.

  4. Configure a WebSocket server.

  5. Run online learning.

  6. Schedule the batch learning with Airflow.

Step 1. Building an ETL Pipeline on Medallion Lakehouse

The first step is to define an ETL pipeline on the medallion lakehouse architecture.

I’ll define the run_lakehouse function, where the raw data is extracted from the public API, stored in the Bronze layer, and then transformed into Delta Tables in the Silver and Gold layers:

src/batch.py

import os
import argparse
import pandas as pd
from pyspark.sql import SparkSession

import src.data_handling as data_handling
from src import TICKER


def run_lakehouse(ticker: str = TICKER, should_local_save: bool = False) -> tuple[pd.DataFrame, SparkSession]:
    # extract
    stock_price_data = data_handling.extract(ticker=ticker)

    # bronze layer
    bronze_s3_path = data_handling.bronze.load(data=stock_price_data, ticker=ticker)

    # start the spark session
    spark = data_handling.config_and_start_spark_session()

    # silver layer
    bronze_delta_table = spark.read.json(bronze_s3_path, multiLine=True)
    silver_df = data_handling.silver.transform(delta_table=bronze_delta_table, spark=spark)
    silver_s3_path = data_handling.silver.load(df=silver_df, ticker=ticker)

    # gold layer
    silver_delta_table = data_handling.retrieve_delta_table(spark=spark, s3_path=silver_s3_path)
    gold_df = data_handling.gold.transform(delta_table=silver_delta_table, spark=spark)
    gold_s3_path = data_handling.gold.load(df=gold_df, ticker=ticker)

    df = gold_df.toPandas()

    return df, spark

# execution
if __name__ == '__main__':
    # args
    parser = argparse.ArgumentParser(description="run batch learning")
    parser.add_argument('--ticker', type=str, default=TICKER, help="ticker")
    args = parser.parse_args()

    # execute etl
    df, spark = run_lakehouse(ticker=args.ticker)
    df.info()

    # stop spark session
    spark.stop()

As shown in Figure B, the Silver and Gold layers define schemas to transform unstructured JSON data into a structured format.

The Silver layer performs imputation, cleaning, and data type casting:

src/data_handling/silver.py

import os
import pandas as pd
from deltalake import DeltaTable, write_deltalake
from pyspark.sql.functions import col, expr, to_date, unix_millis
from pyspark.sql.types import StructType, StructField, DateType, FloatType, IntegerType, LongType

SILVER_SCHEMA = StructType([
    StructField('dt', DateType(), False),
    StructField('open', FloatType(), False),
    StructField('high', FloatType(), False),
    StructField('low', FloatType(), False),
    StructField('close', FloatType(), False),
    StructField('volume', IntegerType(), False),
    StructField('timestamp_in_ms', LongType(), False)
])

def transform(delta_table, spark):
    # get all the date-like column names
    date_columns = delta_table.columns

    # build the expression for the stack function
    stack_expr = f'stack({len(date_columns)}, '
    for date_col in date_columns:
        stack_expr += f'"{date_col}", `{date_col}`, '

    stack_expr = stack_expr.strip(', ') + ')'

    # use stack to unpivot the data from wide to tall format
    _silver_df = delta_table.select(expr(stack_expr).alias('dt_string', 'values'))

    # process the unpivoted df to cast types and rename columns
    silver_df = _silver_df.select(
        to_date(col('dt_string'), 'yyyy-MM-dd').alias('dt'),
        col('values').getItem('1. open').cast('float').alias('open'),
        col('values').getItem('2. high').cast('float').alias('high'),
        col('values').getItem('3. low').cast('float').alias('low'),
        col('values').getItem('4. close').cast('float').alias('close'),
        col('values').getItem('5. volume').cast('integer').alias('volume'),
        unix_millis(col('dt_string').cast('timestamp')).alias('timestamp_in_ms'),
    ).where(col('dt').isNotNull())

    # finalize df
    silver_df = spark.createDataFrame(silver_df.collect(), schema=SILVER_SCHEMA)
    return silver_df

The Gold layer then performs feature engineering:

src/data_handling/gold.py

import os
import pandas as pd
from deltalake import DeltaTable, write_deltalake
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, DateType, FloatType, IntegerType, LongType

GOLD_SCHEMA = StructType([
    StructField('dt', DateType(), False),
    StructField('open', FloatType(), False),
    StructField('high', FloatType(), False),
    StructField('low', FloatType(), False),
    StructField('close', FloatType(), False),
    StructField('volume', IntegerType(), False),
    StructField('timestamp_in_ms', LongType(), False),
    StructField('ave_open', FloatType(), False),
    StructField('ave_high', FloatType(), False),
    StructField('ave_low', FloatType(), False),
    StructField('ave_close', FloatType(), False),
    StructField('total_volume', IntegerType(), False),
    StructField('30_day_ma_close', FloatType(), False),
    StructField('year', IntegerType(), False),
    StructField('month', IntegerType(), False),
    StructField('date', IntegerType(), False),
])

def transform(delta_table, spark, should_filter: bool = False):
    # moving ave. requires sorting by date and looking back 29 rows.
    window_spec = Window.orderBy('dt').rowsBetween(-29, 0)
    _gold_df = delta_table.withColumn('30_day_ma_close', F.avg(F.col('close')).over(window_spec))

    # add temporal features
    _gold_df = _gold_df.withColumn('year', F.year(F.col('dt')).cast(IntegerType()))
    _gold_df = _gold_df.withColumn('month', F.month(F.col('dt')).cast(IntegerType()))
    _gold_df = _gold_df.withColumn('date', F.dayofmonth(F.col('dt')).cast(IntegerType()))

    # log transform close
    _gold_df = _gold_df.withColumn('close', F.log1p(F.col('close')))

    # select final columns
    final_cols = [
        'dt', 'open', 'high', 'low', 'close', 'volume', 'timestamp_in_ms',
        F.col('open').alias('ave_open'), F.col('high').alias('ave_high'),
        F.col('low').alias('ave_low'), F.col('close').alias('ave_close'),
        F.col('volume').alias('total_volume'), '30_day_ma_close', 'year', 'month', 'date'
    ]
    gold_df = _gold_df.select(*final_cols)

    # finalize
    gold_df = gold_df.orderBy(F.col('dt').asc())
    return gold_df

Step 2. Running Initial Batch Learning

After transforming the raw data, I’ll first define the BaselineModel class where the baseline model is initialized, and then add the initial_batch_learning method to the class.

The initial batch learning is key to creating a good starting point for online learning.

So, I'll incorporate early stopping logic, where the training loop stops after observing little improvement in the validation MSE loss for 10 consecutive epochs.

The trained model, preprocessor, and optimizer are stored in the model storage to enable the system to load them during online learning.

src/model/baseline.py

import os
import math
import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from typing import List, Dict, Any
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from category_encoders import BinaryEncoder

import src.data_handling as data_handling
import src.batch as batch
from src import TICKER
from src._utils import main_logger


class BaselineModel:
    def __init__(
            self,
            input_size: int = 10,
            hidden_size: int = 64,
            ticker: str = TICKER,

            # batch learning metrics
            batch_size: int = 32,
            min_delta: float = 1e-3,
            patience: int = 10,
            lr: float = 1e-4,
        ):
        self.ticker = ticker

        # model
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.model = nn.Sequential(nn.Linear(self.input_size, self.hidden_size), nn.ReLU(), nn.Linear(self.hidden_size, 1))

        # spark session
        self.spark = data_handling.config_and_start_spark_session()

        # preprocess
        self.num_cols = ['open', 'high', 'low', 'volume', 'ave_open', 'ave_high', 'ave_low', 'ave_close', 'total_volume', '30_day_ma_close', 'timestamp_in_ms']
        self.cat_cols = ['dt', 'year', 'month', 'date']
        self._setup_preprocessor()

        # training
        self.device_type = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        self.device = torch.device(self.device_type)

        self.batch_size = batch_size
        self.min_delta = min_delta
        self.patience = patience

        self.criterion = nn.MSELoss() # mse for logged closing price (target val)
        self.lr = lr
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.steps_trained = 0

        # filepath
        self.model_filepath = os.path.join('artifacts', 'base_model', self.ticker, 'model.pt')
        self.preprocessor_filepath = os.path.join('artifacts', 'preprocessor', self.ticker, 'preprocessor.pkl')


    def initial_batch_learning(self, num_epochs: int = 5000):
        # create a pandas df from the etl pipeline
        df, _ = batch.run_lakehouse(ticker=self.ticker)

        # create train, validation, and test datasets from the df
        y = df['close']
        X = df.copy().drop(columns='close', axis=1)

        X_tv, X_test, y_tv, _ = train_test_split(X, y, test_size=2000, shuffle=False)
        X_train, X_val, y_train, y_val = train_test_split(X_tv, y_tv, test_size=2000, shuffle=False)

        # preprocess the input features
        X_train = self.preprocessor.fit_transform(X_train)
        X_val = self.preprocessor.transform(X_val)
        X_test = self.preprocessor.transform(X_test)

        # save the trained preprocessor for later use in online learning
        self._save_trained_preprocessor()

        # create tensor data loaders
        train_data_loader = self._create_tensor_data_loader(X=X_train, y=y_train)
        val_data_loader = self._create_tensor_data_loader(X=X_val, y=y_val)

        # reconstruct the baseline model
        self.input_size = X_train.shape[-1]
        self.model = nn.Sequential(
            nn.Linear(self.input_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, 1)
        ).to(self.device)

        # reconstruct the optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        # start training with validation and early stopping
        best_val_loss = float('inf')
        epochs_no_improve = 0
        for epoch in range(num_epochs):
            main_logger.info(f'... start epoch {epoch + 1} ...')
            self.model.train()
            for batch_X, batch_y in train_data_loader:
                batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)
                self.optimizer.zero_grad()

                try:
                    with torch.autocast(device_type=self.device_type):
                        outputs = self.model(batch_X)
                        loss = self.criterion(outputs, batch_y)
                        if torch.any(torch.isnan(outputs)) or torch.any(torch.isinf(outputs)):
                            main_logger.error('pytorch model returns nan or inf. break the training loop.')
                            break

                        if not math.isfinite(loss.item()):
                            main_logger.error('loss is nan or inf. break the training loop.')
                            break
                    loss.backward()
                    nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.optimizer.step()

                except Exception:
                    # fall back to a full-precision step if autocast is unavailable on the device
                    outputs = self.model(batch_X)
                    loss = self.criterion(outputs, batch_y)
                    loss.backward()
                    self.optimizer.step()

            if (epoch + 1) % 10 == 0: main_logger.info(f"epoch [{epoch+1}/{num_epochs}], loss: {loss.item():.4f}")

            # validate on a validation dataset (subset of the entire training dataset)
            self.model.eval()
            val_loss = 0.0

            # switch the grad mode
            with torch.inference_mode():
                for batch_X_val, batch_y_val in val_data_loader:
                    batch_X_val, batch_y_val = batch_X_val.to(self.device), batch_y_val.to(self.device)
                    outputs_val = self.model(batch_X_val)
                    val_loss += self.criterion(outputs_val, batch_y_val).item()

            val_loss /= len(val_data_loader)

            # early stopping
            if val_loss < best_val_loss - self.min_delta:
                best_val_loss = val_loss
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= self.patience:
                    main_logger.info(f'early stopping at epoch {epoch + 1} as validation loss did not improve for {self.patience} epochs.')
                    break

        # create checkpoint data and store it for later use during online learning
        checkpoint = dict(
            state_dict=self.model.state_dict(),
            input_dim=X_train.shape[1],
            optimizer_name='adam',
            optimizer_state_dict=self.optimizer.state_dict(),
            batch_size=self.batch_size,
            lr=self.lr,
        )
        self._save_model(checkpoint=checkpoint)

# execute
BaselineModel().initial_batch_learning()

Step 3. Creating Micro Batches

Next, I’ll create a micro batch for online training.

This step has two key points.

1. Experience Replay (ER) for Stability

Experience Replay (ER) is a crucial technique in Deep Reinforcement Learning (DRL) where the system gives the agent a memory so that it can learn from past interactions as well as recent ones.

In this project, I’ll extend this ER concept to the micro batch by randomly sampling data from the replay buffer (memory) where both old and new data is stored.

This ensures stable model updates, while preventing catastrophic forgetting.

2. Recency-Biased Sampling for Concept Drift Hedge

Stock price data is typically non-stationary: its statistical properties, like the mean, change over time, which causes concept drift.

When randomly sampling data from the replay buffer, the system selects the most recent data with higher probability, instead of selecting all data with equal probability.

This ensures the model focuses on the newest trend without completely losing the stabilizing context of the older, general data.

src/websocket.py

import random
from collections import deque
from typing import Dict, Any, List


def create_micro_batch(er_buffer: deque, batch_size: int) -> List[Dict[str, Any]]:
    """creates a micro batch for online learning with experience replay and recency-biased sampling"""

    # not enough samples to create a micro batch. return all available data
    if len(er_buffer) < batch_size: return list(er_buffer)

    # recency bias - include the latest sample
    latest_experience = er_buffer[-1]
    batch = [latest_experience]

    # sample remaining items from the past experiences in the er buffer. (using a simple random sample)
    num_samples_to_add = batch_size - 1
    past_experiences = list(er_buffer)[:-1]
    sampled_experiences = random.sample(past_experiences, num_samples_to_add)

    # add past experiences to the batch
    batch.extend(sampled_experiences)

    # sort samples in order of timestamp
    sorted_batch = sorted(batch, key=lambda d: d['timestamp_in_ms'])
    return sorted_batch
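Note that create_micro_batch guarantees the latest sample but draws the rest uniformly. To weight the draw toward recent data as described, one could assign exponentially decaying weights by buffer position; this variant (my illustration, not from the original file) uses random.choices:

import random
from collections import deque
from typing import Any, Dict, List

def create_recency_biased_batch(er_buffer: deque, batch_size: int, decay: float = 0.9) -> List[Dict[str, Any]]:
    """Variant of create_micro_batch: newer items get exponentially larger sampling weights."""
    if len(er_buffer) < batch_size:
        return list(er_buffer)

    batch = [er_buffer[-1]]  # always include the newest sample
    past = list(er_buffer)[:-1]
    # weight item i by decay^(distance from the newest): recent items are heavier
    weights = [decay ** (len(past) - 1 - i) for i in range(len(past))]

    # sample without replacement, respecting the recency weights
    candidates = list(zip(past, weights))
    for _ in range(batch_size - 1):
        items, w = zip(*candidates)
        pick = random.choices(range(len(items)), weights=w, k=1)[0]
        batch.append(items[pick])
        candidates.pop(pick)

    # keep the micro batch in chronological order, as the original does
    return sorted(batch, key=lambda d: d['timestamp_in_ms'])

With decay=0.9 and a 300-item buffer, the newest items dominate the draw while older items keep a small but nonzero chance of inclusion, preserving the stabilizing ER context against catastrophic forgetting.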

Step 4. Configuring a WebSocket Server

Next, I’ll configure a WebSocket (WS) server to stream the micro batches.

A WS server is key to stream raw data from the data source, Yahoo Finance.

In this project, I’ll imitate a true real-time data stream by calling the Yahoo Finance API every 5 seconds via the server.

In the main function, the streaming_task is created from the streaming function, where the system calls the I/O layer to fetch the latest raw data asynchronously from the Yahoo Finance API and then transforms it into a micro batch using the create_micro_batch function defined in Step 3.

Then, the WS server streams data until the streaming_task completes or is cancelled.

Note that providers like Polygon, Finnhub, and TwelveData offer paid websocket services ($30 - $200 per month), which would be an option for further refinement.

src/websocket.py

import argparse
import asyncio
import datetime
import json
import time
from collections import deque
from typing import Dict, Any, Set
import websockets
import yfinance as yf

from src import TICKER, PORT, HOST
from src._utils import main_logger

# global set to hold all active websocket connections (clients)
CLIENTS: Set[websockets.ClientConnection] = set()

# serialize datetime objects when broadcasting the batch as json
class CustomJsonEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, datetime.datetime): return o.isoformat()
        return super().default(o)

# fetch data from yfinance
def fetch_data(ticker: str) -> Dict[str, Any]:
    # ticker obj
    stock = yf.Ticker(ticker)

    # fetch the latest 1-minute historical data
    data = stock.history(period="1d", interval="1m")

    # get the latest row
    latest_data = data.iloc[-1]
    return {
        'open': latest_data.get('Open', 0.0),
        'high': latest_data.get('High', 0.0),
        'low': latest_data.get('Low', 0.0),
        'close': latest_data.get('Close', 0.0),
        'volume': int(latest_data.get('Volume', 0.0)),
        'dt': datetime.datetime.now(),
        "timestamp_in_ms": int(time.time() * 1000)
    }


# poll the data source and broadcast the updates to the clients
async def streaming(ticker: str, interval_seconds: int, er_buffer, batch_size: int):
    count = 0
    while True:
        count += 1
        # call the io layer to fetch the latest raw data asynchronously. using asyncio.to_thread() avoids blocking the main asynchronous event loop
        new_data = await asyncio.to_thread(fetch_data, ticker)

        if new_data:
            # store the new data in the er buffer
            er_buffer.append(new_data)
            main_logger.info(f"... # {count} appended new data to the er buffer (buffer size: {len(er_buffer)}) ...")

            if len(er_buffer) >= batch_size:
                # create a micro batch
                current_batch = create_micro_batch(er_buffer=er_buffer, batch_size=batch_size)
                message = json.dumps(current_batch, cls=CustomJsonEncoder)

                # broadcast the batch to the clients
                if CLIENTS:
                    try:
                        websockets.broadcast(CLIENTS, message)
                        main_logger.info(f"... successfully broadcasted the current micro batch ({len(current_batch)} items) to {len(CLIENTS)} clients.")
                    except Exception as e: main_logger.error(f"... error during broadcast: {e}")
                else:
                    main_logger.info("... current batch is ready, but no active clients to broadcast to ...")
            else:
                main_logger.info(f"... buffer filling up ({len(er_buffer)}/{batch_size}) ...")

        # wait for the next polling interval
        await asyncio.sleep(interval_seconds)


# add or remove connections (clients) on the websocket server
async def register(websocket):
    CLIENTS.add(websocket)

async def unregister(websocket):
    CLIENTS.discard(websocket)

async def handler(websocket: websockets.ClientConnection):
    await register(websocket)
    try: await websocket.wait_closed()
    finally: await unregister(websocket)


async def main(ticker: str, interval: int, port: int, host: str, er_buffer, batch_size: int):
    # start the core data fetching/pushing task
    streaming_task = asyncio.create_task(
        streaming(ticker=ticker, interval_seconds=interval, er_buffer=er_buffer, batch_size=batch_size)
    )
    try:
        async with websockets.serve(handler, host, port):
            await asyncio.Future() # keeps the server running indefinitely
    except Exception as e:
        main_logger.error(f"... server failed to start: {e}")
    finally:
        # clean up the streaming task when the server stops
        streaming_task.cancel()


if __name__ == '__main__':
    # handle args
    POLLING_INTERVAL = 1
    ER_MAX_BUFFER_SIZE = 300  # for er
    MICRO_BATCH_SIZE = 4

    parser = argparse.ArgumentParser(description='creating micro batch for online learning')
    parser.add_argument('--ticker', type=str, default=TICKER, help=f"ticker. default = {TICKER}")
    parser.add_argument('--interval', type=int, default=POLLING_INTERVAL, help=f"polling interval. default = {POLLING_INTERVAL}")
    parser.add_argument('--port', type=int, default=PORT, help=f"port. default = {PORT}")
    parser.add_argument('--host', type=str, default=HOST, help=f"host. default = {HOST}")
    parser.add_argument('--er_max_buffer_size', type=int, default=ER_MAX_BUFFER_SIZE, help=f"max number of old samples to store. default = {ER_MAX_BUFFER_SIZE}")
    parser.add_argument('--batch_size', type=int, default=MICRO_BATCH_SIZE, help=f"micro batch size for online learning. default = {MICRO_BATCH_SIZE}")
    args = parser.parse_args()

    ER_BUFFER: deque[Dict[str, Any]] = deque(maxlen=args.er_max_buffer_size)

    # start running websocket server
    try:
        asyncio.run(
            main(ticker=args.ticker, interval=args.interval, port=args.port, host=args.host, er_buffer=ER_BUFFER, batch_size=args.batch_size)
        )
    except KeyboardInterrupt: main_logger.info("\n--- server shutting down ---")
    except Exception as e: main_logger.error(f"... an unexpected error occurred: {e}")

Step 5. Running Online Learning

Next, I’ll configure a client to receive micro batches and perform online learning.

I’ll first add the online_learning method to the BaselineModel class created in Step 2.

During online learning, the trained model, optimizer, and preprocessor are first loaded from storage.

The online_learning method then performs feature engineering and preprocessing on the data in the micro-batch, and finally uses the transformed data to train the model.

src/model/baseline.py

import os
import math
import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from typing import List, Dict, Any
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from category_encoders import BinaryEncoder

import src.data_handling as data_handling
from src._utils import main_logger


class BaselineModel:
    def __init__(self, *args, **kwargs):
        ...  # same as Step 2

    def initial_batch_learning(self, num_epochs: int = 5000):
        ...  # same as Step 2

    def online_learning(self, current_batch: List[Dict[str, Any]]):
        if not current_batch: return

        # data transformation
        df = self.spark.createDataFrame(current_batch, schema=data_handling.silver.SILVER_SCHEMA)
        df = data_handling.gold.transform(delta_table=df, spark=self.spark)

        # create training dataset
        df_pandas = df.toPandas()
        y = df_pandas['close']
        X = df_pandas.copy().drop(columns='close', axis=1)

        try:
            self.preprocessor = self._load_trained_preprocessor()
            X = self.preprocessor.transform(X)
        except Exception:
            X = self.preprocessor.fit_transform(X)

        # convert to pytorch tensors
        X = torch.from_numpy(X).float()
        X = X.to(self.device)
        y = torch.from_numpy(y.to_numpy()).float().reshape(-1, 1)
        y = y.to(self.device)

        # load the model and optimizer
        self._load_model_and_optimizer()

        # start online learning
        self.optimizer.zero_grad()
        y_pred = self.model(X)
        loss = self.criterion(y_pred, y)
        loss.backward()
        self.optimizer.step()
        self.steps_trained += 1

        # pred result: invert the log1p transform applied in the gold layer
        y_pred_actual = np.expm1(y_pred.cpu().detach().numpy())

        main_logger.info(f"... pytorch model trained (step {self.steps_trained}) ...\n - batch size {len(current_batch)}\n - mse loss {loss.item():.6f}\n - new closing price predicted: {np.mean(y_pred_actual):.2f}")

Then, I’ll configure the client_handler function where the client continuously receives micro-batches from the WS server and performs online learning.

The main function asynchronously invokes this client_handler function to ensure continuous online learning.

src/client.py

import json
import datetime
import asyncio
import websockets

from src.model import BaselineModel
from src import PORT, HOST
from src._utils import main_logger


async def client_handler(websocket_uri: str, model: BaselineModel, max_retries: int = 5, delay: int = 1):
    for attempt in range(max_retries):
        try:
            async with websockets.connect(websocket_uri, ping_interval=5) as websocket:
                # loop to retrieve the current batch
                async for message in websocket:
                    # deserialize the current batch from a json string to a list
                    current_batch = json.loads(message)

                    if current_batch and isinstance(current_batch, list):
                        # cast dtypes
                        for record in current_batch:
                            if isinstance(record['dt'], str):
                                record['dt'] = datetime.datetime.fromisoformat(record['dt'])

                        # start online learning
                        model.online_learning(current_batch=current_batch)
                    else:
                        main_logger.info(f"... received empty or invalid data: {message} ...")

        except websockets.exceptions.ConnectionClosedOK:
            main_logger.info("... connection closed by server ...")
            break

        except ConnectionRefusedError:
            main_logger.error(f"... connection refused. retrying in {delay * (2 ** attempt)}s (attempt {attempt + 1}/{max_retries}) ...")

            # exponential backoff between connection attempts
            await asyncio.sleep(delay * (2 ** attempt))

        except Exception as e:
            main_logger.error(f"... unhandled connection error: {e}")
            break

    main_logger.info("... client shutting down ...")


def main():
    try:
        websocket_uri = f"ws://{HOST}:{PORT}"
        asyncio.run(client_handler(websocket_uri=websocket_uri, model=BaselineModel()))

    except KeyboardInterrupt:
        main_logger.info("\n--- client interrupted and shutting down ---")


if __name__ == '__main__':
    main()
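For reference, here is a minimal sketch of the micro-batch payload this client expects. The `dt` field and the cast to `datetime` come from the client code above; the other field names (`close`, `volume`) are illustrative assumptions:

```python
import datetime
import json

# a hypothetical micro-batch as it would arrive over the websocket:
# a JSON list of records whose 'dt' field is an ISO-format string
message = json.dumps([
    {"dt": "2025-10-26T20:38:22", "close": 259.13, "volume": 1200},
    {"dt": "2025-10-26T20:38:23", "close": 259.20, "volume": 900},
])

current_batch = json.loads(message)

# the same cast client_handler performs before calling online_learning()
for record in current_batch:
    if isinstance(record["dt"], str):
        record["dt"] = datetime.datetime.fromisoformat(record["dt"])

print(type(current_batch[0]["dt"]))  # <class 'datetime.datetime'>
```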

Test in Local

Now, let's run the WebSocket server and client locally:

$ uv run src/websocket.py

$ uv run src/client.py --ticker <TICKER>

Replace <TICKER> with a ticker of your choice.

Here are the terminal output snippets:

2025-10-26 20:38:22,495 - root - INFO - ... base model loaded and reconstructed ...
2025-10-26 20:38:22,496 - root - INFO - ... pytorch model trained (step 347) ...
 - batch size 4
 - mse loss 0.105518
 - new closing price predicted: 259.13

2025-10-26 20:38:23,435 - root - INFO - ... calculated gold layer features for 2025-10-26.
2025-10-26 20:38:23,537 - root - INFO - ... base model loaded and reconstructed ...
2025-10-26 20:38:23,539 - root - INFO - ... pytorch model trained (step 348) ...
 - batch size 4
 - mse loss 0.105518
 - new closing price predicted: 259.13

Step 6. Scheduling the Batch Learning with Airflow

The last step is to schedule the batch learning monthly with Airflow.

Airflow is an open-source platform used to programmatically author, schedule, and monitor workflows as Directed Acyclic Graphs (DAGs):

  • Directed: The workflow has a clear, one-way path from start to finish.

  • Acyclic: There are no loops, meaning a task cannot trigger an upstream task. This prevents infinite processing loops.

  • Graph: The structure is a collection of nodes (Tasks) and edges (Dependencies).
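To see the acyclicity guarantee in a runnable form, here is a small sketch using Python's standard-library `graphlib` (the task names are illustrative, mirroring this workflow's quality-check-then-train flow); Airflow performs an analogous validation on its own DAGs:

```python
# why "acyclic" matters: stdlib graphlib refuses to order a graph with a cycle
from graphlib import TopologicalSorter, CycleError

# edges read as "task: its upstream dependencies"
deps = {"batch_learning": {"data_quality_check"}}
order = list(TopologicalSorter(deps).static_order())
print(order)  # ['data_quality_check', 'batch_learning']

# adding a back-edge creates a cycle, which would mean an infinite loop
deps["data_quality_check"] = {"batch_learning"}
try:
    list(TopologicalSorter(deps).static_order())
except CycleError:
    print("cycle detected: not a valid DAG")
```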

This process involves a data quality check performed by the check_data_quality function, followed by monthly model training on the full dataset:

src/dag/dag.py

import datetime
from airflow import DAG
from airflow.providers.standard.operators.python import PythonOperator
from pydeequ.checks import Check, CheckLevel
from pydeequ.verification import VerificationSuite, VerificationResult

from src.model import BaselineModel
from src.data_handling.spark import config_and_start_spark_session

# define the airflow dag
default_args = {
    'owner': 'kuriko iwai',
    'start_date': datetime.datetime(2025, 1, 1),  # use a static start date; a dynamic one confuses the scheduler
    'retries': 1,
}


def check_data_quality():
    # initialize a sparksession to use deequ
    spark = config_and_start_spark_session(session_name='deequ_quality_check')

    # fetch the delta table in the gold layer
    gold_data_path = "s3a://ml-stockprice-pred/data/gold"
    gold_delta_table = spark.read.format("delta").load(gold_data_path)

    # use deequ to run the quality check
    check_result = (VerificationSuite(spark)
        .onData(gold_delta_table)
        .addCheck(
            Check(spark_session=spark, level=CheckLevel.Warning, description="data quality check")
            .hasSize(lambda s: s > 0)  # check that the df is not empty
            .hasCompleteness("close", lambda s: s == 1.0)  # check for missing vals
            .hasCompleteness("volume", lambda s: s == 1.0) # check for missing vals
            .hasMin("open", lambda x: x > 0) # check that vals are positive
            .hasMax("volume", lambda x: x < 10000000000) # check the max val
        )
        .run()
    )

    # convert the deequ result to a spark df for analysis and logging
    result_df = VerificationResult.checkResultsAsDataFrame(spark, check_result, pandas=False)
    result_df.show(truncate=False) # type: ignore

    # stop the spark session
    spark.stop()


with DAG(
    dag_id='pyspark_data_handling',
    default_args=default_args,
    description='A monthly batch learning',
    schedule='@monthly',  # the schedule is set once on the dag, not on each task
    catchup=False,
) as dag:

    # data quality check
    data_quality_check_task = PythonOperator(
        task_id='data_quality_check',
        python_callable=check_data_quality,
        trigger_rule='all_success',
        depends_on_past=True, # ensures it only runs after the previous successful run
    )

    # batch learning
    run_batch_learning = PythonOperator(
        task_id='batch_learning',
        python_callable=BaselineModel().initial_batch_learning,
    )

    # set task dependencies
    data_quality_check_task >> run_batch_learning  # type: ignore

And that’s all for the workflow.

All code is available in my GitHub repo.

Wrapping Up

Online learning is a powerful learning scenario for systems that demand continuous, real-time adaptation to non-stationary data.

Yet it presents fundamental challenges like catastrophic forgetting and concept drift.

In our demonstration, we addressed these challenges with a hybrid approach, combining initial batch training with recency-biased experience replay during online updates.

This enables the model to effectively balance adaptation to new stock data with the stability preserved from historical knowledge.
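As a rough illustration of the idea (a sketch, not the exact implementation used here), a recency-biased replay buffer can be modeled as a bounded deque whose samples are drawn with weights that decay with age, so recent observations dominate each update while older knowledge is still occasionally rehearsed:

```python
import random
from collections import deque

class ReplayBuffer:
    """A minimal recency-biased experience replay buffer (illustrative names)."""

    def __init__(self, capacity: int = 1000, decay: float = 0.99):
        self.buffer = deque(maxlen=capacity)  # oldest samples are evicted first
        self.decay = decay

    def add(self, sample):
        self.buffer.append(sample)

    def sample(self, k: int):
        # weight the i-th oldest sample by decay ** age, so newer
        # samples are drawn more often than older ones
        n = len(self.buffer)
        weights = [self.decay ** (n - 1 - i) for i in range(n)]
        return random.choices(list(self.buffer), weights=weights, k=k)

buf = ReplayBuffer(capacity=5)
for price in [258.9, 259.0, 259.1, 259.2, 259.3, 259.4]:
    buf.add(price)

print(len(buf.buffer))  # 5: the oldest price was evicted
print(buf.sample(k=3))  # recent prices appear more often across draws
```

Tuning `capacity` and `decay` trades off stability (more, older samples) against adaptation speed (fewer, newer samples).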

Moving forward, integrating more sophisticated techniques like drift detection algorithms or exploring more complex model architectures would further refine the system.

Related Books for Further Understanding

These books cover a wide range of theory and practice, from the fundamentals to the PhD level.

Linear Algebra Done Right

Foundations of Machine Learning, second edition (Adaptive Computation and Machine Learning series)

Designing Machine Learning Systems: An Iterative Process for Production-Ready Applications

Machine Learning Design Patterns: Solutions to Common Challenges in Data Preparation, Model Building, and MLOps

Share What You Learned

Kuriko IWAI, " Online Learning in Action — Building Real-Time Stock Forecasting on Lakehouse" in Kernel Labs

https://kuriko-iwai.com/online-learning-in-action

Written by Kuriko IWAI. All images, unless otherwise noted, are by the author. All experiments on this blog use synthetic or licensed data.