Optimizing LSTMs with Hyperband: A Comparative Guide to Bandit-Based Tuning

A deep dive into mechanics and comparison with major hyperparameter tuning methods like Bayesian Optimization

Machine Learning · Deep Learning · Data Science · Python

By Kuriko IWAI

Table of Contents

Introduction
What is Hyperband
How Hyperband Works
The Walkthrough Example - Support Vector Classifier
Simulation
The Data - Creating Training and Test Datasets
The Model - Defining the LSTM Network
The Search Space
Defining Validation Function
Running Hyperband
Results
Comparing with Other Tuning Methods
Consideration
Configuring Hyperband for Better Performance
Wrapping Up

Introduction

Hyperband is a powerful hyperparameter tuning method in machine learning, leveraging successive halving to efficiently allocate resources.

However, executing Hyperband requires careful consideration of its core mechanics and parameters to maximize its benefits.

In this article, I'll explore its core mechanics by tuning LSTM networks for stock price prediction, and compare the performance with other major methods:

  • Bayesian Optimization,

  • Random Search, and

  • Genetic Algorithms.

What is Hyperband

Hyperband is a bandit-based hyperparameter tuning algorithm:

Figure A. Types of hyperparameter tuning methods (Created by Kuriko IWAI)

Among tuning methods, Hyperband is distinctive in combining a multi-armed bandit strategy with the successive halving algorithm (SHA).

The multi-armed bandit problem, from probability theory, captures the fundamental trade-off between:

  • Exploration: Explore a wide range of hyperparameter configurations, and

  • Exploitation: Exploit the most promising configurations.

SHA is a resource-allocation strategy: a fixed budget (a number of training epochs) is split across a set of hyperparameter configurations randomly sampled from the search space.

In every round, SHA evaluates the performance of each configuration, discards the worst-performing ones, and reallocates the remaining budget to the surviving configurations, called survivors.

Hyperband takes this a step further with the multi-armed bandit idea: it runs SHA multiple times with different initial budgets to balance exploration and exploitation.
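
To make this concrete, here is a minimal sketch of a single SHA run in Python. The sample_config and evaluate callables are hypothetical placeholders: in practice, evaluate would train a model on the given budget and return its validation loss.

import random

def successive_halving(sample_config, evaluate, n_configs=27, min_budget=1, eta=3):
    # one SHA run: start with many configurations on a small budget,
    # keep the top 1/eta after every round, and multiply the budget by eta
    configs = [sample_config() for _ in range(n_configs)]
    budget = min_budget
    while len(configs) > 1:
        scored = sorted(configs, key=lambda c: evaluate(c, budget))  # lower loss is better
        configs = scored[:max(1, len(scored) // eta)]                # keep the survivors
        budget *= eta                                                # give the survivors more epochs
    return configs[0]

# toy usage with a hypothetical search space and a stand-in for a real training loop
best = successive_halving(
    sample_config=lambda: {'lr': 10 ** random.uniform(-6, -1)},
    evaluate=lambda config, budget: abs(config['lr'] - 1e-3) / budget,
)
print(best)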

How Hyperband Works

The diagram below shows how Hyperband allocates more budget to the winner (Configuration #4):

Figure B. How Hyperband works (Created by Kuriko IWAI)

Hyperband starts the process by creating many hyperparameter configurations in Bracket 1 and allocating a small budget to each.

It then progressively increases the budget allocated to the survivors as it moves on to the next bracket.

In Figure B, Bracket 2 has two survivors (Configurations #1 and #4), which receive more budget, while Configurations #2 and #3 are discarded.

Lastly, Bracket 3 has the ultimate winner, Configuration #4, which gets the entire budget.

This approach effectively explores a wide range of configurations while quickly eliminating poor performers, striking a balance between exploration and exploitation.

Execution Steps

The process breaks down into four major steps:

Step 1. Define the budget and halving factor

First, we’ll define:

  • Maximum resource budget (R): The total number of epochs a single model can be trained for, and

  • Halving factor (η): A predefined factor that determines how aggressively configurations are culled (Common values = 2, 3, or 4).

In Figure B, the number of hyperparameter configurations is divided by the factor η at each step, and the budget for the survivors is multiplied by η.

Step 2. Iterate through brackets

The algorithm runs a series of brackets, each of which runs SHA with a different starting budget, as shown for Bracket 1 in Figure B.

The maximum number of brackets created is determined by the maximum bracket index s_max:

s_{max} = \lfloor \log_{\eta}(R) \rfloor

where:

  • η: The halving factor by which the number of hyperparameter configurations is reduced, and

  • R: The maximum resource budget for a single configuration.

The algorithm iterates through the brackets from s = s_{max} down to s = 0 (s_{max} + 1 brackets in total).

Step 3. Run SHA

For each bracket s, Hyperband determines the total number of hyperparameter configurations to test:

n_s = \left\lceil \frac{s_{max} + 1}{s + 1} \cdot \eta^{s} \right\rceil

where:

  • n_s​: The total number of hyperparameter configurations to test in the current bracket (index s),

  • η: The halving factor that determines how aggressively the algorithm prunes unpromising configurations,

  • s_{max}: The maximum bracket index, and

  • s: The current bracket index, which ranges from s_{max}​ down to zero.

Then, it determines the initial budget (r_s​) for each bracket:

r_s = \frac{R}{\eta^{s}}

where:

  • r_s​: The initial budget allocated to each configuration in bracket s,

  • R: The maximum resource budget, and

  • η: The halving factor.

To leverage the multi-armed bandit concept, Hyperband intentionally samples a large number of hyperparameter configurations for brackets with a small initial budget, and fewer configurations for brackets with a large initial budget (large n_s pairs with small r_s, and vice versa).

Within each bracket, SHA then selects the survivors (the top n_s / η configurations by validation performance) and trains them with an increased budget of r_s ⋅ η epochs, repeating the halving until the bracket is exhausted.
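
The two formulas above translate directly into code. The sketch below reproduces only this arithmetic (not a full Hyperband run); the small epsilon guards against floating-point error in the logarithm.

from math import ceil, floor, log

def hyperband_schedule(R, eta):
    # return (s, n_s, r_s) for every bracket, from s_max down to 0
    s_max = floor(log(R, eta) + 1e-10)
    schedule = []
    for s in range(s_max, -1, -1):
        n_s = ceil((s_max + 1) * eta ** s / (s + 1))  # configurations to sample in bracket s
        r_s = R / eta ** s                            # initial budget (epochs) per configuration
        schedule.append((s, n_s, r_s))
    return schedule

print(hyperband_schedule(R=81, eta=3))
# [(4, 81, 1.0), (3, 34, 3.0), (2, 15, 9.0), (1, 8, 27.0), (0, 5, 81.0)]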

Step 4. Iteration and convergence

The process in Step 3 continues until only one hyperparameter configuration remains in the bracket or the maximum budget is reached.

Hyperband's efficiency comes from its ability to quickly discard poorly-performing configurations, freeing up resources to train more promising configurations for longer.

The Walkthrough Example - Support Vector Classifier

Now, let us see how it works on a simple model - a Support Vector Classifier (SVC) - by tuning its regularization parameter C and kernel coefficient gamma.

  • Model: Support Vector Classifier (SVC)

  • The Search Space:

    • C: [0.1, 1, 10, 100]

    • gamma: ['scale', 'auto', 0.1, 1, 10]

Step 1. Define a Budget and Halving Factor

Let us set a small maximum budget R = 81 for a simple model like SVC, and the halving factor η = 3.

Step 2. Iterate through Brackets

Hyperband will automatically calculate the different brackets of successive halving.

The maximum bracket index is:

s_{max} = \lfloor \log_{\eta}(R) \rfloor = \lfloor \log_{3}(81) \rfloor = 4

This means that Hyperband will run brackets for s = 4, 3, 2, 1, 0.

Each bracket has a different starting number of hyperparameter configurations n_s and initial budget r_s. Applying the formulas from Step 3:

  • Bracket 1 (s = 4): n_s = 81, r_s = 1

  • Bracket 2 (s = 3): n_s = 34, r_s = 3

  • Bracket 3 (s = 2): n_s = 15, r_s = 9

  • Bracket 4 (s = 1): n_s = 8, r_s = 27

  • Bracket 5 (s = 0): n_s = 5, r_s = 81

Every bracket spends a comparable total budget, but splits it differently: early brackets try many configurations for a few epochs each, while later brackets give a handful of configurations the full budget of R = 81 epochs.

Step 3. Run Successive Halving Algorithm (SHA)

Take Bracket 3 (s = 2) as an example, with n_s = 15 and r_s = 9.

  1. Initial Run:

    • Hyperband randomly samples 15 hyperparameter configurations.

    • Each is trained for a small initial budget of 9 epochs.

    • Records model performance.

    • The top 5 (15 / 3 = 5) best-performing configurations are kept (survivors), and the rest are discarded.

  2. Second Run:

    • The five survivors are now trained with a larger budget of 27 epochs (9 * 3 = 27).

    • Records model performance.

    • The top 1 (5 / 3, rounded down to 1) best-performing configuration is kept.

  3. Final Run:

    • The single remaining survivor is trained with the final budget of 81 epochs (27 * 3 = 81).

    • Records model performance.

Step 4. Find the Best

Hyperband runs Step 3 for every bracket and returns the best-performing configuration found across all of them.

Simulation

Now, let us see how Hyperband works on a more complex model - LSTM networks.

The model provides predictions for the closing stock price of the selected ticker, NVDA.

The Data - Creating Training and Test Datasets

I’ll fetch historical daily stock price data from the API endpoint provided by Alpha Vantage.

After transforming the data, I’ll load it into a Pandas DataFrame and preprocess the dataset to create training and test sets.

The training set will be used for model training and validation, while the test set will be held separate to prevent data leakage.

import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer


# create target and input values (df holds the fetched daily price data)
target_col = 'close'
y = df.copy()[target_col].shift(-1) # target is the next day's close to avoid data leakage
y = y.iloc[:-1] # drop the last row (its target is NaN after the shift)

input_cols = [col for col in df.columns if col not in [target_col, 'dt']] # drop dt as year, month, date capture the sequence
X = df.copy()[input_cols]
X = X.iloc[:-1] # drop the last row to match y

# create training and test datasets (training will be split into train and val for walk-forward validation)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=800, shuffle=False, random_state=42
)

# preprocess
cat_cols = ['year', 'month', 'date']
num_cols = list(set(input_cols) - set(cat_cols))
preprocessor = ColumnTransformer(
    transformers=[
        ('num', StandardScaler(), num_cols),
        ('cat', OneHotEncoder(handle_unknown='ignore'), cat_cols)
    ]
)

X_train = preprocessor.fit_transform(X_train)
X_test = preprocessor.transform(X_test)

# convert the sparse outputs to dense numpy arrays, then to pytorch tensors
X_train = torch.from_numpy(X_train.toarray()).float()
y_train = torch.from_numpy(y_train.values).float().unsqueeze(1)

X_test = torch.from_numpy(X_test.toarray()).float()
y_test = torch.from_numpy(y_test.values).float().unsqueeze(1)

The original data has 6,501 samples of historical stock price records of NVDA:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6501 entries, 0 to 6500
Data columns (total 15 columns):
 #   Column           Non-Null Count  Dtype         
---  ------           --------------  -----         
 0   dt               6501 non-null   datetime64[ns]
 1   open             6501 non-null   float32       
 2   high             6501 non-null   float32       
 3   low              6501 non-null   float32       
 4   close            6501 non-null   float32       
 5   volume           6501 non-null   int32         
 6   ave_open         6501 non-null   float32       
 7   ave_high         6501 non-null   float32       
 8   ave_low          6501 non-null   float32       
 9   ave_close        6501 non-null   float32       
 10  total_volume     6501 non-null   int32         
 11  30_day_ma_close  6501 non-null   float32       
 12  year             6501 non-null   object        
 13  month            6501 non-null   object        
 14  date             6501 non-null   object        
dtypes: datetime64[ns](1), float32(9), int32(2), object(3)
memory usage: 482.6+ KB

The Model - Defining the LSTM Network

Next, I’ll define the LSTMModel class in PyTorch, based on a many-to-one architecture.

import torch
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, dropout):
        super(LSTMModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.dropout = dropout
        self.lstm = nn.LSTM(
            input_dim, hidden_dim, layer_dim, batch_first=True, dropout=dropout
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).to(x.device)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).to(x.device)
        o_t, _ = self.lstm(x, (h0.detach(), c0.detach()))
        o_final = self.fc(o_t[:, -1, :])
        return o_final

The Search Space

Hyperband performs better with a broader search space. I’ll define the following search space:

import random

def search_space():
    return {
        'lr': 10**random.uniform(-6, -1),
        'hidden_dim': random.choice([16, 32, 64, 128, 256]),
        'layer_dim': random.choice([1, 2, 3, 4, 5]),
        'dropout': random.uniform(0.1, 0.6),
        'batch_size': random.choice([16, 32, 64, 128, 256])
    }

Defining Validation Function

Next, I’ll define the train_and_val_wfv function that runs walk-forward validation for the time-series data:

def train_and_val_wfv(hyperparams, budget, X, y, train_window, val_window):
    total_val_loss = 0
    all_loss_histories = []

    num_folds = (X.size(0) - train_window - val_window) // val_window + 1

    for i in range(num_folds):
        train_start = i * val_window
        train_end = train_start + train_window
        val_start = train_end
        val_end = val_start + val_window

        # ensure not to go past the end of the dataset
        if val_end > X.size(0):
            break

        # create folds
        X_train_fold = X[train_start:train_end]
        y_train_fold = y[train_start:train_end]
        X_val_fold = X[val_start:val_end]
        y_val_fold = y[val_start:val_end]

        # train and validate on the current fold
        # (train_and_val is a helper that trains the LSTM for `budget` epochs and returns the fold's validation loss and loss history)
        fold_val_loss, fold_loss_history = train_and_val(
            hyperparams=hyperparams,
            budget=budget,
            X_train=X_train_fold,
            y_train=y_train_fold,
            X_val=X_val_fold,
            y_val=y_val_fold
        )
        total_val_loss += fold_val_loss
        all_loss_histories.append(fold_loss_history)

    # compute ave. loss
    avg_val_loss = total_val_loss / num_folds
    return avg_val_loss, all_loss_histories

Running Hyperband

Lastly, I’ll define the run_hyperband function.

This function takes four arguments:

  • the search space (search_space_fn),

  • the validation function (val_fn),

  • the total budget (R), and

  • the halving factor (eta).

In the provided code snippet, R is set to 100, eta is 3, and the training and validation windows for the Walk-Forward cross-validation are 3,000 and 500, respectively.

import logging
from math import floor, log

logging.basicConfig(level=logging.INFO)
main_logger = logging.getLogger(__name__)


def run_hyperband(search_space_fn, val_fn, R, eta):
    s_max = int(log(R, eta))

    overall_best_config = None
    overall_best_loss = float('inf')
    all_loss_histories = []

    # outer loop: iterate through all brackets
    for s in range(s_max, -1, -1):
        # number of configurations and initial budget for this bracket (simplified allocation)
        n = int(R / eta**s)
        r = int(R / n)

        main_logger.info(f'... running bracket s={s}: {n} configurations, initial budget={r} ...')

        # generate n random hyperparameter configurations
        configs = [search_space_fn() for _ in range(n)]

        # successive halving
        for i in range(s + 1):
            budget = r * (eta**i)
            main_logger.info(f'... training {len(configs)} configurations for budget {budget} epochs ...')

            evaluated_results = []
            for config in configs:
                loss, loss_history = val_fn(config, budget)
                evaluated_results.append((config, loss, loss_history))

            # record loss histories for plotting
            all_loss_histories.append((evaluated_results, budget))

            # sort and select top configurations
            evaluated_results.sort(key=lambda x: x[1])

            # keep track of the best configuration found so far
            if evaluated_results and evaluated_results[0][1] < overall_best_loss:
                overall_best_loss = evaluated_results[0][1]
                overall_best_config = evaluated_results[0][0]

            num_to_keep = floor(len(configs) / eta)
            configs = [result[0] for result in evaluated_results[:num_to_keep]]

            if not configs:
                break

    return overall_best_config, overall_best_loss, all_loss_histories, s_max


# define budget, halving factor
R = 100
eta = 3

# wfv setting
train_window = 3000
val_window = 500

# run hyperband
best_config, best_loss, all_loss_histories, s_max = run_hyperband(
    search_space_fn=search_space,
    val_fn=lambda h, b: train_and_val_wfv(h, b, X_train, y_train, train_window=train_window, val_window=val_window),
    R=R,
    eta=eta
)

Results

Hyperband

Best Hyperparameter Configuration:

  • 'lr': 0.0001614172022855225

  • 'hidden_dim': 128

  • 'layer_dim': 3

  • 'dropout': 0.5825758700895215

  • 'batch_size': 16

Best Validation Loss (MSE):

0.0519

Loss History:

Solid lines in the graph below track average validation losses (MSEs) over training cycles, while the vertical dashed lines indicate where the Hyperband algorithm prunes underperforming models:

Figure C. Loss history (Hyperband) (Created by Kuriko IWAI)

The lines that stop early (mostly purple) were poor performers that were pruned because their loss was too high.

The few lines that continue to 100 epochs (mostly teal and blue) are the most successful configurations. Their loss drops quickly at the beginning and then stabilizes at a very low value, indicating excellent performance.

This is an efficient way to quickly eliminate bad configurations without training them for a long time.

Comparing with Other Tuning Methods

To compare different methods, I'll run 20 trials of:

  • Bayesian Optimization,

  • Random Search, and

  • Genetic Algorithms

using the same search space, model, and training / validation window.

Bayesian Optimization

Bayesian Optimization uses a probabilistic model like a Gaussian Process to model the validation errors and select the next best hyperparameter configuration to evaluate.
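
For reference, a comparable Bayesian Optimization run can be set up with a library such as Optuna, whose default TPE sampler builds a probabilistic model of past trials. This is only a sketch of the comparison setup, assuming the train_and_val_wfv function and the preprocessed X_train and y_train from earlier sections, with the full budget of 100 epochs per trial; it is not necessarily the exact code behind the results below.

import optuna

def objective(trial):
    # sample a configuration from the same search space used for Hyperband
    hyperparams = {
        'lr': trial.suggest_float('lr', 1e-6, 1e-1, log=True),
        'hidden_dim': trial.suggest_categorical('hidden_dim', [16, 32, 64, 128, 256]),
        'layer_dim': trial.suggest_int('layer_dim', 1, 5),
        'dropout': trial.suggest_float('dropout', 0.1, 0.6),
        'batch_size': trial.suggest_categorical('batch_size', [16, 32, 64, 128, 256]),
    }
    # every trial gets the full training budget (no early pruning)
    avg_val_loss, _ = train_and_val_wfv(
        hyperparams, budget=100, X=X_train, y=y_train,
        train_window=3000, val_window=500,
    )
    return avg_val_loss

study = optuna.create_study(direction='minimize')  # TPE sampler by default
study.optimize(objective, n_trials=20)
print(study.best_params, study.best_value)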

Best Hyperparameter Configuration:

  • 'lr': 0.00016768631941614767

  • 'hidden_dim': 256

  • 'layer_dim': 3

  • 'dropout': 0.3932769195043036

  • 'batch_size': 64

Best Validation Loss (MSE):

0.0428

Loss History:

Figure D. Loss history (Bayesian Optimization) (Created by Kuriko IWAI)

Random Search

Random Search samples a fixed number of configurations from the search space at random without using the results of past trials.
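
A minimal version of this baseline can reuse the search_space and train_and_val_wfv functions defined earlier: draw 20 configurations independently and give each one the full training budget. This is a sketch of the setup, assuming a budget of 100 epochs per trial, not necessarily the exact code behind the numbers below.

# random search: 20 independent samples, each trained with the full budget
best_config, best_loss = None, float('inf')

for _ in range(20):
    config = search_space()  # draw a configuration uniformly at random
    loss, _ = train_and_val_wfv(
        config, budget=100, X=X_train, y=y_train,
        train_window=3000, val_window=500,
    )
    if loss < best_loss:
        best_config, best_loss = config, loss

print(best_config, best_loss)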

Best Hyperparameter Configuration:

  • 'lr': 0.0004941205117774383

  • 'hidden_dim': 128

  • 'layer_dim': 2

  • 'dropout': 0.3398469430820351

  • 'batch_size': 64

Best Validation Loss (MSE):

0.0362

Loss History:

Figure E. Loss history (Random Search) (Created by Kuriko IWAI)

Genetic Algorithms (GA)

Inspired by biological evolution, Genetic Algorithms maintain a population of hyperparameter configurations and use concepts like mutation and crossover to generate new, potentially better ones.
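
The sketch below shows one simple way such an evolutionary loop could look with the article's search space: truncation selection, uniform crossover, and random-resample mutation. It is an illustrative implementation under those assumptions, not necessarily the algorithm used for the results below.

import random

def crossover(parent_a, parent_b):
    # uniform crossover: each hyperparameter is inherited from one parent at random
    return {k: random.choice([parent_a[k], parent_b[k]]) for k in parent_a}

def mutate(config, rate=0.2):
    # mutation: re-sample each hyperparameter from the search space with probability `rate`
    fresh = search_space()
    return {k: (fresh[k] if random.random() < rate else v) for k, v in config.items()}

def fitness(config):
    # lower validation loss = fitter individual
    loss, _ = train_and_val_wfv(
        config, budget=100, X=X_train, y=y_train,
        train_window=3000, val_window=500,
    )
    return loss

# small population and few generations, purely for illustration
population = [search_space() for _ in range(10)]
for generation in range(3):
    ranked = sorted(population, key=fitness)
    parents = ranked[: len(ranked) // 2]  # truncation selection: keep the better half
    children = [
        mutate(crossover(*random.sample(parents, 2)))
        for _ in range(len(population) - len(parents))
    ]
    population = parents + children

best_config = min(population, key=fitness)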

Best Hyperparameter Configuration:

  • 'lr': 0.006441170552290832

  • 'hidden_dim': 128

  • 'layer_dim': 3

  • 'dropout': 0.2052570911345997

  • 'batch_size': 128

Best Validation Loss (MSE):

0.1321

Loss History:

Figure F. Loss history (Genetic Algorithms) (Created by Kuriko IWAI)

Consideration

Random Search (0.0362) and Bayesian Optimization (0.0428) slightly outperformed Hyperband (0.0519) in terms of final validation loss.

This indicates a trade-off between efficiency and finding the global optimum.

Hyperband's efficiency comes from aggressively pruning underperforming configurations early in the training process.

While this saves significant time, it risks accidentally eliminating a "late-blooming" configuration that might have eventually achieved the best result.

In this specific case, both Random Search and Bayesian Optimization were more successful.

Random Search allowed a high-performing configuration to reach its full potential by giving every model a full training budget.

Bayesian Optimization's intelligent, informed search was also more effective at finding the best hyperparameter set than Hyperband's early-stopping approach.

Configuring Hyperband for Better Performance

To improve Hyperband's performance, it is worth tuning its own parameters and, where needed, combining it with complementary tuning methods.

1. Tuning Hyperband's Key Parameters

  • Setting a larger R (total budget) gives "late-blooming" models more time to prove their worth, reducing the chance of prematurely pruning a good configuration.

  • Setting a smaller eta (halving factor) makes pruning more moderate, since a larger fraction of configurations advances to the next round: eta = 3 keeps only the top third of configurations at each halving step, while eta = 2 keeps the top half, as illustrated in the snippet below.
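
As a quick, hypothetical illustration of how these two knobs change the schedule, the snippet below compares the number of brackets and the per-round survival rate for a few (R, eta) settings:

from math import floor, log

# hypothetical settings: a smaller eta and a larger R give more brackets and gentler pruning
for R, eta in [(100, 3), (100, 2), (300, 3)]:
    s_max = floor(log(R, eta) + 1e-10)  # maximum bracket index (s_max + 1 brackets in total)
    survive = 1 / eta                   # fraction of configurations kept at each halving step
    print(f'R={R}, eta={eta}: {s_max + 1} brackets, top {survive:.0%} survive each round')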

2. Combining Hyperband with Bayesian Optimization

BOHB (Bayesian Optimization and HyperBand) is a hybrid approach that uses Hyperband's successive halving as a framework but replaces its random sampling with Bayesian Optimization's probabilistic model.

BOHB uses Bayesian Optimization to select the most promising candidates to feed into the Hyperband brackets.

This approach combines the best of both, offering Hyperband's fast results with the strong final performance of Bayesian Optimization.
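
As one concrete (hedged) example, Optuna can approximate this combination by pairing its TPE sampler with a HyperbandPruner: the pruner reproduces Hyperband's brackets while TPE proposes the configurations. The objective below uses a toy per-epoch loss purely to show the report/prune pattern; in practice it would report the real validation loss of the LSTM after each epoch.

import optuna

def objective(trial):
    lr = trial.suggest_float('lr', 1e-6, 1e-1, log=True)
    loss = float('inf')
    for epoch in range(100):                       # stand-in for a real training loop
        loss = 1.0 / (epoch + 1) + abs(lr - 1e-3)  # toy "validation loss"
        trial.report(loss, step=epoch)             # report intermediate results...
        if trial.should_prune():                   # ...so the Hyperband pruner can stop the trial early
            raise optuna.TrialPruned()
    return loss

study = optuna.create_study(
    direction='minimize',
    sampler=optuna.samplers.TPESampler(),  # Bayesian (TPE) proposals
    pruner=optuna.pruners.HyperbandPruner(
        min_resource=1,        # smallest budget a trial can be stopped at
        max_resource=100,      # R: the full training budget
        reduction_factor=3,    # eta: the halving factor
    ),
)
study.optimize(objective, n_trials=20)
print(study.best_params, study.best_value)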

Wrapping Up

Hyperband is a powerful and efficient hyperparameter optimization method that effectively balances the exploration of a wide search space with the exploitation of promising configurations.

Its ability to quickly prune poor performers makes it significantly faster and more scalable than traditional methods like Grid Search and Random Search.

While other methods like Bayesian Optimization can be sample-efficient, Hyperband is a robust option for a wide range of machine learning tasks, particularly when training is computationally expensive.

Related Books for Further Understanding

These books cover a wide range of theories and practices, from fundamentals to PhD level.

Linear Algebra Done Right

Foundations of Machine Learning, second edition (Adaptive Computation and Machine Learning series)

Designing Machine Learning Systems: An Iterative Process for Production-Ready Applications

Machine Learning Design Patterns: Solutions to Common Challenges in Data Preparation, Model Building, and MLOps

Share What You Learned

Kuriko IWAI, "Optimizing LSTMs with Hyperband: A Comparative Guide to Bandit-Based Tuning" in Kernel Labs

https://kuriko-iwai.com/hyperband

Written by Kuriko IWAI. All images, unless otherwise noted, are by the author. All experimentations on this blog utilize synthetic or licensed data.