43. KL Divergence#
The Kullback–Leibler (KL) divergence measures how one probability distribution differs from another.
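For reference, the quantity estimated numerically throughout this notebook is $D_\mathrm{KL}(p \parallel q) = \int p(x)\,\log\frac{p(x)}{q(x)}\,dx$, which in general is not symmetric in $p$ and $q$.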
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import sympy as sp
43.1. Gaussians#
Generate Gaussian data
stats.norm.ppf(0.01)
-2.3263478740408408
stats.norm.ppf(0.99)
2.3263478740408408
# Evaluation grid and the two densities
x = np.linspace(-10, 10, 10000)
y = stats.norm.pdf(x)        # p(x): standard normal N(0, 1)
y_a = stats.norm.pdf(x, 3)   # q(x): normal N(3, 1)
fig, ax = plt.subplots(figsize=(10,6.18))
ax.plot(x, y, color="tab:red")
ax.plot(x, y_a, color="tab:blue")
ax.fill_between(x, y, color="tab:red")
ax.fill_between(x, y_a, color="tab:blue")
ax.set_xlim([-4,6])
ax.set_xlabel("$x$")
ax.set_title("Two Gaussians")
ax.annotate('$p(x)$', xy=(-1, 0.3), xytext=(-2, 0.3),
arrowprops=dict(color='tab:red', shrink=0.001, width=1, headwidth=4),
color="tab:red"
)
ax.annotate('$q(x)$', xy=(4, 0.3), xytext=(5, 0.3),
arrowprops=dict(color='tab:blue', shrink=0.001, width=1, headwidth=4),
color="tab:blue"
)
ax.text(-2.5, 0.39, r"$D_\mathrm{KL}(p\parallel q)\sim 4.5$")
ax.text(-2.5, 0.35, r"$D_\mathrm{KL}(q\parallel p)\sim 4.5$")
plt.savefig("assets/kl-divergence/two-gaussians.png")
# With a second argument, stats.entropy returns the KL divergence of the normalised arrays
stats.entropy(y, y_a)
4.499999999998729
stats.entropy(y_a, y)
4.499999999974058
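As a cross-check (my addition, with a hypothetical helper kl_normal), the divergence between two univariate Gaussians has a closed form, and for unit variances and means $0$ and $3$ it gives exactly $(0-3)^2/2 = 4.5$:
def kl_normal(mu_p, sigma_p, mu_q, sigma_q):
    # Closed-form KL( N(mu_p, sigma_p^2) || N(mu_q, sigma_q^2) )
    return (np.log(sigma_q / sigma_p)
            + (sigma_p**2 + (mu_p - mu_q)**2) / (2 * sigma_q**2)
            - 0.5)

kl_normal(0, 1, 3, 1)   # 4.5, matching stats.entropy(y, y_a) above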
Investigate $\log\left(\frac{y}{y_a}\right)$
np.min(y/y_a)
8.423463754468646e-12
np.max(y/y_a)
961965785544776.6
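The ratio here is $p(x)/q(x) = \exp\big((9 - 6x)/2\big)$, monotone in $x$, so those extremes sit at the two ends of the grid; a quick check (my addition):
# Locations of the extreme ratios: the grid endpoints
x[np.argmax(y/y_a)], x[np.argmin(y/y_a)]   # (-10.0, 10.0)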
fig, ax = plt.subplots( nrows=2, ncols=1, figsize=(10,2*6.18))
ax[0].plot(x, np.log(y/y_a), color="k")
ax[0].fill_between(x, np.log(y/y_a), color="k")
ax[0].set_xlim([-4,6])
ax[0].set_xlabel("$x$")
ax[0].set_title(r"Two Gaussians: $\log\left(p/q\right)$")
ax[1].plot(x, y*np.log(y/y_a), color="k")
ax[1].fill_between(x, y*np.log(y/y_a), color="k")
ax[1].set_xlim([-4,6])
ax[1].set_xlabel("$x$")
ax[1].set_title(r"Two Gaussians: $p \log(p/q)$")
ax[1].annotate(r'Area under curve $D_\mathrm{KL} \sim 4.5$', xy=(1, 0.5), xytext=(3, 1),
arrowprops=dict(color='k', shrink=0.001, width=1, headwidth=4),
color="k"
)
plt.savefig("assets/kl-divergence/integrants.png")
# Riemann-sum approximation of the integral of p*log(p/q)
np.sum(
    ((x.max() - x.min())/len(x))*y*np.log(y/y_a)
)
4.49955
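The Riemann sum and stats.entropy agree because stats.entropy normalises both arrays to sum to one, and on a uniform grid both normalisation constants are roughly $1/\Delta x$, so they nearly cancel inside the log. A sketch of the equivalence (my addition):
# What stats.entropy(y, y_a) computes: normalise, then sum p*log(p/q)
p_n, q_n = y / y.sum(), y_a / y_a.sum()
print(np.sum(p_n * np.log(p_n / q_n)))

# The explicit Riemann sum with grid spacing dx, as above
dx = (x.max() - x.min()) / len(x)
print(np.sum(dx * y * np.log(y / y_a)))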
44. A Gaussian mixture#
y_b = stats.norm.pdf(x)/2 + stats.norm.pdf(x, 3)/2   # q(x): equal-weight mixture of N(0, 1) and N(3, 1)
fig, ax = plt.subplots(figsize=(10,6.18))
ax.plot(x, y, color="tab:red")
ax.plot(x, y_b, color="tab:blue")
ax.fill_between(x, y, color="tab:red")
ax.fill_between(x, y_b, color="tab:blue")
ax.set_xlim([-4,7])
ax.set_xlabel("$x$")
ax.set_title("Guassian and Guassian Mixture")
ax.annotate('$p(x)$', xy=(-1, 0.3), xytext=(-2, 0.3),
arrowprops=dict(color='tab:red', shrink=0.001, width=1, headwidth=4),
color="tab:red"
)
ax.annotate('$q(x)$', xy=(4, 0.15), xytext=(5, 0.3),
arrowprops=dict(color='tab:blue', shrink=0.001, width=1, headwidth=4),
color="tab:blue"
)
ax.text(-2.5, 0.39, r"$D_\mathrm{KL}(p\parallel q)\sim 0.53$")
ax.text(-2.5, 0.35, r"$D_\mathrm{KL}(q\parallel p)\sim 1.72$")
plt.savefig("assets/kl-divergence/guassian-mixture.png")
stats.entropy(y, y_b)
0.526777306520738
stats.entropy(y_b, y)
1.723222693464332
fig, ax = plt.subplots( nrows=2, ncols=1, figsize=(10,2*6.18))
ax[0].plot(x, np.log(y/y_b), color="k")
ax[0].fill_between(x, np.log(y/y_b), color="k")
ax[0].set_xlim([-4,7])
ax[0].set_xlabel("$x$")
ax[0].set_title(r"Gaussian and mixture: $\log\left(p/q\right)$")
ax[1].plot(x, y*np.log(y/y_b), color="k")
ax[1].fill_between(x, y*np.log(y/y_b), color="k")
ax[1].set_xlim([-4,7])
ax[1].set_xlabel("$x$")
ax[1].set_title(r"Gaussian and mixture: $p \log(p/q)$")
ax[1].annotate(r'Area under curve $D_\mathrm{KL} \sim 0.53$', xy=(1, 0.2), xytext=(3, 0.2),
arrowprops=dict(color='k', shrink=0.001, width=1, headwidth=4),
color="k"
)
plt.savefig("assets/kl-divergence/guassian-mixture-integrants.png")
# Riemann sum for D_KL(p || q): p the Gaussian, q the mixture
np.sum(
    ((x.max() - x.min())/len(x))*y*np.log(y/y_b)
)
0.5267246287907215
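An alternative route to roughly the same 0.53 (my addition) is a Monte Carlo estimate of $D_\mathrm{KL}(p\parallel q)=\mathbb{E}_{x\sim p}\left[\log p(x)-\log q(x)\right]$ using draws from $p$:
# Monte Carlo estimate of D_KL(p || q) from samples of p = N(0, 1)
rng = np.random.default_rng(0)
samples = rng.normal(size=200_000)
log_p = stats.norm.logpdf(samples)
log_q = np.log(stats.norm.pdf(samples)/2 + stats.norm.pdf(samples, 3)/2)
print((log_p - log_q).mean())   # ~ 0.53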
fig, ax = plt.subplots( nrows=2, ncols=1, figsize=(10,2*6.18))
ax[0].plot(x, np.log(y_b/y), color="k")
ax[0].fill_between(x, np.log(y_b/y), color="k")
ax[0].set_xlim([-4,7])
ax[0].set_xlabel("$x$")
ax[0].set_title(r"Gaussian and mixture: $\log\left(q/p\right)$")
ax[1].plot(x, y_b*np.log(y_b/y), color="k")
ax[1].fill_between(x, y_b*np.log(y_b/y), color="k")
ax[1].set_xlim([-4,7])
ax[1].set_xlabel("$x$")
ax[1].set_title(r"Gaussian and mixture: $q \log(q/p)$")
ax[1].annotate(r'Area under curve $D_\mathrm{KL} \sim 1.72$', xy=(2, 0.2), xytext=(-2, 0.2),
arrowprops=dict(color='k', shrink=0.001, width=1, headwidth=4),
color="k"
)
plt.savefig("assets/kl-divergence/guassian-mixture-integrants-d-q-p.png")
# Riemann sum for D_KL(q || p): integrating q*log(q/p)
np.sum(
    ((x.max() - x.min())/len(x))*y_b*np.log(y_b/y)
)
1.7230503711932554
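Splitting the integrand at $x = 1.5$ (an arbitrary cut, my addition) confirms what the plot suggests: most of $D_\mathrm{KL}(q\parallel p)$ comes from the region around the mixture's second mode, where $p$ has almost no mass:
# Contributions to D_KL(q || p) to the right and left of x = 1.5
dx = (x.max() - x.min()) / len(x)
integrand = y_b * np.log(y_b / y)
right = x > 1.5
print(np.sum(dx * integrand[right]), np.sum(dx * integrand[~right]))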
Move the mixture leftward. Since $p$ is symmetric about zero and the new mixture is the mirror image of the old one, both divergences should be unchanged.
y_c = stats.norm.pdf(x)/2 + stats.norm.pdf(x, -3)/2   # mirror image of y_b
fig, ax = plt.subplots(figsize=(10,6.18))
ax.plot(x, y)
ax.plot(x, y_c)
ax.fill_between(x, y)
ax.fill_between(x, y_c)
ax.set_xlim([-6,3])
(-6.0, 3.0)
stats.entropy(y, y_c)
0.5267773065207384
stats.entropy(y_c, y)
1.723222693464331
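As anticipated, the values match the previous section; the mirroring can be checked directly on the (symmetric) grid, a quick check of my own:
# y_c is the mirror image of y_b, and y is symmetric about zero
print(np.allclose(y_c, y_b[::-1]), np.allclose(y, y[::-1]))   # True True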