Chapter 12. Monsters and Mixtures

In [0]:
import math
import os

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import set_matplotlib_formats

import jax.numpy as jnp
from jax import lax, random
from jax.scipy.special import expit

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.distributions.transforms import OrderedTransform
from numpyro.infer import ELBO, MCMC, NUTS, SVI, Predictive, init_to_value
from numpyro.infer.autoguide import AutoLaplaceApproximation

# Render figures as SVG when the SVG environment variable is set.
if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
az.style.use("arviz-darkgrid")
# Four host devices so the 4 MCMC chains below can run in parallel.
numpyro.set_host_device_count(4)
# Double precision: with x32, chains can get stuck on badly generated
# initial values (see the note after Code 12.16).
numpyro.enable_x64()

Code 12.1

In [1]:
# Beta distribution parameterised by its mean (pbar) and concentration
# (theta) rather than the usual two shape parameters.
pbar = 0.5
theta = 5
x = jnp.linspace(0, 1, 101)
density = jnp.exp(dist.Beta(pbar * theta, (1 - pbar) * theta).log_prob(x))
plt.plot(x, density)
plt.gca().set(xlabel="probability", ylabel="Density")
plt.show()
Out[1]:

Code 12.2

In [2]:
# UCBadmit: graduate admissions aggregated by department and applicant gender.
UCBadmit = pd.read_csv("../data/UCBadmit.csv", sep=";")
d = UCBadmit
d["gid"] = (d["applicant.gender"] != "male").astype(int)  # 0 = male, 1 = female
dat = dict(A=d.admit.values, N=d.applications.values, gid=d.gid.values)


def model(gid, N, A=None):
    """Beta-binomial admissions model (m12.1) with one intercept per gender."""
    a = numpyro.sample("a", dist.Normal(0, 1.5).expand([2]))
    phi = numpyro.sample("phi", dist.Exponential(1))
    # theta = phi + 2 keeps the beta concentration above 2
    theta = numpyro.deterministic("theta", phi + 2)
    pbar = expit(a[gid])
    # BetaBinomial parameterised via mean pbar and concentration theta
    numpyro.sample("A", dist.BetaBinomial(pbar * theta, (1 - pbar) * theta, N), obs=A)


# 500 warmup + 500 posterior draws per chain, 4 chains
m12_1 = MCMC(NUTS(model), 500, 500, num_chains=4)
m12_1.run(random.PRNGKey(0), **dat)

Code 12.3

In [3]:
post = m12_1.get_samples()
# Recover the deterministic site "theta" (= phi + 2) from the posterior draws;
# get_samples() only returns sample sites.
post["theta"] = Predictive(m12_1.sampler.model, post)(random.PRNGKey(1), **dat)["theta"]
# da: male (a[0]) minus female (a[1]) intercept contrast
post["da"] = post["a"][:, 0] - post["a"][:, 1]
print_summary(post, 0.89, False)
Out[3]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]     -0.44      0.42     -0.44     -1.11      0.22   1835.21      1.00
      a[1]     -0.31      0.41     -0.31     -0.94      0.34   1805.35      1.00
        da     -0.13      0.59     -0.12     -1.09      0.80   1741.81      1.00
       phi      1.04      0.81      0.87      0.00      2.06   1760.76      1.00
     theta      3.04      0.81      2.87      2.00      4.06   1760.76      1.00

Code 12.4

In [4]:
gid = 1  # 1 = female applicants (see Code 12.2)
# draw posterior mean beta distribution
x = jnp.linspace(0, 1, 101)
pbar = jnp.mean(expit(post["a"][:, gid]))
theta = jnp.mean(post["theta"])
plt.plot(x, jnp.exp(dist.Beta(pbar * theta, (1 - pbar) * theta).log_prob(x)))
plt.gca().set(ylabel="Density", xlabel="probability admit", ylim=(0, 3))

# draw 50 beta distributions sampled from posterior
for i in range(50):
    p = expit(post["a"][i, gid])
    theta = post["theta"][i]
    plt.plot(
        x, jnp.exp(dist.Beta(p * theta, (1 - p) * theta).log_prob(x)), "k", alpha=0.2
    )
plt.title("distribution of female admission rates")
plt.show()
Out[4]:

Code 12.5

In [5]:
post = m12_1.get_samples()
# Posterior predictive admissions for each of the 12 dept/gender rows
# (A is omitted from the call so Predictive simulates it).
admit_pred = Predictive(m12_1.sampler.model, post)(
    random.PRNGKey(1), gid=dat["gid"], N=dat["N"]
)["A"]
admit_rate = admit_pred / dat["N"]
# observed admission rates
plt.scatter(range(1, 13), dat["A"] / dat["N"])
# open circles: predicted mean; error bars span std/2 on each side
plt.errorbar(
    range(1, 13),
    jnp.mean(admit_rate, 0),
    jnp.std(admit_rate, 0) / 2,
    fmt="o",
    c="k",
    mfc="none",
    ms=7,
    elinewidth=1,
)
# crosses: 89% posterior predictive interval
plt.plot(range(1, 13), jnp.percentile(admit_rate, 5.5, 0), "k+")
plt.plot(range(1, 13), jnp.percentile(admit_rate, 94.5, 0), "k+")
plt.show()
Out[5]:

Code 12.6

In [6]:
# Kline: tool counts for 10 Oceanic societies.
Kline = pd.read_csv("../data/Kline.csv", sep=";")
d = Kline
# standardized log population.  NOTE(review): dat2 below passes the *raw*
# population, not this column — that matches the book's "scientific" model
# in Code 12.6, but worth confirming against the text.
d["P"] = d.population.apply(math.log).pipe(lambda x: (x - x.mean()) / x.std())
d["contact_id"] = (d.contact == "high").astype(int)

dat2 = dict(T=d.total_tools.values, P=d.population.values, cid=d.contact_id.values)


def model(cid, P, T):
    """Gamma-Poisson (negative-binomial) tools model m12.2, scientific form."""
    a = numpyro.sample("a", dist.Normal(1, 1).expand([2]))
    b = numpyro.sample("b", dist.Exponential(1).expand([2]))
    g = numpyro.sample("g", dist.Exponential(1))
    phi = numpyro.sample("phi", dist.Exponential(1))
    # expected tools: lambda = exp(a) * P^b / g, indexed by contact level
    lambda_ = jnp.exp(a[cid]) * jnp.power(P, b[cid]) / g
    # GammaPoisson(concentration, rate): mean lambda_, overdispersion via phi
    numpyro.sample("T", dist.GammaPoisson(lambda_ / phi, 1 / phi), obs=T)


m12_2 = MCMC(NUTS(model), 500, 500, num_chains=4)
m12_2.run(random.PRNGKey(0), **dat2)

Code 12.7

In [7]:
# define parameters
prob_drink = 0.2  # 20% of days
rate_work = 1  # average 1 manuscript per day

# sample one year of production
N = 365

# The seed handler makes the numpyro.sample calls reproducible without
# threading explicit PRNG keys.
with numpyro.handlers.seed(rng_seed=365):
    # simulate days monks drink
    drink = numpyro.sample("drink", dist.Binomial(1, prob_drink).expand([N]))

    # simulate manuscripts completed (zero on drinking days, Poisson otherwise)
    y = (1 - drink) * numpyro.sample("work", dist.Poisson(rate_work).expand([N]))

Code 12.8

In [8]:
# Histogram of simulated output; the blue segment at zero marks the *extra*
# zeros caused by drinking, stacked on top of the Poisson zeros.
plt.hist(y, color="k", bins=jnp.arange(-0.5, 6), rwidth=0.1)
plt.gca().set(xlabel="manuscripts completed")
zeros_drink = jnp.sum(drink)  # NOTE(review): unused below — kept to mirror the book code
zeros_work = jnp.sum((y == 0) & (drink == 0))  # zeros from genuine work days
zeros_total = jnp.sum(y == 0)
plt.plot([0, 0], [zeros_work, zeros_total], "royalblue", lw=8)
plt.show()
Out[8]:

Code 12.9

In [9]:
def model(y):
    """Zero-inflated Poisson (m12.3): gate p = P(drink), rate = work rate."""
    # prior skewed toward low drinking probability
    ap = numpyro.sample("ap", dist.Normal(-1.5, 1))
    al = numpyro.sample("al", dist.Normal(1, 0.5))
    p = expit(ap)
    lambda_ = jnp.exp(al)
    numpyro.sample("y", dist.ZeroInflatedPoisson(p, lambda_), obs=y)


m12_3 = MCMC(NUTS(model), 500, 500, num_chains=4)
m12_3.run(random.PRNGKey(0), y=y)
m12_3.print_summary(0.89)
Out[9]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
        al     -0.07      0.08     -0.07     -0.20      0.07    537.05      1.00
        ap     -1.82      0.53     -1.73     -2.54     -1.01    517.88      1.00

Number of divergences: 0

Code 12.10

In [10]:
post = m12_3.get_samples()
# Transform the posterior back to the outcome scale of the ZIP model.
p_drink = jnp.mean(expit(post["ap"]))  # probability drink
lambda_work = jnp.mean(jnp.exp(post["al"]))  # rate finish manuscripts, when not drinking
print(p_drink)
print(lambda_work)
Out[10]:
0.15079226386038988
0.936732765207624

Code 12.11

In [11]:
def model(y):
    """ZIP likelihood written out with numpyro.factor (same model as m12.3).

    y > 0 can only come from the Poisson branch; y == 0 is a mixture of the
    drinking branch (prob p) and a Poisson zero from the working branch.
    """
    ap = numpyro.sample("ap", dist.Normal(-1.5, 1))
    al = numpyro.sample("al", dist.Normal(1, 0.5))
    p = expit(ap)
    lambda_ = jnp.exp(al)
    # log[(1 - p) * Poisson(y | lambda)]
    log_prob = jnp.log1p(-p) + dist.Poisson(lambda_).log_prob(y)
    numpyro.factor("y|y>0", log_prob[y > 0])
    # log[p + (1 - p) * Poisson(0 | lambda)], via logaddexp for stability
    numpyro.factor("y|y==0", jnp.logaddexp(jnp.log(p), log_prob[y == 0]))


m12_3_alt = MCMC(NUTS(model), 500, 500, num_chains=4)
m12_3_alt.run(random.PRNGKey(0), y=y)
m12_3_alt.print_summary(0.89)
Out[11]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
        al     -0.07      0.09     -0.06     -0.20      0.08    673.65      1.01
        ap     -1.82      0.51     -1.73     -2.54     -1.07    613.02      1.00

Number of divergences: 0

Code 12.12

In [12]:
# Trolley: trolley-problem survey data; "response" is an ordinal 1-7 rating.
Trolley = pd.read_csv("../data/Trolley.csv", sep=";")
d = Trolley

Code 12.13

In [13]:
# Raw distribution of the ordinal 1..7 responses.
ax = plt.gca()
ax.hist(d.response, bins=jnp.arange(0.5, 8), rwidth=0.1)
ax.set(xlim=(0.7, 7.3), xlabel="response")
plt.show()
Out[13]:

Code 12.14

In [14]:
# proportion of each response value 1..7
response_counts = d.response.value_counts().sort_index()
pr_k = response_counts.values / d.shape[0]

# cumulative proportion Pr(response <= k)
cum_pr_k = jnp.cumsum(pr_k, -1)

# plot the empirical cumulative distribution
plt.plot(range(1, 8), cum_pr_k, "--o")
plt.gca().set(xlabel="response", ylabel="cumulative proportion", ylim=(-0.1, 1.1))
plt.show()
Out[14]:

Code 12.15

In [15]:
def logit(x):
    """Log-odds transform (inverse of expit); convenience function."""
    return jnp.log(x / (1 - x))


# last entry is inf: the final cumulative proportion is exactly 1
lco = logit(cum_pr_k)
lco
Out[15]:
DeviceArray([-1.91609116, -1.26660559, -0.718634  ,  0.24778573,
              0.88986365,  1.76938091,         inf], dtype=float64)

Code 12.16

In [16]:
def model(R):
    """Intercept-only ordered logistic model (m12.4); R takes values 0..6."""
    # Six cutpoints; OrderedTransform maps an unconstrained Normal(0, 1.5)
    # vector to a strictly increasing one.
    cutpoints = numpyro.sample(
        "cutpoints",
        dist.TransformedDistribution(
            dist.Normal(0, 1.5).expand([6]), OrderedTransform()
        ),
    )
    numpyro.sample("R", dist.OrderedLogistic(0, cutpoints), obs=R)


m12_4 = MCMC(NUTS(model), 500, 500, num_chains=4)
# responses are 1..7 in the data; shift to 0..6 for OrderedLogistic
m12_4.run(random.PRNGKey(0), R=d.response.values - 1)

Note: With single-precision (x32) computations, MCMC chains might get stuck when the initial values are badly generated. Changing the random seed can solve the issue but it is better to enable x64 mode at the beginning of our program (see numpyro.enable_x64() in the first cell).

Code 12.17

In [17]:
def model(response):
    """Same ordered-logistic model as m12.4, fit by quadratic approximation."""
    cutpoints = numpyro.sample(
        "cutpoints",
        dist.TransformedDistribution(
            dist.Normal(0, 1.5).expand([6]), OrderedTransform()
        ),
    )
    numpyro.sample("response", dist.OrderedLogistic(0, cutpoints), obs=response)


# Laplace (quadratic) approximation; start the cutpoints at ordered values
# so optimisation begins in a sensible region.
m12_4q = AutoLaplaceApproximation(
    model,
    init_strategy=init_to_value(
        values={"cutpoints": jnp.array([-2, -1, 0, 1, 2, 2.5])}
    ),
)
svi = SVI(model, m12_4q, optim.Adam(0.3), ELBO(), response=d.response.values - 1)
init_state = svi.init(random.PRNGKey(0))
# 1000 optimisation steps via lax.scan (faster than a Python loop)
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p12_4q = svi.get_params(state)

Code 12.18

In [18]:
# Posterior cutpoints on the log-cumulative-odds scale
# (close to the empirical lco values from Code 12.15).
m12_4.print_summary(0.89)
Out[18]:
                  mean       std    median      5.5%     94.5%     n_eff     r_hat
cutpoints[0]     -1.92      0.03     -1.92     -1.96     -1.87   1372.38      1.00
cutpoints[1]     -1.27      0.02     -1.27     -1.31     -1.23   1852.31      1.00
cutpoints[2]     -0.72      0.02     -0.72     -0.75     -0.68   1995.63      1.00
cutpoints[3]      0.25      0.02      0.25      0.22      0.28   2170.70      1.00
cutpoints[4]      0.89      0.02      0.89      0.86      0.93   2149.10      1.00
cutpoints[5]      1.77      0.03      1.77      1.73      1.82   2155.08      1.00

Number of divergences: 0

Code 12.19

In [19]:
# Map the mean cutpoints back to cumulative probabilities via the inverse logit.
expit(jnp.mean(m12_4.get_samples()["cutpoints"], 0))
Out[19]:
DeviceArray([0.12832244, 0.21987312, 0.32772031, 0.56164761, 0.70886956,
             0.85443873], dtype=float64)

Code 12.20

In [20]:
coef = jnp.mean(m12_4.get_samples()["cutpoints"], 0)
# implied probability of each outcome 0..6 at the posterior-mean cutpoints
pk = jnp.exp(dist.OrderedLogistic(0, coef).log_prob(jnp.arange(7)))
pk
Out[20]:
DeviceArray([0.12832244, 0.09155067, 0.1078472 , 0.2339273 , 0.14722195,
             0.14556917, 0.14556127], dtype=float64)

Code 12.21

In [21]:
# expected response on the original 1..7 scale
jnp.sum(pk * jnp.arange(1, 8))
Out[21]:
DeviceArray(4.19912823, dtype=float64)

Code 12.22

In [22]:
# subtracting 0.5 from every cutpoint shifts probability mass
# toward higher responses
coef = jnp.mean(m12_4.get_samples()["cutpoints"], 0) - 0.5
pk = jnp.exp(dist.OrderedLogistic(0, coef).log_prob(jnp.arange(7)))
pk
Out[22]:
DeviceArray([0.08197025, 0.0640196 , 0.08220823, 0.20909667, 0.1589639 ,
             0.18445802, 0.21928333], dtype=float64)

Code 12.23

In [23]:
# expected response after the 0.5 shift (compare Code 12.21)
jnp.sum(pk * jnp.arange(1, 8))
Out[23]:
DeviceArray(4.72957174, dtype=float64)

Code 12.24

In [24]:
dat = dict(
    R=d.response.values - 1, A=d.action.values, I=d.intention.values, C=d.contact.values
)


def model(A, I, C, R=None):
    """Ordered logistic with predictors (m12.5).

    Intention interacts with action and contact: BI is the intention slope
    conditional on A and C.
    """
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    bI = numpyro.sample("bI", dist.Normal(0, 0.5))
    bC = numpyro.sample("bC", dist.Normal(0, 0.5))
    bIA = numpyro.sample("bIA", dist.Normal(0, 0.5))
    bIC = numpyro.sample("bIC", dist.Normal(0, 0.5))
    cutpoints = numpyro.sample(
        "cutpoints",
        dist.TransformedDistribution(
            dist.Normal(0, 1.5).expand([6]), OrderedTransform()
        ),
    )
    # slope of intention depends on action and contact
    BI = bI + bIA * A + bIC * C
    # phi recorded as a deterministic site so Code 12.27 can extract it
    phi = numpyro.deterministic("phi", bA * A + bC * C + BI * I)
    numpyro.sample("R", dist.OrderedLogistic(phi, cutpoints), obs=R)


m12_5 = MCMC(NUTS(model), 500, 500, num_chains=4)
m12_5.run(random.PRNGKey(0), **dat)
m12_5.print_summary(0.89)
Out[24]:
                  mean       std    median      5.5%     94.5%     n_eff     r_hat
          bA     -0.47      0.05     -0.48     -0.56     -0.39    889.58      1.00
          bC     -0.34      0.07     -0.34     -0.46     -0.24    994.23      1.00
          bI     -0.29      0.06     -0.29     -0.39     -0.20    867.73      1.00
         bIA     -0.43      0.08     -0.44     -0.55     -0.28   1048.93      1.00
         bIC     -1.24      0.10     -1.23     -1.38     -1.08   1159.67      1.00
cutpoints[0]     -2.64      0.05     -2.63     -2.72     -2.55    858.78      1.00
cutpoints[1]     -1.94      0.05     -1.94     -2.01     -1.86    865.31      1.00
cutpoints[2]     -1.34      0.05     -1.35     -1.41     -1.27    879.29      1.00
cutpoints[3]     -0.31      0.04     -0.31     -0.38     -0.24    855.14      1.00
cutpoints[4]      0.36      0.04      0.36      0.29      0.43    942.08      1.00
cutpoints[5]      1.27      0.05      1.27      1.19      1.34   1009.99      1.00

Number of divergences: 0

Code 12.25

In [25]:
post = m12_5.get_samples(group_by_chain=True)
# Forest plot of the slope posteriors (all negative).
az.plot_forest(
    post, var_names=["bIC", "bIA", "bC", "bI", "bA"], combined=True, hdi_prob=0.89,
)
plt.gca().set(xlim=(-1.42, 0.02))
plt.show()
Out[25]:

Code 12.26

In [26]:
# Empty axes for the posterior plot; the figure handle is kept so
# Code 12.28 can draw onto it and re-display it.
ax = plt.subplot(xlabel="intention", ylabel="probability", xlim=(0, 1), ylim=(0, 1))
fig = plt.gcf()
Out[26]:

Code 12.27

In [27]:
kA = 0  # value for action
kC = 0  # value for contact
kI = jnp.arange(2)  # values of intention to calculate over
pdat = dict(A=kA, C=kC, I=kI)
# posterior draws of the linear predictor phi at these covariate settings
phi = Predictive(m12_5.sampler.model, m12_5.get_samples())(random.PRNGKey(1), **pdat)[
    "phi"
]

Code 12.28

In [28]:
post = m12_5.get_samples()
# For 50 posterior draws, plot the six cumulative probabilities
# Pr(R <= k) = expit(cutpoint_k - phi) at intention = 0 and 1.
for s in range(50):
    pk = expit(post["cutpoints"][s] - phi[s][..., None])
    for i in range(6):
        ax.plot(kI, pk[:, i], c="k", alpha=0.2)
fig  # re-display the figure prepared in Code 12.26
Out[28]:

Code 12.29

In [29]:
kA = 0  # value for action
kC = 0  # value for contact
kI = jnp.arange(2)  # values of intention to calculate over
pdat = dict(A=kA, C=kC, I=kI)
# simulate responses (R omitted so Predictive draws it), shift back to 1..7
s = (
    Predictive(m12_5.sampler.model, m12_5.get_samples())(random.PRNGKey(1), **pdat)["R"]
    + 1
)
plt.hist(s[:, 0], bins=jnp.arange(0.5, 8), rwidth=0.1)
# second histogram offset slightly so both intention levels are visible
plt.hist(s[:, 1], bins=jnp.arange(0.65, 8), rwidth=0.1)
plt.gca().set(xlabel="response")
plt.show()
Out[29]:

Code 12.30

In [30]:
Trolley = pd.read_csv("../data/Trolley.csv", sep=";")
d = Trolley
# education is stored as free-text categories; unique() shows the raw levels
d.edu.unique()
Out[30]:
array(['Middle School', "Bachelor's Degree", 'Some College',
       "Master's Degree", 'High School Graduate', 'Graduate Degree',
       'Some High School', 'Elementary School'], dtype=object)

Code 12.31

In [31]:
# Education levels in increasing order, so the categorical codes 0..7
# respect the natural ordering.
edu_levels = [
    "Elementary School",
    "Middle School",
    "Some High School",
    "High School Graduate",
    "Some College",
    "Bachelor's Degree",
    "Master's Degree",
    "Graduate Degree",
]
edu_dtype = pd.api.types.CategoricalDtype(categories=edu_levels, ordered=True)
# edu_new: integer code, 0 = Elementary School .. 7 = Graduate Degree
d["edu_new"] = d.edu.astype(edu_dtype).cat.codes

Code 12.32

In [32]:
# Ten draws from a Dirichlet(2,...,2) prior over the 7 education increments.
delta = dist.Dirichlet(jnp.repeat(2, 7)).sample(random.PRNGKey(1805), (10,))
delta
Out[32]:
DeviceArray([[0.25509119, 0.03908582, 0.02062807, 0.07466964, 0.01564467,
              0.06088717, 0.53399344],
             [0.06188233, 0.16273207, 0.33820846, 0.14339588, 0.10220403,
              0.09596682, 0.09561042],
             [0.10517178, 0.26792887, 0.17099041, 0.0463058 , 0.09862729,
              0.19184056, 0.11913529],
             [0.13225867, 0.17854144, 0.27357335, 0.09591914, 0.20810338,
              0.08527108, 0.02633294],
             [0.06846251, 0.17774259, 0.13601505, 0.13269377, 0.23953351,
              0.01396916, 0.23158341],
             [0.10417249, 0.19923656, 0.10265471, 0.10296115, 0.30281302,
              0.08507872, 0.10308334],
             [0.17357477, 0.08437654, 0.29003704, 0.0621773 , 0.06377041,
              0.08843999, 0.23762396],
             [0.09415017, 0.13096043, 0.09720853, 0.02269312, 0.01563354,
              0.41263503, 0.22671918],
             [0.03496056, 0.02213823, 0.04275664, 0.17735326, 0.1310294 ,
              0.32329965, 0.26846226],
             [0.02851972, 0.03833038, 0.04343796, 0.3296052 , 0.24054727,
              0.29424004, 0.02531943]], dtype=float64)

Code 12.33

In [33]:
# Highlight draw h in bold; show the remaining Dirichlet draws faintly.
h = 3
plt.subplot(xlim=(0.9, 7.1), ylim=(-0.01, 0.41), xlabel="index", ylabel="probability")
for draw_num, probs in enumerate(delta, start=1):
    if draw_num == h:
        plt.plot(range(1, 8), probs, "ko-", ms=8, lw=4)
    else:
        plt.plot(range(1, 8), probs, "ko-", mfc="w", ms=8, lw=1.5, alpha=0.7)
Out[33]:

Code 12.34

In [34]:
dat = dict(
    R=d.response.values - 1,
    action=d.action.values,
    intention=d.intention.values,
    contact=d.contact.values,
    E=d.edu_new.values,  # edu_new as an index
    alpha=jnp.repeat(2, 7),
)  # delta prior


def model(action, intention, contact, E, alpha, R):
    """Ordered logistic with an ordered-categorical education predictor (m12.6)."""
    bA = numpyro.sample("bA", dist.Normal(0, 1))
    bI = numpyro.sample("bI", dist.Normal(0, 1))
    bC = numpyro.sample("bC", dist.Normal(0, 1))
    bE = numpyro.sample("bE", dist.Normal(0, 1))
    # delta: simplex of increments between adjacent education levels
    delta = numpyro.sample("delta", dist.Dirichlet(alpha))
    kappa = numpyro.sample(
        "kappa",
        dist.TransformedDistribution(
            dist.Normal(0, 1.5).expand([6]), OrderedTransform()
        ),
    )
    # prepend 0 so education level 0 contributes nothing
    delta_j = jnp.pad(delta, (1, 0))
    # cumulative sum of increments up to each person's education level E
    delta_E = jnp.sum(jnp.where(jnp.arange(8) <= E[..., None], delta_j, 0), -1)
    phi = bE * delta_E + bA * action + bI * intention + bC * contact
    numpyro.sample("R", dist.OrderedLogistic(phi, kappa), obs=R)


# NOTE(review): 3 chains here (vs 4 elsewhere) — mirrors the original port
m12_6 = MCMC(NUTS(model), 500, 500, num_chains=3)
m12_6.run(random.PRNGKey(0), **dat)

Code 12.35

In [35]:
# Posterior summary for m12.6; bE is reliably negative.
m12_6.print_summary(0.89)
Out[35]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
        bA     -0.71      0.04     -0.71     -0.77     -0.64   1550.58      1.00
        bC     -0.96      0.05     -0.96     -1.05     -0.89   1304.82      1.00
        bE     -0.38      0.18     -0.36     -0.66     -0.09    468.01      1.01
        bI     -0.72      0.04     -0.72     -0.78     -0.67   1467.93      1.00
  delta[0]      0.26      0.15      0.24      0.01      0.48    611.81      1.00
  delta[1]      0.14      0.09      0.12      0.01      0.26   1535.10      1.00
  delta[2]      0.19      0.11      0.17      0.02      0.34   1443.81      1.00
  delta[3]      0.17      0.10      0.15      0.02      0.31   1104.17      1.00
  delta[4]      0.03      0.04      0.02      0.00      0.07    762.27      1.00
  delta[5]      0.09      0.06      0.07      0.01      0.17   1328.63      1.00
  delta[6]      0.12      0.07      0.11      0.01      0.22   1700.77      1.00
  kappa[0]     -3.14      0.17     -3.12     -3.41     -2.88    418.98      1.01
  kappa[1]     -2.45      0.17     -2.43     -2.71     -2.19    422.58      1.01
  kappa[2]     -1.87      0.17     -1.85     -2.12     -1.60    420.42      1.01
  kappa[3]     -0.85      0.17     -0.83     -1.07     -0.55    429.93      1.01
  kappa[4]     -0.18      0.17     -0.16     -0.43      0.09    433.61      1.01
  kappa[5]      0.73      0.17      0.75      0.48      0.99    421.72      1.01

Number of divergences: 0

Code 12.36

In [36]:
# Short labels for the eight education levels; the first 7 label delta.
delta_labels = ["Elem", "MidSch", "SHS", "HSG", "SCol", "Bach", "Mast", "Grad"]
a12_6 = az.from_numpyro(
    m12_6, coords={"labels": delta_labels[:7]}, dims={"delta": ["labels"]}
)
az.plot_pair(a12_6, var_names="delta")
# png keeps the large 7x7 pair-plot output lightweight
set_matplotlib_formats("png")
Out[36]: