Chapter 10. Big Entropy and the Generalized Linear Model
In [ ]:
!pip install -q numpyro arviz
In [0]:
import os
import arviz as az
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import random, tree_map, vmap
import numpyro
import numpyro.distributions as dist
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
Code 10.1
In [1]:
# Five candidate count vectors over five positions; each one sums to 10.
p = {
    "A": jnp.array([0, 0, 10, 0, 0]),
    "B": jnp.array([0, 1, 8, 1, 0]),
    "C": jnp.array([0, 2, 6, 2, 0]),
    "D": jnp.array([1, 2, 4, 2, 1]),
    "E": jnp.array([2, 2, 2, 2, 2]),
}
Code 10.2
In [2]:
# Normalize every count vector in `p` so that it sums to 1.
p_norm = tree_map(lambda counts: counts / counts.sum(), p)
Code 10.3
In [3]:
# Shannon entropy (natural log) of each normalized distribution. Zero
# probabilities contribute 0, the limit of q * log(q) as q -> 0. jnp.where
# evaluates both branches, so the NaN from 0 * log(0) is computed and then discarded.
H = tree_map(
    lambda probs: -jnp.sum(jnp.where(probs != 0, probs * jnp.log(probs), 0)),
    p_norm,
)
H
Out[3]:
Code 10.4
In [4]:
# Multiplicity 10! / (n1! * ... * n5!) of each count vector A-E from Code 10.1,
# e.g. B = [0, 1, 8, 1, 0] gives 10! / (1! * 8! * 1!) = 90.
ways = jnp.asarray([1, 90, 1_260, 37_800, 113_400])
# Log multiplicity per unit (each count vector totals 10).
logwayspp = jnp.log(ways) / 10
Code 10.5
In [5]:
# Candidate distributions over four outcomes, keyed 1-4.
p = {
    1: jnp.array([1 / 4, 1 / 4, 1 / 4, 1 / 4]),
    2: jnp.array([2 / 6, 1 / 6, 1 / 6, 2 / 6]),
    3: jnp.array([1 / 6, 2 / 6, 2 / 6, 1 / 6]),
    4: jnp.array([1 / 8, 4 / 8, 2 / 8, 1 / 8]),
}
# Expected value of each candidate when the outcomes take the values 0, 1, 1, 2.
tree_map(lambda probs: (probs * jnp.array([0, 1, 1, 2])).sum(), p)
Out[5]:
Code 10.6
In [6]:
# compute entropy of each distribution
# Entropy (natural log) of each candidate distribution. None of them contains a
# zero, so no 0 * log(0) guard is needed.
tree_map(lambda probs: -(probs * jnp.log(probs)).sum(), p)
Out[6]:
Code 10.7
In [7]:
p = 0.7
# Joint probabilities of two draws, each drawn from (1 - p, p): the products
# (1-p)(1-p), (1-p)p, p(1-p), p*p, i.e. 0, 1, 1, 2 draws of the `p` event.
A = jnp.array([first * second for first in (1 - p, p) for second in (1 - p, p)])
A
Out[7]:
Code 10.8
In [8]:
# Entropy (natural log) of the two-draw distribution A.
-(A * jnp.log(A)).sum()
Out[8]:
Code 10.9
In [9]:
def sim_p(i, G=1.4):
    """Simulate one random 4-outcome distribution with a fixed expected value.

    Three weights are drawn from ``dist.Uniform()``. A fourth weight is then
    solved for so that, after normalization, the probabilities satisfy
    ``p[1] + p[2] + 2 * p[3] == G`` (up to float rounding). This is the
    expected value when the outcomes take the values [0, 1, 1, 2].

    Args:
        i: Integer seed passed to ``random.PRNGKey``. It may be a traced value,
            as when mapped with ``vmap``.
        G: Target expected value. ``G == 2`` divides by zero. For some values
            (e.g. G < 1) the fourth weight can be negative, which yields NaN
            entropy. This is not checked.

    Returns:
        dict with ``"H"``, the entropy of ``p`` (natural log), and ``"p"``, the
        normalized probabilities (shape (4,)).
    """
    key = random.PRNGKey(i)
    x123 = dist.Uniform().sample(key, (3,))
    # keepdims keeps x4 one-dimensional (shape (1,)) so it can be concatenated.
    x4 = (G * x123.sum(keepdims=True) - x123[1] - x123[2]) / (2 - G)
    weights = jnp.concatenate([x123, x4])
    probs = weights / weights.sum()
    entropy = -jnp.sum(probs * jnp.log(probs))
    return {"H": entropy, "p": probs}
Code 10.10
In [10]:
# Simulate 100,000 random distributions whose expected value is G = 1.4.
H = vmap(lambda i: sim_p(i, G=1.4))(jnp.arange(100_000))
# Recent ArviZ releases interpret a numeric `bw` in plot_kde as an absolute
# bandwidth; older releases treated it as a scaling factor. The entropy of a
# 4-outcome distribution is at most log(4) ~= 1.39, so the old `bw=0.25` would
# smear the density over much of that range. Let ArviZ choose the bandwidth.
ax = az.plot_kde(H["H"])
ax.set(xlabel="Entropy", ylabel="Density")
plt.show()
Out[10]:
Code 10.11
In [11]:
# Unpack the simulation: one entropy and one probability vector per draw.
entropies, distributions = H["H"], H["p"]
Code 10.12
In [12]:
# Largest entropy among the simulated distributions.
entropies.max()
Out[12]:
Code 10.13
In [13]:
# The simulated distribution that attains the largest entropy.
distributions[entropies.argmax()]
Out[13]:
Comments
Comments powered by Disqus