2025-04-25
This article summarizes and connects three essential concepts in statistical inference and machine learning: KL Divergence, Cross Entropy and Maximum Likelihood Estimation, focusing on the discrete case.
Given two probability distributions, \(p\) and \(q\), the Kullback-Leibler Divergence of \(p\) from \(q\), denoted \(\mathop{\mathrm{\operatorname{D_{\text{KL}}}}}(p \mid \mid q)\), measures how likely \(q\) is to generate samples drawn from \(p\). This gives us a measure of how similar the two distributions are. In ML we often have observations in the form of a training dataset. When we learn a model on this data, we want to minimize the divergence between our learned model and the true distribution that generated the training data. The \(\mathop{\mathrm{\operatorname{D_{\text{KL}}}}}\) captures this divergence.
Assume that we want to measure the similarity between the probabilities of getting heads or tails for two different coins: a fair coin and a heavily biased coin:
\[\text{Fair coin} = \begin{cases} 0.5 & \text{Heads}\\ 0.5 & \text{Tails}\\ \end{cases}\]
\[\text{Biased coin} = \begin{cases} 0.7 & \text{Heads}\\ 0.3 & \text{Tails}\\ \end{cases}\]
A natural way of quantifying how different the fair and biased coins are is to measure the ratio of the probabilities that the two coins assign to the same series of observations.
If we let \(O= \{\text{Tail}, \text{Head}, \text{Head}, \text{Tail}, \text{Head}\}\), the difference between the two coins given these observations is:
\[\frac{P(O \mid \text{Fair})}{P(O \mid \text{Biased})}\]
Because coin tosses are independent events, each likelihood is simply the product of the probabilities of the individual tosses, so \(P(O \mid \text{Fair})=0.5^3 \times 0.5^2\) and \(P(O \mid \text{Biased})=0.7^3 \times 0.3^2\).
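To make this concrete, here is a minimal Python sketch (my own, not from any library) that computes both likelihoods and their ratio for the observation sequence above:

```python
# Per-toss probabilities for the two coins.
fair = {"H": 0.5, "T": 0.5}
biased = {"H": 0.7, "T": 0.3}

# The observation sequence O = {Tail, Head, Head, Tail, Head}.
observations = ["T", "H", "H", "T", "H"]

def likelihood(coin, obs):
    """Probability of the whole sequence under independent tosses."""
    prob = 1.0
    for outcome in obs:
        prob *= coin[outcome]
    return prob

p_fair = likelihood(fair, observations)      # 0.5**3 * 0.5**2 = 0.03125
p_biased = likelihood(biased, observations)  # 0.7**3 * 0.3**2 ≈ 0.03087
print(p_fair / p_biased)                     # ≈ 1.01: the fair coin fits these tosses slightly better
```

If we now generalize the probability of each event for both coins: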
\[\text{Fair coin} = \begin{cases} p_1 & \text{Heads}\\ p_2 & \text{Tails}\\ \end{cases}\]
\[\text{Biased coin} = \begin{cases} q_1 & \text{Heads}\\ q_2 & \text{Tails}\\ \end{cases}\]
We get that our ratio is:
\[\frac{P(O \mid \text{Fair})}{P(O \mid \text{Biased})} = \frac{p_1^{N_{H}} \times p_2^{N_{T}}}{q_1^{N_{H}} \times q_2^{N_{T}}},\]
where \(N_H\) and \(N_T\) are the numbers of heads and tails, respectively. If we take the logarithm (to avoid numerical problems when we have many observations) and normalize by the total number of observations \(N = N_H + N_T\), we get:
\[\begin{aligned} \frac{1}{N} \log \frac{P(O \mid \text{Fair})}{P(O \mid \text{Biased})} & = \frac{1}{N} \log(\frac{p_1^{N_{H}} \times p_2^{N_{T}}}{q_1^{N_{H}} \times q_2^{N_{T}}}) \\ & = \frac{1}{N} \log p_1^{N_{H}} + \frac{1}{N} \log p_2^{N_{T}} - \frac{1}{N} \log q_1^{N_{H}} - \frac{1}{N} \log q_2^{N_{T}} \\ & = \frac{N_H}{N} \log p_1 + \frac{N_T}{N} \log p_2 - \frac{N_H}{N} \log q_1 - \frac{N_T}{N} \log q_2 \end{aligned}\]
When \(N\) increases, the observed frequencies approach the true distribution, which means that in the limit \(\frac{N_H}{N}=p_1\) and \(\frac{N_T}{N}=p_2\) (the observations are drawn from the fair coin): \[\begin{aligned} \frac{1}{N} \log \frac{P(O \mid \text{Fair})}{P(O \mid \text{Biased})} & = p_1 \log p_1 + p_2 \log p_2 - p_1 \log q_1 - p_2 \log q_2 \\ & = p_1 \log p_1 - p_1 \log q_1 + p_2 \log p_2 - p_2 \log q_2 \\ & = p_1 \log(\frac{p_1}{q_1}) + p_2 \log(\frac{p_2}{q_2}) \\ & = \sum_i p_i \log(\frac{p_i}{q_i}) \\ & = \mathop{\mathrm{\operatorname{D_{\text{KL}}}}}(p \mid \mid q) \end{aligned}\]
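This limit can also be checked empirically. Below is a small sketch of my own (arbitrary seed and sample size) that draws a long sequence of tosses from the fair coin and compares the normalized log-ratio with the closed-form divergence:

```python
import math
import random

random.seed(0)

p = {"H": 0.5, "T": 0.5}  # fair coin: the true distribution generating the data
q = {"H": 0.7, "T": 0.3}  # biased coin

N = 100_000
tosses = random.choices(["H", "T"], weights=[p["H"], p["T"]], k=N)

# Normalized log-likelihood ratio: (1/N) * log(P(O | Fair) / P(O | Biased)).
log_ratio = sum(math.log(p[t]) - math.log(q[t]) for t in tosses) / N

# Closed-form D_KL(p || q) = sum_i p_i * log(p_i / q_i).
d_kl = sum(p[k] * math.log(p[k] / q[k]) for k in p)

print(log_ratio, d_kl)  # the two values should be close (≈ 0.087 nats)
```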
Although the \(\mathop{\mathrm{\operatorname{D_{\text{KL}}}}}\) measures the difference between two distributions, it is not a proper distance metric. Importantly, we have that \(\mathop{\mathrm{\operatorname{D_{\text{KL}}}}}(p \mid \mid q) \neq \mathop{\mathrm{\operatorname{D_{\text{KL}}}}}(q \mid \mid p)\) in general, so the divergence is not symmetric. Another important property of the \(\mathop{\mathrm{\operatorname{D_{\text{KL}}}}}\) is that it is always non-negative, and a \(\mathop{\mathrm{\operatorname{D_{\text{KL}}}}}\) of 0 indicates that the two distributions are identical.
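Both properties are easy to verify numerically with a minimal sketch (the function name kl_divergence is my own):

```python
import math

def kl_divergence(p, q):
    """D_KL(p || q) = sum_i p_i * log(p_i / q_i) for discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.5, 0.5]  # fair coin
q = [0.7, 0.3]  # biased coin

print(kl_divergence(p, q))  # ≈ 0.0872 -> D_KL(p || q)
print(kl_divergence(q, p))  # ≈ 0.0823 -> D_KL(q || p): not symmetric
print(kl_divergence(p, p))  # 0.0      -> identical distributions
```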
If we want to learn a set of parameters \(\theta\) for a model \(q\) that approximates a distribution \(p\), we essentially want to find values for \(\theta\) that minimize the divergence between \(q_\theta\) and \(p\).
This can be modelled using \(\mathop{\mathrm{\operatorname{D_{\text{KL}}}}}\). Given class labels \(y\) and an observation \(x_i\), we want to minimize \(\mathop{\mathrm{\operatorname{D_{\text{KL}}}}}(p(y \mid x_i) \mid \mid q_\theta (y \mid x_i))\). If we rewrite this, we get:
\[\begin{aligned} \mathop{\mathrm{\operatorname{D_{\text{KL}}}}}(p(y \mid x_i) \mid \mid q_\theta (y \mid x_i)) & = \sum_y p(y \mid x_i) \log \frac{p(y \mid x_i)}{q_\theta (y \mid x_i)} \\ & = \sum_y p(y \mid x_i) (\log p(y \mid x_i) - \log q_\theta(y \mid x_i)) \\ & = \sum_y p(y \mid x_i) \log p(y \mid x_i) - \sum_y p(y \mid x_i) \log q_\theta(y \mid x_i) \end{aligned}\]
Since the first term does not depend on \(\theta\), we only need to minimize the second term in order to minimize the overall divergence:
\[\mathop{\mathrm{argmin}}_\theta \mathop{\mathrm{\operatorname{D_{\text{KL}}}}}(p \mid \mid q_\theta) = \mathop{\mathrm{argmin}}_\theta - \sum_y p(y \mid x_i) \log q_\theta(y \mid x_i)\]
This is exactly the Cross Entropy between \(p\) and \(q_\theta\). The Cross Entropy of \(p\) and \(q\) is denoted \(\operatorname{H}(p,q)\).
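To connect the pieces numerically, here is a short sketch (function names are my own) of the identity \(\operatorname{H}(p,q) = \operatorname{H}(p) + \mathop{\mathrm{\operatorname{D_{\text{KL}}}}}(p \mid \mid q)\), which is just the decomposition derived above: only the divergence term depends on \(q\), so minimizing the Cross Entropy over \(q_\theta\) also minimizes the divergence.

```python
import math

def entropy(p):
    """H(p) = -sum_i p_i * log(p_i)."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def cross_entropy(p, q):
    """H(p, q) = -sum_i p_i * log(q_i)."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

def kl_divergence(p, q):
    """D_KL(p || q) = sum_i p_i * log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.5, 0.5]  # fair coin
q = [0.7, 0.3]  # biased coin

print(cross_entropy(p, q))               # ≈ 0.7803
print(entropy(p) + kl_divergence(p, q))  # same value: H(p, q) = H(p) + D_KL(p || q)
```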
In statistical inference, Maximum Likelihood Estimation (MLE) is a general method for estimating the parameters of probability distributions given some data. This is achieved by maximizing the probability of the data, the so-called likelihood function.
Given a set of observations \(x_1, \dots, x_n\) generated from a distribution \(P\), the likelihood of the observations is defined as:
\[\mathcal{L} = \prod_{i=1}^n P(x_i)\]
However, if we do not know \(P\), we have to approximate it using another distribution \(Q\), parameterized by \(\theta\):
\[\mathcal{L}(\theta) = \prod_{i=1}^n Q_\theta(x_i)\]
The idea behind MLE is to find the optimal set of parameters \(\hat{\theta}\) such that the likelihood is maximized:
\[\hat{\theta} = \mathop{\mathrm{argmax}}_\theta \mathcal{L}(\theta) = \mathop{\mathrm{argmax}}_\theta \prod_{i=1}^n Q_\theta(x_i)\]
In log space, this is:
\[\hat{\theta} = \mathop{\mathrm{argmax}}_\theta \sum_{i=1}^n \log Q_\theta(x_i)\]
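As a concrete example (my own sketch, with a simple grid search standing in for a proper optimizer), let the observations be the coin tosses from earlier and let \(Q_\theta\) be a Bernoulli distribution with \(\theta = Q_\theta(\text{Heads})\). Maximizing the log-likelihood recovers the familiar closed-form estimate \(\hat{\theta} = N_H / N\):

```python
import math

observations = ["T", "H", "H", "T", "H"]  # the sequence from the coin example
n_heads = observations.count("H")
n_tails = observations.count("T")

def log_likelihood(theta):
    """log L(theta) = N_H * log(theta) + N_T * log(1 - theta)."""
    return n_heads * math.log(theta) + n_tails * math.log(1 - theta)

# Grid search over theta in (0, 1).
grid = [i / 1000 for i in range(1, 1000)]
theta_hat = max(grid, key=log_likelihood)

print(theta_hat)                    # 0.6
print(n_heads / len(observations))  # closed-form MLE: N_H / N = 3 / 5 = 0.6
```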
If we now average over the \(n\) samples (which does not change the maximizer):
\[\hat{\theta} = \mathop{\mathrm{argmax}}_\theta \frac{1}{n} \sum_{i=1}^n \log Q_\theta(x_i),\]
we can use the Law of Large Numbers to see that as \(n\) increases, this average converges to the expectation of \(\log Q_\theta(x)\) under the true distribution \(P\), so we get:
\[\hat{\theta} = \mathop{\mathrm{argmax}}_\theta \sum_{x} P(x) \log Q_\theta(x).\]
If we negate this and minimize instead of maximizing, we get the Cross Entropy:
\[\hat{\theta} = \mathop{\mathrm{argmin}}_\theta - \sum_{x} P(x) \log Q_\theta(x).\]
This means that minimizing the Cross Entropy is equivalent to maximizing the likelihood.
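As a final numerical sketch (my own construction, using the empirical distribution \(\hat{p}\) of the observations in the role of \(P\)): for any \(\theta\), the negative average log-likelihood of the data equals the Cross Entropy \(\operatorname{H}(\hat{p}, Q_\theta)\), so minimizing one minimizes the other.

```python
import math

observations = ["T", "H", "H", "T", "H"]
n = len(observations)

# Empirical distribution of the data.
p_hat = {"H": observations.count("H") / n, "T": observations.count("T") / n}

def q(theta):
    """Bernoulli model Q_theta over {Heads, Tails}."""
    return {"H": theta, "T": 1 - theta}

theta = 0.42  # an arbitrary parameter value

# Negative average log-likelihood: -(1/n) * sum_i log Q_theta(x_i).
nll = -sum(math.log(q(theta)[x]) for x in observations) / n

# Cross Entropy H(p_hat, Q_theta) = -sum_x p_hat(x) * log Q_theta(x).
ce = -sum(p_hat[x] * math.log(q(theta)[x]) for x in p_hat)

print(nll, ce)  # identical values, for any choice of theta
```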
In summary, training a neural network using Cross Entropy is the same as minimizing the \(\mathop{\mathrm{\operatorname{D_{\text{KL}}}}}\) between your model and the unknown distribution that generated your training data. This process is a form of maximum likelihood estimation, as minimizing the Cross Entropy (for example through gradient descent) is the same as maximizing the probability your model assigns to each datapoint.