Scaling Generalization: Automating Flexible AI with Meta-Learning and NAS
Explore how adaptable neural networks handle few-shot learning
By Kuriko IWAI

Table of Contents
Introduction
What is Meta-Learning
Introduction
AI models are highly specialized in specific tasks but struggle to generalize when facing slightly different types of problems.
Meta-learning combined with Neural Architecture Search (NAS) can tackle this challenge by finding the optimal architecture and training the model to be a fast learner.
In this article, I’ll explore the practical steps of this synergy, demonstrating how to build truly general-purpose AI that can quickly adapt to new challenges.
What is Meta-Learning
Meta-learning, or “learning to learn,” is one of the primary strategies for Few-Shot Learning (FSL) where machine learning models learn and generalize from a very small number of labeled examples.
Instead of training a model to be an expert at a single task, meta-learning trains a model to become a fast learner, measuring model performance by its ability to adapt to a new, unseen task with minimal examples.
The meta-learner model is exposed to a wide distribution of related tasks and learns a strategy for quickly acquiring new skills with minimal data.
For example, a model trained to recognize cats, dogs, and birds can classify other animals, such as elephants or lions, that it has never seen before.

Figure A. Image of meta-learning in an animal classification task (Created by Kuriko IWAI)
Without meta-learning, we would need to retrain the model from scratch on many images of elephants and lions, a critical bottleneck for building AI that can operate in the real, unpredictable world.
◼ Major Meta-Learning Methods
The primary approaches to meta-learning are categorized as metric-based or optimization-based.
▫ Metric-Based Methods
Metric-based methods learn a “metric” to compare new data points to the limited examples available.
In this approach, the model classifies a new example by finding its nearest neighbors in an embedding space using networks like:
Prototypical Networks: Learns an embedding space where a prototype (e.g., mean value) is computed for each class. New examples are classified based on their distance to these prototypes.
Matching Networks: Uses an attention mechanism to compare a new sample with all the examples in the support set to make a prediction.
Siamese Networks: Consists of two identical neural networks with shared weights, trained to determine if two inputs belong to the same class by learning a similarity function.
For example, in the case of Prototypical Networks, the algorithm can recognize three new animals (a dog, a cat, and a bird) using only one picture of each.
The process flows as follows:
Learning to map each of the three images into a feature space,
Creating a prototype vector for each class by averaging the feature vectors of its example images (with a single image per class, the prototype is simply that image’s feature vector),
Mapping a new, unseen image to the feature space,
Measuring distance from a new image’s feature vector to each of the three prototypes created in Step 2, and
Classifying the new image as a class (animal) whose prototype is the closest.
In this case, the “metric” is the distance function used to measure closeness in the feature space.
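For intuition, here is a minimal PyTorch sketch of this classification rule. The embedding network embed, the tensor shapes, and the use of Euclidean distance are assumptions for illustration:

import torch

# a minimal sketch of the prototypical classification rule above;
# `embed` (a trained embedding network) and tensor shapes are assumptions
def prototypical_predict(embed, support_x, support_y, query_x, num_classes):
    support_emb = embed(support_x)               # step 1: map support images into the feature space
    prototypes = torch.stack([
        support_emb[support_y == c].mean(dim=0)  # step 2: average per-class features into a prototype
        for c in range(num_classes)
    ])
    query_emb = embed(query_x)                   # step 3: map the new, unseen images
    dists = torch.cdist(query_emb, prototypes)   # step 4: measure distance to each prototype
    return dists.argmin(dim=1)                   # step 5: assign each query to the closest class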
This method is best when the core of the few-shot problem is learning a similarity function among classes or examples.
▫ Optimization-Based Methods
Optimization-based methods, on the other hand, train a model that can quickly adapt its parameters to a new task with just a few gradient updates.
Major approaches include:
Model-Agnostic Meta-Learning (MAML): MAML finds an optimal set of initial parameters for a model. This initialization allows the model to be fine-tuned to a new task efficiently with a small number of gradient steps.
Reptile: A simpler and more computationally efficient alternative to MAML that works by iteratively training on a random batch of tasks and moving the model’s weights toward the trained weights for each task.
For example, when using MAML on the same dog, cat, and bird classification task, the model acquires the ability to quickly learn new animal categories instead of requiring huge training datasets of animals (see the sketch after this list).
This method is best when the few-shot problem requires a highly flexible and rapid learner, particularly when:
The model needs to adapt to a wide variety of tasks because the optimization-based methods can train the model to find a good starting point for any task, and
Speed of adaptation is critical because the method allows the model to achieve high performance on a new task with a minimal number of examples and gradient updates.
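For intuition, here is a minimal first-order MAML (FOMAML) sketch with a single inner gradient step. The function name fomaml_step and the task-tuple format are assumptions for illustration; full MAML would also backpropagate through the inner update:

import copy
import torch
import torch.nn.functional as F

# first-order MAML (FOMAML) sketch, simplified to one inner gradient step;
# `tasks` is assumed to be a list of ((support_x, support_y), (query_x, query_y)) tuples
def fomaml_step(model, tasks, inner_lr=0.01, meta_lr=0.001):
    meta_grads = [torch.zeros_like(p) for p in model.parameters()]
    for (sx, sy), (qx, qy) in tasks:
        # inner loop: adapt a clone of the current initialization on the support set
        learner = copy.deepcopy(model)
        inner_opt = torch.optim.SGD(learner.parameters(), lr=inner_lr)
        inner_opt.zero_grad()
        F.cross_entropy(learner(sx), sy).backward()
        inner_opt.step()
        # outer loop: take the query-set gradient at the adapted weights
        query_loss = F.cross_entropy(learner(qx), qy)
        grads = torch.autograd.grad(query_loss, list(learner.parameters()))
        for mg, g in zip(meta_grads, grads):
            mg.add_(g / len(tasks))
    # move the shared initialization along the averaged query gradient
    with torch.no_grad():
        for p, mg in zip(model.parameters(), meta_grads):
            p.sub_(meta_lr * mg)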
◼ Meta-Learning Works for “Related” Tasks
Meta-learning works when the underlying tasks have similar learning dynamics, sharing a common structure that allows the model to learn a good learning strategy.
In the previous animal classification example, the tasks of classifying dogs vs. cats or birds vs. dogs are all related because the common structure lies in the fundamental image features that the model must analyze.
Particularly, both tasks involve:
Identifying and distinguishing between different animal species,
Recognizing common features like fur, feathers, eyes, and noses, and
Understanding the concept of different classes and how to assign an image to one of them.
This shared structure enables the model to learn a meta-strategy—a higher-level approach to solving classification problems—rather than just memorizing features for a single task.
▫ Task Heterogeneity (Domain Mismatch)
On the other hand, when a significant domain shift occurs between tasks, meta-learning does not work well.
For example, a meta-learning model trained on animal image classification would inaccurately classify human images because the visual features in human images differ from those of animals.
The model is biased by its prior training on animal images, so it tries to apply animal-centric features to the human images it is shown.
When tasks are related, they aren't just similar on a superficial level. Instead, they possess a fundamental connection that makes knowledge transferable between them.
This is the core principle behind meta-learning.
How Meta-Learning Leverages Neural Architecture Search (NAS)
Neural Architecture Search (NAS) is a method that automates the design of neural network architectures by formulating it as an optimization problem.
The ultimate goal of NAS is to find the architecture that minimizes validation loss with minimal computational effort.
However, NAS itself has limitations like:
Computational cost: This is the biggest hurdle for NAS. Searching for an optimal architecture can take hundreds or even thousands of GPU days.
Lack of generalization: A model found by NAS is highly optimized for a specific task and dataset. So, when we need to solve a slightly different task or use a new dataset, we often have to run the entire, costly NAS process again from scratch.
The combination of meta-learning and NAS is powerful in leveraging advantages of both, while overcoming the limitations of NAS.
Now, let us take a look at its step-by-step approach.
◼ Step 1. The Blueprint — Defining the Search Space
First, define the search space for the NAS algorithm.
The search space includes:
Different types of convolutional layers (e.g., standard, separable).
Various activation functions (e.g., ReLU, Sigmoid).
Options for pooling layers or skip connections.
The number of layers or channels.
Then, the NAS algorithm will explore different combinations of these components to create unique architectures.
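Concretely, a candidate architecture can be encoded as a nested list of operation indices into the search space. A hypothetical example (the index-to-name mapping matches the OPERATIONS list defined in the simulation below):

# a hypothetical encoding of one candidate: two cells with four operations each,
# where each integer indexes an operation in the search space
candidate_architecture = [
    [0, 2, 4, 0],  # cell 1: conv3x3, maxpool3x3, identity, conv3x3
    [1, 1, 3, 0],  # cell 2: conv5x5, conv5x5, avgpool3x3, conv3x3
]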
◼ Step 2. The Experiment — Running Meta-Training Loop
This is the core of the process, where the NAS algorithm and the meta-learner work together.
First, the NAS algorithm (e.g., a reinforcement learning agent or an evolutionary algorithm) proposes a candidate neural network architecture from the search space.
Let us call this candidate Model_A.
Then, the meta-learning algorithm (e.g., MAML) takes Model_A and evaluates its ability to learn. This involves a series of few-shot classification tasks like:
Task 1: Classify dogs vs. cats. Model_A is given a few images of each and performs a few gradient updates to adapt its weights.
Task 2: Classify birds vs. dogs. Model_A is given a new small dataset and again performs a few quick updates.
The process continues over a variety of different two-class tasks.
The meta-learner's job is to optimize Model_A so that it can adapt to these tasks as quickly and accurately as possible.
After this meta-training, Model_A is evaluated on a held-out set of new few-shot tasks.
The key here is not just to test its performance, but to measure how quickly and accurately it can adapt.
The resulting performance score (e.g., the average accuracy across all test tasks) is a measure of how "meta-learnable" that specific architecture is.
◼ Step 3. The Feedback — Updating the NAS Algorithm
The NAS algorithm receives the performance score of Model_A.
It uses this feedback to refine its strategy for proposing new architectures (a minimal sketch of this update follows the list below):
If Model_A performed well, the NAS algorithm learns that architectures with similar properties (e.g., more skip connections or a certain type of convolutional layer) are likely to be good choices.
If Model_A performed poorly, the algorithm learns to avoid similar architectures in the future.
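With a reinforcement-learning NAS algorithm, this feedback step is typically a REINFORCE-style policy gradient, as used later in this article. A minimal sketch (the optional baseline term is an assumption added here to reduce variance):

import torch

# REINFORCE-style controller update sketch: log_probs are the controller's
# log-probabilities for each sampled operation, reward is the meta-test score;
# the `baseline` argument is an assumption, not part of the article's code
def controller_update(optimizer, log_probs, reward, baseline=0.0):
    # rewards above the baseline increase the probability of the sampled choices
    loss = -torch.sum(log_probs) * (reward - baseline)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()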
◼ Step 4. Repeat and Finalize
Step 2 and Step 3 are repeated many times; the NAS algorithm iteratively proposes new architectures, and the meta-learner evaluates their learning abilities.
Eventually, the search stops, and the best-performing architecture found throughout the process is selected.
This final model is the ultimate outcome—an architecture that is not only effective but is also specifically designed to be an optimal few-shot learner.
Simulation
Now, let us take a look at these steps in a practical coding implementation on the animal classification task:
Main task: Image classification (cat, dog, and bird)
Few-shot learning: 5 shots per class
Dataset: Extremely limited. We only have 10 images for each animal.
Base model: Convolutional Neural Network (CNN)
NAS algorithm: Reinforcement Learning (RL)
Meta-learning method: Model-Agnostic Meta-Learning (MAML)
◼ The Base Model
First, I defined the BaseModel class, whose architecture is subject to search by the NAS algorithm.
It uses the BaseModelCell class as a single cell in the network, and each cell processes the architecture_list that the NAS will provide:
import torch.nn as nn


# a single block in the neural network
class BaseModelCell(nn.Module):
    def __init__(self, architecture_list, in_channels, out_channels):
        super(BaseModelCell, self).__init__()
        self.ops = nn.ModuleList()

        # create an instance of each chosen operation
        for op_index in architecture_list:
            op_fn = OPERATIONS[op_index].op
            self.ops.append(op_fn(in_channels, out_channels))

        # add a final convolutional layer to match output size if needed
        self.final_conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        h = x
        for op in self.ops:
            h = op(h)
        # apply the final convolution (residual path) for consistent output channels
        return self.final_conv(x) + h


# primary model to be optimized (cnn)
class BaseModel(nn.Module):
    def __init__(self, architecture, input_channels, num_classes):
        super(BaseModel, self).__init__()
        self.architecture = architecture

        # initial number of channels
        channels = 16
        self.initial_conv = nn.Conv2d(input_channels, channels, 3, padding=1)

        # construct cells in the network
        self.cells = nn.ModuleList()
        for cell_arch in self.architecture:
            self.cells.append(BaseModelCell(cell_arch, channels, channels))

        # add a simple classifier head
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(channels, num_classes)

    # forward pass
    def forward(self, x):
        h = self.initial_conv(x)
        for cell in self.cells:
            h = cell(h)
        h = self.global_pool(h)
        h = h.view(h.size(0), -1)
        return self.classifier(h)
◼ Step 1. The Blueprint - Defining Search Space
The first step is to define a search space for the NAS algorithm that is flexible enough yet computationally tractable.
The search space OPERATIONS contains various convolutional and pooling operations:
import torch.nn as nn
from collections import namedtuple

Op = namedtuple('Op', ['name', 'op'])

OPERATIONS = [
    Op('conv3x3', lambda c_in, c_out: nn.Conv2d(c_in, c_out, 3, padding=1)),
    Op('conv5x5', lambda c_in, c_out: nn.Conv2d(c_in, c_out, 5, padding=2)),
    Op('maxpool3x3', lambda c_in, c_out: nn.MaxPool2d(3, stride=1, padding=1)),
    Op('avgpool3x3', lambda c_in, c_out: nn.AvgPool2d(3, stride=1, padding=1)),
    Op('identity', lambda c_in, c_out: nn.Identity()),
]
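Note that the snippets in this section reference a shared CONFIG dictionary that is not shown. A minimal version consistent with how it is used might look like this (the specific values are assumptions):

import torch

# a minimal CONFIG consistent with how it is referenced below; values are assumptions
CONFIG = {
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'hidden_size': 64,                  # controller LSTM hidden size
    'num_layers': 2,                    # controller LSTM layers
    'input_channels': 3,                # RGB images
    'ways': 3,                          # classes per few-shot task
    'num_tasks_per_batch': 4,           # tasks sampled per meta-training step
    'learning_rate_controller': 0.001,
    'learning_rate_child_inner': 0.001,
}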
◼ Step 2. The Experiment - Running Meta-Training Loop
First, I defined the Controller class, an LSTM-based controller that learns to generate a good network architecture.
Because an architecture can be thought of as a sequence of operations like conv3x3 or maxpool3x3, the LSTM outputs a sequence of choices that represent a neural network's architecture.
import torch
import torch.nn as nn
from torch.distributions import Categorical


class Controller(nn.Module):
    def __init__(self, search_space_size, hidden_size, num_layers):
        super(Controller, self).__init__()
        self.search_space_size = search_space_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # embedding for the search space operations
        self.op_embedding = nn.Embedding(self.search_space_size, self.hidden_size)

        # lstm to generate the architecture
        self.lstm = nn.LSTM(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers
        )

        # linear layer to output the logits for the next operation choice
        self.logits = nn.Linear(self.hidden_size, self.search_space_size)

    # forward pass (taking x, the index of the previous choice, as an argument)
    def forward(self, x, hidden):
        x = self.op_embedding(x)
        x = x.unsqueeze(0)  # add a sequence dimension
        output, hidden = self.lstm(x, hidden)
        logits = self.logits(output.squeeze(0))
        return logits, hidden

    # sample a new architecture using the controller
    def sample_architecture(self, max_cells, num_ops_per_cell):
        # initialize hidden state
        hidden = (
            torch.zeros(self.num_layers, 1, self.hidden_size).to(CONFIG['device']),
            torch.zeros(self.num_layers, 1, self.hidden_size).to(CONFIG['device'])
        )

        # start with a placeholder input token (e.g., index 0)
        input_token = torch.tensor([0]).to(CONFIG['device'])

        log_probs = []
        architecture = []

        # sample architecture cell by cell
        for _ in range(max_cells):
            cell_arch = []
            for _ in range(num_ops_per_cell):
                logits, hidden = self.forward(input_token, hidden)

                # get the probability distribution and sample from it
                op_dist = Categorical(logits=logits)
                op_choice = op_dist.sample()

                # store the log probability for reward calculation
                log_probs.append(op_dist.log_prob(op_choice))

                # update input for the next step and store the choice
                input_token = op_choice
                cell_arch.append(op_choice.item())
            architecture.append(cell_arch)

        return architecture, torch.stack(log_probs)


# define the NAS controller
controller = Controller(
    search_space_size=len(OPERATIONS),
    hidden_size=CONFIG['hidden_size'],
    num_layers=CONFIG['num_layers']
).to(CONFIG['device'])
Then, I defined the MAMLTrainer class that handles the MAML training loop.
It trains the child model on support sets, evaluates on query sets, and uses the performance as a reward to update the controller.
import torch
import copy
import torch.nn.functional as F
import numpy as np


class MAMLTrainer:
    def __init__(self, controller, meta_dataset):
        self.controller = controller
        self.meta_dataset = meta_dataset
        self.controller_optimizer = torch.optim.Adam(
            self.controller.parameters(), lr=CONFIG['learning_rate_controller']
        )

    def train_step(self, architecture, log_probs):
        # sample a batch of few-shot tasks
        task_data = self.meta_dataset.get_tasks(
            num_tasks=4,
            ways=3,  # number of classes in the task
            shots=5,
            query_size=15
        )

        rewards = []
        losses = []

        # outer loop: iterate through tasks in the batch
        for task in task_data:
            # build a fresh base model with the candidate architecture
            model = BaseModel(
                architecture,
                CONFIG['input_channels'],
                CONFIG['ways']
            ).to(CONFIG['device'])

            # inner loop (MAML): create support and query sets
            support_set_x, support_set_y = task['support']
            query_set_x, query_set_y = task['query']

            support_set_x = support_set_x.to(CONFIG['device'])
            support_set_y = support_set_y.to(CONFIG['device'])
            query_set_x = query_set_x.to(CONFIG['device'])
            query_set_y = query_set_y.to(CONFIG['device'])

            # clone the model's initial weights for the inner loop
            inner_loop_model = copy.deepcopy(model)
            inner_loop_optimizer = torch.optim.Adam(
                inner_loop_model.parameters(), lr=CONFIG['learning_rate_child_inner']
            )

            # train on the support set to get adapted weights
            support_preds = inner_loop_model(support_set_x)
            support_loss = F.cross_entropy(support_preds, support_set_y)

            inner_loop_optimizer.zero_grad()
            support_loss.backward()
            inner_loop_optimizer.step()

            # outer loop: evaluation and reward computation
            with torch.no_grad():
                query_preds = inner_loop_model(query_set_x)
                query_loss = F.cross_entropy(query_preds, query_set_y)

                # calculate accuracy as the reward for the controller
                _, predicted = torch.max(query_preds.data, 1)
                accuracy = (predicted == query_set_y).sum().item() / query_set_y.size(0)

            rewards.append(accuracy)
            losses.append(query_loss.item())

        avg_reward = np.mean(rewards)
        avg_loss = np.mean(losses)

        # compute the policy gradient loss for the controller based on this batch's performance
        reward = torch.tensor(avg_reward).to(CONFIG['device'])
        loss_controller = -torch.mean(log_probs * reward)  # use the average reward

        # update the controller (Step 3)
        self.controller_optimizer.zero_grad()
        loss_controller.backward()
        self.controller_optimizer.step()

        return avg_reward, avg_loss


# define the MAML trainer using a synthetic dataset
meta_dataset = SimulatedFewShotDataset(class_labels=['cat', 'dog', 'bird'])
trainer = MAMLTrainer(controller, meta_dataset)
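The SimulatedFewShotDataset class is not shown in this article. A minimal synthetic stand-in with the same get_tasks interface might look like this (the random tensors and the per-class interpretation of query_size are assumptions):

import torch

# a synthetic stand-in for the few-shot dataset used above (assumed interface)
class SimulatedFewShotDataset:
    def __init__(self, class_labels, image_size=32):
        self.class_labels = class_labels
        self.image_size = image_size

    def get_tasks(self, num_tasks, ways, shots, query_size):
        tasks = []
        for _ in range(num_tasks):
            # random RGB tensors as placeholder images for each episode
            support_x = torch.randn(ways * shots, 3, self.image_size, self.image_size)
            support_y = torch.arange(ways).repeat_interleave(shots)
            query_x = torch.randn(ways * query_size, 3, self.image_size, self.image_size)
            query_y = torch.arange(ways).repeat_interleave(query_size)
            tasks.append({'support': (support_x, support_y), 'query': (query_x, query_y)})
        return tasks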
Lastly, sample an architecture using the NAS controller and evaluate it using the train_step method of the MAMLTrainer class:
# sample one architecture for the entire training batch
current_architecture, log_probs = controller.sample_architecture(
    max_cells=2,
    num_ops_per_cell=4
)

# evaluate the architecture and use its performance to update the controller
avg_reward, avg_loss = trainer.train_step(current_architecture, log_probs)
max_cells and num_ops_per_cell can be any positive integers of our choice.
◼ Step 3. The Feedback - Updating the NAS Algorithm
The last part of the train_step function in the MAMLTrainer class updates the NAS algorithm based on the learning from Step 2:
# update the controller (Step 3)
self.controller_optimizer.zero_grad()
loss_controller.backward()
self.controller_optimizer.step()
The avg_reward computed in the previous step serves as the feedback signal to the NAS algorithm, since we use reinforcement learning.
◼ Step 4. Repeat and Finalize
Step 2 and Step 3 are iterated through epochs:
best_avg_reward = 0.0
best_architecture = None

# store history of rewards and losses
top_k_architectures_and_rewards = []
rewards_history = []
losses_history = []

epochs = 500
for epoch in range(epochs):
    current_architecture, log_probs = controller.sample_architecture(
        max_cells=2, num_ops_per_cell=4
    )

    avg_reward, avg_loss = trainer.train_step(current_architecture, log_probs)

    rewards_history.append(avg_reward)
    losses_history.append(avg_loss)

    # select top performing architectures
    if len(top_k_architectures_and_rewards) < 5:
        # if the list isn't full, just add the new architecture
        top_k_architectures_and_rewards.append((avg_reward, current_architecture))
    else:
        # if the list is full, replace the worst entry when the current reward beats it
        top_k_architectures_and_rewards.sort(key=lambda x: x[0], reverse=True)
        if avg_reward > top_k_architectures_and_rewards[-1][0]:
            top_k_architectures_and_rewards[-1] = (avg_reward, current_architecture)
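To print the ranked architectures by operation name, as in the results below, a small hypothetical reporting helper can map the sampled indices back to the OPERATIONS names:

# map sampled operation indices back to their names for reporting (hypothetical helper)
def describe(architecture):
    return [[OPERATIONS[i].name for i in cell] for cell in architecture]

top_k_architectures_and_rewards.sort(key=lambda x: x[0], reverse=True)
for rank, (reward, arch) in enumerate(top_k_architectures_and_rewards, start=1):
    print(f"Rank {rank}: Reward = {reward:.4f}")
    for n, ops in enumerate(describe(arch), start=1):
        print(f"  Cell {n}: - " + " - ".join(ops))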
◼ Results
After running 500 epochs with five shots per class, the final top five architectures are:
Rank 1: Reward = 0.3667
Cell 1: - identity - avgpool3x3 - identity - conv3x3
Cell 2: - conv3x3 - conv3x3 - avgpool3x3 - conv5x5
Rank 2: Reward = 0.3556
Cell 1: - conv5x5 - conv5x5 - conv3x3 - identity
Cell 2: - maxpool3x3 - conv5x5 - conv3x3 - maxpool3x3
Rank 3: Reward = 0.3556
Cell 1: - avgpool3x3 - avgpool3x3 - conv5x5 - avgpool3x3
Cell 2: - conv3x3 - conv3x3 - avgpool3x3 - conv5x5
Rank 4: Reward = 0.3556
Cell 1: - maxpool3x3 - identity - avgpool3x3 - identity
Cell 2: - maxpool3x3 - conv3x3 - avgpool3x3 - conv3x3
Rank 5: Reward = 0.3500
Cell 1: - avgpool3x3 - conv5x5 - conv5x5 - conv3x3
Cell 2: - conv3x3 - maxpool3x3 - conv3x3 - conv5x5

Figure B. Learning curve of the NAS controller with MAML (Left: Accuracy, Right: Loss) (Created by Kuriko IWAI)
◼ Applying the Meta-Learning Model to New, Unseen Tasks
Now, let us see how the model works on the animal classification task with elephant, lion, zebra, and giraffe, instead of cat, dog, and bird.
I ran the same five-shot learning for each class using the run_new_task function:
import torch
import torch.nn.functional as F
import copy
import numpy as np


def run_new_task(
        best_architecture,
        epochs,
        ways,
        shots,
        query_size,
        num_tasks_per_batch,
        new_task_classes
    ):
    # initiate a new few-shot dataset
    new_meta_dataset = SimulatedFewShotDataset(class_labels=new_task_classes)

    # initiate the base cnn model with the best architecture
    model = BaseModel(
        architecture=best_architecture,
        input_channels=3,
        num_classes=ways
    ).to(CONFIG['device'])

    # optimizer for the meta-learner
    child_optimizer = torch.optim.Adam(
        model.parameters(), lr=0.001
    )

    losses_history = []

    # training
    for _ in range(epochs):
        task_data = new_meta_dataset.get_tasks(
            num_tasks_per_batch,
            ways,
            shots,
            query_size
        )

        batch_losses = []
        for task in task_data:
            # create datasets
            support_set_x, support_set_y = task['support']
            query_set_x, query_set_y = task['query']

            support_set_x = support_set_x.to(CONFIG['device'])
            support_set_y = support_set_y.to(CONFIG['device'])
            query_set_x = query_set_x.to(CONFIG['device'])
            query_set_y = query_set_y.to(CONFIG['device'])

            # clone the model for the inner loop
            inner_loop_model = copy.deepcopy(model)
            inner_loop_optimizer = torch.optim.Adam(
                inner_loop_model.parameters(), lr=0.001
            )

            # inner loop: train on the support set
            support_preds = inner_loop_model(support_set_x)
            support_loss = F.cross_entropy(support_preds, support_set_y)

            inner_loop_optimizer.zero_grad()
            support_loss.backward()
            inner_loop_optimizer.step()

            # outer loop: evaluate on the query set
            query_preds = inner_loop_model(query_set_x)
            query_loss = F.cross_entropy(query_preds, query_set_y)

            batch_losses.append(query_loss.item())

        avg_loss = np.mean(batch_losses)
        losses_history.append(avg_loss)

        # update the main model's weights on the last task's query set
        # (this is a simplified outer-loop update)
        model_loss = F.cross_entropy(model(query_set_x), query_set_y)
        child_optimizer.zero_grad()
        model_loss.backward()
        child_optimizer.step()

    return losses_history


# execute the function
new_task_epochs = 500
new_task_ways = 4
new_task_shots = 5
new_task_query_size = 15
new_task_classes = ['elephant', 'lion', 'zebra', 'giraffe']

# reuse the best architecture found in the NAS training loop
top_k_architectures_and_rewards.sort(key=lambda x: x[0], reverse=True)
best_architecture = top_k_architectures_and_rewards[0][1]

run_new_task(
    best_architecture=best_architecture,
    epochs=new_task_epochs,
    ways=new_task_ways,
    shots=new_task_shots,
    query_size=new_task_query_size,
    num_tasks_per_batch=CONFIG['num_tasks_per_batch'],
    new_task_classes=new_task_classes
)
The point here is that we can reuse the best_architecture that the NAS controller found without rerunning the costly architecture search.
◼ Results
The architecture found by the NAS algorithm achieved an accuracy of 0.3611 on the new task, similar to the 0.3444 accuracy on the initial training task.
This indicates that it can generalize its learning ability to completely new classes like elephants, as long as it is provided with a few examples to adapt to the new task.
In summary, the cat, dog, and bird classes serve as a proxy for the general problem of few-shot image classification.
Best Accuracy: 0.3611

Figure C. Loss history during the training epochs (Created by Kuriko IWAI)
◼ Optional Discussion - Using Reptile Instead of MAML
Optionally, we can use Reptile instead of MAML as the meta-learning algorithm.
Reptile performs a meta-update by adjusting model parameters toward the adapted parameters found after a training step.
import torch
import torch.nn.functional as F
import copy
import numpy as np


class ReptileTrainer:
    def __init__(self, controller, meta_dataset):
        self.controller = controller
        self.meta_dataset = meta_dataset
        self.controller_optimizer = torch.optim.Adam(
            self.controller.parameters(), lr=0.001
        )

    def train_step(self, architecture, log_probs):
        # sample a batch of few-shot tasks
        task_data = self.meta_dataset.get_tasks(
            num_tasks=4,
            ways=3,
            shots=5,
            query_size=15
        )

        rewards = []
        losses = []

        # build the base model once for the whole batch.
        # this is the key difference from MAML, where a fresh base model
        # is initialized for every task evaluation.
        model = BaseModel(
            architecture=architecture,
            input_channels=3,
            num_classes=3
        ).to(CONFIG['device'])

        # outer loop: iterate through tasks in the batch
        for task in task_data:
            # inner loop (Reptile): create datasets
            support_set_x, support_set_y = task['support']
            query_set_x, query_set_y = task['query']

            support_set_x = support_set_x.to(CONFIG['device'])
            support_set_y = support_set_y.to(CONFIG['device'])
            query_set_x = query_set_x.to(CONFIG['device'])
            query_set_y = query_set_y.to(CONFIG['device'])

            # store the original weights before adaptation
            original_weights = copy.deepcopy(model.state_dict())

            # inner loop update
            child_optimizer = torch.optim.Adam(
                model.parameters(), lr=0.001,
            )

            support_preds = model(support_set_x)
            support_loss = F.cross_entropy(support_preds, support_set_y)
            child_optimizer.zero_grad()
            support_loss.backward()
            child_optimizer.step()

            # outer loop: evaluation and reward calculation.
            # evaluate the *adapted* model on the query set to get the reward for the controller
            with torch.no_grad():
                query_preds = model(query_set_x)
                query_loss = F.cross_entropy(query_preds, query_set_y)

                # calculate accuracy as the reward for the controller
                _, predicted = torch.max(query_preds.data, 1)
                accuracy = (predicted == query_set_y).sum().item() / query_set_y.size(0)

            rewards.append(accuracy)
            losses.append(query_loss.item())

            # core of the reptile update: move the original weights toward the adapted weights
            adapted_weights = model.state_dict()
            with torch.no_grad():
                for key in original_weights.keys():
                    new_param = original_weights[key] + 0.001 * (adapted_weights[key] - original_weights[key])
                    model.state_dict()[key].copy_(new_param)

        avg_reward = np.mean(rewards)
        avg_loss = np.mean(losses)

        # compute the policy gradient loss for the controller
        reward = torch.tensor(avg_reward).to(CONFIG['device'])
        loss_controller = -torch.mean(log_probs * reward)

        # update the controller
        self.controller_optimizer.zero_grad()
        loss_controller.backward()
        self.controller_optimizer.step()

        return avg_reward, avg_loss
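Since ReptileTrainer exposes the same train_step interface as MAMLTrainer, it can be swapped into the same search loop. A usage sketch:

# swap the MAML trainer for Reptile in the same NAS training loop
reptile_trainer = ReptileTrainer(controller, meta_dataset)
current_architecture, log_probs = controller.sample_architecture(
    max_cells=2, num_ops_per_cell=4
)
avg_reward, avg_loss = reptile_trainer.train_step(current_architecture, log_probs)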
◼ Results
Reptile trailed MAML during meta-training with a best accuracy of 0.3389, yet it generalized slightly better on the same elephant, lion, zebra, and giraffe classification task.
Best Accuracy: 0.3389
Best Generalization Accuracy: 0.3722
Final Top 5 Architectures:
Rank 1: Reward = 0.3722
Cell 1: - identity - identity - conv5x5 - identity
Cell 2: - avgpool3x3 - identity - avgpool3x3 - maxpool3x3
Rank 2: Reward = 0.3722
Cell 1: - conv5x5 - conv3x3 - maxpool3x3 - avgpool3x3
Cell 2: - conv3x3 - avgpool3x3 - conv3x3 - conv5x5
Rank 3: Reward = 0.3611
Cell 1: - avgpool3x3 - identity - conv5x5 - avgpool3x3
Cell 2: - conv5x5 - avgpool3x3 - maxpool3x3 - conv3x3
Rank 4: Reward = 0.3556
Cell 1: - avgpool3x3 - conv3x3 - conv5x5 - avgpool3x3
Cell 2: - maxpool3x3 - conv3x3 - conv3x3 - avgpool3x3
Rank 5: Reward = 0.3500
Cell 1: - conv5x5 - identity - conv5x5 - conv5x5
Cell 2: - conv3x3 - conv5x5 - avgpool3x3 - avgpool3x3

Figure D. Learning curve of the NAS controller with Reptile (Left: Accuracy, Right: Loss) (Created by Kuriko IWAI)

Figure E. Loss history during the training epochs (Created by Kuriko IWAI)
Wrapping Up
The combination of meta-learning and NAS is a competitive approach to building a versatile AI model.
In our experiment, we observed that optimization-based meta-learning methods like MAML and Reptile offer competitive generalization capabilities on the image classification task, even when using completely different sets of animal images.
NAS also supported the process by finding the optimal architectural design for the base model (CNN), allowing the meta learner to reuse the architecture for generalization tasks.
This strong combination offers a variety of applications beyond few-shot learning:
Transfer Learning: The process of fine-tuning a pre-trained model for a new task can take hours or days. With a meta-architecture found by NAS, this fine-tuning can be completed in a fraction of the time, saving immense computational resources and development time.
Personalized Systems: Think of a recommendation engine that can instantly adapt to a new user's preferences after they interact with just a few items, or a language model that quickly adapts its tone to match a user's writing style.
Adaptable Robotics: A robot could learn a new manipulation task with only a few demonstrations, as its underlying neural architecture is already optimized for fast, robust learning.
As computational resources become more accessible and our understanding of meta-learning deepens, NAS for meta-learning will become more and more common.
Continue Your Learning
If you enjoyed this blog, these related entries will complete the picture:
A Comparative Guide to Hyperparameter Optimization Strategies
Optimizing LSTMs with Hyperband: A Comparative Guide to Bandit-Based Tuning
Automating Deep Learning: A Guide to Neural Architecture Search (NAS) Strategies
The Definitive Guide to Machine Learning Loss Functions: From Theory to Implementation
Related Books for Further Understanding
These books cover a wide range of theory and practice, from the fundamentals to PhD level.

Linear Algebra Done Right

Foundations of Machine Learning, second edition (Adaptive Computation and Machine Learning series)

Designing Machine Learning Systems: An Iterative Process for Production-Ready Applications

Machine Learning Design Patterns: Solutions to Common Challenges in Data Preparation, Model Building, and MLOps
Share What You Learned
Kuriko IWAI, "Scaling Generalization: Automating Flexible AI with Meta-Learning and NAS" in Kernel Labs
https://kuriko-iwai.com/meta-learning-and-neural-architectural-search
Looking for Solutions?
- Deploying ML Systems 👉 Book a briefing session
- Hiring an ML Engineer 👉 Drop an email
- Learn by Doing 👉 Enroll in the AI Engineering Masterclass
Written by Kuriko IWAI. All images, unless otherwise noted, are by the author. All experimentations on this blog utilize synthetic or licensed data.


