Online Learning in Action — Building Real-Time Stock Forecasting on Lakehouse
Explore best practices for balancing model stability and adaptation in non-stationary price streams
By Kuriko IWAI

Table of Contents
Introduction
What is Online Learning
What We’ll Build

Introduction
Online learning is a powerful learning scenario for training machine learning models on continuous streams of data, allowing them to adapt to new information instantly without requiring full retraining.
However, its practical implementation can be challenging because maintaining model stability is difficult when the underlying data distribution is non-stationary, leading to potential catastrophic forgetting.
In this article, I'll demonstrate best practices for balancing real-time adaptation with model stability by building an online learning system that uses a simple baseline model to forecast stock prices.
What is Online Learning
Online learning is one of the learning scenarios in machine learning where the model is trained sequentially as new data arrives.
The diagram below illustrates how online learning (bottom) works in comparison with traditional batch learning (top):

Kernel Labs | Kuriko IWAI | kuriko-iwai.com
Figure A. Comparison of online learning and batch learning (Created by Kuriko IWAI)
Batch learning loads and processes the entire dataset infrequently (e.g., weekly or monthly), making it suitable for tasks where model update latency isn't critical.
Online learning, on the other hand, processes streaming data as micro-batches (e.g., as small as N = 1) in real time, immediately discarding each batch once it has been used.
The model is continuously trained on the latest micro-batch within milliseconds, making it ideal for tasks that require immediate reflection of the latest data from the data streams like WebSocket, Kafka, or sensors.
This approach enables the model to:
Maintain high predictive accuracy over time, making the model more robust to concept drift, where the underlying distribution of the input data shifts over time, and
Minimize computational overhead during training, as it does not need to load the entire dataset or store old data.
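To make the mechanics concrete, here is a minimal, self-contained sketch of an online update loop: a single-feature linear model updated one observation at a time with SGD, on a simulated stream whose slope shifts mid-stream (all names and values here are illustrative, not from the project):

```python
import random

def sgd_update(w, b, x, y, lr=0.01):
    """One online SGD step on a single observation (squared-error loss)."""
    y_hat = sum(wi * xi for wi, xi in zip(w, x)) + b
    err = y_hat - y
    w = [wi - lr * err * xi for wi, xi in zip(w, x)]
    b = b - lr * err
    return w, b

# simulated stream: y = 2*x + 1, with a distribution shift halfway through
random.seed(0)
w, b = [0.0], 0.0
for t in range(2000):
    x = [random.uniform(0, 1)]
    slope = 2.0 if t < 1000 else 3.0  # concept drift at t = 1000
    y = slope * x[0] + 1.0
    w, b = sgd_update(w, b, x, y, lr=0.1)

# w and b have adapted toward the post-drift relationship
print(round(w[0], 2), round(b, 2))
```

Because each observation is processed once and discarded, the model tracks the new slope after the drift without ever revisiting the full history.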
What We’ll Build
I’ll build a real-time stock prediction system where the baseline model (a simple multi-layered neural network) is continuously trained on live data streamed via WebSockets and forecasts the closing price for the next trading day.
The model is trained on both batch and online learning:

Figure B. Hybrid approach of batch and online learning (Created by Kuriko IWAI)
The system is engineered using the Medallion Lakehouse Architecture on AWS S3 to support both batch and real-time learning.
◼ Data Pipeline Architecture
Bronze Layer stores raw, unstructured data as-is.
Silver Layer checks data quality, structures data, and loads it into a Spark-processed Data Warehouse (DWH).
Gold Layer runs feature engineering using Spark and loads the data into the final DWH for consumption.
◼ Hybrid Learning Workflow
The system utilizes a hybrid approach, combining initial Batch Learning with continuous Online Learning:
Initial batch learning: The model is first trained using preprocessed historical data from the Gold layer.
Continuous adaptation via online learning: The model then trains on a real-time data stream from the WebSocket server.
Scheduled batch learning: Airflow manages a monthly scheduled batch re-training using the latest historical data.
The scheduled run mitigates the risk of catastrophic forgetting, where the model overwrites parameters crucial for historical performance, “forgetting” what it learned initially.
Workflow in Action
I’ll take the following steps to build the system:
Build an ETL pipeline on medallion lakehouse.
Run initial batch learning.
Create micro batches for online learning.
Configure a WebSocket server.
Run online learning.
Schedule the batch learning with Airflow.
◼ Step 1. Building an ETL Pipeline on Medallion Lakehouse
The first step is to define an ETL pipeline on the medallion lakehouse architecture.
I’ll define the run_lakehouse function, where the raw data is extracted from the public API, stored in the Bronze layer, and then transformed into Delta Tables in the Silver and Gold layers:
src/batch.py
import os
import argparse
import pandas as pd
from pyspark.sql import SparkSession

import src.data_handling as data_handling
from src import TICKER


def run_lakehouse(ticker: str = TICKER, should_local_save: bool = False) -> tuple[pd.DataFrame, SparkSession]:
    # extract
    stock_price_data = data_handling.extract(ticker=ticker)

    # bronze layer
    bronze_s3_path = data_handling.bronze.load(data=stock_price_data, ticker=ticker)

    # start the spark session
    spark = data_handling.config_and_start_spark_session()

    # silver layer
    bronze_delta_table = spark.read.json(bronze_s3_path, multiLine=True)
    silver_df = data_handling.silver.transform(delta_table=bronze_delta_table, spark=spark)
    silver_s3_path = data_handling.silver.load(df=silver_df, ticker=ticker)

    # gold layer
    silver_delta_table = data_handling.retrieve_delta_table(spark=spark, s3_path=silver_s3_path)
    gold_df = data_handling.gold.transform(delta_table=silver_delta_table, spark=spark)
    gold_s3_path = data_handling.gold.load(df=gold_df, ticker=ticker)

    df = gold_df.toPandas()

    return df, spark


# execution
if __name__ == '__main__':
    # args
    parser = argparse.ArgumentParser(description="run batch learning")
    parser.add_argument('--ticker', type=str, default=TICKER, help="ticker")
    args = parser.parse_args()

    # execute etl
    df, spark = run_lakehouse(ticker=args.ticker)
    df.info()

    # stop spark session
    spark.stop()
As shown in Figure B, the Silver and Gold layers define schemas to transform unstructured JSON data into structured tables.
The Silver layer performs imputation, cleaning, and data type casting:
src/data_handling/silver.py
import os
import pandas as pd
from deltalake import DeltaTable, write_deltalake
from pyspark.sql.functions import col, expr, to_date, unix_millis
from pyspark.sql.types import StructType, StructField, DateType, FloatType, IntegerType, LongType

SILVER_SCHEMA = StructType([
    StructField('dt', DateType(), False),
    StructField('open', FloatType(), False),
    StructField('high', FloatType(), False),
    StructField('low', FloatType(), False),
    StructField('close', FloatType(), False),
    StructField('volume', IntegerType(), False),
    StructField('timestamp_in_ms', LongType(), False)
])

def transform(delta_table, spark):
    # get all the date-like column names
    date_columns = delta_table.columns

    # build the expression for the stack function
    stack_expr = f'stack({len(date_columns)}, '
    for date_col in date_columns:
        stack_expr += f'"{date_col}", `{date_col}`, '

    stack_expr = stack_expr.strip(', ') + ')'

    # use stack to unpivot the data from wide to tall format
    _silver_df = delta_table.select(expr(stack_expr).alias('dt_string', 'values'))

    # process the unpivoted df to cast types and rename columns
    silver_df = _silver_df.select(
        to_date(col('dt_string'), 'yyyy-MM-dd').alias('dt'),
        col('values').getItem('1. open').cast('float').alias('open'),
        col('values').getItem('2. high').cast('float').alias('high'),
        col('values').getItem('3. low').cast('float').alias('low'),
        col('values').getItem('4. close').cast('float').alias('close'),
        col('values').getItem('5. volume').cast('integer').alias('volume'),
        unix_millis(col('dt_string').cast('timestamp')).alias('timestamp_in_ms'),
    ).where(col('dt').isNotNull())

    # finalize df
    silver_df = spark.createDataFrame(silver_df.collect(), schema=SILVER_SCHEMA)
    return silver_df
The Gold layer then performs feature engineering:
src/data_handling/gold.py
import os
import pandas as pd
from deltalake import DeltaTable, write_deltalake
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, DateType, FloatType, IntegerType, LongType

GOLD_SCHEMA = StructType([
    StructField('dt', DateType(), False),
    StructField('open', FloatType(), False),
    StructField('high', FloatType(), False),
    StructField('low', FloatType(), False),
    StructField('close', FloatType(), False),
    StructField('volume', IntegerType(), False),
    StructField('timestamp_in_ms', LongType(), False),
    StructField('ave_open', FloatType(), False),
    StructField('ave_high', FloatType(), False),
    StructField('ave_low', FloatType(), False),
    StructField('ave_close', FloatType(), False),
    StructField('total_volume', IntegerType(), False),
    StructField('30_day_ma_close', FloatType(), False),
    StructField('year', IntegerType(), False),
    StructField('month', IntegerType(), False),
    StructField('date', IntegerType(), False),
])

def transform(delta_table, spark, should_filter: bool = False):
    # moving ave. requires sorting by date and looking back 29 rows.
    window_spec = Window.orderBy('dt').rowsBetween(-29, 0)
    _gold_df = delta_table.withColumn('30_day_ma_close', F.avg(F.col('close')).over(window_spec))

    # add temporal features
    _gold_df = _gold_df.withColumn('year', F.year(F.col('dt')).cast(IntegerType()))
    _gold_df = _gold_df.withColumn('month', F.month(F.col('dt')).cast(IntegerType()))
    _gold_df = _gold_df.withColumn('date', F.dayofmonth(F.col('dt')).cast(IntegerType()))

    # log transform close
    _gold_df = _gold_df.withColumn('close', F.log1p(F.col('close')))

    # select final columns
    final_cols = [
        'dt', 'open', 'high', 'low', 'close', 'volume', 'timestamp_in_ms',
        F.col('open').alias('ave_open'), F.col('high').alias('ave_high'),
        F.col('low').alias('ave_low'), F.col('close').alias('ave_close'),
        F.col('volume').alias('total_volume'), '30_day_ma_close', 'year', 'month', 'date'
    ]
    gold_df = _gold_df.select(*final_cols)

    # finalize
    gold_df = gold_df.orderBy(F.col('dt').asc())
    return gold_df
◼ Step 2. Running Initial Batch Learning
After transforming the raw data, I’ll first define the BaselineModel class where the baseline model is initialized, and then add the initial_batch_learning method to the class.
The initial batch learning is key to creating a good starting point for online learning.
So, I'll incorporate early-stopping logic that halts the training loop once the MSE validation loss shows little improvement for 10 consecutive epochs.
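This early-stopping rule can be factored into a small standalone helper. The sketch below isolates the logic just described (the actual listing that follows inlines it in the training loop):

```python
class EarlyStopper:
    """Stops training when the validation loss fails to improve
    by at least `min_delta` for `patience` consecutive epochs."""
    def __init__(self, patience: int = 10, min_delta: float = 1e-3):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.epochs_no_improve = 0

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            # meaningful improvement: reset the counter
            self.best_loss = val_loss
            self.epochs_no_improve = 0
        else:
            self.epochs_no_improve += 1
        return self.epochs_no_improve >= self.patience

# example: a loss curve that plateaus after epoch 2
stopper = EarlyStopper(patience=3, min_delta=1e-3)
losses = [1.0, 0.5, 0.3, 0.2999, 0.2998, 0.2997, 0.2996]
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
print(stopped_at)  # stops after three non-improving epochs
```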
The trained model, preprocessor, and optimizer are stored in the model storage to enable the system to load them during online learning.
src/model/baseline.py
import os
import math
import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from typing import List, Dict, Any
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from category_encoders import BinaryEncoder

import src.data_handling as data_handling
import src.batch as batch
from src import TICKER
from src._utils import main_logger


class BaselineModel:
    def __init__(
        self,
        input_size: int = 10,
        hidden_size: int = 64,
        ticker: str = TICKER,

        # batch learning metrics
        batch_size: int = 32,
        min_delta: float = 1e-3,
        patience: int = 10,
        lr: float = 1e-4,
    ):
        self.ticker = ticker

        # model
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.model = nn.Sequential(nn.Linear(self.input_size, self.hidden_size), nn.ReLU(), nn.Linear(self.hidden_size, 1))

        # spark session
        self.spark = data_handling.config_and_start_spark_session()

        # preprocess
        self.num_cols = ['open', 'high', 'low', 'volume', 'ave_open', 'ave_high', 'ave_low', 'ave_close', 'total_volume', '30_day_ma_close', 'timestamp_in_ms']
        self.cat_cols = ['dt', 'year', 'month', 'date']
        self._setup_preprocessor()

        # training
        self.device_type = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        self.device = torch.device(self.device_type)

        self.batch_size = batch_size
        self.min_delta = min_delta
        self.patience = patience

        self.criterion = nn.MSELoss()  # mse for logged closing price (target val)
        self.lr = lr
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.steps_trained = 0

        # filepath
        self.model_filepath = os.path.join('artifacts', 'base_model', self.ticker, 'model.pt')
        self.preprocessor_filepath = os.path.join('artifacts', 'preprocessor', self.ticker, 'preprocessor.pkl')


    def initial_batch_learning(self, num_epochs: int = 5000):
        # create pandas df from the etl pipeline
        df, _ = batch.run_lakehouse(ticker=self.ticker)

        # create train, validation, and test dataset from the df
        y = df['close']
        X = df.copy().drop(columns='close', axis=1)

        X_tv, X_test, y_tv, _ = train_test_split(X, y, test_size=2000, shuffle=False)
        X_train, X_val, y_train, y_val = train_test_split(X_tv, y_tv, test_size=2000, shuffle=False)

        # preprocess the input features
        X_train = self.preprocessor.fit_transform(X_train)
        X_val = self.preprocessor.transform(X_val)
        X_test = self.preprocessor.transform(X_test)

        # save trained preprocessor for later use on online learning
        self._save_trained_preprocessor()

        # create tensor data loader
        train_data_loader = self._create_tensor_data_loader(X=X_train, y=y_train)
        val_data_loader = self._create_tensor_data_loader(X=X_val, y=y_val)

        # reconstruct the baseline model
        self.input_size = X_train.shape[-1]
        self.model = nn.Sequential(
            nn.Linear(self.input_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, 1)
        ).to(self.device)

        # reconstruct the optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        # start training with validation and early stopping
        best_val_loss = float('inf')
        epochs_no_improve = 0
        for epoch in range(num_epochs):
            main_logger.info(f'... start epoch {epoch + 1} ...')
            self.model.train()
            for batch_X, batch_y in train_data_loader:
                batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)
                self.optimizer.zero_grad()

                try:
                    with torch.autocast(device_type=self.device_type):
                        outputs = self.model(batch_X)
                        loss = self.criterion(outputs, batch_y)
                        if torch.any(torch.isnan(outputs)) or torch.any(torch.isinf(outputs)):
                            main_logger.error('pytorch model returns nan or inf. break the training loop.')
                            break

                        if not math.isfinite(loss.item()):
                            main_logger.error('loss is nan or inf. break the training loop.')
                            break
                    loss.backward()
                    nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.optimizer.step()

                except Exception:
                    # fall back to full precision when autocast is unavailable
                    outputs = self.model(batch_X)
                    loss = self.criterion(outputs, batch_y)
                    loss.backward()
                    self.optimizer.step()

            if (epoch + 1) % 10 == 0: main_logger.info(f"epoch [{epoch+1}/{num_epochs}], loss: {loss.item():.4f}")

            # validate on a validation dataset (subset of the entire training dataset)
            self.model.eval()
            val_loss = 0.0

            # switch the grad mode
            with torch.inference_mode():
                for batch_X_val, batch_y_val in val_data_loader:
                    batch_X_val, batch_y_val = batch_X_val.to(self.device), batch_y_val.to(self.device)
                    outputs_val = self.model(batch_X_val)
                    val_loss += self.criterion(outputs_val, batch_y_val).item()

            val_loss /= len(val_data_loader)

            # early stopping
            if val_loss < best_val_loss - self.min_delta:
                best_val_loss = val_loss
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= self.patience:
                    main_logger.info(f'early stopping at epoch {epoch + 1} as validation loss did not improve for {self.patience} epochs.')
                    break

        # create checkpoint data and store it for later use during the online learning
        checkpoint = dict(
            state_dict=self.model.state_dict(),
            input_dim=X_train.shape[1],
            optimizer_name='adam',
            optimizer_state_dict=self.optimizer.state_dict(),
            batch_size=self.batch_size,
            lr=self.lr,
        )
        self._save_model(checkpoint=checkpoint)


# execute (guarded so importing this module does not trigger training)
if __name__ == '__main__':
    BaselineModel().initial_batch_learning()
◼ Step 3. Creating Micro Batches
Next, I’ll create a micro batch for online training.
This step has two key points.
▫ 1. Experience Replay (ER) for Stability
Experience Replay (ER) is a crucial technique in Deep Reinforcement Learning (DRL) where the system gives the agent a memory so that it can learn from past interactions as well as recent ones.
In this project, I’ll extend this ER concept to the micro batch by randomly sampling data from the replay buffer (memory) where both old and new data is stored.
This ensures stable model updates, while preventing catastrophic forgetting.
▫ 2. Recency-Biased Sampling for Concept Drift Hedge
Stock price data is a typical example of non-stationary data, whose statistical properties (such as the mean) change over time, causing concept drift.
When randomly sampling data from the replay buffer, the system selects the most recent data first with higher probability, rather than giving every sample an equal chance.
This ensures the model focuses on the newest trend without completely losing the stabilizing context of the older, general data.
src/websocket.py
import random
from collections import deque
from typing import Dict, Any, List


def create_micro_batch(er_buffer: deque, batch_size: int) -> List[Dict[str, Any]]:
    """creates a micro batch for online learning with experience replay and recency-biased sampling"""

    # not enough samples to create a micro batch. return all available data
    if len(er_buffer) < batch_size: return list(er_buffer)

    # recency bias - include the latest sample
    latest_experience = er_buffer[-1]
    batch = [latest_experience]

    # sample remaining items from the past experiences in the er buffer. (using a simple random sample)
    num_samples_to_add = batch_size - 1
    past_experiences = list(er_buffer)[:-1]
    sampled_experiences = random.sample(past_experiences, num_samples_to_add)

    # add past experiences to the batch
    batch.extend(sampled_experiences)

    # sort samples in order of timestamp
    sorted_batch = sorted(batch, key=lambda d: d['timestamp_in_ms'])
    return sorted_batch
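Note that the listing above guarantees the newest sample is included but draws the remaining items uniformly. To weight recent experiences more heavily, as described earlier, the uniform `random.sample` call could be replaced with exponentially decaying weights. The helper below is an illustrative sketch of that variant (its name and decay rate are assumptions, not repo code):

```python
import random
from collections import deque
from typing import Dict, Any, List


def create_recency_weighted_batch(er_buffer: deque, batch_size: int, decay: float = 0.99) -> List[Dict[str, Any]]:
    """samples a micro batch where newer experiences have a higher selection probability"""
    if len(er_buffer) <= batch_size:
        return sorted(er_buffer, key=lambda d: d['timestamp_in_ms'])

    experiences = list(er_buffer)
    n = len(experiences)

    # weight w_i = decay ** age; the newest sample gets weight 1.0
    weights = [decay ** (n - 1 - i) for i in range(n)]

    # draw without replacement, biased toward the tail (newest) of the buffer
    batch: List[Dict[str, Any]] = []
    pool = list(zip(experiences, weights))
    for _ in range(batch_size):
        total = sum(w for _, w in pool)
        r = random.uniform(0, total)
        acc = 0.0
        for idx, (exp, w) in enumerate(pool):
            acc += w
            if acc >= r:
                batch.append(exp)
                pool.pop(idx)
                break

    # keep the micro batch in chronological order, as in create_micro_batch
    return sorted(batch, key=lambda d: d['timestamp_in_ms'])
```

The decay rate controls how aggressively the sampler favors recent data: values near 1.0 behave almost uniformly, while smaller values concentrate the batch on the newest experiences.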
◼ Step 4. Configuring a WebSocket Server
Next, I’ll configure a WebSocket (WS) server to stream the micro batches.
A WS server is key to streaming raw data from the data source, Yahoo Finance.
In this project, I’ll imitate a true real-time data stream by polling the Yahoo Finance API every few seconds via the server.
In the main function, the streaming function is launched as streaming_task: it calls the I/O layer to fetch the latest raw data asynchronously from the Yahoo Finance API, then transforms it into a micro batch using the create_micro_batch function defined in Step 3.
The WS server then keeps streaming data until the streaming_task completes or is cancelled.
Note that providers like Polygon, Finnhub, and Twelve Data offer paid WebSocket feeds ($30 - $200 per month), which would be an option for further refinement.
src/websocket.py
import argparse
import asyncio
import datetime
import json
import time
from collections import deque
from typing import Dict, Any, Set

import websockets
import yfinance as yf

from src import TICKER, PORT, HOST
from src._utils import main_logger

# global set to hold all active websocket connections (clients)
CLIENTS: Set[websockets.ClientConnection] = set()


# json encoder that serializes datetime objects when broadcasting micro batches
class CustomJsonEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, datetime.datetime):
            return obj.isoformat()
        return super().default(obj)


# fetch data from yfinance
def fetch_data(ticker: str) -> Dict[str, Any]:
    # ticker obj
    stock = yf.Ticker(ticker)

    # fetch the latest 1-minute historical data
    data = stock.history(period="1d", interval="1m")

    # get the latest row
    latest_data = data.iloc[-1]
    return {
        'open': latest_data.get('Open', 0.0),
        'high': latest_data.get('High', 0.0),
        'low': latest_data.get('Low', 0.0),
        'close': latest_data.get('Close', 0.0),
        'volume': int(latest_data.get('Volume', 0.0)),
        'dt': datetime.datetime.now(),
        'timestamp_in_ms': int(time.time() * 1000)
    }


# poll the data source and broadcast the updates to the clients
async def streaming(ticker: str, interval_seconds: int, er_buffer, batch_size: int):
    count = 0
    while True:
        count += 1
        # call the io layer to fetch the latest raw data asynchronously. using asyncio.to_thread() avoids blocking the main asynchronous event loop
        new_data = await asyncio.to_thread(fetch_data, ticker)

        if new_data:
            # store the new data in the er buffer
            er_buffer.append(new_data)
            main_logger.info(f"... # {count} appended new data to the er buffer (buffer size: {len(er_buffer)}) ...")

            if len(er_buffer) >= batch_size:
                # create a micro batch (create_micro_batch from step 3 lives in this same module)
                current_batch = create_micro_batch(er_buffer=er_buffer, batch_size=batch_size)
                message = json.dumps(current_batch, cls=CustomJsonEncoder)

                # broadcast the batch to the clients
                if CLIENTS:
                    try:
                        websockets.broadcast(CLIENTS, message)
                        main_logger.info(f"... successfully broadcasted the current micro batch ({len(current_batch)} items) to {len(CLIENTS)} clients.")
                    except Exception as e: main_logger.error(f"... error during broadcast: {e}")
                else:
                    main_logger.info("... current batch is ready, but no active clients to broadcast to ...")
            else:
                main_logger.info(f"... buffer filling up ({len(er_buffer)}/{batch_size}) ...")

        # wait for the next polling interval
        await asyncio.sleep(interval_seconds)


# add a connection (client) to the websocket server
async def register(websocket: websockets.ClientConnection):
    CLIENTS.add(websocket)


# remove a connection (client) from the websocket server
async def unregister(websocket: websockets.ClientConnection):
    CLIENTS.discard(websocket)


async def handler(websocket: websockets.ClientConnection):
    await register(websocket)
    try: await websocket.wait_closed()
    finally: await unregister(websocket)


async def main(ticker: str, interval: int, port: int, host: str, er_buffer, batch_size: int):
    # start the core data fetching/pushing task
    streaming_task = asyncio.create_task(
        streaming(ticker=ticker, interval_seconds=interval, er_buffer=er_buffer, batch_size=batch_size)
    )
    try:
        async with websockets.serve(handler, host, port):
            await asyncio.Future()  # keeps the server running indefinitely
    except Exception as e:
        main_logger.error(f"... server failed to start: {e}")
    finally:
        # clean up the streaming task when the server stops
        streaming_task.cancel()


if __name__ == '__main__':
    # handle args
    POLLING_INTERVAL = 1
    ER_MAX_BUFFER_SIZE = 300  # for er
    MICRO_BATCH_SIZE = 4

    parser = argparse.ArgumentParser(description='creating micro batch for online learning')
    parser.add_argument('--ticker', type=str, default=TICKER, help=f"ticker. default = {TICKER}")
    parser.add_argument('--interval', type=int, default=POLLING_INTERVAL, help=f"polling interval. default = {POLLING_INTERVAL}")
    parser.add_argument('--port', type=int, default=PORT, help=f"port. default = {PORT}")
    parser.add_argument('--host', type=str, default=HOST, help=f"host. default = {HOST}")
    parser.add_argument('--er_max_buffer_size', type=int, default=ER_MAX_BUFFER_SIZE, help=f"max number of old samples to store. default = {ER_MAX_BUFFER_SIZE}")
    parser.add_argument('--batch_size', type=int, default=MICRO_BATCH_SIZE, help=f"micro batch size for online learning. default = {MICRO_BATCH_SIZE}")
    args = parser.parse_args()

    ER_BUFFER: deque[Dict[str, Any]] = deque(maxlen=args.er_max_buffer_size)

    # start running websocket server
    try:
        asyncio.run(
            main(ticker=args.ticker, interval=args.interval, port=args.port, host=args.host, er_buffer=ER_BUFFER, batch_size=args.batch_size)
        )
    except KeyboardInterrupt: main_logger.info("\n--- server shutting down ---")
    except Exception as e: main_logger.error(f"... an unexpected error occurred: {e}")
◼ Step 5. Running Online Learning
Next, I’ll configure a client to receive micro batches and perform online learning.
I’ll first add the online_learning method to the BaselineModel class created in Step 2.
During online learning, the trained model, optimizer, and preprocessor are first loaded from storage.
The online_learning method then performs feature engineering and preprocessing on the data in the micro-batch, and finally uses the transformed data to train the model.
src/model/baseline.py
import os
import math
import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from typing import List, Dict, Any
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from category_encoders import BinaryEncoder

import src.data_handling as data_handling
from src._utils import main_logger


class BaselineModel:
    def __init__(self):
        ...  # same as Step 2

    def initial_batch_learning(self, num_epochs: int = 5000):
        ...  # same as Step 2

    def online_learning(self, current_batch: List[Dict[str, Any]]):
        if not current_batch: return

        # data transformation
        df = self.spark.createDataFrame(current_batch, schema=data_handling.silver.SILVER_SCHEMA)
        df = data_handling.gold.transform(delta_table=df, spark=self.spark)

        # create training dataset
        df_pandas = df.toPandas()
        y = df_pandas['close']
        X = df_pandas.copy().drop(columns='close', axis=1)

        try:
            self.preprocessor = self._load_trained_preprocessor()
            X = self.preprocessor.transform(X)
        except Exception:
            X = self.preprocessor.fit_transform(X)

        # convert to pytorch tensors
        X = torch.from_numpy(X).float().to(self.device)
        y = torch.from_numpy(y.to_numpy()).float().reshape(-1, 1).to(self.device)

        # load the model and optimizer
        self._load_model_and_optimizer()

        # start online learning
        self.optimizer.zero_grad()
        y_pred = self.model(X)
        loss = self.criterion(y_pred, y)
        loss.backward()
        self.optimizer.step()
        self.steps_trained += 1

        # pred result (invert the log1p transform applied in the gold layer)
        y_pred_actual = np.expm1(y_pred.cpu().detach().numpy())

        main_logger.info(f"... pytorch model trained (step {self.steps_trained}) ...\n - batch size {len(current_batch)}\n - mse loss {loss.item():.6f}\n - new closing price predicted: {np.mean(y_pred_actual):.2f}")
Then, I’ll configure the client_handler function where the client continuously receives micro-batches from the WS server and performs online learning.
The main function asynchronously invokes this client_handler function to ensure continuous online learning.
src/client.py
import json
import datetime
import asyncio
import websockets

from src.model import BaselineModel
from src import PORT, HOST
from src._utils import main_logger


async def client_handler(websocket_uri: str, model: BaselineModel, max_retries: int = 5, delay: int = 1):
    for attempt in range(max_retries):
        try:
            async with websockets.connect(websocket_uri, ping_interval=5) as websocket:
                # loop to retrieve current batch
                async for message in websocket:
                    # deserialize the current batch from json string to list
                    current_batch = json.loads(message)

                    if current_batch and isinstance(current_batch, list):
                        # cast dtypes
                        for record in current_batch:
                            if isinstance(record['dt'], str):
                                record['dt'] = datetime.datetime.fromisoformat(record['dt'])

                        # start online learning
                        model.online_learning(current_batch=current_batch)
                    else:
                        main_logger.info(f"... received empty or invalid data: {message} ...")

        except websockets.exceptions.ConnectionClosedOK:
            main_logger.info("... connection closed by server ...")
            break

        except ConnectionRefusedError:
            main_logger.error(f"... connection refused. retrying in {delay}s (attempt {attempt + 1}/{max_retries}) ...")

            # exponential backoff for connection attempts
            await asyncio.sleep(delay * (2 ** attempt))

        except Exception as e:
            main_logger.error(f"... unhandled connection error: {e}")
            break

    main_logger.info("... client shutting down ...")


def main():
    try:
        websocket_uri = f"ws://{HOST}:{PORT}"
        asyncio.run(client_handler(websocket_uri=websocket_uri, model=BaselineModel()))

    except KeyboardInterrupt:
        main_logger.info("\n--- client interrupted and shutting down ---")


if __name__ == '__main__':
    main()
◼ Testing Locally
Now, let’s run the WS server and the client locally:
$ uv run src/websocket.py

$ uv run src/client.py --ticker <TICKER>
Replace <TICKER> with a ticker of your choice.
Here are the terminal output snippets:
2025-10-26 20:38:22,495 - root - INFO - ... base model loaded and reconstructed ...
2025-10-26 20:38:22,496 - root - INFO - ... pytorch model trained (step 347) ...
 - batch size 4
 - mse loss 0.105518
 - new closing price predicted: 259.13

2025-10-26 20:38:23,435 - root - INFO - ... calculated gold layer features for 2025-10-26.
2025-10-26 20:38:23,537 - root - INFO - ... base model loaded and reconstructed ...
2025-10-26 20:38:23,539 - root - INFO - ... pytorch model trained (step 348) ...
 - batch size 4
 - mse loss 0.105518
 - new closing price predicted: 259.13
◼ Step 6. Scheduling the Batch Learning with Airflow
The last step is to schedule the batch learning monthly with Airflow.
Airflow is an open-source platform used to programmatically author, schedule, and monitor workflows as Directed Acyclic Graphs (DAGs):
Directed: The workflow has a clear, one-way path from start to finish.
Acyclic: There are no loops, meaning a task cannot trigger an upstream task. This prevents infinite processing loops.
Graph: The structure is a collection of nodes (Tasks) and edges (Dependencies).
This process involves a data quality check via the check_data_quality function, followed by monthly model retraining on the full dataset:
src/dag/dag.py
import datetime
from airflow import DAG
from airflow.providers.standard.operators.python import PythonOperator
from pydeequ.checks import Check, CheckLevel
from pydeequ.verification import VerificationSuite, VerificationResult

from src.batch import run_lakehouse
from src.data_handling.spark import config_and_start_spark_session
from src.model import BaselineModel

# define the airflow dag
default_args = {
    'owner': 'kuriko iwai',
    'start_date': datetime.datetime(2025, 1, 1),  # static start date; a dynamic now() breaks scheduling
    'retries': 1,
}


def check_data_quality():
    # initialize a sparksession to use deequ
    spark = config_and_start_spark_session(session_name='deequ_quality_check')

    # fetch delta table in the gold layer
    gold_data_path = "s3a://ml-stockprice-pred/data/gold"
    gold_delta_table = spark.read.format("delta").load(gold_data_path)

    # use deequ to run quality check
    check_result = (VerificationSuite(spark)
        .onData(gold_delta_table)
        .addCheck(
            Check(spark_session=spark, level=CheckLevel.Warning, description="data quality check")
            .hasSize(lambda s: s > 0)  # check if the df is not empty
            .hasCompleteness("close", lambda s: s == 1.0)  # check for missing vals
            .hasCompleteness("volume", lambda s: s == 1.0)  # check for missing vals
            .hasMin("open", lambda x: x > 0)  # check if val is positive
            .hasMax("volume", lambda x: x < 10000000000)  # check for max val
        )
        .run()
    )

    # convert the deequ result to a spark df for analysis and logging
    result_df = VerificationResult.checkResultsAsDataFrame(spark, check_result)
    result_df.show(truncate=False)  # type: ignore

    # stop the spark session
    spark.stop()


def run_monthly_batch_learning():
    # defer model construction to task runtime instead of dag-parse time
    BaselineModel().initial_batch_learning()


with DAG(
    dag_id='pyspark_data_handling',
    default_args=default_args,
    description='A monthly batch learning',
    schedule='@monthly',  # the schedule is set on the dag, not on individual tasks
    catchup=False,
) as dag:

    # data quality check
    data_quality_check_task = PythonOperator(
        task_id='data_quality_check',
        python_callable=check_data_quality,
        trigger_rule='all_success',
        depends_on_past=True,  # ensures it only runs after the previous successful run
    )

    # batch learning
    run_batch_learning = PythonOperator(
        task_id='run_batch_learning',
        python_callable=run_monthly_batch_learning,
    )

    # set task dependencies
    data_quality_check_task >> run_batch_learning  # type: ignore
And that’s all for the workflow.
All code is available in my GitHub repo.
Wrapping Up
Online learning is a powerful learning scenario for systems demanding continuous, real-time adaptation to non-stationary data.
Yet it presents fundamental challenges like catastrophic forgetting and concept drift.
In this demonstration, we addressed these challenges with a hybrid approach, combining initial batch training with recency-biased Experience Replay during online updates.
This enables the model to effectively balance adaptation to new stock data with the stability preserved from historical knowledge.
Moving forward, integrating more sophisticated techniques like drift detection algorithms or exploring more complex model architectures would further refine the system.
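As an illustration of the first idea, a lightweight drift detector such as the Page-Hinkley test could watch the stream of online losses and flag when their mean rises. The sketch below is a common formulation of that test, with illustrative thresholds rather than code from the repo:

```python
import random

class PageHinkley:
    """Page-Hinkley test: flags an upward drift in a stream's mean
    (e.g., a rising online loss) once the cumulative deviation exceeds a threshold."""
    def __init__(self, delta: float = 0.005, threshold: float = 5.0):
        self.delta = delta          # tolerated magnitude of change
        self.threshold = threshold  # detection threshold
        self.mean = 0.0
        self.n = 0
        self.cum = 0.0              # cumulative deviation m_t
        self.min_cum = 0.0          # running minimum M_t

    def update(self, x: float) -> bool:
        self.n += 1
        self.mean += (x - self.mean) / self.n
        self.cum += x - self.mean - self.delta
        self.min_cum = min(self.min_cum, self.cum)
        return self.cum - self.min_cum > self.threshold

# example: a loss stream that jumps from ~0.1 to ~1.0 at step 100
random.seed(42)
detector = PageHinkley(delta=0.005, threshold=5.0)
drift_step = None
for t in range(200):
    loss = random.gauss(0.1 if t < 100 else 1.0, 0.02)
    if detector.update(loss):
        drift_step = t
        break
print(drift_step)
```

On a drift signal like this, the system could trigger an out-of-cycle batch retraining instead of waiting for the monthly Airflow run.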
Related Books for Further Understanding
These books cover a wide range of theory and practice, from fundamentals to the PhD level.

Linear Algebra Done Right

Foundations of Machine Learning, second edition (Adaptive Computation and Machine Learning series)

Designing Machine Learning Systems: An Iterative Process for Production-Ready Applications

Machine Learning Design Patterns: Solutions to Common Challenges in Data Preparation, Model Building, and MLOps
Share What You Learned
Kuriko IWAI, " Online Learning in Action — Building Real-Time Stock Forecasting on Lakehouse" in Kernel Labs
https://kuriko-iwai.com/online-learning-in-action
Looking for Solutions?
- Deploying ML Systems 👉 Book a briefing session
- Hiring an ML Engineer 👉 Drop an email
- Learn by Doing 👉 Enroll in the AI Engineering Masterclass
Written by Kuriko IWAI. All images, unless otherwise noted, are by the author. All experimentations on this blog utilize synthetic or licensed data.