Chapter 12. Monsters and Mixtures

In [ ]:
!pip install -q numpyro arviz
In [0]:
import math
import os

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import set_matplotlib_formats

import jax.numpy as jnp
from jax import lax, random
from jax.scipy.special import expit

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.distributions.transforms import OrderedTransform
from numpyro.infer import MCMC, NUTS, Predictive, SVI, Trace_ELBO, init_to_value
from numpyro.infer.autoguide import AutoLaplaceApproximation

# Switch inline figures to SVG only when an "SVG" environment variable is present.
if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
# Double precision: the note after Code 12.16 reports stuck chains under x32.
numpyro.enable_x64()
# Four host devices; every MCMC run below uses num_chains=4
# (presumably so the chains can run in parallel — confirm against NumPyro docs).
numpyro.set_host_device_count(4)

Code 12.1

In [1]:
# Beta density written in terms of pbar and theta: Beta(pbar * theta, (1 - pbar) * theta).
pbar = 0.5
theta = 5
x = jnp.linspace(0, 1, 101)
beta_density = jnp.exp(dist.Beta(pbar * theta, (1 - pbar) * theta).log_prob(x))
plt.plot(x, beta_density)
plt.gca().set(xlabel="probability", ylabel="Density")
plt.show()
Out[1]:
No description has been provided for this image

Code 12.2

In [2]:
UCBadmit = pd.read_csv("../data/UCBadmit.csv", sep=";")
d = UCBadmit
# gender index: 0 for "male" applicants, 1 for everyone else
d["gid"] = (d["applicant.gender"] != "male").astype(int)
dat = dict(A=d.admit.values, N=d.applications.values, gid=d.gid.values)


def model(gid, N, A=None):
    """Beta-binomial model of admissions with one logit-scale intercept per ``gid``.

    Args:
        gid: integer index (0 or 1) selecting an entry of the intercept vector ``a``.
        N: total count passed to the beta-binomial (applications in ``dat``).
        A: observed admissions; when ``None`` the "A" site is left unobserved.
    """
    a = numpyro.sample("a", dist.Normal(0, 1.5).expand([2]))
    phi = numpyro.sample("phi", dist.Exponential(1))
    # phi is non-negative, so the recorded theta is at least 2
    theta = numpyro.deterministic("theta", phi + 2)
    pbar = expit(a[gid])
    numpyro.sample("A", dist.BetaBinomial(pbar * theta, (1 - pbar) * theta, N), obs=A)


# 4 chains, each with 500 warmup and 500 retained draws
m12_1 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m12_1.run(random.PRNGKey(0), **dat)
Out[2]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[2]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[2]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[2]:
  0%|          | 0/1000 [00:00<?, ?it/s]

Code 12.3

In [3]:
post = m12_1.get_samples()
# evaluate the deterministic "theta" site for every posterior draw
theta_predictive = Predictive(m12_1.sampler.model, post)
post["theta"] = theta_predictive(random.PRNGKey(1), **dat)["theta"]
# contrast between the two intercepts: a[0] - a[1]
a_draws = post["a"]
post["da"] = a_draws[:, 0] - a_draws[:, 1]
print_summary(post, 0.89, False)
Out[3]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]     -0.44      0.41     -0.43     -1.08      0.23   1360.42      1.00
      a[1]     -0.34      0.40     -0.33     -0.94      0.33   1684.60      1.00
        da     -0.10      0.57     -0.10     -0.96      0.83   1592.54      1.00
       phi      1.03      0.82      0.85      0.00      2.13   1718.89      1.00
     theta      3.03      0.82      2.85      2.00      4.13   1718.89      1.00

Code 12.4

In [4]:
gid = 1  # index 1 = applicants whose gender is not "male" (see Code 12.2)
# draw posterior mean beta distribution
x = jnp.linspace(0, 1, 101)
pbar = jnp.mean(expit(post["a"][:, gid]))
theta = jnp.mean(post["theta"])
plt.plot(x, jnp.exp(dist.Beta(pbar * theta, (1 - pbar) * theta).log_prob(x)))
plt.gca().set(ylabel="Density", xlabel="probability admit", ylim=(0, 3))

# draw 50 beta distributions sampled from posterior
# (uses the first 50 draws; note that `theta` is overwritten on every iteration)
for i in range(50):
    p = expit(post["a"][i, gid])
    theta = post["theta"][i]
    plt.plot(
        x, jnp.exp(dist.Beta(p * theta, (1 - p) * theta).log_prob(x)), "k", alpha=0.2
    )
plt.title("distribution of female admission rates")
plt.show()
Out[4]:
No description has been provided for this image

Code 12.5

In [5]:
post = m12_1.get_samples()
predict_admit = Predictive(m12_1.sampler.model, post)
admit_pred = predict_admit(random.PRNGKey(1), gid=dat["gid"], N=dat["N"])["A"]
admit_rate = admit_pred / dat["N"]
case_ids = range(1, 13)

# observed admission rate for each case
plt.scatter(case_ids, dat["A"] / dat["N"])

# predictive mean per case, with half a standard deviation as the error bar
plt.errorbar(
    case_ids,
    jnp.mean(admit_rate, 0),
    jnp.std(admit_rate, 0) / 2,
    fmt="o",
    c="k",
    mfc="none",
    ms=7,
    elinewidth=1,
)

# 5.5% and 94.5% percentiles of the simulated rates
for q in (5.5, 94.5):
    plt.plot(case_ids, jnp.percentile(admit_rate, q, 0), "k+")
plt.show()
Out[5]:
No description has been provided for this image

Code 12.6

In [6]:
Kline = pd.read_csv("../data/Kline.csv", sep=";")
d = Kline
# standardized log population
d["P"] = d.population.apply(math.log).pipe(lambda x: (x - x.mean()) / x.std())
d["contact_id"] = (d.contact == "high").astype(int)  # 1 = "high" contact, 0 otherwise

# note: the model is given the raw population, not the standardized d["P"] column
dat2 = dict(T=d.total_tools.values, P=d.population.values, cid=d.contact_id.values)


def model(cid, P, T):
    """Gamma-Poisson model of tool counts with parameters indexed by contact level.

    Args:
        cid: integer index (0 or 1) into the two-element vectors ``a`` and ``b``.
        P: population; enters the rate as ``P ** b[cid]``.
        T: observed tool counts.
    """
    a = numpyro.sample("a", dist.Normal(1, 1).expand([2]))
    b = numpyro.sample("b", dist.Exponential(1).expand([2]))
    g = numpyro.sample("g", dist.Exponential(1))
    phi = numpyro.sample("phi", dist.Exponential(1))
    lambda_ = jnp.exp(a[cid]) * jnp.power(P, b[cid]) / g
    # NOTE(review): presumably GammaPoisson takes (concentration, rate), which would make
    # lambda_ the mean and phi a dispersion control — confirm against the NumPyro docs
    numpyro.sample("T", dist.GammaPoisson(lambda_ / phi, 1 / phi), obs=T)


# 4 chains, each with 500 warmup and 500 retained draws
m12_2 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m12_2.run(random.PRNGKey(0), **dat2)
Out[6]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[6]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[6]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[6]:
  0%|          | 0/1000 [00:00<?, ?it/s]

Note: The results here might differ from the book. There seems to have been a bug in R's dgampois implementation at the time the book was printed. According to this issue, the bug has since been fixed upstream.

Code 12.7

In [7]:
# define parameters
prob_drink = 0.2  # 20% of days
rate_work = 1  # average 1 manuscript per day

# sample one year of production
N = 365

with numpyro.handlers.seed(rng_seed=365):
    # simulate days monks drink
    drink = numpyro.sample("drink", dist.Binomial(1, prob_drink).expand([N]))

    # simulate manuscripts completed
    y = (1 - drink) * numpyro.sample("work", dist.Poisson(rate_work).expand([N]))

Code 12.8

In [8]:
# bin edges at -0.5, 0.5, ..., 5.5, so each bar is centred on an integer count 0..5
plt.hist(np.asarray(y), color="k", bins=jnp.arange(-0.5, 6), rwidth=0.1)
plt.gca().set(xlabel="manuscripts completed")
zeros_drink = jnp.sum(drink)  # number of drinking days; not used by the plot below
zeros_work = jnp.sum((y == 0) & (drink == 0))  # zeros on non-drinking days
zeros_total = jnp.sum(y == 0)
# blue segment on the zero bar spans zeros_work..zeros_total, i.e. the zeros from drinking days
plt.plot([0, 0], [zeros_work, zeros_total], "royalblue", lw=8)
plt.show()
Out[8]:
No description has been provided for this image

Code 12.9

In [9]:
def model(y):
    """Zero-inflated Poisson model for the counts ``y``.

    ``ap`` is on the logit scale (``p = expit(ap)``) and ``al`` on the log scale
    (``lambda_ = exp(al)``); both are passed to ``ZeroInflatedPoisson``.

    Args:
        y: observed counts.
    """
    ap = numpyro.sample("ap", dist.Normal(-1.5, 1))
    al = numpyro.sample("al", dist.Normal(1, 0.5))
    p = expit(ap)
    lambda_ = jnp.exp(al)
    numpyro.sample("y", dist.ZeroInflatedPoisson(p, lambda_), obs=y)


# 4 chains, each with 500 warmup and 500 retained draws; summary uses 89% intervals
m12_3 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m12_3.run(random.PRNGKey(0), y=y)
m12_3.print_summary(0.89)
Out[9]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[9]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[9]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[9]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[9]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
        al     -0.06      0.09     -0.06     -0.19      0.08    781.22      1.00
        ap     -1.79      0.48     -1.71     -2.52     -1.07    756.19      1.00

Number of divergences: 0

Code 12.10

In [10]:
post = m12_3.get_samples()
# back-transform each draw to the natural scale, then average over draws
drink_prob_mean = jnp.mean(expit(post["ap"]))  # probability drink
work_rate_mean = jnp.mean(jnp.exp(post["al"]))  # rate finish manuscripts, when not drinking
print(drink_prob_mean)
print(work_rate_mean)
Out[10]:
0.15255071713826623
0.9416965716188807

Code 12.11

In [11]:
def model(y):
    """Hand-written zero-inflated Poisson likelihood, added through ``numpyro.factor``.

    Args:
        y: observed counts. It is used for boolean indexing (``log_prob[y > 0]``),
            and the caller below passes a NumPy array via ``np.asarray(y)``.
    """
    ap = numpyro.sample("ap", dist.Normal(-1.5, 1))
    al = numpyro.sample("al", dist.Normal(1, 0.5))
    p = expit(ap)
    lambda_ = jnp.exp(al)
    # log[(1 - p) * Poisson(y | lambda_)] for every observation
    log_prob = jnp.log1p(-p) + dist.Poisson(lambda_).log_prob(y)
    numpyro.factor("y|y>0", log_prob[y > 0])
    # for zeros: log[p + (1 - p) * Poisson(0 | lambda_)], computed in log space
    numpyro.factor("y|y==0", jnp.logaddexp(jnp.log(p), log_prob[y == 0]))


m12_3_alt = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m12_3_alt.run(random.PRNGKey(0), y=np.asarray(y))
m12_3_alt.print_summary(0.89)
Out[11]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[11]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[11]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[11]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[11]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
        al     -0.06      0.09     -0.06     -0.21      0.07    713.86      1.00
        ap     -1.79      0.51     -1.69     -2.50     -1.00    637.00      1.01

Number of divergences: 0

Note: JAX 0.2 requires array boolean indices to be concrete. So, to make log_prob[y > 0] work, we need to pass y as a concrete NumPy ndarray (obtained by np.asarray(y)) instead of a JAX DeviceArray.

Code 12.12

In [12]:
# Load the semicolon-separated Trolley data; `d` is the working alias used by the cells below.
Trolley = pd.read_csv("../data/Trolley.csv", sep=";")
d = Trolley

Code 12.13

In [13]:
# bin edges at 0.5, 1.5, ..., 7.5
response_bins = jnp.arange(0.5, 8)
plt.hist(d.response, bins=response_bins, rwidth=0.1)
axes = plt.gca()
axes.set_xlim(0.7, 7.3)
axes.set_xlabel("response")
plt.show()
Out[13]:
No description has been provided for this image

Code 12.14

In [14]:
# discrete proportion of each response value
pr_k = d.response.value_counts().sort_index().values / d.shape[0]

# cumsum converts to cumulative proportions
cum_pr_k = jnp.cumsum(pr_k, -1)

# plot
plt.plot(range(1, 8), cum_pr_k, "--o")
plt.gca().set(xlabel="response", ylabel="cumulative proportion", ylim=(-0.1, 1.1))
plt.show()
Out[14]:
No description has been provided for this image

Code 12.15

In [15]:
def logit(x):
    """Return the log-odds ``log(x / (1 - x))`` of ``x`` (convenience function).

    Args:
        x: probability, or array of probabilities.

    Returns:
        The log-odds, with the same shape as ``x``.
    """
    return jnp.log(x / (1 - x))
# log-cumulative-odds of each response value; the last entry is inf because the final
# cumulative proportion is 1
lco = logit(cum_pr_k)
lco
Out[15]:
DeviceArray([-1.91609116, -1.26660559, -0.718634  ,  0.24778573,
              0.88986365,  1.76938091,         inf], dtype=float64)

Code 12.16

In [16]:
def model(R):
    """Intercept-only ordered-logistic model: six ordered cutpoints, linear predictor 0.

    Args:
        R: observed ordinal responses; the caller passes ``d.response.values - 1``
            (presumably shifting a 1-7 scale to 0-6 — confirm against the data).
    """
    # OrderedTransform maps the six Normal(0, 1.5) draws to an ordered vector of cutpoints
    cutpoints = numpyro.sample(
        "cutpoints",
        dist.TransformedDistribution(
            dist.Normal(0, 1.5).expand([6]), OrderedTransform()
        ),
    )
    numpyro.sample("R", dist.OrderedLogistic(0, cutpoints), obs=R)


m12_4 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m12_4.run(random.PRNGKey(0), R=d.response.values - 1)
Out[16]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[16]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[16]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[16]:
  0%|          | 0/1000 [00:00<?, ?it/s]

Note: With single-precision (x32) computations, MCMC chains might get stuck when the randomly generated initial values are poor. Changing the random seed can work around the issue, but it is better to enable x64 mode at the beginning of the program (see numpyro.enable_x64() in the first cell).

Code 12.17

In [17]:
def model(response):
    """Same intercept-only ordered-logistic model as Code 12.16, with the data named ``response``.

    Args:
        response: observed ordinal responses; the caller passes ``d.response.values - 1``.
    """
    cutpoints = numpyro.sample(
        "cutpoints",
        dist.TransformedDistribution(
            dist.Normal(0, 1.5).expand([6]), OrderedTransform()
        ),
    )
    numpyro.sample("response", dist.OrderedLogistic(0, cutpoints), obs=response)


# Laplace-approximation guide, started from explicit, increasing cutpoint values
m12_4q = AutoLaplaceApproximation(
    model,
    init_loc_fn=init_to_value(values={"cutpoints": jnp.array([-2, -1, 0, 1, 2, 2.5])}),
)
# the observed data is bound as a static keyword argument when SVI is built
svi = SVI(model, m12_4q, optim.Adam(0.3), Trace_ELBO(), response=d.response.values - 1)
# 1000 optimisation steps
svi_result = svi.run(random.PRNGKey(0), 1000)
p12_4q = svi_result.params
Out[17]:
100%|██████████| 1000/1000 [00:01<00:00, 870.82it/s, init loss: 20457.6601, avg. loss [951-1000]: 18930.1558]

Code 12.18

In [18]:
# posterior summary of the six cutpoints from the MCMC fit, with 89% intervals
m12_4.print_summary(0.89)
Out[18]:
                  mean       std    median      5.5%     94.5%     n_eff     r_hat
cutpoints[0]     -1.92      0.03     -1.92     -1.96     -1.86   1221.83      1.00
cutpoints[1]     -1.27      0.02     -1.27     -1.30     -1.23   1764.20      1.00
cutpoints[2]     -0.72      0.02     -0.72     -0.75     -0.68   2133.58      1.00
cutpoints[3]      0.25      0.02      0.25      0.22      0.28   2374.45      1.00
cutpoints[4]      0.89      0.02      0.89      0.85      0.93   2265.39      1.00
cutpoints[5]      1.77      0.03      1.77      1.72      1.81   1952.07      1.00

Number of divergences: 0

Code 12.19

In [19]:
# inverse-logit of the posterior-mean cutpoints (mean taken over draws, axis 0)
expit(jnp.mean(m12_4.get_samples()["cutpoints"], 0))
Out[19]:
DeviceArray([0.12832982, 0.21992469, 0.32781562, 0.56182347, 0.7090021 ,
             0.8544513 ], dtype=float64)

Code 12.20

In [20]:
# posterior-mean cutpoints
coef = jnp.mean(m12_4.get_samples()["cutpoints"], 0)
# probability of each of the 7 outcomes (coded 0..6) under those cutpoints
pk = jnp.exp(dist.OrderedLogistic(0, coef).log_prob(jnp.arange(7)))
pk
Out[20]:
DeviceArray([0.12832982, 0.09159486, 0.10789093, 0.23400785, 0.14717863,
             0.1454492 , 0.1455487 ], dtype=float64)

Code 12.21

In [21]:
# probability-weighted average of the response values 1..7
jnp.sum(pk * jnp.arange(1, 8))
Out[21]:
DeviceArray(4.198653, dtype=float64)

Code 12.22

In [22]:
# same as Code 12.20, but with every posterior-mean cutpoint lowered by 0.5
coef = jnp.mean(m12_4.get_samples()["cutpoints"], 0) - 0.5
pk = jnp.exp(dist.OrderedLogistic(0, coef).log_prob(jnp.arange(7)))
pk
Out[22]:
DeviceArray([0.08197522, 0.06405212, 0.08224693, 0.20919626, 0.15894274,
             0.1843207 , 0.21926602], dtype=float64)

Code 12.23

In [23]:
# probability-weighted average of the response values 1..7, using the shifted-cutpoint pk
jnp.sum(pk * jnp.arange(1, 8))
Out[23]:
DeviceArray(4.72910539, dtype=float64)

Code 12.24

In [24]:
# note: rebinds `dat`, which previously held the UCBadmit data from Code 12.2
dat = dict(
    R=d.response.values - 1, A=d.action.values, I=d.intention.values, C=d.contact.values
)


def model(A, I, C, R=None):
    """Ordered-logistic model with action, intention and contact predictors.

    The linear predictor is ``phi = bA * A + bC * C + BI * I``, where the intention
    slope ``BI = bI + bIA * A + bIC * C`` itself depends on action and contact.

    Args:
        A: action indicator(s).
        I: intention indicator(s).
        C: contact indicator(s).
        R: observed ordinal responses; when ``None`` the "R" site is left unobserved.
    """
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    bI = numpyro.sample("bI", dist.Normal(0, 0.5))
    bC = numpyro.sample("bC", dist.Normal(0, 0.5))
    bIA = numpyro.sample("bIA", dist.Normal(0, 0.5))
    bIC = numpyro.sample("bIC", dist.Normal(0, 0.5))
    cutpoints = numpyro.sample(
        "cutpoints",
        dist.TransformedDistribution(
            dist.Normal(0, 1.5).expand([6]), OrderedTransform()
        ),
    )
    BI = bI + bIA * A + bIC * C
    # recorded as a deterministic site so it can be retrieved under the name "phi"
    phi = numpyro.deterministic("phi", bA * A + bC * C + BI * I)
    numpyro.sample("R", dist.OrderedLogistic(phi, cutpoints), obs=R)


m12_5 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m12_5.run(random.PRNGKey(0), **dat)
m12_5.print_summary(0.89)
Out[24]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[24]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[24]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[24]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[24]:
                  mean       std    median      5.5%     94.5%     n_eff     r_hat
          bA     -0.47      0.05     -0.47     -0.56     -0.39   1088.90      1.00
          bC     -0.35      0.07     -0.34     -0.46     -0.24   1056.64      1.00
          bI     -0.29      0.05     -0.29     -0.38     -0.21    845.90      1.00
         bIA     -0.43      0.08     -0.43     -0.55     -0.31   1034.40      1.00
         bIC     -1.23      0.10     -1.23     -1.39     -1.07   1031.78      1.00
cutpoints[0]     -2.63      0.05     -2.63     -2.70     -2.54    948.24      1.00
cutpoints[1]     -1.94      0.05     -1.94     -2.01     -1.87   1024.60      1.00
cutpoints[2]     -1.34      0.04     -1.34     -1.41     -1.27   1011.23      1.00
cutpoints[3]     -0.31      0.04     -0.31     -0.37     -0.24    947.12      1.00
cutpoints[4]      0.36      0.04      0.36      0.29      0.43    961.96      1.00
cutpoints[5]      1.27      0.05      1.27      1.19      1.34   1145.43      1.00

Number of divergences: 0

Code 12.25

In [25]:
# keep the chain dimension so ArviZ receives per-chain draws
post = m12_5.get_samples(group_by_chain=True)
slope_names = ["bIC", "bIA", "bC", "bI", "bA"]
az.plot_forest(post, var_names=slope_names, combined=True, hdi_prob=0.89)
plt.gca().set_xlim(-1.42, 0.02)
plt.show()
Out[25]:
No description has been provided for this image

Code 12.26

In [26]:
# empty axes; `ax` and `fig` are kept so that Code 12.28 can draw on them and redisplay
ax = plt.subplot(xlabel="intention", ylabel="probability", xlim=(0, 1), ylim=(0, 1))
fig = plt.gcf()
Out[26]:
No description has been provided for this image

Code 12.27

In [27]:
kA = 0  # value for action
kC = 0  # value for contact
kI = jnp.arange(2)  # values of intention to calculate over
pdat = {"A": kA, "C": kC, "I": kI}
post = m12_5.get_samples()
# discard the "phi" draws stored during fitting before predicting on the new inputs
del post["phi"]
predict_phi = Predictive(m12_5.sampler.model, post)
phi = predict_phi(random.PRNGKey(1), **pdat)["phi"]

Code 12.28

In [28]:
# overlay the first 50 posterior draws on the axes created in Code 12.26
for s in range(50):
    # expit(cutpoint - phi): one row per value in kI, one column per cutpoint
    pk = expit(post["cutpoints"][s] - phi[s][..., None])
    for i in range(6):
        ax.plot(kI, pk[:, i], c="k", alpha=0.2)
# redisplay the figure now that the lines have been added
fig
Out[28]:
No description has been provided for this image

Code 12.29

In [29]:
kA = 0  # value for action
kC = 0  # value for contact
kI = jnp.arange(2)  # values of intention to calculate over
pdat = dict(A=kA, C=kC, I=kI)
# `post` is the dict from Code 12.27, i.e. with its "phi" entry already removed;
# R is not passed, so "R" is simulated; + 1 undoes the - 1 shift applied to the data
s = (
    Predictive(m12_5.sampler.model, post)(random.PRNGKey(1), **pdat)["R"]
    + 1
)
# column 0 is kI = 0, column 1 is kI = 1; the second set of bins is offset by 0.15
# so the two histograms sit side by side
plt.hist(s[:, 0], bins=jnp.arange(0.5, 8), rwidth=0.1)
plt.hist(s[:, 1], bins=jnp.arange(0.65, 8), rwidth=0.1)
plt.gca().set(xlabel="response")
plt.show()
Out[29]:
No description has been provided for this image

Code 12.30

In [30]:
# reload the Trolley data (rebinding `d`) and list the distinct education labels
Trolley = pd.read_csv("../data/Trolley.csv", sep=";")
d = Trolley
d.edu.unique()
Out[30]:
array(['Middle School', "Bachelor's Degree", 'Some College',
       "Master's Degree", 'High School Graduate', 'Graduate Degree',
       'Some High School', 'Elementary School'], dtype=object)

Code 12.31

In [31]:
# education labels in the order used for the ordered categorical below
edu_levels = [
    "Elementary School",
    "Middle School",
    "Some High School",
    "High School Graduate",
    "Some College",
    "Bachelor's Degree",
    "Master's Degree",
    "Graduate Degree",
]
cat_type = pd.api.types.CategoricalDtype(categories=edu_levels, ordered=True)
# integer codes follow the position in edu_levels (0 = "Elementary School", 7 = "Graduate Degree")
d["edu_new"] = d.edu.astype(cat_type).cat.codes

Code 12.32

In [32]:
# 10 draws from a 7-dimensional Dirichlet whose concentration parameters are all 2
delta = dist.Dirichlet(jnp.repeat(2, 7)).sample(random.PRNGKey(1805), (10,))
delta
Out[32]:
DeviceArray([[0.22049101, 0.1143812 , 0.14306989, 0.04105449, 0.20959843,
              0.12338147, 0.14802351],
             [0.09893053, 0.11666275, 0.07634449, 0.20023139, 0.03087877,
              0.16119573, 0.31575634],
             [0.04633454, 0.30351494, 0.08330251, 0.10624879, 0.04593036,
              0.20940956, 0.20525929],
             [0.02105808, 0.20165225, 0.24391652, 0.15172088, 0.12766377,
              0.18156365, 0.07242485],
             [0.01556295, 0.06088043, 0.22734239, 0.13104622, 0.16324558,
              0.1284364 , 0.27348604],
             [0.0589938 , 0.09383189, 0.16710615, 0.13034609, 0.10075376,
              0.21677285, 0.23219547],
             [0.18233871, 0.43998852, 0.10003894, 0.16439816, 0.03279143,
              0.02268928, 0.05775495],
             [0.36959052, 0.09905208, 0.13345888, 0.14184757, 0.03573069,
              0.11845324, 0.10186702],
             [0.08939418, 0.10191662, 0.15333418, 0.14231584, 0.12424018,
              0.1280424 , 0.2607566 ],
             [0.01904163, 0.22182973, 0.13638751, 0.15848971, 0.14371073,
              0.1162337 , 0.20430697]], dtype=float64)

Code 12.33

In [33]:
h = 3  # 1-based number of the draw to highlight
plt.subplot(xlim=(0.9, 7.1), ylim=(-0.01, 0.41), xlabel="index", ylabel="probability")
for i in range(delta.shape[0]):
    # the highlighted draw gets a thick line and filled markers; the rest are thin and faded
    if i + 1 == h:
        line_style = dict(ms=8, lw=4)
    else:
        line_style = dict(mfc="w", ms=8, lw=1.5, alpha=0.7)
    plt.plot(range(1, 8), delta[i], "ko-", **line_style)
Out[33]:
No description has been provided for this image

Code 12.34

In [34]:
dat = dict(
    R=d.response.values - 1,
    action=d.action.values,
    intention=d.intention.values,
    contact=d.contact.values,
    E=d.edu_new.values,  # edu_new as an index
    alpha=jnp.repeat(2, 7),  # delta prior
)


def model(action, intention, contact, E, alpha, R=None):
    """Ordered-logistic model with education entering as an ordered (monotonic) predictor.

    The education effect is ``bE`` times the sum of the ``delta`` increments up to
    level ``E``; ``delta`` is a simplex drawn from ``Dirichlet(alpha)``.

    Args:
        action: action indicator(s).
        intention: intention indicator(s).
        contact: contact indicator(s).
        E: integer education level index, starting at 0.
        alpha: Dirichlet concentration for ``delta``; its length sets the number of
            increments (one fewer than the number of education levels).
        R: observed ordinal responses; when ``None`` (the default, matching the other
            models in this notebook) the "R" site is left unobserved.
    """
    bA = numpyro.sample("bA", dist.Normal(0, 1))
    bI = numpyro.sample("bI", dist.Normal(0, 1))
    bC = numpyro.sample("bC", dist.Normal(0, 1))
    bE = numpyro.sample("bE", dist.Normal(0, 1))
    delta = numpyro.sample("delta", dist.Dirichlet(alpha))
    kappa = numpyro.sample(
        "kappa",
        dist.TransformedDistribution(
            dist.Normal(0, 1.5).expand([6]), OrderedTransform()
        ),
    )
    # prepend a zero so that level 0 contributes no increment
    delta_j = jnp.pad(delta, (1, 0))
    # level indices derived from the padded vector rather than hard-coded as 8
    levels = jnp.arange(delta_j.shape[-1])
    # for each observation, sum the increments at positions <= its education level
    delta_E = jnp.sum(jnp.where(levels <= E[..., None], delta_j, 0), -1)
    phi = bE * delta_E + bA * action + bI * intention + bC * contact
    numpyro.sample("R", dist.OrderedLogistic(phi, kappa), obs=R)


m12_6 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m12_6.run(random.PRNGKey(0), **dat)
Out[34]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[34]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[34]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[34]:
  0%|          | 0/1000 [00:00<?, ?it/s]

Code 12.35

In [35]:
# posterior summary with 89% intervals, leaving out deterministic sites
m12_6.print_summary(0.89, exclude_deterministic=True)
Out[35]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
        bA     -0.70      0.04     -0.71     -0.77     -0.65   1544.94      1.00
        bC     -0.96      0.05     -0.96     -1.03     -0.87   1661.19      1.00
        bE     -0.37      0.19     -0.36     -0.65     -0.07    977.17      1.00
        bI     -0.72      0.03     -0.72     -0.77     -0.66   2043.79      1.00
  delta[0]      0.26      0.15      0.24      0.02      0.47   1254.68      1.00
  delta[1]      0.15      0.09      0.13      0.01      0.26   2314.24      1.00
  delta[2]      0.19      0.11      0.17      0.03      0.35   1960.75      1.00
  delta[3]      0.16      0.09      0.15      0.02      0.29   1838.66      1.00
  delta[4]      0.04      0.04      0.03      0.00      0.07    757.05      1.00
  delta[5]      0.09      0.06      0.08      0.00      0.17   2076.75      1.00
  delta[6]      0.12      0.07      0.11      0.01      0.21   2044.37      1.00
  kappa[0]     -3.12      0.17     -3.11     -3.36     -2.84    949.07      1.00
  kappa[1]     -2.44      0.17     -2.43     -2.70     -2.18    945.69      1.00
  kappa[2]     -1.86      0.17     -1.85     -2.10     -1.58    951.26      1.00
  kappa[3]     -0.84      0.17     -0.83     -1.08     -0.57    950.36      1.00
  kappa[4]     -0.17      0.17     -0.16     -0.41      0.09    949.31      1.00
  kappa[5]      0.74      0.17      0.75      0.50      1.00    946.09      1.00

Number of divergences: 0

Code 12.36

In [36]:
delta_labels = ["Elem", "MidSch", "SHS", "HSG", "SCol", "Bach", "Mast", "Grad"]
# only the first 7 labels are used: "delta" has 7 entries (alpha has length 7)
a12_6 = az.from_numpyro(
    m12_6, coords={"labels": delta_labels[:7]}, dims={"delta": ["labels"]}
)
az.plot_pair(a12_6, var_names="delta")
# NOTE(review): IPython.display.set_matplotlib_formats is reportedly deprecated in recent
# IPython releases — confirm, and consider the `%config InlineBackend.figure_formats`
# form already used in the setup cell.
set_matplotlib_formats("png")
Out[36]:
No description has been provided for this image

Code 12.37

In [37]:
# education codes divided by their maximum, so the largest code maps to 1
dat["edu_norm"] = d.edu_new.values / d.edu_new.max()


def model(edu_norm, action, intention, contact, y):
    """Ordered-logistic model with education entering as a single linear term.

    Args:
        edu_norm: rescaled education code.
        action: action indicator(s).
        intention: intention indicator(s).
        contact: contact indicator(s).
        y: observed ordinal responses.
    """
    bA = numpyro.sample("bA", dist.Normal(0, 1))
    bI = numpyro.sample("bI", dist.Normal(0, 1))
    bC = numpyro.sample("bC", dist.Normal(0, 1))
    bE = numpyro.sample("bE", dist.Normal(0, 1))
    cutpoints = numpyro.sample(
        "cutpoints",
        dist.TransformedDistribution(
            dist.Normal(0, 1.5).expand([6]), OrderedTransform()
        ),
    )
    mu = bE * edu_norm + bA * action + bI * intention + bC * contact
    numpyro.sample("y", dist.OrderedLogistic(mu, cutpoints), obs=y)


m12_7 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
# data is passed positionally, in the same order as the model's parameters
m12_7.run(
    random.PRNGKey(0),
    dat["edu_norm"],
    dat["action"],
    dat["intention"],
    dat["contact"],
    dat["R"],
)
m12_7.print_summary(0.89)
Out[37]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[37]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[37]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[37]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[37]:
                  mean       std    median      5.5%     94.5%     n_eff     r_hat
          bA     -0.71      0.04     -0.71     -0.77     -0.64   1587.96      1.00
          bC     -0.96      0.05     -0.96     -1.04     -0.88   1912.27      1.00
          bE     -0.11      0.09     -0.11     -0.25      0.05   1458.40      1.00
          bI     -0.72      0.04     -0.72     -0.77     -0.65   1939.67      1.00
cutpoints[0]     -2.91      0.08     -2.91     -3.04     -2.79   1134.36      1.00
cutpoints[1]     -2.23      0.08     -2.23     -2.35     -2.11   1162.34      1.00
cutpoints[2]     -1.64      0.07     -1.64     -1.75     -1.52   1185.14      1.00
cutpoints[3]     -0.62      0.07     -0.62     -0.74     -0.51   1250.90      1.00
cutpoints[4]      0.05      0.07      0.05     -0.06      0.17   1290.31      1.00
cutpoints[5]      0.95      0.07      0.95      0.83      1.06   1371.53      1.00

Number of divergences: 0

Code 12.38

In [38]:
# load the semicolon-separated Hurricanes data
Hurricanes = pd.read_csv("../data/Hurricanes.csv", sep=";")

Comments

Comments powered by Disqus