Implementing Attention Approximation: Transformer Efficiency & Trade-offs
Self-attention mechanisms and attention approximation techniques in Transformer
By Kuriko IWAI

Table of Contents
Introduction
What is Self-Attention
How Self-Attention Works - The Q-K-V Mechanism
What is Attention Approximation
Analysis in Action

Introduction
The Transformer architecture, introduced in the “Attention Is All You Need” paper, has revolutionized Natural Language Processing (NLP).
Its core innovation, the self-attention mechanism, allows models to weigh the importance of different parts of the input sequence.
However, the standard self-attention mechanism suffers from its computational complexity which scales quadratically (O(N²)) as the length of the input sequence N grows, creating a bottleneck especially in tasks with long N such as document summarization or high-resolution image processing.
Attention approximation solves this challenge by reducing the complexity through various techniques.
In this article, I’ll demonstrate how self-attention works and implement major attention approximation techniques, analyzing efficiency and accuracy trade-offs.
What is Self-Attention
Self-attention is a mechanism that allows a model to weigh the importance of different elements like words in the input sequence when processing a single element.
The below diagram illustrates the architecture and computational process of a standard transformer model:

Kernel Labs | Kuriko IWAI | kuriko-iwai.com
Figure A. Standard architecture and computation process of an MHA transformer model (Created by Kuriko IWAI based on the paper “Attention Is All You Need”)
Steps 3 to 5 in Figure A leverage the self-attention mechanism, where the network decides the relevance of each element in the input sequence to the current element being processed.
This achieves two distinct advantages:
Enhances contextual understanding of the input regardless of the distance between the elements, and
Leverages parallelization to improve training speed of long-sequential tasks.
How Self-Attention Works - The Q-K-V Mechanism
Transformers use three vectors to handle self-attention:
Query (Q)
Key (K), and
Value (V).
These vectors are created by multiplying the input (X) by three learnable model parameters, the weight matrices (W_Q, W_K, W_V):

Figure B. How Q, K, V vectors are computed (Created by Kuriko IWAI)
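In code, the three projections amount to three matrix multiplications. The sketch below is minimal and illustrative (the shapes and tensor names are my own, not from the article's later code):

```python
import torch

torch.manual_seed(0)

N, d_model, d_k = 4, 8, 8      # sequence length, model dim, projection dim
X = torch.randn(N, d_model)    # one input embedding per token

# learnable weight matrices (plain tensors here for illustration)
W_Q = torch.randn(d_model, d_k)
W_K = torch.randn(d_model, d_k)
W_V = torch.randn(d_model, d_k)

# each token gets its own Query, Key, and Value row
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
```

Each of Q, K, V has one row per token, so their shapes are N x d_k.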
This mechanism proceeds in four distinct steps:
◼ Step 1: Compute Raw Attention Scores
For a specific query token (i.e., the i-th word), its Query vector (Q_i) is multiplied by the Key vector (K_j) of every token from j = 1 to N in the sequence to compute a raw attention score:

S_{i, j} = Q_i · K_j

where:
S_{i, j}: The raw attention score between the i-th and j-th tokens,
Q_i: The Query vector for the i-th token, and
K_j: The Key vector for the j-th token.
The raw attention score indicates how similar the query token i is to the key token j, and thus how much attention i should pay to j.
The set of raw attention scores for the entire sequence is generalized in matrix form:

S = Q K^T

where:
S: The matrix of raw attention scores (NxN),
Q: The Query matrix (stacking all N query vectors), and
K^T: The transposed Key matrix.
◼ Step 2: Compute Attention Weights
Then, the attention score matrix S is scaled to stabilize the gradients:

S' = (Q K^T) / \sqrt{d_k}

where:
S': The scaled raw attention score matrix,
Q: The Query matrix,
K: The Key matrix,
d_k: The dimension of the Key vectors, and
\sqrt{d_k}: The scaling factor.
Then, the Softmax function is applied to the scaled scores to generate a set of attention weights A:

A = Softmax(S')

Each row of attention weights sums to 1, ensuring the scores are normalized.
◼ Step 3: Compute the Context Vector
Each attention weight calculated in Step 2 is multiplied by the corresponding Value vector (V_j) and summed up to create the final context vector:

Z_i = \sum_{j=1}^{N} A_{i, j} V_j

where:
Z_i: The Context vector for the i-th token,
N: The total number of tokens in the sequence,
A_{i, j}: The attention weight between the i-th and j-th tokens, and
V_j: The Value vector for the j-th token.
This Context vector Z_i is a rich representation of the i-th token, containing information from all other tokens in the sequence, weighted by their calculated importance.
◼ Step 4: Output
Lastly, the resulting matrix Z (Z = AV), which contains all context vectors from Z_1 to Z_N, is passed forward through a linear layer and a residual connection for the next layer of the Transformer.
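The four steps above can be condensed into a few lines. This is a minimal single-head sketch with illustrative shapes, omitting the output projection and residual connection of Step 4:

```python
import torch

torch.manual_seed(0)
N, d_k = 4, 8
Q, K, V = (torch.randn(N, d_k) for _ in range(3))

# step 1: raw attention scores S = Q K^T  -> (N, N)
S = Q @ K.T

# step 2: scale by sqrt(d_k), then normalize each row with softmax
A = torch.softmax(S / d_k ** 0.5, dim=-1)

# step 3: context vectors Z = A V  -> (N, d_k)
Z = A @ V
```

Because of the softmax, every row of A sums to 1, so each Z_i is a convex combination of the Value vectors.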
◼ Standard Components of Self-Attention
The core Q-K-V mechanism is implemented in two standard forms in the Transformer architecture:
The scaled dot-product attention (SDPA), and
Multi-head attention (MHA).
▫ 1. The Scaled Dot-Product Attention (SDPA)
Steps 1 and 2 use the scaled dot-product attention (SDPA), a mathematical formula where the Transformer applies the dot product between Q and K (Step 1), followed by scaling and normalizing the raw attention scores S (Step 2).
The scaling factor \sqrt{d_k} in Step 2 plays a critical role because large scores can push the Softmax function into regions with very small gradients (flat areas), hindering effective learning.
▫ 2. Multi-Head Attention (MHA)
Standard Transformers apply the multi-head attention (MHA) mechanism where the Q, K, V vectors are split into smaller chunks called heads.
The below diagram illustrates how MHA works when total heads H = 8:

Figure C. Multi-head attention with H=8 (Created by Kuriko IWAI)
It splits the Q, K, V vectors of each token into eight heads, applies the Q-K-V mechanism to each head, and then concatenates the attention outputs Z's.
Although optional, the process enhances contextual understanding as each head learns to attend to different aspects of the input like subject-verb or object-adjective relationships (hence, they output unique Z’s).
Mathematically, this process is generalized for an arbitrary head h:

S^{(h)} = (Q^{(h)} (K^{(h)})^T) / \sqrt{d_h}

where:
h: The h-th head out of H total heads,
S^{(h)}: The score matrix for head h,
Q^{(h)}, K^{(h)}: The matrices containing all the query and key vectors for head h only, and
d_h: The dimension of the head vectors.
Then, the model computes the context vector for head h:

Z^{(h)} = A^{(h)} V^{(h)}

where:
Z^{(h)}: The context vector for head h,
V^{(h)}: The Value matrix for head h, and
A^{(h)}: The attention weight matrix for head h, computed by applying the Softmax function to the scaled scores.
Lastly, the context vectors Z^{(h)} from all H heads are concatenated along the last dimension:

Z = Concat(Z^{(1)}, ..., Z^{(H)})

Then, the concatenated output is multiplied by a learned weight matrix of the output layer W^O to project the result back into the desired model dimension:

O_{MHA} = Z W^O

where:
O_{MHA}: The final output of the Multi-Head Attention block (NxD dimension), and
W^O: The learned output weight matrix (DxD dimension).
This step is critical for mixing the information learned by all heads.
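The split-attend-concatenate-project pipeline can be sketched as follows. This is a minimal illustration with my own shapes (N = 6 tokens, D = 16, H = 4 heads), not the article's implementation:

```python
import torch

torch.manual_seed(0)
N, D, H = 6, 16, 4                 # tokens, model dim, heads
d_h = D // H                       # per-head dimension
X = torch.randn(N, D)

W_Q, W_K, W_V, W_O = (torch.randn(D, D) for _ in range(4))

def split_heads(M):                # (N, D) -> (H, N, d_h)
    return M.view(N, H, d_h).transpose(0, 1)

Qh, Kh, Vh = (split_heads(X @ W) for W in (W_Q, W_K, W_V))

# scaled dot-product attention within each head, in parallel
A = torch.softmax(Qh @ Kh.transpose(-2, -1) / d_h ** 0.5, dim=-1)
Zh = A @ Vh                        # (H, N, d_h)

# concatenate the heads back to (N, D) and project with W^O
Z = Zh.transpose(0, 1).reshape(N, D)
O_mha = Z @ W_O
```

Note that splitting into heads does not change the total amount of computation per layer; it only lets each d_h-dimensional slice attend independently.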
◼ Computation Complexity of the Self-Attention Mechanism
With or without MHA, the complexity of computing the attention matrix is O(N²) because the core operation requires comparing every element in the Query vector with every element in the Key vector.
The below diagram illustrates the operation:

Figure D. Time complexity of a standard self-attention mechanism (Created by Kuriko IWAI)
When the length of the sequence is N, the computation results in an NxN matrix calculation, generating O(N²) complexity (based on the assumption that both d_v and d_k are much smaller than N).
What is Attention Approximation
Attention Approximation refers to techniques to reduce the computational demands of the standard attention mechanism while preserving its high performance.
The goal is to reduce the complexity to either linear (O(N)) or near-linear (O(N log N)) instead of O(N²).
◼ Types of Attention Approximation Techniques
Major attention approximation techniques fall into the following categories:
Low-Rank Approximation,
Kernel-Based Approximation, and
Sparse Attention.
Let us take a look.
▫ Low-Rank Approximation
Low-rank approximation represents the attention by a product of lower-dimensional matrices, based on the assumption that the full attention matrix has high redundancy in its elements:

Figure E. Time complexity of low-rank approximation (Created by Kuriko IWAI)
The Key vector (K, orange box in Figure E) and Value vector (V, blue box in Figure E) have k x d_k and d_v x k dimensions respectively.
k is much smaller than the sequence length N because it captures only the essential information from the entire sequence, reducing the time complexity to O(N).
Major techniques include:
Linformer: Projects the Key (K) and Value (V) matrices onto a smaller, fixed-size dimension (L « N) using learned linear projection matrices. Reduces the complexity to O(NL).
Nyströmformer: Uses the Nyström method where a matrix is approximated with a fixed number of landmark tokens (m « N) from the sequence. Reduces the complexity to O(Nm).
Low-Rank Factorization: Directly approximates the attention matrix A as the product of two lower-dimensional matrices U and V - such that S ≈ U V^T instead of S = Q K^T.
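A shape-level sketch of the low-rank idea (in the spirit of Linformer, with illustrative sizes and a random projection standing in for the learned one):

```python
import torch

torch.manual_seed(0)
N, d_k, k = 1024, 64, 32           # sequence length, head dim, low rank k << N

Q = torch.randn(N, d_k)
K = torch.randn(N, d_k)
V = torch.randn(N, d_k)

# project K and V from N rows down to k rows
# (in Linformer this projection E is a learned parameter)
E = torch.randn(k, N) / N ** 0.5
K_low, V_low = E @ K, E @ V        # both (k, d_k)

# the attention matrix is now N x k instead of N x N
A = torch.softmax(Q @ K_low.T / d_k ** 0.5, dim=-1)
Z = A @ V_low                      # (N, d_k)
```

Every token still attends to a summary of the whole sequence, but the score matrix shrinks from N x N to N x k, giving O(Nk) work.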
▫ Kernel-Based Approximation
Kernel-based methods omit the quadratic computation of the Softmax function by utilizing the kernel trick and Random Feature Maps (RFM), leading to O(N) complexity:

Figure F. Time complexity of kernel approximation vs standard self-attention (Created by Kuriko IWAI)
The method skips computing the attention matrix, which requires O(N²) complexity.
Instead, it applies the Value vector (blue box in Figure F) directly to the kernels (purple and orange boxes in Figure F), achieving O(N).
Major techniques include:
Performer (FAVOR+): Finds a positive random feature map ϕ (phi) such that the Softmax kernel exp(x) is approximated by the inner product of features: exp(Q K^T) ≈ ϕ(Q) ϕ(K)^T.
Kernelized Attention: Replaces the Softmax-based attention mechanism with a generalized kernel function k(Q_i, K_j) where the kernel k is an inner product of explicit feature maps (i.e., k(x, y) = ϕ(x)^T ϕ(y)).
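The O(N) saving comes from matrix-multiplication associativity: once Softmax is replaced by a feature map, (phi(Q) phi(K)^T) V can be reordered as phi(Q) (phi(K)^T V), so the N x N matrix is never materialized. A minimal sketch, using the elu(x) + 1 feature map common in linear-attention work (sizes are illustrative):

```python
import torch

torch.manual_seed(0)
N, d = 512, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))

# positive feature map phi approximating the softmax kernel
phi = lambda x: torch.nn.functional.elu(x) + 1

Qp, Kp = phi(Q), phi(K)

# associativity trick: compute phi(K)^T V first -> (d, d), O(N d^2)
KV = Kp.T @ V
numer = Qp @ KV                               # (N, d)

# per-row normalizer: phi(Q) @ sum over tokens of phi(K)
denom = Qp @ Kp.sum(dim=0, keepdim=True).T    # (N, 1)
Z = numer / denom
```

The left-to-right evaluation (Qp @ Kp.T) @ V would cost O(N^2 d); the reordered version costs O(N d^2), linear in N.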
▫ Sparse Attention
Sparse attention also avoids calculating the full NxN attention matrix A just like kernel approximation.
But its approach is distinct: it explicitly restricts the allowed connections between tokens using a mask M:

Figure G. Time complexity of sparse attention (Created by Kuriko IWAI)
Although the attention matrix (pink box in Figure G) has NxN dimensions, the mask M (grey box in Figure G) restricts the computation to only L elements of the matrix.
Assuming that L is a linear function of N such that L = cN, sparse attention can reduce time complexity to O(N).
Major techniques include:
LongFormer: Combines two types of attention for efficiency, achieving O(N) complexity:
Local Attention: A fixed-size window around each token.
Global Attention: A few pre-selected tokens like the [CLS] token attend to all tokens, and all tokens attend back to them, ensuring global context flow.
Reformer: Uses Locality-Sensitive Hashing (LSH) to group similar queries Q and keys K and only computes attention within these small, relevant clusters. Reduces complexity to O(N log N).
Sparse Transformer: Uses fixed strided and fixed-pattern sparsity (e.g., attention to every k-th token, or only the k previous tokens) to define the sparse attention matrix. Reduces the complexity to O(N \sqrt{N}) or O(N log N).
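The local-window component shared by these methods reduces to a banded boolean mask. A minimal sketch (N and the window radius w are illustrative):

```python
import torch

N, w = 8, 2                        # sequence length, window radius

idx = torch.arange(N)
# allow attention only where |i - j| <= w (a banded mask)
mask = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= w

# each row has at most 2w + 1 allowed positions -> O(N * w) attention work
allowed_per_row = mask.sum(dim=-1)
```

Masked positions are then filled with a large negative value before Softmax, exactly as the Longformer class later in this article does.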
Analysis in Action
In this section, I’ll demonstrate how to implement major approximation methods and analyze the trade-off between computational efficiency and model expressiveness.
◼ The Task & Data
The transformers are trained to perform an English-to-French translation task using the dataset:

Figure H. Sample dataset (source)
◼ The Transformers
I’ll instantiate four Transformers, each with a different attention layer:
Standard MHA (O(N²)): The baseline. Expects the maximum context capture.
Linformer (O(NL)): Represents a simple and efficient approach of learning a compressed representation with a fixed dimension (L).
Performer (O(N)): A rigorous approach that avoids the full attention matrix through the kernel trick.
LongFormer (O(N)): Solves the local context limitation by ingeniously combining local windowed attention with global attention.
All transformers share the encoder-decoder architecture, and the same attention layer is applied to both the encoder and the decoder.
◼ Step 1. Computing K, V, and Q Vectors
The first step is to extract the raw data and compute the K, V, Q vectors:
import pandas as pd
import torch
from typing import Tuple, List
from transformers import AutoTokenizer


CSV_PATH = filepath  # path to the dataset csv
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# sequence lengths and model dimensions (train vs. test)
N = 2048
N_TEST = 4096
D_MODEL = 512
BATCH_SIZE = 4

# pre-trained tokenizer (defined after N so model_max_length resolves)
TOKENIZER = AutoTokenizer.from_pretrained("t5-small", model_max_length=N)


# extract the raw data and structure Q, K, V, Y_true tensors
def extract_data(
        batch_size: int,
        seq_len: int,
        d_model: int,
        device: torch.device = DEVICE,
        csv_path: str = CSV_PATH
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:  # q, k, v, y_true

    chunk_iterator = pd.read_csv(
        csv_path,
        chunksize=batch_size * 2,
        header=None,  # ignore file header
        names=['en', 'fr'],
        usecols=['en', 'fr'],
        on_bad_lines='skip'
    )
    first_chunk = next(chunk_iterator)

    # tokenization and padding for the input data (en)
    raw_english_sentences: List[str] = first_chunk['en'].astype(str).tolist()[:batch_size]
    tokenized_inputs = TOKENIZER(raw_english_sentences, padding='max_length', truncation=True, max_length=seq_len, return_tensors='pt')

    # tokenization and padding for the target data (fr)
    raw_french_sentences: List[str] = first_chunk['fr'].astype(str).tolist()[:batch_size]
    tokenized_targets = TOKENIZER(raw_french_sentences, padding='max_length', truncation=True, max_length=seq_len, return_tensors='pt')

    # total tokens
    N_tokens = tokenized_inputs['input_ids'].shape[1]

    # simulate Q, K, V vectors (used by the encoder)
    Q = torch.randn(batch_size, N_tokens, d_model, device=device)
    K = torch.randn(batch_size, N_tokens, d_model, device=device)
    V = torch.randn(batch_size, N_tokens, d_model, device=device)

    # true fr target token ids to compute the loss
    Y_true = tokenized_targets['input_ids'].to(device)

    return Q, K, V, Y_true


# extract train and test data with different sequence lengths
Q_base, K_base, V_base, Y_true_base = extract_data(BATCH_SIZE, N, D_MODEL, DEVICE)
Q_test, K_test, V_test, Y_true_test = extract_data(BATCH_SIZE, N_TEST, D_MODEL, DEVICE)
64
◼ Step 2. Defining Attention Layers
Next, I’ll define each attention layer using the PyTorch library:
import torch
import torch.nn as nn
import torch.nn.functional as f


# standard mha self-attention (o(n^2))
class StandardMHA(nn.Module):
    def __init__(self, d_k: int):
        super().__init__()
        self.scaling_factor = 1 / (d_k ** 0.5)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        S = torch.matmul(Q, K.transpose(-2, -1)) * self.scaling_factor  # scaled score s
        A = torch.softmax(S, dim=-1)  # attention weights a
        Z = torch.matmul(A, V)  # context vector z - [b, n, d]
        return Z


# linformer attention (o(nl))
class Linformer(nn.Module):
    def __init__(self, N_max: int, L: int, d_k: int):
        super().__init__()
        # use n_max (n from global scope) to size the projection matrix
        self.E = nn.Parameter(torch.randn(N_max, L, device=DEVICE), requires_grad=True)  # learned projection unique to linformer
        self.scaling_factor = 1 / (d_k ** 0.5)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        N_curr = K.shape[1]

        # the projection matrix only handles sequences up to its fixed max size (n_max)
        if N_curr > self.E.shape[0]:
            raise RuntimeError(f"... linformer: sequence length {N_curr} exceeds fixed projection size {self.E.shape[0]}. failed to generalize ...")

        # use the first n_curr rows of the learned projection
        L_proj = self.E[:N_curr, :]

        # compress the key and value vectors along the sequence dimension
        K_compressed = torch.matmul(K.transpose(-2, -1), L_proj).transpose(-2, -1)
        V_compressed = torch.matmul(V.transpose(-2, -1), L_proj).transpose(-2, -1)

        # compute context vector z
        S = torch.matmul(Q, K_compressed.transpose(-2, -1)) * self.scaling_factor
        A = torch.softmax(S, dim=-1)
        Z = torch.matmul(A, V_compressed)
        return Z


# performer attention (o(n))
class Performer(nn.Module):
    def __init__(self, d_model: int, m_features: int):
        super().__init__()
        self.m = m_features
        self.d_model = d_model
        # fixed random features normalized by 1/sqrt(m)
        self.random_features = nn.Parameter(torch.randn(d_model, m_features, device=DEVICE) / (m_features ** 0.5), requires_grad=False)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        # simplified rfa approximation: exp(q' * k') * v / exp(q' * sum(k'))

        # apply elu + 1 to the feature map for a positive approximation of exp(x)
        Q_prime = f.elu(torch.matmul(Q, self.random_features)) + 1
        K_prime = f.elu(torch.matmul(K, self.random_features)) + 1

        # term 1: k_prime.t @ v -> [b, m, d_model]; numerator: q_prime @ term 1 -> [b, n, d_model]
        numerator_term_1 = torch.matmul(K_prime.transpose(-2, -1), V)
        numerator = torch.matmul(Q_prime, numerator_term_1)

        # denominator term 1: sum(k_prime) -> [b, 1, m]
        denominator_term_1 = torch.sum(K_prime, dim=1).unsqueeze(1)
        denominator = torch.matmul(Q_prime, denominator_term_1.transpose(-2, -1))

        epsilon = 1e-6  # small constant for numerical stability
        Z = numerator / (denominator + epsilon)
        return Z


# longformer (o(nw))
class Longformer(nn.Module):
    def __init__(self, window_size: int, d_k: int):
        super().__init__()
        self.w = window_size
        self.scaling_factor = 1 / (d_k ** 0.5)
        self.neg_inf = -1e9

    def _create_local_mask(self, N_curr: int, device: torch.device) -> torch.Tensor:
        indices = torch.arange(N_curr, device=device).unsqueeze(-1)
        diff_matrix = torch.abs(indices - indices.transpose(0, 1))
        mask = (diff_matrix <= self.w)
        return mask.bool()

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        B, N_curr, _ = Q.shape

        S = torch.matmul(Q, K.transpose(-2, -1)) * self.scaling_factor  # attention scores

        # create and apply a mask m
        local_mask = self._create_local_mask(N_curr, Q.device)
        M = local_mask.unsqueeze(0).expand(B, -1, -1)
        masked_scores = S.masked_fill(~M, self.neg_inf)

        # compute z
        A = torch.softmax(masked_scores, dim=-1)
        Z = torch.matmul(A, V)
        return Z
◼ Step 3. Instantiating the Transformers
Next, I’ll define a full Seq2SeqTransformer class using the encoder and decoder layers.
Both layers take one of the attention layers defined in Step 2:
# standard encoder layer w/ attention module
class EncoderLayer(nn.Module):
    def __init__(self, attention_module: nn.Module, d_model: int):
        super().__init__()
        self.attn = attention_module  # takes one of the four attention classes
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # attention block
        norm_x = self.norm1(x)

        # encoder uses the same q, k, v (norm_x, norm_x, norm_x)
        attn_out = self.attn(norm_x, norm_x, norm_x)
        x = x + self.dropout1(attn_out)  # residual connection

        # feed-forward block
        norm_x = self.norm2(x)
        ffn_out = self.ffn(norm_x)
        output = x + self.dropout2(ffn_out)  # residual connection
        return output


# standard decoder layer
class DecoderLayer(nn.Module):
    def __init__(self, attention_module: nn.Module, d_model: int):
        super().__init__()
        # decoder self-attention (same type as the encoder for this simulation)
        self.self_attn = attention_module
        # encoder-decoder attention
        self.cross_attn = StandardMHA(d_model)  # always standard mha

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )

        self.norm1 = nn.LayerNorm(d_model)  # self-attn norm
        self.norm2 = nn.LayerNorm(d_model)  # cross-attn norm
        self.norm3 = nn.LayerNorm(d_model)  # ffn norm
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.dropout3 = nn.Dropout(0.1)

    # takes the decoder input (tgt) and the encoder output (mem)
    def forward(self, tgt: torch.Tensor, mem: torch.Tensor) -> torch.Tensor:
        # decoder self-attention (q = k = v = tgt; no causal mask in this simplified setup)
        norm_tgt = self.norm1(tgt)
        self_attn_out = self.self_attn(norm_tgt, norm_tgt, norm_tgt)
        tgt = tgt + self.dropout1(self_attn_out)  # residual connection

        # encoder-decoder cross-attention (q = tgt, k = v = mem)
        norm_tgt = self.norm2(tgt)
        cross_attn_out = self.cross_attn(norm_tgt, mem, mem)
        tgt = tgt + self.dropout2(cross_attn_out)  # residual connection

        # feed-forward block
        norm_tgt = self.norm3(tgt)
        ffn_out = self.ffn(norm_tgt)
        output = tgt + self.dropout3(ffn_out)  # residual connection
        return output


# full model
class Seq2SeqTransformer(nn.Module):
    def __init__(self, attention_module: nn.Module, d_model: int, vocab_size: int, pad_token_id: int):
        super().__init__()
        # encoder layer
        self.encoder = EncoderLayer(attention_module, d_model)

        # decoder layer
        self.decoder = DecoderLayer(attention_module, d_model)

        # target token embedding
        self.target_embedding = nn.Embedding(vocab_size, d_model)

        # final linear head for logits
        self.linear_head = nn.Linear(d_model, vocab_size)

        # store pad id for mask generation
        self.pad_token_id = pad_token_id

    def forward(self, Q, K, V, Y_true):
        # encoder (q is the encoder input used to create the encoded memory)
        mem = self.encoder(Q)
        # the decoder input is the true target tokens shifted right
        tgt_input_tokens = Y_true[:, :-1]
        tgt = self.target_embedding(tgt_input_tokens)
        decoder_output = self.decoder(tgt, mem)
        logits = self.linear_head(decoder_output)
        return logits
Then, I’ll define each transformer, inheriting the Seq2SeqTransformer class:
from transformers import AutoTokenizer


# pre-trained tokenizer
TOKENIZER = AutoTokenizer.from_pretrained("t5-small", model_max_length=N)
VOCAB_SIZE = TOKENIZER.vocab_size


# full seq2seq transformers using different attention mechanisms
class TransformerStandard(Seq2SeqTransformer):
    def __init__(self, d_model: int, vocab_size: int):
        attn = StandardMHA(d_model)
        pad_id = TOKENIZER.pad_token_id
        super().__init__(attn, d_model, vocab_size, pad_token_id=pad_id)


class TransformerLinformer(Seq2SeqTransformer):
    def __init__(self, N_max: int, L: int, d_model: int, vocab_size: int):
        attn = Linformer(N_max, L, d_model)
        pad_id = TOKENIZER.pad_token_id
        super().__init__(attn, d_model, vocab_size, pad_token_id=pad_id)


class TransformerPerformer(Seq2SeqTransformer):
    def __init__(self, d_model: int, m_features: int, vocab_size: int):
        attn = Performer(d_model, m_features)
        pad_id = TOKENIZER.pad_token_id
        super().__init__(attn, d_model, vocab_size, pad_token_id=pad_id)


class TransformerLongformer(Seq2SeqTransformer):
    def __init__(self, window_size: int, d_model: int, vocab_size: int):
        attn = Longformer(window_size, d_model)
        pad_id = TOKENIZER.pad_token_id
        super().__init__(attn, d_model, vocab_size, pad_token_id=pad_id)


# instantiate the full seq2seq transformer models
D_MODEL = 512
L_LINFORMER = 128   # linformer compressed sequence length L << N
m_PERFORMER = 256   # performer number of random features m < D
w_LONGFORMER = 128  # longformer local window size w << N

models = {
    '1. Standard MHA': TransformerStandard(D_MODEL, VOCAB_SIZE).to(DEVICE),
    '2. Linformer': TransformerLinformer(Q_base.shape[1], L_LINFORMER, D_MODEL, VOCAB_SIZE).to(DEVICE),
    '3. Performer': TransformerPerformer(D_MODEL, m_PERFORMER, VOCAB_SIZE).to(DEVICE),
    '4. LongFormer': TransformerLongformer(w_LONGFORMER, D_MODEL, VOCAB_SIZE).to(DEVICE),
}
50
◼ Step 4. Training the Standard MHA Transformer
Next, I’ll train the standard MHA model to create a stabilized target for the rest of the models.
I’ll define the train function with early stopping after 20 consecutive epochs without improvement in the loss:
EPOCHS = 3000
LEARNING_RATE = 1e-5


# training standard mha
def train(
        model: nn.Module,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        Y_true: torch.Tensor,
        num_steps: int = EPOCHS,
        lr: float = LEARNING_RATE
    ) -> torch.Tensor:

    # target - the model's own random initial output logits
    model.eval()
    with torch.no_grad():
        init_output = model(Q, K, V, Y_true).detach()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # loss for reporting & early stopping
    criterion = nn.CrossEntropyLoss(ignore_index=TOKENIZER.pad_token_id)

    # training mode
    model.train()
    Y_target = Y_true[:, 1:].flatten()

    # teacher target labels for mse stabilization
    target = init_output.view(-1, init_output.shape[-1]).detach()

    # early stopper helper (assumed defined elsewhere; tracks loss improvement with patience)
    early_stopper = EarlyStopper(patience=20)

    # training epochs
    for step in range(num_steps):
        optimizer.zero_grad()
        output_logits = model(Q, K, V, Y_true)
        output_for_loss = output_logits.view(-1, output_logits.shape[-1])  # reshape to [b*n_tokens, vocab_size]

        # stabilization loss (mse proxy against the initial logits)
        loss = f.mse_loss(output_for_loss, target)
        loss.backward()
        optimizer.step()

        # calculate the true language loss for reporting and early stopping
        with torch.no_grad():
            true_loss = criterion(output_for_loss, Y_target)

        # early stop
        if early_stopper.early_stop(true_loss.item()):
            print(f"... early stopping triggered at step {step + 1}. no improvement in true ce loss for {early_stopper.patience} steps.")
            break

    model.eval()

    # return the trained output logits (used as the stable distillation target)
    with torch.no_grad():
        return model(Q, K, V, Y_true)


# train the standard mha to get stable targets
standard_output_target_logits = train(
    models['1. Standard MHA'], Q_base, K_base, V_base, Y_true_base
)

# get the test target logits from the trained mha
with torch.no_grad():
    standard_output_test_target_logits = models['1. Standard MHA'](Q_test, K_test, V_test, Y_true_test)
◼ Step 5. Perform Knowledge Distillation (KD) for Approximation Transformers
Knowledge distillation (KD) is a machine learning technique used to transfer the knowledge from a large, complex model (the teacher, in this project, the standard MHA) to a smaller, more efficient model (the student).
This process compresses the model, which is crucial for deploying powerful models in resource-constrained environments like mobile devices or edge hardware.
I’ll define the run_distillation_simulation function where the rest of the models are optimized based on the MSE loss against the standard MHA model’s performance in Step 4:
import torch

# kd for the attention approximation models
EPOCHS = 3000
LEARNING_RATE = 1e-5
DISTILLATION_TEMP = 3.0  # temperature for distillation
ALPHA = 0.5

def run_distillation_simulation(
        model: nn.Module,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        Y_true: torch.Tensor,
        target_model_name: str,
        stable_target_logits: torch.Tensor,
        num_steps: int = EPOCHS,
        lr: float = LEARNING_RATE):

    print(f"\n... training {target_model_name} (runs distillation) ...")

    # define optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # loss functions
    ce_loss = nn.CrossEntropyLoss(ignore_index=TOKENIZER.pad_token_id)

    # set training mode
    model.train()

    # reshape target for cross-entropy loss (hard labels)
    Y_target = Y_true[:, 1:].flatten()

    # reshape stable_target_logits (teacher's output)
    stable_target_for_loss = stable_target_logits.view(-1, stable_target_logits.shape[-1]).detach()

    # early stopper (uses increased patience)
    early_stopper = EarlyStopper(is_distillation=True)

    # a utility function for kl-divergence on soft targets (kd loss)
    def distillation_loss(student_logits, teacher_logits, temperature, alpha, true_labels):
        # soft targets
        student_log_softmax = f.log_softmax(student_logits / temperature, dim=-1)
        teacher_softmax = f.softmax(teacher_logits / temperature, dim=-1)

        # kl divergence loss
        kd_loss = f.kl_div(student_log_softmax, teacher_softmax, reduction='batchmean') * (temperature ** 2)

        # hard targets - standard cross entropy loss on true labels
        hard_loss = ce_loss(student_logits, true_labels)

        # combined loss
        return alpha * hard_loss + (1.0 - alpha) * kd_loss

    for step in range(num_steps):
        optimizer.zero_grad()
        output_logits = model(Q, K, V, Y_true)
        output_for_loss = output_logits.view(-1, output_logits.shape[-1])

        # combined distillation loss (kl + ce)
        loss = distillation_loss(
            student_logits=output_for_loss,
            teacher_logits=stable_target_for_loss,
            temperature=DISTILLATION_TEMP,
            alpha=ALPHA,
            true_labels=Y_target
        )

        loss.backward()
        optimizer.step()

        # calculate the true language loss for reporting and early stopping
        with torch.no_grad():
            # use the standard ce loss for the true performance metric
            true_loss = ce_loss(output_for_loss, Y_target)

        if early_stopper.early_stop(true_loss.item()):
            print(f"... early stopping triggered at step {step + 1}. no improvement in true ce loss for {early_stopper.patience} steps.")
            break

    model.eval()


# loop over the rest of the models
for name, model in models.items():
    if name != '1. Standard MHA':
        run_distillation_simulation(
            model=model,
            Q=Q_base,
            K=K_base,
            V=V_base,
            Y_true=Y_true_base,
            target_model_name=name,
            stable_target_logits=standard_output_target_logits,
            num_steps=EPOCHS,
        )

# ensure inputs are detached for timing and evaluation
Q_base.requires_grad_(False); K_base.requires_grad_(False); V_base.requires_grad_(False)
Q_test.requires_grad_(False); K_test.requires_grad_(False); V_test.requires_grad_(False)
Y_true_base.requires_grad_(False); Y_true_test.requires_grad_(False)
◼ Step 6. Perform Inference
Lastly, I’ll define the perform_inference function where all models are evaluated on three metrics:
Quality of approximation (based on the loss),
Generalization capabilities, and
Training speed.
import time
from typing import Dict

def perform_inference() -> Dict[str, Dict[str, float]]:
    # recording the results
    results: Dict[str, Dict[str, float]] = dict()

    # define loss
    final_criterion = nn.CrossEntropyLoss(ignore_index=TOKENIZER.pad_token_id, reduction='mean')

    # reshape true token targets for loss calc
    Y_target_base_flat = Y_true_base[:, 1:].flatten()
    Y_target_test_flat = Y_true_test[:, 1:].flatten()

    # perform inference
    for name, model in models.items():
        NUM_RUNS = 10

        # timing setup
        if DEVICE.type == 'cuda':
            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            with torch.no_grad():
                for _ in range(NUM_RUNS):
                    output_logits_base = model(Q_base, K_base, V_base, Y_true_base)
            end_event.record()
            torch.cuda.synchronize()
            avg_time = start_event.elapsed_time(end_event) / 1000 / NUM_RUNS  # compute avg. time
        else:
            # fallback to cpu
            start_time = time.time()
            with torch.no_grad():
                for _ in range(NUM_RUNS):
                    output_logits_base = model(Q_base, K_base, V_base, Y_true_base)
            avg_time = (time.time() - start_time) / NUM_RUNS

        # analysis 1 - accuracy vs. sparsity / rank (approximation quality)
        output_logits_base_flat = output_logits_base.view(-1, output_logits_base.shape[-1])
        standard_output_target_logits_flat = standard_output_target_logits.view(-1, standard_output_target_logits.shape[-1])

        if name == '1. Standard MHA':
            # standard model's loss against true labels (ce loss)
            mse_accuracy = final_criterion(output_logits_base_flat, Y_target_base_flat).item()
        else:
            # approximation model's loss (mse between logits of student and teacher)
            mse_accuracy = f.mse_loss(output_logits_base_flat, standard_output_target_logits_flat).item()

        # analysis 2 - generalization (approximation on n_test)
        try:
            with torch.no_grad():
                output_logits_test = model(Q_test, K_test, V_test, Y_true_test)

            output_logits_test_flat = output_logits_test.view(-1, output_logits_test.shape[-1])
            standard_output_test_target_logits_flat = standard_output_test_target_logits.view(-1, standard_output_test_target_logits.shape[-1])

            if name == '1. Standard MHA':
                mse_test = final_criterion(output_logits_test_flat, Y_target_test_flat).item()
                generalization_ratio = 1.0  # as target model
            else:
                mse_test = f.mse_loss(output_logits_test_flat, standard_output_test_target_logits_flat).item()
                if mse_accuracy > 1e-9:
                    generalization_ratio = mse_test / mse_accuracy
                else:
                    generalization_ratio = 1.0

        except RuntimeError as e:
            # generalization is allowed to fail only if seq len exceeds n_max
            if 'exceeds fixed projection size' in str(e) and name == '2. Linformer':
                print(f" [{name}]: generalization test failed (sequence length error: {e})")
            else:
                # for performer or longformer, a runtime error is a major failure
                print(f" [{name}]: generalization test failed (unexpected error: {e})")

            mse_test = 1000.0
            generalization_ratio = 1000.0

        # store results
        results[name] = {
            "Practical Speed (s)": avg_time,
            "Approximation Quality (Metric)": mse_accuracy,
            "Generalization Factor (Test Metric / Base Metric)": generalization_ratio
        }

    return results


# perform inference
results = perform_inference()

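For readability, the dictionary returned by perform_inference can be rendered as a simple comparison table with plain string formatting. A minimal sketch; the sample values below mirror the figures reported in the results discussion (LongFormer's approximation metric is not reported, so it is left as NaN), and the model keys other than those used above are illustrative:

```python
# sketch: tabulate a perform_inference()-style results dict
results = {
    "3. Performer": {
        "Practical Speed (s)": 0.2381,
        "Approximation Quality (Metric)": 0.051980,
        "Generalization Factor (Test Metric / Base Metric)": 1.02,
    },
    "4. LongFormer": {
        "Practical Speed (s)": 0.3168,
        "Approximation Quality (Metric)": float("nan"),  # not reported in the article
        "Generalization Factor (Test Metric / Base Metric)": 0.68,
    },
}

metrics = list(next(iter(results.values())).keys())
# truncate long metric names so columns stay aligned
header = f"{'Model':<18}" + "".join(f"{m[:28]:>30}" for m in metrics)
print(header)
for name, row in results.items():
    print(f"{name:<18}" + "".join(f"{row[m]:>30.4f}" for m in metrics))
```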
◼ Results - Analyzing the Tradeoff
The Performer model is the overall best performer, achieving:
The best accuracy proxy value of 0.051980,
The fastest speed among the validly compared approximation methods, and
Efficient training, stopping much earlier than Linformer while reaching a much lower true CE loss.
Linformer recorded the fastest raw average time, but its generalization test failed due to a sequence length mismatch, making its results invalid for comparison.
LongFormer had the highest average time and the worst generalization ratio, while Standard MHA showed the weakest performance in both training loss and the final metric.
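The Linformer failure is structural rather than a bug: Linformer compresses the sequence axis with a learned projection of fixed size, so any test sequence longer than the training-time maximum cannot be projected. A minimal, self-contained sketch of this failure mode (function and variable names are illustrative, not the article's implementation):

```python
import torch

def linformer_attention(Q, K, V, E):
    # E: (k, n_max) fixed learned projection along the sequence axis
    n = K.shape[-2]
    if n != E.shape[-1]:
        raise RuntimeError(f"sequence length {n} exceeds fixed projection size {E.shape[-1]}")
    K_proj = E @ K  # (k, d): compress n keys down to k
    V_proj = E @ V  # (k, d)
    scores = Q @ K_proj.transpose(-2, -1) / K.shape[-1] ** 0.5  # (n, k) instead of (n, n)
    return torch.softmax(scores, dim=-1) @ V_proj

n_max, k, d = 64, 16, 32
E = torch.randn(k, n_max)
Q = K = V = torch.randn(n_max, d)
out = linformer_attention(Q, K, V, E)  # works: scores are (64, 16), not (64, 64)

Q2 = K2 = V2 = torch.randn(128, d)  # longer than n_max
try:
    linformer_attention(Q2, K2, V2, E)
except RuntimeError as e:
    print(e)  # the same class of failure the generalization test hits
```

This is the trade-off behind Linformer's speed: the (k, n_max) projection removes the quadratic term but hard-codes the maximum sequence length into the weights.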

Figure Z. Comparison of Seq2Seq Transformers (left: approximation quality, mid: generalization, right: training speed)
▫ Training Results (Left in Figure Z)
The Performer and LongFormer models achieved significantly lower final training loss (true CE loss) than Standard MHA and Linformer, demonstrating superior performance in minimizing the true Cross-Entropy (CE) loss during distillation.
Standard MHA stopped very early and maintained a high CE loss (11.5).
Linformer ran for the longest duration (870 steps) but the final true CE loss remained high (6.8).
▫ Generalization (Middle in Figure Z)
Relative to the Standard MHA baseline ratio of 1.0, Performer shows strong generalization with a ratio of 1.02, while LongFormer underperformed the others with a ratio of 0.68.
▫ Training Speed (Right in Figure Z)
Performer achieved an average time of 0.2381 seconds, while LongFormer was the slowest at 0.3168 seconds.
Overall, the Performer model shows the best balance, achieving the lowest accuracy proxy metric with very fast inference time.
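For context on why Performer is fast: its FAVOR+ mechanism rewrites softmax(QKᵀ)V as φ(Q)(φ(K)ᵀV), so the N×N attention matrix is never materialized and the cost scales linearly in N. A minimal sketch using a simple positive feature map (elu(x) + 1, a common simplification of Performer's random-feature construction; not the article's implementation):

```python
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V):
    # phi(x) = elu(x) + 1 keeps features strictly positive
    phi_q = F.elu(Q) + 1  # (n, d)
    phi_k = F.elu(K) + 1  # (n, d)
    kv = phi_k.transpose(-2, -1) @ V  # (d, d): O(n) memory, no (n, n) matrix
    # per-row normalizer: phi_q dotted with the column sums of phi_k
    z = phi_q @ phi_k.sum(dim=-2, keepdim=True).transpose(-2, -1)  # (n, 1)
    return (phi_q @ kv) / z

n, d = 512, 64
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
out = linear_attention(Q, K, V)
print(out.shape)  # torch.Size([512, 64])
```

Because φ(K)ᵀV is a (d, d) matrix, doubling the sequence length doubles the work instead of quadrupling it, which is consistent with Performer's speed advantage observed above.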
Conclusion
Attention approximation methods are powerful tools for improving the efficiency and performance of Transformer models by replacing the quadratic complexity with more efficient mechanisms.
In the experiment, we observed that the Performer model offered the best combination of accuracy and speed among the successfully generalized methods, confirming its effectiveness as an attention approximation technique.
The development of scalable attention mechanisms is not just an optimization but a fundamental enabler for the next generation of AI models.
By selectively reducing complexity, the techniques unlock the potential of Transformers for tasks requiring vast context, pushing the boundaries of what is computationally feasible.
Continue Your Learning
If you enjoyed this blog, these related entries will complete the picture:
The Definitive Guide to LLM Fine-Tuning: Objectives, Mechanisms, and Hardware
LLM Decoding Strategies: A Guide to Algorithms and Sampling Methods
Tokenization Strategies for LLM Applications
Regularizing LLMs with Kullback-Leibler Divergence
Grouped Query Attention (GQA): Balancing LLM Quality and Speed
Related Books for Further Understanding
These books cover a wide range of theory and practice, from fundamentals to PhD level.

Linear Algebra Done Right

Foundations of Machine Learning, second edition (Adaptive Computation and Machine Learning series)

Designing Machine Learning Systems: An Iterative Process for Production-Ready Applications

Machine Learning Design Patterns: Solutions to Common Challenges in Data Preparation, Model Building, and MLOps

Hands-On Generative AI with Transformers and Diffusion Models
Share What You Learned
Kuriko IWAI, "Implementing Attention Approximation: Transformer Efficiency & Trade-offs" in Kernel Labs
https://kuriko-iwai.com/attention-approximation
Looking for Solutions?
- Deploying ML Systems 👉 Book a briefing session
- Hiring an ML Engineer 👉 Drop an email
- Learn by Doing 👉 Enroll in the AI Engineering Masterclass
Written by Kuriko IWAI. All images, unless otherwise noted, are by the author. All experimentations on this blog utilize synthetic or licensed data.