Regularizing LLMs with Kullback-Leibler Divergence
Master the exploration-exploitation trade-off in fine-tuning with KL divergence regularization
By Kuriko IWAI

Introduction
Large language models (LLMs) have transformed machine learning, yet their fine-tuning still hinges on balancing exploration (generating diverse outputs) and exploitation (staying consistent with the desired behavior).
Numerous approaches, such as reinforcement learning from human feedback (RLHF) [1] and controlled decoding [2], have been developed to tackle this challenge.
Many of these approaches can be framed as a KL-regularized reward maximization problem [3], whose solution is an optimal distribution that balances reward against proximity to the reference model [4].
This helps models avoid catastrophic forgetting and stay consistent with the reference behavior while still maximizing the reward objective.
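For a finite set of candidate outputs, the KL-regularized reward maximization mentioned above has a well-known closed-form solution: maximizing E_π[r] − β·D_KL(π||π_ref) gives π*(y) ∝ π_ref(y)·exp(r(y)/β). The sketch below is a minimal NumPy illustration of that formula, assuming a hypothetical three-output reference distribution and hypothetical rewards; the function names are illustrative, not from any library.

```python
import numpy as np

def kl_regularized_optimum(p_ref, rewards, beta):
    """Maximizer of E_pi[r] - beta * D_KL(pi || p_ref) over distributions on a finite output set.

    Closed form: pi*(y) = p_ref(y) * exp(r(y) / beta) / Z.
    """
    logits = np.log(p_ref) + rewards / beta
    logits = logits - logits.max()  # subtract the max for numerical stability
    pi = np.exp(logits)
    return pi / pi.sum()

def objective(pi, p_ref, rewards, beta):
    """Expected reward minus the beta-weighted KL penalty toward p_ref."""
    return float(np.sum(pi * rewards) - beta * np.sum(pi * np.log(pi / p_ref)))

p_ref = np.array([0.5, 0.3, 0.2])    # hypothetical reference model over 3 candidate outputs
rewards = np.array([0.0, 1.0, 2.0])  # hypothetical rewards for those outputs

for beta in [0.1, 1.0, 10.0]:
    pi = kl_regularized_optimum(p_ref, rewards, beta)
    print(f"beta = {beta}: pi* = {np.round(pi, 3)}, objective = {objective(pi, p_ref, rewards, beta):.3f}")
```

A small β lets the reward dominate and concentrates π* on the highest-reward output, while a large β keeps π* close to π_ref.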
In this article, I’ll explore its core mechanism and examine how it regularizes LLMs.
What is Kullback-Leibler (KL) Divergence
The Kullback-Leibler (KL) divergence, also known as relative entropy, is a non-symmetric measure of the difference between two probability distributions.
The diagram below illustrates how the KL divergence between two distributions changes as one of them shifts:

Kernel Labs | Kuriko IWAI | kuriko-iwai.com
Figure A. Forward KL divergence (Created by Kuriko IWAI)
KL divergence measures the information lost (the divergence) when an approximating (model) distribution Q (the moving blue curve in Figure A) is used to approximate the true distribution P (red).
Although KL divergence applies to non-Gaussian distributions as well, P and Q in Figure A are both Gaussian for demonstration: P ∼ N(1.0, 4.0) and Q ∼ N(μ_Q, 1.0), where the second parameter denotes the variance.
The value D_{KL}(P||Q) in the top-left corner shows the divergence between P and Q. It is minimized at Q ∼ N(1.0, 1.0), where Q overlaps P the most, but it does not reach zero there because the two variances still differ.
Formally, for continuous distributions, the (forward) KL divergence is defined as:

$$D_{KL}(P \| Q) = \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)} \, dx \tag{1.1}$$

where:
P is the true distribution,
Q is the approximating distribution, and
p(x) and q(x) are the probability density functions (PDFs) of P and Q, respectively.
For discrete distributions:

$$D_{KL}(P \| Q) = \sum_{x \in \mathcal{X}} P(x) \log \frac{P(x)}{Q(x)} \tag{1.2}$$

where
x represents a specific, discrete event in the sample space X, and
P(x) and Q(x) are the probability mass function (PMF) values that P and Q assign to the event x, respectively.
In both cases, Eqs. (1.1) and (1.2) compute the expected value of the log-likelihood ratio log(P(x)/Q(x)), with the expectation taken with respect to the true distribution P: D_{KL}(P||Q) = E_{x∼P}[log(P(x)/Q(x))].
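Eq. (1.2) and its expectation form translate directly into a few lines of NumPy. The sketch below, which assumes two hypothetical three-event PMFs, computes the exact sum and a Monte Carlo estimate obtained by sampling x from P and averaging log(P(x)/Q(x)); `kl_discrete` is an illustrative helper, not a library function.

```python
import numpy as np

def kl_discrete(p, q):
    """D_KL(P || Q) in nats for two PMFs on the same finite support (Eq. 1.2)."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    support = p > 0  # events with P(x) = 0 contribute nothing (0 * log 0 is taken as 0)
    with np.errstate(divide="ignore"):  # Q(x) = 0 on the support yields inf instead of a warning
        return float(np.sum(p[support] * (np.log(p[support]) - np.log(q[support]))))

p = np.array([0.5, 0.3, 0.2])  # hypothetical true PMF P
q = np.array([0.4, 0.4, 0.2])  # hypothetical approximating PMF Q

# the same quantity as an expectation under P: sample x ~ P and average log(P(x) / Q(x))
rng = np.random.default_rng(0)
x = rng.choice(len(p), size=100_000, p=p)
mc_estimate = float(np.mean(np.log(p[x] / q[x])))

print(kl_discrete(p, q), mc_estimate)
```

The two numbers agree up to sampling noise, which is what "expectation with respect to P" means in practice.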
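For example, when the mean of the approximating distribution Q is 3.0 (μ_Q = 3.0) in Figure A, Eq. (1.1) has a closed form for two Gaussians and evaluates to:

$$D_{KL}(P \| Q) = \log\frac{\sigma_Q}{\sigma_P} + \frac{\sigma_P^2 + (\mu_P - \mu_Q)^2}{2\sigma_Q^2} - \frac{1}{2} = \log\frac{1}{2} + \frac{4.0 + (1.0 - 3.0)^2}{2} - \frac{1}{2} \approx 2.8068 \text{ nats} \tag{2.1}$$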
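When the mean is 5.0 (μ_Q = 5.0), Eq. (1.1) evaluates to:

$$D_{KL}(P \| Q) = \log\frac{1}{2} + \frac{4.0 + (1.0 - 5.0)^2}{2} - \frac{1}{2} \approx 8.8068 \text{ nats} \tag{2.2}$$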
Choosing μ_Q = 3.0 instead of μ_Q = 5.0 results in a smaller KL divergence (2.8068 nats in Eq. (2.1) < 8.8068 nats in Eq. (2.2)), which correctly reflects that Q is a better approximation of the true distribution P.
This is consistent with the fact that μ_Q = 3.0 is closer to the mean of the true distribution (μ_P = 1.0) than μ_Q = 5.0.
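The numbers in Eqs. (2.1) and (2.2) can be reproduced with the closed form for two univariate Gaussians, assuming (as in Figure A) that the second parameter of N(·, ·) is the variance. The minimal NumPy sketch below also checks the closed form against a Monte Carlo estimate of Eq. (1.1); all helper names are hypothetical.

```python
import numpy as np

def gaussian_kl(mu_p, var_p, mu_q, var_q):
    """Closed-form D_KL(P || Q) in nats for P = N(mu_p, var_p) and Q = N(mu_q, var_q)."""
    return 0.5 * np.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / (2 * var_q) - 0.5

def gaussian_logpdf(x, mu, var):
    return -0.5 * np.log(2 * np.pi * var) - (x - mu) ** 2 / (2 * var)

def monte_carlo_kl(mu_p, var_p, mu_q, var_q, n=200_000, seed=0):
    """Estimate E_{x~P}[log p(x) - log q(x)] by sampling from P."""
    rng = np.random.default_rng(seed)
    x = rng.normal(mu_p, np.sqrt(var_p), size=n)
    return float(np.mean(gaussian_logpdf(x, mu_p, var_p) - gaussian_logpdf(x, mu_q, var_q)))

# P ~ N(1.0, 4.0) and Q ~ N(mu_Q, 1.0), as in Figure A
for mu_q in [1.0, 3.0, 5.0]:
    exact = gaussian_kl(1.0, 4.0, mu_q, 1.0)
    approx = monte_carlo_kl(1.0, 4.0, mu_q, 1.0)
    print(f"mu_Q = {mu_q}: exact = {exact:.4f} nats, Monte Carlo = {approx:.4f} nats")
```

At μ_Q = 1.0 the divergence bottoms out at about 0.81 nats rather than zero, because Q's variance (1.0) still differs from P's (4.0).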
It is important to note that the KL divergence is always non-negative (Gibbs' inequality):

$$D_{KL}(P \| Q) \geq 0 \tag{3.1}$$
◼ Identity of Distributions
The divergence is zero if and only if the two distributions P and Q are identical:

$$D_{KL}(P \| Q) = 0 \iff P = Q$$

When the two distributions are the same, the log-likelihood ratio log(P(x)/Q(x)) is zero for every x, so the total divergence is zero.
◼ No Upper Bound
Eq. (3.1) sets only a lower bound; there is no fixed upper limit to the KL divergence.
The divergence grows without bound when the approximating distribution assigns zero density (or mass) to an event x that the true distribution considers possible (P(x) > 0), because the log-likelihood ratio log(P(x)/Q(x)) then goes to infinity.
In other words, Q is infinitely penalized for ruling out an event that P says can happen.
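The blow-up is easy to see numerically. In this minimal sketch, which assumes a hypothetical two-event distribution P = (0.5, 0.5), the model Q squeezes the mass of the second event toward zero; the divergence keeps growing and becomes infinite once Q assigns exactly zero to an event that P considers possible.

```python
import numpy as np

p = np.array([0.5, 0.5])  # hypothetical true distribution: both events are possible

def kl_to_squeezed_model(eps):
    """D_KL(P || Q) when the model Q assigns only `eps` mass to the second event."""
    q = np.array([1.0 - eps, eps])
    with np.errstate(divide="ignore"):  # log(0) -> -inf instead of a warning
        return float(np.sum(p * (np.log(p) - np.log(q))))

for eps in [1e-1, 1e-3, 1e-6, 0.0]:
    print(f"Q(x_2) = {eps:g}: D_KL(P || Q) = {kl_to_squeezed_model(eps):.3f}")
```

In the classifier code later in this article, probabilities come from a softmax over finite logits, which is never exactly zero, so `log_softmax` keeps every KL term finite.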
◼ Asymmetric Traits - Forward vs. Reverse KL Divergence in Practice
KL divergence is asymmetric: in general, D(P||Q) ≠ D(Q||P).
D(P||Q), as defined in Eqs. (1.1) and (1.2), is called the forward KL divergence, whereas D(Q||P) is called the reverse KL divergence.
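Analogous to the forward KL divergence, the reverse KL divergence is defined as:

$$D_{KL}(Q \| P) = \int_{-\infty}^{\infty} q(x) \log \frac{q(x)}{p(x)} \, dx \tag{4.1}$$

for continuous distributions and

$$D_{KL}(Q \| P) = \sum_{x \in \mathcal{X}} Q(x) \log \frac{Q(x)}{P(x)} \tag{4.2}$$

for discrete distributions.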
In both cases, Eqs. (4.1) and (4.2) take the expectation with respect to the approximating distribution Q instead of the true distribution P.
The reverse KL divergence heavily penalizes the model Q when Q(x) is high and P(x) is low, meaning the model puts large probability mass where the true distribution P puts very little.
The difference between forward and reverse KL divergence is critical in ML:

Figure B. Forward vs. reverse KL divergence (Created by Kuriko IWAI)
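Under forward KL divergence, the model Q tries to cover every region where P(x) > 0 (mass-covering behavior), whereas under reverse KL divergence it avoids placing probability mass where P(x) is small and tends to concentrate on a single mode of P (mode-seeking behavior).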
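The mass-covering versus mode-seeking behavior can be reproduced with a small experiment. The sketch below assumes a hypothetical bimodal P (an equal mixture of N(−3, 1) and N(3, 1)), restricts Q to a single Gaussian, approximates both integrals with a Riemann sum on a grid, and picks the best Q by brute-force grid search rather than gradient descent.

```python
import numpy as np

x = np.linspace(-10.0, 10.0, 2001)  # integration grid
dx = x[1] - x[0]

def normal_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# hypothetical bimodal "true" distribution P: equal mixture of N(-3, 1) and N(3, 1)
p = 0.5 * normal_pdf(x, -3.0, 1.0) + 0.5 * normal_pdf(x, 3.0, 1.0)

def kl(a, b):
    """Riemann-sum approximation of the integral of a(x) * log(a(x) / b(x))."""
    eps = 1e-300  # guards against log(0) where a density underflows
    return float(np.sum(a * (np.log(a + eps) - np.log(b + eps))) * dx)

# best[name] = (divergence, mu, sigma) of the best single-Gaussian Q found so far
best = {"forward": (np.inf, 0.0, 0.0), "reverse": (np.inf, 0.0, 0.0)}
for mu in np.linspace(-5.0, 5.0, 101):
    for sigma in np.linspace(0.5, 5.0, 91):
        q = normal_pdf(x, mu, sigma)
        for name, value in (("forward", kl(p, q)), ("reverse", kl(q, p))):
            if value < best[name][0]:
                best[name] = (value, float(mu), float(sigma))

for name, (value, mu, sigma) in best.items():
    print(f"{name} KL: best Q has mean {mu:.2f} and std {sigma:.2f} (divergence {value:.3f})")
```

Forward KL settles on a wide Gaussian centered between the modes (mean ≈ 0, standard deviation ≈ √10 ≈ 3.16, i.e., moment matching), whereas reverse KL locks onto one of the two modes (mean ≈ ±3, standard deviation ≈ 1).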
◼ The Divergence Family
A divergence family refers to the broad category of measures used in information theory and statistics to quantify the difference (distance) between two probability distributions.
They are essential tools in machine learning, particularly in areas like model comparison, clustering, and variational inference.
The diagram below illustrates the relationships between the main classes, or families, of these divergences [5]:

Figure C. Principal classes of distance and divergence (Source, Edited by Kuriko IWAI)
Figure C shows three major classes (Csiszár, Bregman, and Jensen divergences) along with conformal and projective divergences, and illustrates how the Jensen and Bregman divergences can be interpreted geometrically via the chordal slope theorem [5].
Among these classes, both forward and reverse KL divergences belong to the Csiszár f-divergence class, using the convention D_f(P||Q) = \sum_x P(x) f(Q(x)/P(x)).
The forward KL divergence D(P||Q) can be written in two forms:
Csiszár f-divergence with the generator function f(x) = -log(x), and
Bregman divergence with the negative entropy function F(P) = \sum_x P(x) log P(x).
The reverse KL divergence D(Q||P), on the other hand, is the Csiszár f-divergence with the generator function f(x) = x log(x); it equals the same Bregman divergence with its two arguments swapped.
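These identities can be checked numerically. The sketch below assumes two hypothetical three-event PMFs and the f-divergence convention stated above, D_f(P||Q) = Σ P(x) f(Q(x)/P(x)); under the other common convention, Σ Q(x) f(P(x)/Q(x)), the two generators swap roles.

```python
import numpy as np

p = np.array([0.6, 0.3, 0.1])  # hypothetical PMFs on a 3-event support
q = np.array([0.3, 0.4, 0.3])

def kl(a, b):
    return float(np.sum(a * np.log(a / b)))

def f_divergence(a, b, f):
    """Csiszar f-divergence with the convention D_f(A || B) = sum_x A(x) * f(B(x) / A(x))."""
    return float(np.sum(a * f(b / a)))

def bregman(a, b, F, grad_F):
    """Bregman divergence B_F(A, B) = F(A) - F(B) - <grad F(B), A - B>."""
    return float(F(a) - F(b) - np.dot(grad_F(b), a - b))

def neg_entropy(v):
    return np.sum(v * np.log(v))

def grad_neg_entropy(v):
    return np.log(v) + 1.0

# forward KL: f(t) = -log t, and Bregman with the negative entropy
print(kl(p, q), f_divergence(p, q, lambda t: -np.log(t)), bregman(p, q, neg_entropy, grad_neg_entropy))
# reverse KL: f(t) = t log t, and the same Bregman divergence with swapped arguments
print(kl(q, p), f_divergence(p, q, lambda t: t * np.log(t)), bregman(q, p, neg_entropy, grad_neg_entropy))
```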
KL Divergence in Optimization (SFT, RLHF, and VI)
Within the divergence family, the KL divergence has a direct information-theoretic meaning: it measures the information lost when one probability distribution is used to approximate another. This makes it a natural fit for regularization and classification tasks.
In this context, the KL divergence serves as a loss function that an optimization algorithm such as gradient descent minimizes by adjusting the parameters of the approximating distribution Q.
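As a minimal PyTorch sketch of this idea, assume Q is a categorical distribution parameterized by a logit vector θ and the target P is a hypothetical fixed distribution. Plain gradient descent on D_KL(P||Q_θ) then drives Q_θ toward P; for softmax logits the gradient is simply softmax(θ) − P.

```python
import torch
import torch.nn.functional as F

p_true = torch.tensor([[0.7, 0.2, 0.1]])       # hypothetical target distribution P (batch of 1)
theta = torch.zeros(1, 3, requires_grad=True)  # logits parameterizing Q; starts uniform
optimizer = torch.optim.SGD([theta], lr=1.0)

for step in range(300):
    optimizer.zero_grad()
    log_q = F.log_softmax(theta, dim=-1)
    loss = F.kl_div(log_q, p_true, reduction="batchmean")  # D_KL(P || Q_theta)
    loss.backward()  # gradient w.r.t. theta is softmax(theta) - P
    optimizer.step()

q_fit = F.softmax(theta, dim=-1).detach()
print(q_fit, loss.item())
```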
Let us see some examples:
◼ Regularization in LLMs (SFT and RLHF)
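In SFT regularization, the KL divergence constrains a supervised fine-tuned model Q to stay close to a base (reference) model P so that it retains general knowledge. With the forward KL divergence, the loss becomes:

$$L = L_{Task} + \beta \, D_{KL}(P \| Q)$$

where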
L is the total loss,
L_{Task} is the task loss generated by the standard loss function for the task that the model is trained on, and
β ∈ (0, 1] is the regularization coefficient (a hyperparameter) to control the strength of KL penalty.
A high β gives more weight to the KL term, strongly forcing the new model Q to stay close to the reference P and prioritizing stability and knowledge retention over task performance.
A low β reduces the importance of the KL term, allowing the new model Q to diverge more freely and prioritizing task performance.
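For intuition about β, consider a single training example with one-hot label y and the forward-KL version of the loss above. Because −log Q(y) + β·D_KL(P||Q) equals −Σ_x [1(x = y) + βP(x)] log Q(x) up to a constant, its minimizer over all distributions Q is Q* = (one-hot(y) + βP) / (1 + β), a mixture of the task target and the reference. The NumPy sketch below assumes a hypothetical diffuse reference P; it ignores the model parameterization and batching, so it only illustrates the direction of the effect.

```python
import numpy as np

def regularized_objective(q, label, p_ref, beta):
    """Per-example loss: cross-entropy on `label` plus beta * D_KL(P_ref || Q)."""
    return float(-np.log(q[label]) + beta * np.sum(p_ref * np.log(p_ref / q)))

def regularized_optimum(label, p_ref, beta):
    """Closed-form minimizer of regularized_objective over all distributions Q."""
    target = np.zeros_like(p_ref)
    target[label] = 1.0  # one-hot task target
    return (target + beta * p_ref) / (1.0 + beta)

p_ref = np.array([0.3, 0.4, 0.3])  # hypothetical diffuse reference over 3 labels

for beta in [0.1, 0.5, 0.9]:
    q_star = regularized_optimum(label=0, p_ref=p_ref, beta=beta)
    print(f"beta = {beta}: Q* = {np.round(q_star, 3)}")
```

Even at β = 0.9, Q* still puts about two-thirds of its mass on the labeled class, so the KL term softens the task signal rather than cancelling it.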
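In RLHF regularization, on the other hand, the KL divergence constrains the policy being trained relative to the SFT model to prevent destructive policy updates. The objective, which is maximized, is:

$$L = L_{Reward} - \beta \, D_{KL}(\pi_\theta \| \pi_{ref})$$

where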
L_{Reward} is the expected reward assigned to the model's generated response,
π_θ is the new, optimized policy (the LLM being trained),
π_ref is the reference policy (the SFT-trained model), and
β is the KL penalty coefficient that controls the strength of the regularization.
The divergence term measures the difference between the new, optimized policy π_θ and the initial policy π_ref and penalizes shifts that are too large. Because π_θ is the first argument, this penalty is a reverse KL divergence, with the expectation taken over the policy's own generations.
This discourages the model from drifting toward low-quality text while still allowing it to improve on the human-derived reward.
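In practice, the KL term is usually estimated from samples rather than computed exactly. One common approach in PPO-style RLHF pipelines is to subtract β times the summed log-probability ratio of the sampled tokens from the reward. The sketch below is a minimal PyTorch illustration under that assumption; `kl_penalized_reward` is a hypothetical helper, and the sanity check uses a toy one-token vocabulary rather than a real LLM.

```python
import torch
from torch.distributions import Categorical, kl_divergence

def kl_penalized_reward(reward, logp_policy, logp_ref, beta):
    """Sequence-level reward r(x, y) - beta * sum_t [log pi_theta(y_t) - log pi_ref(y_t)].

    logp_policy, logp_ref: log-probabilities of the sampled tokens, shape (batch, seq_len).
    For y sampled from pi_theta, the summed log-ratio is a single-sample estimate of
    D_KL(pi_theta || pi_ref) for that prompt.
    """
    log_ratio = (logp_policy - logp_ref).sum(dim=-1)
    return reward - beta * log_ratio

# sanity check on a hypothetical one-token "vocabulary" of size 3
torch.manual_seed(0)
policy = Categorical(probs=torch.tensor([0.6, 0.3, 0.1]))  # pi_theta
ref = Categorical(probs=torch.tensor([0.4, 0.4, 0.2]))     # pi_ref
tokens = policy.sample((100_000, 1))  # 100k sequences of length 1, sampled from the policy
shaped = kl_penalized_reward(torch.zeros(100_000), policy.log_prob(tokens), ref.log_prob(tokens), beta=1.0)
estimate = -shaped.mean()           # with zero reward and beta = 1, this is the average log-ratio
exact = kl_divergence(policy, ref)  # analytic D_KL(pi_theta || pi_ref)
print(estimate.item(), exact.item())
```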
◼ Variational Inference
In Bayesian machine learning, variational inference (VI) is used to approximate complex, intractable posterior distributions.
The KL divergence D_KL(Q||P) measures the difference between the approximating distribution Q and the true posterior distribution P, and training selects the Q that minimizes it.
Because the true posterior is intractable, VI instead maximizes the Evidence Lower Bound (ELBO): since log p(x) = ELBO(Q) + D_KL(Q||P) and log p(x) does not depend on Q, maximizing the ELBO is equivalent to minimizing the KL divergence.
Models like Variational Autoencoders (VAEs) and Variational Transformers rely on this principle.
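The relationship between the ELBO and the KL divergence can be verified on a toy model whose posterior is still tractable. This NumPy sketch assumes a hypothetical joint p(x, z) for a single observation x and a three-state latent z, and checks that log p(x) = ELBO(q) + D_KL(q||p(z|x)) for an arbitrary q.

```python
import numpy as np

# hypothetical joint p(x, z) for one fixed observation x and a latent z with 3 states
p_xz = np.array([0.10, 0.05, 0.15])  # p(x, z) for z = 0, 1, 2
log_evidence = np.log(p_xz.sum())    # log p(x)
posterior = p_xz / p_xz.sum()        # p(z | x), tractable here only because z is tiny

def elbo(q):
    """ELBO(q) = E_q[log p(x, z) - log q(z)]."""
    return float(np.sum(q * (np.log(p_xz) - np.log(q))))

def kl(a, b):
    return float(np.sum(a * np.log(a / b)))

q = np.array([0.2, 0.3, 0.5])  # an arbitrary variational distribution q(z)
print(log_evidence, elbo(q) + kl(q, posterior))  # equal up to floating-point error
print(elbo(posterior), log_evidence)             # the ELBO is tight when q is the posterior
```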
◼ Training Generative Models
The KL divergence measures the dissimilarity between the true data distribution and the model distribution (what the model actually generates).
Minimizing the divergence is equivalent to minimizing the cross-entropy loss when the target distribution P is fixed, making it a foundation for many classification and sequence modeling losses.
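The equivalence follows from H(P, Q) = H(P) + D_KL(P||Q): the entropy H(P) does not depend on the model, so the two losses differ by a constant and share the same gradients. The PyTorch sketch below checks the identity on a hypothetical batch of soft targets; it relies on `F.cross_entropy` accepting class-probability targets, which recent PyTorch versions support.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])  # model Q, as logits
p = torch.tensor([[0.7, 0.2, 0.1], [0.2, 0.3, 0.5]])        # fixed soft targets P

cross_entropy = F.cross_entropy(logits, p)  # H(P, Q), averaged over the batch
kl = F.kl_div(F.log_softmax(logits, dim=-1), p, reduction="batchmean")  # D_KL(P || Q)
entropy = -(p * p.log()).sum(dim=-1).mean()  # H(P), constant with respect to the model

print(cross_entropy.item(), (entropy + kl).item())
```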
Experiment: Preventing Policy Collapse in SFT
Now, I’ll examine KL divergence as SFT regularization on a classification task and compare the model performance.
◼ SFT Regularization
I’ll train models with the following loss configurations:
sft: SFT model without KL divergence regularization.
fkl: SFT model with forward KL regularization, with β set to each of 0.1, 0.3, 0.5, 0.7, and 0.9.
rkl: SFT model with reverse KL regularization, with β set to each of 0.1, 0.3, 0.5, 0.7, and 0.9.
Each loss term is computed by its own function:
```python
import torch
import torch.nn.functional as F

# task loss (sft)
def calculate_cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    return F.cross_entropy(logits, labels)

# loss w/ forward kl reg (fkl): D_KL(P_ref || P_new)
def calculate_fkl_loss(new_logits: torch.Tensor, ref_logits: torch.Tensor) -> torch.Tensor:
    p_ref = F.softmax(ref_logits.detach(), dim=-1)
    log_p_new = F.log_softmax(new_logits, dim=-1)
    # F.kl_div(input, target) sums target * (log target - input), i.e. D_KL(target || exp(input))
    kl_div = F.kl_div(log_p_new, p_ref, reduction='batchmean')
    return kl_div

# loss w/ reverse kl reg (rkl): D_KL(P_new || P_ref)
def calculate_rkl_loss(new_logits: torch.Tensor, ref_logits: torch.Tensor) -> torch.Tensor:
    p_new = F.softmax(new_logits, dim=-1)
    log_p_ref = F.log_softmax(ref_logits.detach(), dim=-1)
    log_p_new = F.log_softmax(new_logits, dim=-1)
    rkl_term = torch.sum(p_new * (log_p_new - log_p_ref), dim=-1)
    return torch.mean(rkl_term)  # average over the batch
```
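As a sanity check on the two KL helpers above, PyTorch's `torch.distributions.kl_divergence` computes the same quantities for categorical distributions. The sketch below uses random hypothetical logits; it also shows why `reduction='batchmean'` matters: `reduction='mean'` divides by the number of elements (batch size × number of labels), not by the batch size.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical, kl_divergence

torch.manual_seed(0)
new_logits = torch.randn(8, 3)  # hypothetical batch: 8 examples, 3 labels
ref_logits = torch.randn(8, 3)

log_p_new = F.log_softmax(new_logits, dim=-1)
p_ref = F.softmax(ref_logits, dim=-1)

# forward KL, D_KL(P_ref || P_new), as in the fkl helper
fkl_kl_div = F.kl_div(log_p_new, p_ref, reduction="batchmean")
# the same quantity via torch.distributions: kl_divergence(p, q) = D_KL(p || q), one value per example
fkl_dist = kl_divergence(Categorical(logits=ref_logits), Categorical(logits=new_logits)).mean()
# reverse KL, D_KL(P_new || P_ref): swap the arguments
rkl_dist = kl_divergence(Categorical(logits=new_logits), Categorical(logits=ref_logits)).mean()
# pitfall: 'mean' averages over all 8 * 3 elements, so it is 3x smaller than the true batch KL
fkl_mean = F.kl_div(log_p_new, p_ref, reduction="mean")

print(fkl_kl_div.item(), fkl_dist.item(), rkl_dist.item(), fkl_mean.item())
```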
◼ Base Model & Configuration
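I’ll train BERT classifiers (bert-base-uncased) with a learning rate of 1e-4 for 100 epochs. The reference model is a separately loaded, frozen instance of the same checkpoint. Note that the bert-base-uncased checkpoint does not include a sequence-classification head, so the head is randomly initialized on load; the reference distribution therefore comes from an untrained head rather than from learned sentiment knowledge: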
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# base config
MODEL_NAME = 'bert-base-uncased'
MAX_LENGTH = 128
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
NUM_EPOCHS = 100
NUM_LABELS = 3  # 0: positive, 1: negative, 2: neutral
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# init models (one per loss configuration)
model_sft = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS).to(DEVICE)
model_fkl = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS).to(DEVICE)
model_rkl = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS).to(DEVICE)

# reference model (fixed)
model_ref = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS).to(DEVICE)
for param in model_ref.parameters():
    param.requires_grad = False  # no training
model_ref.eval()
```
◼ Datasets
I’ll use synthetic data for a classification task with three labels:
```python
from datasets import Dataset

def create_mock_data() -> Dataset:
    texts = [
        "This movie was incredible and I loved every moment.",
        "Terrible service, slow response, waste of money.",
        "It was okay, not great but not bad either.",
        "Absolutely brilliant, a must-watch experience.",
        "I regret buying this. Zero stars if I could.",
        "Highly recommended for a chill evening.",
        "Mediocre at best, very disappointing performance.",
        "A masterpiece of modern cinema. Truly outstanding.",
        "Never again. Simply the worst."
    ]
    labels = [0, 1, 2, 0, 1, 0, 2, 0, 1]  # 0: positive, 1: negative, 2: neutral
    return Dataset.from_dict({"text": texts, "label": labels})

raw_dataset = create_mock_data()
```
Then, tokenize the texts and split the tokenized data into training and test datasets:
```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize(batch: dict) -> dict:
    # with batched=True, `batch` maps each column name to a list of values
    return tokenizer(batch['text'], truncation=True, padding='max_length', max_length=MAX_LENGTH)

tokenized_datasets = raw_dataset.map(tokenize, batched=True)
tokenized_datasets = tokenized_datasets.train_test_split(test_size=0.2)
tokenized_datasets.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

train_dataloader = DataLoader(tokenized_datasets['train'], shuffle=True, batch_size=BATCH_SIZE)
eval_dataloader = DataLoader(tokenized_datasets['test'], batch_size=BATCH_SIZE)
```
◼ Training
I’ll train each model with its own loss configuration using the following training loop:
```python
import torch

# the loop is wrapped in a function (the name train_model is illustrative) so it can be reused per configuration
def train_model(model, dataloader, kl_type: str = 'none', beta_kl: float = 0.0, model_ref=None) -> list[float]:
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    loss_history = []

    model.train()
    for epoch in range(NUM_EPOCHS):
        for step, batch in enumerate(dataloader):
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            labels = batch.pop("label")
            optimizer.zero_grad()

            # forward pass
            outputs_new = model(**batch)
            logits_new = outputs_new.logits

            # task loss (L_CE)
            ce_loss = calculate_cross_entropy_loss(logits_new, labels)
            total_loss = ce_loss

            # kl reg loss (D_KL)
            if beta_kl > 0.0 and model_ref is not None:
                with torch.no_grad():
                    # get logits from the frozen reference model
                    outputs_ref = model_ref(**batch)
                    logits_ref = outputs_ref.logits

                # compute kl divergence by kl_type
                match kl_type:
                    case 'FKL' | 'fkl':
                        kl_loss = calculate_fkl_loss(logits_new, logits_ref)
                    case 'RKL' | 'rkl':
                        kl_loss = calculate_rkl_loss(logits_new, logits_ref)
                    case _:
                        kl_loss = 0.0

                # combine loss
                total_loss = ce_loss + beta_kl * kl_loss

            total_loss.backward()
            optimizer.step()
            loss_history.append(total_loss.item())

    return loss_history

# example calls; the β sweep repeats the fkl/rkl runs for each β value
history_sft = train_model(model_sft, train_dataloader)
history_fkl = train_model(model_fkl, train_dataloader, kl_type='fkl', beta_kl=0.5, model_ref=model_ref)
history_rkl = train_model(model_rkl, train_dataloader, kl_type='rkl', beta_kl=0.5, model_ref=model_ref)
```
◼ Results & Analysis
▫ 1. Learning Capabilities
The total losses with KL regularization are largely stable after the initial drop, apart from a couple of spikes that suggest occasional, transient large parameter updates:

Figure D. Comparison of the loss history. Left: SFT vs. Forward KL, right: SFT vs. Reverse KL
The SFT loss (black dotted lines in Figure D) starts high (1.10, close to ln 3 ≈ 1.099, the cross-entropy of a uniform guess over three labels) and decreases steadily to very low values.
The FKL loss (left in Figure D) penalizes the new fine-tuned model for deviating too far from the reference model.
Lower β values (β = 0.1 in blue, β = 0.3 in orange) yield lower losses, allowing the model to deviate more easily from the reference model.
As β increases from 0.5 (green) to 0.9 (purple), the loss rises. Part of this increase is mechanical, since β scales the KL term in the total loss, but it also reflects stronger regularization that keeps the fine-tuned model closer to the reference model.
The RKL loss is an alternative regularizer that swaps the arguments of the divergence: it weights the log-ratio by the fine-tuned model's own probabilities, penalizing the fine-tuned model for placing mass where the reference model places little.
Although less common as an SFT regularizer, the RKL loss is more stable at lower β (β = 0.1 in blue, β = 0.3 in orange) than its FKL counterpart.
▫ 2. Logits Probability Distribution
Both KL regularizers mitigate the policy collapse seen with pure SFT:

Figure E. Logits probability distribution. Left: SFT vs. Forward KL, right: SFT vs. Reverse KL
The reference model (red dotted lines in Figure E) has a diffuse distribution across the three labels. The model is uncertain, preferring Negative (label = 1) slightly, but still giving significant weight to Positive (label = 0) and Neutral (label = 2).
The SFT model (blue dotted lines in Figure E) shows policy collapse: its distribution is sharply concentrated on Positive (99.8% probability), leaving almost no probability for the other labels.
Both forward and reverse KL are used as regularization terms to pull the model away from the extreme confidence of the SFT model back towards the reference distribution, preventing policy collapse while still incorporating the SFT learning.
The FKL model successfully re-introduced uncertainty, lowering the probabilities for Positive and retaining a higher combined probability for Negative and Neutral.
It created a smoother, more diffuse distribution that still favors Positive but keeps the policy closer to the reference model's uncertainty across all labels.
The RKL model also moderated the policy collapse but resulted in a generally sharper distribution compared to FKL. Positive confidence remained higher, and conversely, the combined probability for Negative and Neutral was lower than in the FKL case.
Conclusion
KL divergence is a powerful tool for fine-tuning large language models, providing a principled way to manage the trade-off between policy improvement and stability.
In our experiment, both forward and reverse KL regularization prevented the severe policy collapse seen with pure SFT by adding the necessary uncertainty to the loss function through the regularization terms.
Moving forward, a deeper analysis of the optimal β parameter range is warranted to balance the sharp preference learned by SFT with the smooth prior knowledge retained by the reference model.
◼ References
[1]. Deep reinforcement learning from human preferences (Christiano et al., arXiv:1706.03741)
[2]. Controlled Decoding from Language Models (Mudgal et al., arXiv:2310.17022)
[3]. Best-of-N through the Smoothing Lens: KL Divergence and Regret Analysis (Aminian et al., arXiv:2507.05913)
[4]. Asymptotics of Language Model Alignment (Yang et al., arXiv:2404.01730)
[5]. An Elementary Introduction to Information Geometry (Frank Nielsen)
Share What You Learned
Kuriko IWAI, "Regularizing LLMs with Kullback-Leibler Divergence" in Kernel Labs
https://kuriko-iwai.com/kullback-leibler-divergence-for-llms
Written by Kuriko IWAI. All images, unless otherwise noted, are by the author. All experimentations on this blog utilize synthetic or licensed data.



