Implementing Attention Approximation: Transformer Efficiency & Trade-offs

Self-attention mechanisms and attention approximation techniques in Transformer

Deep Learning | Data Science | Python | LLM

By Kuriko IWAI


Table of Contents

Introduction
What is Self-Attention
How Self-Attention Works - The Q-K-V Mechanism
Step 1: Compute Raw Attention Scores
Step 2: Compute Attention Weights
Step 3: Compute the Context Vector
Step 4: Output
Standard Components of Self-Attention
Computation Complexity of the Self-Attention Mechanism
What is Attention Approximation
Types of Attention Approximation Techniques
Analysis in Action
The Task & Data
The Transformers
Step 1. Computing K, V, and Q Vectors
Step 2. Defining Attention Layers
Step 3. Instantiating the Transformers
Step 4. Training the Standard MHA Transformer
Step 5. Perform Knowledge Distillation (KD) for Approximation Transformers
Step 6. Perform Inference
Results - Analyzing the Trade-off
Conclusion

Introduction

The Transformer architecture, introduced in the “Attention Is All You Need” paper, has revolutionized Natural Language Processing (NLP).

Its core innovation, the self-attention mechanism, allows models to weigh the importance of different parts of the input sequence.

However, the standard self-attention mechanism suffers from its computational complexity which scales quadratically (O(N²)) as the length of the input sequence N grows, creating a bottleneck especially in tasks with long N such as document summarization or high-resolution image processing.

Attention approximation solves this challenge by reducing this complexity through various techniques.

In this article, I’ll demonstrate how self-attention works and implement major attention approximation techniques, analyzing efficiency and accuracy trade-offs.

What is Self-Attention

Self-attention is a mechanism that allows a model to weigh the importance of different elements like words in the input sequence when processing a single element.

The below diagram illustrates the architecture and computational process of a standard transformer model:

Figure A. Standard architecture and computation process of an MHA transformer model (Created by Kuriko IWAI based on the paper “Attention Is All You Need”)

Kernel Labs | Kuriko IWAI | kuriko-iwai.com


Steps 3 to 5 in Figure A leverage the self-attention mechanism, where the network decides the relevance of each element in the input sequence to the current element being processed.

This achieves two distinct advantages:

  • Enhances contextual understanding of the input regardless of the distance between the elements, and

  • Leverages parallelization to improve training speed of long-sequential tasks.

How Self-Attention Works - The Q-K-V Mechanism

Transformers use three vectors to handle the self-attention:

  • Query (Q)

  • Key (K), and

  • Value (V).

These vectors are created by multiplying the input (X) by three learnable weight matrices (W_Q, W_K, W_V), which are trained model parameters:

Figure B. How Q, K, V vectors are computed (Created by Kuriko IWAI)


This mechanism proceeds in four distinct steps:

Step 1: Compute Raw Attention Scores

For a specific query token (i.e., the i-th word), its Query vector (Q_i) is multiplied by the Key vector (K_j) of every token from j=1 to N in the sequence to compute a raw attention score:

S_{i, j} = Q_i \cdot K_j

where:

  • S_{i, j}: The raw attention score between the i-th and j-th tokens,

  • Q_i: The Query vector for the i-th token, and

  • K_j: The Key vector for the j-th token.

The raw attention score indicates how similar the query token i is to the key token j, and thus how much attention i should pay to j.

For the entire sequence, the set of raw attention scores is generalized as:

S = QK^T

where:

  • S: The matrix of raw attention scores (NxN),

  • Q: The Query matrix, and

  • K^T: The transposed Key matrix.

Step 2: Compute Attention Weights

Then, the attention score matrix S is scaled to stabilize the gradients:

S' = \frac{Q K^T}{\sqrt{d_k}}

where:

  • S’: The scaled raw attention score matrix,

  • Q: The Query vector,

  • K: The Key vector,

  • d_k: The dimension of the Key vector, and

  • \sqrt{d_k}: The scaling factor.

Then, the Softmax function is applied to the scaled scores to generate a set of attention weights A:

A = \text{Softmax} \left( \frac{Q K^T}{\sqrt{d_k}} \right)

The attention weights in each row sum to 1, ensuring all scores are normalized.

Step 3: Compute the Context Vector

Each attention weight calculated in Step 2 is multiplied by the corresponding Value vector (V_j) and summed up to create the final context vector:

Z_i = \sum_{j=1}^{N} (A_{i, j} \cdot V_j)

where:

  • Z_i: The Context vector for the i-th token,

  • N: Total number of the tokens in the sequence,

  • A_{i,j}: The attention weight for the i-th token and j-th token, and

  • V_j: The value vector for the j-th token.

This Context vector Z_i is a rich representation of the i-th token, containing information from all other tokens in the sequence, weighted by their calculated importance.

Step 4: Output

Lastly, the resulting matrix Z (Z = AV), which contains all context vectors from Z_1 to Z_N, is passed forward through a linear layer and a residual connection for the next layer of the Transformer.
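Before moving on, the four steps can be condensed into a few lines of code. The following NumPy sketch (illustrative only, and separate from the PyTorch implementation later in this article) runs the full Q-K-V computation for a toy sequence:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Steps 1-3: raw scores, scaled softmax weights, context vectors."""
    d_k = K.shape[-1]
    S = Q @ K.T / np.sqrt(d_k)                 # Step 1 + scaling: N x N scores
    A = np.exp(S - S.max(axis=-1, keepdims=True))
    A = A / A.sum(axis=-1, keepdims=True)      # Step 2: softmax, each row sums to 1
    Z = A @ V                                  # Step 3: context vectors, N x d_v
    return Z, A

rng = np.random.default_rng(0)
N, d_k, d_v = 5, 8, 8
Q = rng.normal(size=(N, d_k))
K = rng.normal(size=(N, d_k))
V = rng.normal(size=(N, d_v))
Z, A = scaled_dot_product_attention(Q, K, V)   # Z: (5, 8), A: (5, 5)
```

Each row of A holds the normalized attention that one query token pays to every key token.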

Standard Components of Self-Attention

The core Q-K-V mechanism is implemented in two standard forms in the Transformer architecture:

  1. The scaled dot-product attention (SDPA), and

  2. Multi-head attention (MHA).

1. The Scaled Dot-Product Attention (SDPA)

Steps 1 and 2 use the scaled dot-product attention (SDPA), a mathematical formula where the Transformer applies the dot product between Q and K (Step 1), followed by scaling and normalizing the raw attention scores S (Step 2).

The scaling factor \sqrt{d_k} in Step 2 plays a critical role: large scores can push the Softmax function into regions with very small gradients (flat areas), hindering effective learning.

2. Multi-Head Attention (MHA)

Standard Transformers apply the multi-head attention (MHA) mechanism where the Q, K, V vectors are split into smaller chunks called heads.

The below diagram illustrates how MHA works when total heads H = 8:

Figure C. Multi-head attention with H=8 (Created by Kuriko IWAI)


It splits all tokens into eight heads, applies the Q-K-V mechanism to each head, and then concatenates the attention outputs Z’s.

Although optional, the process enhances contextual understanding as each head learns to attend to different aspects of the input like subject-verb or object-adjective relationships (hence, they output unique Z’s).

Mathematically, this process is generalized with a random head h:

S^{(h)} = \frac{Q^{(h)} (K^{(h)})^T}{\sqrt{d_h}}

where:

  • h: h-th head out of total H heads,

  • S^{(h)}: The score matrix for head h,

  • Q^{(h)}, K^{(h)}: The matrices containing all the query and key vectors for head h only, and

  • d_h: The dimension of the head vectors.

Then, the model computes the context vector for the head h:

Z^{(h)} = \text{Softmax} (S^{(h)}) \cdot V^{(h)} = A^{(h)} V^{(h)}

where:

  • Z^{(h)}: The context vector for the head h,

  • V^{(h)}: The value vector for the head h, and

  • A^{(h)}: The attention score matrix for the head h, computed by applying Softmax function to the scaled score.

Lastly, the context vectors Z^{(h)} from all H heads are concatenated along the last dimension:

Z_{concat} = \text{concat}(Z^{(1)}, Z^{(2)}, \dots, Z^{(H)})

Then, the concatenated output is multiplied by a learned weight matrix of the output layer W^O to project the result back into the desired model dimension:

O_{MHA} = Z_{concat} W^O

where:

  • O_{MHA}: The final output of the Multi-Head Attention block (NxD dimension), and

  • W^O: The learned output weight matrix (DxD dimension).

This step is critical for mixing the information learned by all heads.
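To make the head-splitting concrete, here is a rough NumPy sketch of MHA, with randomly initialized matrices standing in for the learned W_Q, W_K, W_V, and W^O:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, H):
    N, D = X.shape
    d_h = D // H                                   # per-head dimension
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    heads = []
    for h in range(H):
        sl = slice(h * d_h, (h + 1) * d_h)         # columns belonging to head h
        S_h = Q[:, sl] @ K[:, sl].T / np.sqrt(d_h) # per-head scaled scores
        heads.append(softmax(S_h) @ V[:, sl])      # per-head context Z^(h)
    Z_concat = np.concatenate(heads, axis=-1)      # concat along the last dimension
    return Z_concat @ W_O                          # mix information across heads

rng = np.random.default_rng(1)
N, D, H = 6, 16, 4
X = rng.normal(size=(N, D))
W_Q, W_K, W_V, W_O = (rng.normal(size=(D, D)) for _ in range(4))
O_mha = multi_head_attention(X, W_Q, W_K, W_V, W_O, H)  # shape (6, 16)
```

Each head attends over its own slice of the projected dimensions before the final W^O projection recombines them.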

Computation Complexity of the Self-Attention Mechanism

With or without MHA, the complexity of computing the attention matrix is O(N²) because the core operation requires comparing every element in the Query vector with every element in the Key vector.

The below diagram illustrates the operation:

Figure D. Time complexity of a standard self-attention mechanism (Created by Kuriko IWAI)


When the length of the sequence is N, the computation results in an NxN matrix calculation, generating O(N²) complexity (based on the assumption that both d_v and d_k are much smaller than N).

What is Attention Approximation

Attention Approximation refers to techniques to reduce the computational demands of the standard attention mechanism while preserving its high performance.

The goal is to reduce the complexity to either linear (O(N)) or near-linear (O(N log N)) instead of O(N²).

Types of Attention Approximation Techniques

Major attention approximation techniques fall into the following categories:

  • Low-Rank Approximation,

  • Kernel-Based Approximation, and

  • Sparse Attention.

Let us take a look.

Low-Rank Approximation

Low-rank approximation represents the attention by a product of lower-dimensional matrices, based on the assumption that the full attention matrix has high redundancy in its elements:

Figure E. Time complexity of low-rank approximation (Created by Kuriko IWAI)


The Key vector (K, orange box in Figure E) and Value vector (V, blue box in Figure E) have k x d_k and d_v x k dimensions respectively.

k is much smaller than the sequence length N because it captures only the key essence of the entire sequence, reducing the time complexity from O(N²) to O(Nk), which is linear in N.

Major techniques include:

  • Linformer: Projects the Key (K) and Value (V) matrices onto a smaller, fixed-size dimension (L ≪ N) using learned linear projection matrices. Reduces the complexity to O(NL).

  • Nyströmformer: Uses the Nyström method, where a matrix is approximated with a fixed number of landmark tokens (m ≪ N) from the sequence. Reduces the complexity to O(Nm).

  • Low-Rank Factorization: Directly approximates the attention matrix A as the product of two lower-dimensional matrices U and V - such that S ≈ U V^T instead of S = Q K^T.
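As a minimal sketch of the low-rank idea, the following NumPy snippet applies a Linformer-style projection; a random (untrained) matrix E stands in for the learned projection:

```python
import numpy as np

rng = np.random.default_rng(3)
N, d, L = 1024, 64, 32                        # sequence length, model dim, projected length
Q = rng.normal(size=(N, d))
K = rng.normal(size=(N, d))
V = rng.normal(size=(N, d))
E = rng.normal(size=(N, L)) / np.sqrt(N)      # projection (learned in Linformer, random here)

K_proj = E.T @ K                              # compress the sequence axis: N -> L
V_proj = E.T @ V

S = Q @ K_proj.T / np.sqrt(d)                 # scores are N x L instead of N x N
A = np.exp(S - S.max(axis=-1, keepdims=True))
A = A / A.sum(axis=-1, keepdims=True)
Z = A @ V_proj                                # context vectors in O(N * L * d)
```

Because L is fixed, the score matrix grows linearly rather than quadratically with N.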

Kernel-Based Approximation

Kernel-based methods omit the quadratic computation of the Softmax function by utilizing the kernel trick and Random Feature Maps (RFM), leading to O(N) complexity:

Figure F. Time complexity of kernel approximation vs standard self-attention (Created by Kuriko IWAI)


The method skips computing the attention matrix, which requires O(N²) complexity.

Instead, it applies the Value vector (blue box in Figure F) directly to the kernels (purple and orange boxes in Figure F), achieving O(N).

Major techniques include:

  • Performer (FAVOR+): Finds a positive random feature map ϕ (phi) such that the Softmax kernel exp(x) is approximated by the inner product of features: exp(Q K^T) ≈ ϕ(Q) ϕ(K)^T.

  • Kernelized Attention: Replaces the Softmax-based attention mechanism with a generalized kernel function k(Q_i, K_j) where the kernel k is an inner product of explicit feature maps (i.e., k(x, y) = ϕ(x)^T ϕ(y)).
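The O(N) saving comes from reassociating the matrix product: ϕ(Q)(ϕ(K)^T V) only ever builds small d x d intermediates, whereas (ϕ(Q)ϕ(K)^T)V materializes the full N x N matrix. A NumPy sketch, using elu(x) + 1 as a simple positive feature map (a stand-in for the random feature maps above):

```python
import numpy as np

def feature_map(x):
    # elu(x) + 1: a simple positive feature map (stand-in for FAVOR+ random features)
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(2)
N, d = 512, 32
Q = rng.normal(size=(N, d))
K = rng.normal(size=(N, d))
V = rng.normal(size=(N, d))
Q_p, K_p = feature_map(Q), feature_map(K)

# quadratic order: builds an N x N matrix first -> O(N^2)
out_quadratic = (Q_p @ K_p.T) @ V

# linear order: reassociate so only d x d intermediates appear -> O(N)
out_linear = Q_p @ (K_p.T @ V)

# matrix multiplication is associative, so both orders agree
assert np.allclose(out_quadratic, out_linear)
```

The two results are identical up to floating-point error; only the cost of computing them differs.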

Sparse Attention

Sparse attention, like kernel approximation, avoids calculating the full NxN attention matrix A.

But its approach is distinct: it explicitly restricts the allowed connections between tokens using a mask M:

Figure G. Time complexity of sparse attention (Created by Kuriko IWAI)


Although the attention matrix (pink box in Figure G) has NxN dimensions, the mask M (grey box in Figure G) restricts the computation to only L of its elements.

Assuming that L is a linear function of N such that L = cN, sparse attention can reduce time complexity to O(N).

Major techniques include:

  • LongFormer: Combines two types of attention for efficiency, achieving O(N) complexity:

    1. Local Attention: A fixed-size window around each token.

    2. Global Attention: A few pre-selected tokens like the [CLS] token attend to all tokens, and all tokens attend back to them, ensuring global context flow.

  • Reformer: Uses Locality-Sensitive Hashing (LSH) to group similar queries Q and keys K and only computes attention within these small, relevant clusters. Reduces complexity to O(N log N).

  • Sparse Transformer: Uses fixed strided and fixed-pattern sparsity (e.g., attention to every k-th token, or only the k previous tokens) to define the sparse attention matrix. Reduces the complexity to O(N\sqrt{N}) or O(N log N).
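As a small sketch of the masking idea, the local-window part of such a mask can be built as follows (the window size w here is illustrative):

```python
import numpy as np

def local_attention_mask(N, w):
    # True where |i - j| <= w: each token attends only to its local window
    idx = np.arange(N)
    return np.abs(idx[:, None] - idx[None, :]) <= w

M = local_attention_mask(8, 1)
# each row allows at most 2w + 1 = 3 connections, so only O(N * w) pairs are scored
assert M.sum(axis=1).max() == 3
```

A Longformer-style mask would additionally mark a few global rows and columns as fully attendable.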

Analysis in Action

In this section, I’ll demonstrate how to implement major approximation methods and analyze the trade-off between computational efficiency and model expressiveness.

The Task & Data

The transformers are trained to perform an English-to-French translation task using the dataset:

Figure H. Sample dataset (source)


The Transformers

I’ll instantiate four Transformers, each with a different attention layer:

  • Standard MHA (O(N²)): The baseline, expected to capture the most context.

  • Linformer (O(NL)): A simple, efficient approach that learns a compressed representation with a fixed dimension (L).

  • Performer (O(N)): A rigorous approach that avoids materializing the attention matrix through the kernel trick.

  • LongFormer (O(N)): Addresses the local-context limitation by combining local windowed attention with global attention.

All transformers have the encoder-decoder architecture, and the same attention layer is applied to both encoder and decoder.

Step 1. Computing K, V, and Q Vectors

The first step is to extract the raw data and compute the K, V, Q vectors:

import pandas as pd
import torch
from typing import Tuple, List
from transformers import AutoTokenizer


# sequence lengths and model dimensions (train vs test use different sequence lengths)
N = 2048
N_TEST = 4096
D_MODEL = 512
BATCH_SIZE = 4

CSV_PATH = filepath
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# pre-trained tokenizer
TOKENIZER = AutoTokenizer.from_pretrained("t5-small", model_max_length=N)


# extract the raw data and structure Q, K, V, Y_true tensors
def extract_data(
        batch_size: int,
        seq_len: int,
        d_model: int,
        device: torch.device = DEVICE,
        csv_path: str = CSV_PATH
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # q, k, v, y_true

    chunk_iterator = pd.read_csv(
        csv_path,
        chunksize=batch_size * 2,
        header=None, # ignore file header
        names=['en', 'fr'],
        usecols=['en', 'fr'],
        on_bad_lines='skip'
    )
    first_chunk = next(chunk_iterator)

    # tokenization and padding for input data (en)
    raw_english_sentences: List[str] = first_chunk['en'].astype(str).tolist()[:batch_size]
    tokenized_inputs = TOKENIZER(raw_english_sentences, padding='max_length', truncation=True, max_length=seq_len, return_tensors='pt')

    # tokenization and padding for the target data (fr)
    raw_french_sentences: List[str] = first_chunk['fr'].astype(str).tolist()[:batch_size]
    tokenized_targets = TOKENIZER(raw_french_sentences, padding='max_length', truncation=True, max_length=seq_len, return_tensors='pt')

    # total tokens
    N_tokens = tokenized_inputs['input_ids'].shape[1]

    # simulate Q, K, V vectors (used by the encoder)
    Q = torch.randn(batch_size, N_tokens, d_model, device=device)
    K = torch.randn(batch_size, N_tokens, d_model, device=device)
    V = torch.randn(batch_size, N_tokens, d_model, device=device)

    # true fr target token ids to compute loss
    Y_true = tokenized_targets['input_ids'].to(device)

    return Q, K, V, Y_true


# extract train and test data with different sequence lengths
Q_base, K_base, V_base, Y_true_base = extract_data(BATCH_SIZE, N, D_MODEL, DEVICE)
Q_test, K_test, V_test, Y_true_test = extract_data(BATCH_SIZE, N_TEST, D_MODEL, DEVICE)

Step 2. Defining Attention Layers

Next, I’ll define each attention layer using the PyTorch library:

import torch
import torch.nn as nn
import torch.nn.functional as f

# standard mha self-attention (o(n^2))
class StandardMHA(nn.Module):
    def __init__(self, d_k: int):
        super().__init__()
        self.scaling_factor = 1 / (d_k ** 0.5)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        S = torch.matmul(Q, K.transpose(-2, -1)) * self.scaling_factor # scaled score s
        A = torch.softmax(S, dim=-1) # attention vector a
        Z = torch.matmul(A, V) # context vector z - [b, n, d]
        return Z


# linformer attention (o(nl))
class Linformer(nn.Module):
    def __init__(self, N_max: int, L: int, d_k: int):
        super().__init__()
        # projection matrix sized by the maximum sequence length n_max (learned parameter unique to linformer)
        self.E = nn.Parameter(torch.randn(N_max, L, device=DEVICE), requires_grad=True)
        self.scaling_factor = 1 / (d_k ** 0.5)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        N_curr = K.shape[1]

        # the projection matrix can only handle sequences up to its fixed size (n_max)
        if N_curr > self.E.shape[0]:
            raise RuntimeError(f"linformer: sequence length {N_curr} exceeds fixed projection size {self.E.shape[0]}; failed to generalize")

        # slice the learned projection to the current sequence length
        L_proj = self.E[:N_curr, :]

        # compress key and value vectors along the sequence axis (n -> l)
        K_compressed = torch.matmul(K.transpose(-2, -1), L_proj).transpose(-2, -1)
        V_compressed = torch.matmul(V.transpose(-2, -1), L_proj).transpose(-2, -1)

        # compute context vector z
        S = torch.matmul(Q, K_compressed.transpose(-2, -1)) * self.scaling_factor
        A = torch.softmax(S, dim=-1)
        Z = torch.matmul(A, V_compressed)
        return Z


# performer attention (o(n))
class Performer(nn.Module):
    def __init__(self, d_model: int, m_features: int):
        super().__init__()
        self.m = m_features
        self.d_model = d_model
        # fixed random features normalized by 1/sqrt(m)
        self.random_features = nn.Parameter(torch.randn(d_model, m_features, device=DEVICE) / (m_features ** 0.5), requires_grad=False)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        # simplified rfa approximation: exp(q' * k') * v / exp(q' * sum(k'))

        # apply elu + 1 to the feature map for a positive approximation of exp(x)
        Q_prime = f.elu(torch.matmul(Q, self.random_features)) + 1
        K_prime = f.elu(torch.matmul(K, self.random_features)) + 1

        # numerator: k_prime.T @ v -> [b, m, d_model], then q_prime @ term -> [b, n, d_model]
        numerator_term_1 = torch.matmul(K_prime.transpose(-2, -1), V)
        numerator = torch.matmul(Q_prime, numerator_term_1)

        # denominator: sum(k_prime) -> [b, 1, m]
        denominator_term_1 = torch.sum(K_prime, dim=1).unsqueeze(1)
        denominator = torch.matmul(Q_prime, denominator_term_1.transpose(-2, -1))

        epsilon = 1e-6 # small epsilon for numerical stability
        Z = numerator / (denominator + epsilon)
        return Z


# longformer (o(nw))
class Longformer(nn.Module):
    def __init__(self, window_size: int, d_k: int):
        super().__init__()
        self.w = window_size
        self.scaling_factor = 1 / (d_k ** 0.5)
        self.neg_inf = -1e9

    def _create_local_mask(self, N_curr: int, device: torch.device) -> torch.Tensor:
        indices = torch.arange(N_curr, device=device).unsqueeze(-1)
        diff_matrix = torch.abs(indices - indices.transpose(0, 1))
        mask = (diff_matrix <= self.w)
        return mask.bool()

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        B, N_curr, _ = Q.shape

        S = torch.matmul(Q, K.transpose(-2, -1)) * self.scaling_factor # attention scores

        # create and apply a mask m
        local_mask = self._create_local_mask(N_curr, Q.device)
        M = local_mask.unsqueeze(0).expand(B, -1, -1)
        masked_scores = S.masked_fill(~M, self.neg_inf)

        # compute z
        A = torch.softmax(masked_scores, dim=-1)
        Z = torch.matmul(A, V)
        return Z

Step 3. Instantiating the Transformers

Next, I’ll define a full Seq2SeqTransformer class using the encoder and decoder layers.

Both layers take one of the attention layers defined in Step 2:

# standard encoder layer w/ attention module
class EncoderLayer(nn.Module):
    def __init__(self, attention_module: nn.Module, d_model: int):
        super().__init__()
        self.attn = attention_module # takes one of the four attention classes
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # attention block
        norm_x = self.norm1(x)

        # encoder has the same q, k, v (norm_x, norm_x, norm_x)
        attn_out = self.attn(norm_x, norm_x, norm_x)
        x = x + self.dropout1(attn_out) # residual connection

        # feed-forward block
        norm_x = self.norm2(x)
        ffn_out = self.ffn(norm_x)
        output = x + self.dropout2(ffn_out) # residual connection
        return output


# standard decoder layer
class DecoderLayer(nn.Module):
    def __init__(self, attention_module: nn.Module, d_model: int):
        super().__init__()
        # decoder self-attention (same type as encoder for this simulation)
        self.self_attn = attention_module
        # encoder-decoder cross-attention uses standard mha
        self.cross_attn = StandardMHA(d_model)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )

        self.norm1 = nn.LayerNorm(d_model) # self-attn norm
        self.norm2 = nn.LayerNorm(d_model) # cross-attn norm
        self.norm3 = nn.LayerNorm(d_model) # ffn norm
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.dropout3 = nn.Dropout(0.1)

    # takes the decoder input (tgt) and the encoder output (mem)
    def forward(self, tgt: torch.Tensor, mem: torch.Tensor) -> torch.Tensor:
        # decoder self-attention (q = k = v = tgt)
        norm_tgt = self.norm1(tgt)
        self_attn_out = self.self_attn(norm_tgt, norm_tgt, norm_tgt)
        tgt = tgt + self.dropout1(self_attn_out) # residual connection

        # encoder-decoder cross-attention (q = tgt, k = v = mem)
        norm_tgt = self.norm2(tgt)
        cross_attn_out = self.cross_attn(norm_tgt, mem, mem)
        tgt = tgt + self.dropout2(cross_attn_out) # residual connection

        # feed-forward block
        norm_tgt = self.norm3(tgt)
        ffn_out = self.ffn(norm_tgt)
        output = tgt + self.dropout3(ffn_out) # residual connection
        return output


# full model
class Seq2SeqTransformer(nn.Module):
    def __init__(self, attention_module: nn.Module, d_model: int, vocab_size: int, pad_token_id: int):
        super().__init__()
        # encoder layer
        self.encoder = EncoderLayer(attention_module, d_model)

        # decoder layer
        self.decoder = DecoderLayer(attention_module, d_model)

        # target token embedding
        self.target_embedding = nn.Embedding(vocab_size, d_model)

        # final linear/softmax head for logits
        self.linear_head = nn.Linear(d_model, vocab_size)

        # store pad id for mask generation
        self.pad_token_id = pad_token_id

    def forward(self, Q, K, V, Y_true):
        # encoder (q, k, v are used in the encoder layer to create encoded memory)
        mem = self.encoder(Q)
        # the decoder input is the true target tokens shifted right
        tgt_input_tokens = Y_true[:, :-1]
        tgt = self.target_embedding(tgt_input_tokens)
        decoder_output = self.decoder(tgt, mem)
        logits = self.linear_head(decoder_output)
        return logits

Then, I’ll define each transformer by inheriting from the Seq2SeqTransformer class:

from transformers import AutoTokenizer


# pre-trained tokenizer
TOKENIZER = AutoTokenizer.from_pretrained("t5-small", model_max_length=N)
VOCAB_SIZE = TOKENIZER.vocab_size


# full seq2seq transformers using different attention mechanisms
class TransformerStandard(Seq2SeqTransformer):
    def __init__(self, d_model: int, vocab_size: int):
        attn = StandardMHA(d_model)
        pad_id = TOKENIZER.pad_token_id
        super().__init__(attn, d_model, vocab_size, pad_token_id=pad_id)


class TransformerLinformer(Seq2SeqTransformer):
    def __init__(self, N_max: int, L: int, D_MODEL: int, vocab_size: int):
        attn = Linformer(N_max, L, D_MODEL)
        pad_id = TOKENIZER.pad_token_id
        super().__init__(attn, D_MODEL, vocab_size, pad_token_id=pad_id)


class TransformerPerformer(Seq2SeqTransformer):
    def __init__(self, D_MODEL: int, m_PERFORMER: int, vocab_size: int):
        attn = Performer(D_MODEL, m_PERFORMER)
        pad_id = TOKENIZER.pad_token_id
        super().__init__(attn, D_MODEL, vocab_size, pad_token_id=pad_id)


class TransformerLongformer(Seq2SeqTransformer):
    def __init__(self, w_LONGFORMER: int, D_MODEL: int, vocab_size: int):
        attn = Longformer(w_LONGFORMER, D_MODEL)
        pad_id = TOKENIZER.pad_token_id
        super().__init__(attn, D_MODEL, vocab_size, pad_token_id=pad_id)


# instantiate full seq2seq transformer models
D_MODEL = 512
L_LINFORMER = 128   # linformer compressed sequence length L << N
m_PERFORMER = 256   # performer number of random features m < D
w_LONGFORMER = 128  # longformer local window size w << N

models = {
    '1. Standard MHA': TransformerStandard(D_MODEL, VOCAB_SIZE).to(DEVICE),
    '2. Linformer': TransformerLinformer(Q_base.shape[1], L_LINFORMER, D_MODEL, VOCAB_SIZE).to(DEVICE),
    '3. Performer': TransformerPerformer(D_MODEL, m_PERFORMER, VOCAB_SIZE).to(DEVICE),
    '4. LongFormer': TransformerLongformer(w_LONGFORMER, D_MODEL, VOCAB_SIZE).to(DEVICE),
}

Step 4. Training the Standard MHA Transformer

Next, I’ll train the standard MHA model to create a stabilized target for the rest of the models.

I’ll define the train function with early stopping after 20 consecutive epochs without improvement in the loss history:
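The train function references an EarlyStopper helper that is not shown in the article. A minimal sketch of what it might look like, with the patience default and the is_distillation flag inferred from how it is called (the doubled patience for distillation is an assumption):

```python
class EarlyStopper:
    """Stops training after `patience` consecutive steps without loss improvement."""
    def __init__(self, patience: int = 20, min_delta: float = 0.0, is_distillation: bool = False):
        # assumption: distillation runs use increased patience, as the article's comment hints
        self.patience = patience * 2 if is_distillation else patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')

    def early_stop(self, loss: float) -> bool:
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss   # improvement: reset the counter
            self.counter = 0
        else:
            self.counter += 1       # no improvement this step
        return self.counter >= self.patience
```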

EPOCHS = 3000
LEARNING_RATE = 1e-5


# training standard mha
def train(
        model: nn.Module,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        Y_true: torch.Tensor,
        num_steps: int = EPOCHS,
        lr: float = LEARNING_RATE
        ) -> torch.Tensor:

    # target - the model's own random initial output logits
    model.eval()
    with torch.no_grad():
        init_output = model(Q, K, V, Y_true).detach()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # true language loss (ignores padding)
    criterion = nn.CrossEntropyLoss(ignore_index=TOKENIZER.pad_token_id)

    # training mode
    model.train()
    Y_target = Y_true[:, 1:].flatten()

    # teacher target labels for mse stabilization
    target = init_output.view(-1, init_output.shape[-1]).detach()

    # use default patience for stabilization
    early_stopper = EarlyStopper(patience=20)

    # training epochs
    for step in range(num_steps):
        optimizer.zero_grad()
        output_logits = model(Q, K, V, Y_true)
        output_for_loss = output_logits.view(-1, output_logits.shape[-1]) # reshape output_logits to [b*n_tokens, vocab_size]

        # stabilization loss (mse against the initial logits)
        loss = f.mse_loss(output_for_loss, target)
        loss.backward()
        optimizer.step()

        # calculate the true language loss for reporting and early stopping
        with torch.no_grad():
            true_loss = criterion(output_for_loss, Y_target)

        # early stop
        if early_stopper.early_stop(true_loss.item()):
            print(f"... early stopping triggered at step {step + 1}. no improvement in true ce loss for {early_stopper.patience} steps.")
            break

    model.eval()

    # return the trained output logits (used as the stable distillation target)
    with torch.no_grad():
        return model(Q, K, V, Y_true)


# train the standard mha to get stable targets
standard_output_target_logits = train(
    models['1. Standard MHA'], Q_base, K_base, V_base, Y_true_base
)

# get the test target logits from the trained mha
with torch.no_grad():
    standard_output_test_target_logits = models['1. Standard MHA'](Q_test, K_test, V_test, Y_true_test)

Step 5. Perform Knowledge Distillation (KD) for Approximation Transformers

Knowledge distillation (KD) is a machine learning technique used to transfer the knowledge from a large, complex model (the teacher, in this project, the standard MHA) to a smaller, more efficient model (the student).

This process compresses the model, which is crucial for deploying powerful models in resource-constrained environments like mobile devices or edge hardware.

I’ll define the run_distillation_simulation function, where the remaining models are optimized against the standard MHA model’s outputs from Step 4 using a combined KL-divergence and cross-entropy distillation loss:

import torch
import torch.nn as nn
import torch.nn.functional as f

# kd for the attention approx models
EPOCHS = 3000
LEARNING_RATE = 1e-5
DISTILLATION_TEMP = 3.0 # temperature for distillation
ALPHA = 0.5 # weight of the hard-label ce term in the combined loss

def run_distillation_simulation(
        model: nn.Module,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        Y_true: torch.Tensor,
        target_model_name: str,
        stable_target_logits: torch.Tensor,
        num_steps: int = EPOCHS,
        lr: float = LEARNING_RATE):

    print(f"\n... training {target_model_name} (runs distillation) ...")

    # define optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # loss functions
    ce_loss = nn.CrossEntropyLoss(ignore_index=TOKENIZER.pad_token_id)

    # set training mode
    model.train()

    # reshape target for cross-entropy loss (hard labels)
    Y_target = Y_true[:, 1:].flatten()

    # reshape stable_target_logits (teacher's output)
    stable_target_for_loss = stable_target_logits.view(-1, stable_target_logits.shape[-1]).detach()

    # early stopper (uses increased patience)
    early_stopper = EarlyStopper(is_distillation=True)

    # a utility function for kl-divergence on soft targets (kd loss)
    def distillation_loss(student_logits, teacher_logits, temperature, alpha, true_labels):
        # soft targets
        student_log_softmax = f.log_softmax(student_logits / temperature, dim=-1)
        teacher_softmax = f.softmax(teacher_logits / temperature, dim=-1)

        # kl divergence loss (scaled by t^2 to keep gradient magnitudes comparable)
        kd_loss = f.kl_div(student_log_softmax, teacher_softmax, reduction='batchmean') * (temperature ** 2)

        # hard targets - standard cross entropy loss on true labels
        hard_loss = ce_loss(student_logits, true_labels)

        # combined loss
        return alpha * hard_loss + (1.0 - alpha) * kd_loss

    for step in range(num_steps):
        optimizer.zero_grad()
        output_logits = model(Q, K, V, Y_true)
        output_for_loss = output_logits.view(-1, output_logits.shape[-1])

        # combined distillation loss (kl + ce)
        loss = distillation_loss(
            student_logits=output_for_loss,
            teacher_logits=stable_target_for_loss,
            temperature=DISTILLATION_TEMP,
            alpha=ALPHA,
            true_labels=Y_target
        )

        loss.backward()
        optimizer.step()

        # calculate the true language loss for reporting and early stopping
        with torch.no_grad():
            # use the standard ce loss for the true performance metric
            true_loss = ce_loss(output_for_loss, Y_target)

        if early_stopper.early_stop(true_loss.item()):
            print(f"... early stopping triggered at step {step + 1}. no improvement in true ce loss for {early_stopper.patience} steps.")
            break

    model.eval()


# distill the rest of the models against the trained standard mha
for name, model in models.items():
    if name != '1. Standard MHA':
        run_distillation_simulation(
            model=model,
            Q=Q_base,
            K=K_base,
            V=V_base,
            Y_true=Y_true_base,
            target_model_name=name,
            stable_target_logits=standard_output_target_logits,
            num_steps=EPOCHS,
        )

# ensure inputs are detached for timing and evaluation
Q_base.requires_grad_(False); K_base.requires_grad_(False); V_base.requires_grad_(False)
Q_test.requires_grad_(False); K_test.requires_grad_(False); V_test.requires_grad_(False)
Y_true_base.requires_grad_(False); Y_true_test.requires_grad_(False)

Step 6. Perform Inference

Lastly, I’ll define the perform_inference function, where every model is evaluated on three metrics:

  • Quality of approximation (based on the loss),

  • Generalization capabilities, and

  • Training speed.
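Timing a GPU model needs care because CUDA kernels launch asynchronously, so naive wall-clock measurement can finish before the GPU does. The sketch below isolates that event-based timing pattern in a hypothetical `timed_forward` helper, synchronizing with CUDA events on GPU and falling back to wall-clock time on CPU:

```python
import time
import torch

def timed_forward(fn, num_runs=10):
    """Average the runtime of one forward call over num_runs."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain any pending kernels before timing
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        with torch.no_grad():
            for _ in range(num_runs):
                fn()
        end.record()
        torch.cuda.synchronize()  # wait for the timed kernels to finish
        return start.elapsed_time(end) / 1000 / num_runs  # ms -> s, averaged
    # cpu fallback: plain wall-clock timing
    t0 = time.time()
    with torch.no_grad():
        for _ in range(num_runs):
            fn()
    return (time.time() - t0) / num_runs

# toy usage on a small matmul instead of a full model
avg = timed_forward(lambda: torch.matmul(torch.randn(64, 64), torch.randn(64, 64)))
```

The two `synchronize()` calls are the key detail: one before recording so earlier work is not counted, one after so the measurement includes all timed kernels.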

import time
from typing import Dict

def perform_inference() -> Dict[str, Dict[str, float]]:
    NUM_RUNS = 10

    # recording the results
    results: Dict[str, Dict[str, float]] = dict()

    # define loss
    final_criterion = nn.CrossEntropyLoss(ignore_index=TOKENIZER.pad_token_id, reduction='mean')

    # reshape true token targets for loss calc
    Y_target_base_flat = Y_true_base[:, 1:].flatten()
    Y_target_test_flat = Y_true_test[:, 1:].flatten()

    # perform inference
    for name, model in models.items():
        # timing setup
        if DEVICE.type == 'cuda':
            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            with torch.no_grad():
                for _ in range(NUM_RUNS):
                    output_logits_base = model(Q_base, K_base, V_base, Y_true_base)
            end_event.record()
            torch.cuda.synchronize()
            avg_time = start_event.elapsed_time(end_event) / 1000 / NUM_RUNS # ms -> s, averaged
        else:
            # fallback to cpu wall-clock timing
            start_time = time.time()
            with torch.no_grad():
                for _ in range(NUM_RUNS):
                    output_logits_base = model(Q_base, K_base, V_base, Y_true_base)
            avg_time = (time.time() - start_time) / NUM_RUNS

        # analysis 1 - accuracy vs. sparsity / rank (approximation quality)
        output_logits_base_flat = output_logits_base.view(-1, output_logits_base.shape[-1])
        standard_output_target_logits_flat = standard_output_target_logits.view(-1, standard_output_target_logits.shape[-1])

        if name == '1. Standard MHA':
            # standard model's loss against true labels (ce loss)
            mse_accuracy = final_criterion(output_logits_base_flat, Y_target_base_flat).item()
        else:
            # approximation model's loss (mse between logits of student and teacher)
            mse_accuracy = f.mse_loss(output_logits_base_flat, standard_output_target_logits_flat).item()

        # analysis 2 - generalization (approximation on n_test)
        try:
            with torch.no_grad():
                output_logits_test = model(Q_test, K_test, V_test, Y_true_test)

            output_logits_test_flat = output_logits_test.view(-1, output_logits_test.shape[-1])
            standard_output_test_target_logits_flat = standard_output_test_target_logits.view(-1, standard_output_test_target_logits.shape[-1])

            if name == '1. Standard MHA':
                mse_test = final_criterion(output_logits_test_flat, Y_target_test_flat).item()
                generalization_ratio = 1.0  # as target model
            else:
                mse_test = f.mse_loss(output_logits_test_flat, standard_output_test_target_logits_flat).item()
                if mse_accuracy > 1e-9:
                    generalization_ratio = mse_test / mse_accuracy
                else:
                    generalization_ratio = 1.0

        except RuntimeError as e:
            # generalization is allowed to fail only if seq len exceeds n_max
            if 'exceeds fixed projection size' in str(e) and name == '2. Linformer':
                print(f"  [{name}]: generalization test failed (sequence length error: {e})")
            else:
                # for performer or longformer, a runtime error is a major failure
                print(f"  [{name}]: generalization test failed (unexpected error: {e})")

            mse_test = 1000.0
            generalization_ratio = 1000.0

        # store results
        results[name] = {
            "Practical Speed (s)": avg_time,
            "Approximation Quality (Metric)": mse_accuracy,
            "Generalization Factor (Test Metric / Base Metric)": generalization_ratio
        }

    return results


# perform inference
results = perform_inference()

Results - Analyzing the Tradeoff

The Performer model is the overall best performer, achieving:

  • The best accuracy proxy value of 0.051980.

  • The fastest inference among the methods that generalized successfully.

  • Training efficiency, stopping much earlier than Linformer while reaching a much lower true CE loss.

Linformer recorded the fastest average time, but its generalization test failed due to a sequence length mismatch, making its results invalid for comparison.

LongFormer had the highest average time and the worst generalization ratio. Standard MHA showed the weakest performance in both training loss and the final metric.

Figure Z. Comparison of Seq2Seq Transformers (left: approximation quality, mid: generalization, right: training speed) (Created by Kuriko IWAI)


Training Results (Left in Figure Z)

The Performer and LongFormer models achieved significantly lower final training loss (True CE Loss) compared to Standard MHA and Linformer:

  • Performer and LongFormer demonstrated superior performance in minimizing the true Cross-Entropy (CE) loss during distillation.

  • Standard MHA stopped very early and maintained a high CE loss (11.5).

  • Linformer ran for the longest duration (870 steps) but the final true CE loss remained high (6.8).

Generalization (Middle in Figure Z)

Relative to the Standard MHA baseline ratio of 1.0, Performer shows strong generalization with a ratio of 1.02, while LongFormer underperformed the others with a ratio of 0.68.

Training Speed (Right in Figure Z)

Performer averaged 0.2381 seconds, while LongFormer was the slowest at 0.3168 seconds.

Overall, the Performer model strikes the best balance, achieving the lowest accuracy proxy metric with a very fast inference time.

Conclusion

Attention approximation methods are powerful tools for improving the efficiency and performance of Transformer models by replacing the quadratic complexity with more efficient mechanisms.
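To put the quadratic bottleneck in concrete numbers, here is a quick back-of-the-envelope sketch; the `score_matrix_mb` helper is illustrative only and ignores batch size and the number of heads:

```python
def score_matrix_mb(n, k=None):
    """Memory (MB) of one float32 attention score matrix.

    Full attention stores an N x N matrix; a Linformer-style
    projection reduces it to N x k.
    """
    cols = k if k is not None else n
    return n * cols * 4 / 1024 ** 2  # 4 bytes per float32 entry

full_4k = score_matrix_mb(4096)            # standard attention at N = 4096
linformer_4k = score_matrix_mb(4096, 256)  # projected down to k = 256
```

At N = 4096, the full score matrix alone takes 64 MB per head versus 4 MB for the projected version, and the gap widens quadratically as N grows; this is exactly the saving these methods trade approximation quality for.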

In the experiment, we observed that the Performer model offered the best combination of accuracy and speed among the successfully generalized methods, confirming its effectiveness as an attention approximation technique.

The development of scalable attention mechanisms is not just an optimization but a fundamental enabler for the next generation of AI models.

By selectively reducing complexity, the techniques unlock the potential of Transformers for tasks requiring vast context, pushing the boundaries of what is computationally feasible.

Continue Your Learning

If you enjoyed this blog, these related entries will complete the picture:

Related Books for Further Understanding

These books cover a wide range of theory and practice, from fundamentals to PhD level.

Linear Algebra Done Right

Foundations of Machine Learning, second edition (Adaptive Computation and Machine Learning series)

Designing Machine Learning Systems: An Iterative Process for Production-Ready Applications

Machine Learning Design Patterns: Solutions to Common Challenges in Data Preparation, Model Building, and MLOps

Hands-On Generative AI with Transformers and Diffusion Models

Share What You Learned

Kuriko IWAI, "Implementing Attention Approximation: Transformer Efficiency & Trade-offs" in Kernel Labs

https://kuriko-iwai.com/attention-approximation

Written by Kuriko IWAI. All images, unless otherwise noted, are by the author. All experimentations on this blog utilize synthetic or licensed data.