Chapter 13. Models With Memory

In [ ]:
!pip install -q numpyro arviz
In [0]:
import os
import warnings

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd

import jax.numpy as jnp
from jax import lax, random
from jax.scipy.special import expit

import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import effective_sample_size
from numpyro.infer import MCMC, NUTS, Predictive

# render figures as SVG when the SVG environment variable is set
if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
# compact warning format: "Category: message" (drops file/line noise)
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
    category.__name__, message
)
az.style.use("arviz-darkgrid")
# run each of the 4 MCMC chains on its own CPU device
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)

Code 13.1

In [1]:
# reed frog tadpole survival data: one row per tank experiment
# (density = initial tadpole count, surv = number surviving)
reedfrogs = pd.read_csv("../data/reedfrogs.csv", sep=";")
d = reedfrogs
d.head()
Out[1]:
density pred size surv propsurv
0 10 no big 9 0.9
1 10 no big 10 1.0
2 10 no big 7 0.7
3 10 no big 10 1.0
4 10 no small 9 0.9

Code 13.2

In [2]:
# make the tank cluster variable
d["tank"] = jnp.arange(d.shape[0])

dat = dict(S=d.surv.values, N=d.density.values, tank=d.tank.values)


# approximate posterior
def model(tank, N, S):
    # one independent intercept per tank (no pooling across tanks)
    a = numpyro.sample("a", dist.Normal(0, 1.5), sample_shape=tank.shape)
    logit_p = a[tank]
    # S survivors out of N tadpoles in each tank
    numpyro.sample("S", dist.Binomial(N, logits=logit_p), obs=S)


m13_1 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_1.run(random.PRNGKey(0), **dat)
Out[2]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[2]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[2]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[2]:
  0%|          | 0/1000 [00:00<?, ?it/s]

Code 13.3

In [3]:
def model(tank, N, S):
    # hyper-priors: population mean and spread of the tank intercepts
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # adaptive prior: tank intercepts are partially pooled toward a_bar
    a = numpyro.sample("a", dist.Normal(a_bar, sigma), sample_shape=tank.shape)
    logit_p = a[tank]
    numpyro.sample("S", dist.Binomial(N, logits=logit_p), obs=S)


m13_2 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_2.run(random.PRNGKey(0), **dat)
Out[3]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[3]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[3]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[3]:
  0%|          | 0/1000 [00:00<?, ?it/s]

Code 13.4

In [4]:
# compare no-pooling (m13.1) vs partial-pooling (m13.2) by WAIC on the
# deviance scale; expect the multilevel model to win with fewer effective
# parameters despite having more actual parameters
az.compare(
    {"m13.1": az.from_numpyro(m13_1), "m13.2": az.from_numpyro(m13_2)},
    ic="waic",
    scale="deviance",
)
Out[4]:
UserWarning: The default method used to estimate the weights for each model,has changed from BB-pseudo-BMA to stacking
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
Out[4]:
rank waic p_waic d_waic weight se dse warning waic_scale
m13.2 0 200.729635 21.306688 0.000000 1.000000e+00 7.149090 0.000000 True deviance
m13.1 1 215.670229 26.160694 14.940594 1.154632e-13 4.379542 3.857661 True deviance

Code 13.5

In [5]:
# extract NumPyro samples
post = m13_2.get_samples()

# compute median intercept for each tank
# also transform to probability with logistic
d["propsurv.est"] = expit(jnp.mean(post["a"], 0))

# display raw proportions surviving in each tank
plt.plot(jnp.arange(1, 49), d.propsurv, "o", alpha=0.5, zorder=3)
plt.gca().set(ylim=(-0.05, 1.05), xlabel="tank", ylabel="proportion survival")
plt.gca().set(xticks=[1, 16, 32, 48], xticklabels=[1, 16, 32, 48])

# overlay posterior means
plt.plot(jnp.arange(1, 49), d["propsurv.est"], "ko", mfc="w")

# mark posterior mean probability across tanks
plt.gca().axhline(y=jnp.mean(expit(post["a_bar"])), c="k", ls="--", lw=1)

# draw vertical dividers between tank densities
plt.gca().axvline(x=16.5, c="k", lw=0.5)
plt.gca().axvline(x=32.5, c="k", lw=0.5)
plt.annotate("small tanks", (8, 0), ha="center")
plt.annotate("medium tanks", (16 + 8, 0), ha="center")
plt.annotate("large tanks", (32 + 8, 0), ha="center")
plt.show()
Out[5]:
No description has been provided for this image

Code 13.6

In [6]:
# show first 100 populations in the posterior
plt.subplot(xlim=(-3, 4), ylim=(0, 0.35), xlabel="log-odds survive", ylabel="Density")
x = jnp.linspace(-3, 4, 101)  # hoisted out of the loop: the grid is invariant
for i in range(100):
    plt.plot(
        x,
        jnp.exp(dist.Normal(post["a_bar"][i], post["sigma"][i]).log_prob(x)),
        "k",
        alpha=0.2,
    )
plt.show()

# sample 8000 imaginary tanks from the posterior distribution
# NOTE: jax.random.randint's maxval is exclusive, so it must equal the number
# of posterior draws (4 chains x 500 samples = 2000); the previous
# maxval=1999 silently excluded the last draw
n_draws = post["a_bar"].shape[0]
idxs = random.randint(random.PRNGKey(1), (8000,), minval=0, maxval=n_draws)
sim_tanks = dist.Normal(post["a_bar"][idxs], post["sigma"][idxs]).sample(
    random.PRNGKey(2)
)

# transform to probability and visualize
az.plot_kde(expit(sim_tanks), bw=0.3)
plt.show()
Out[6]:
No description has been provided for this image
Out[6]:
No description has been provided for this image

Code 13.7

In [7]:
# true hyper-parameters for the simulated pond survival experiment
a_bar = 1.5
sigma = 1.5
nponds = 60
# 15 ponds each at initial tadpole densities 5, 10, 25, 35
Ni = jnp.repeat(jnp.array([5, 10, 25, 35]), repeats=15)

Code 13.8

In [8]:
# simulate one true log-odds intercept per pond from the population
a_pond = dist.Normal(a_bar, sigma).sample(random.PRNGKey(5005), (nponds,))

Code 13.9

In [9]:
# assemble the simulated data: pond id (1-based), density, and true intercept
dsim = pd.DataFrame(dict(pond=range(1, nponds + 1), Ni=Ni, true_a=a_pond))

Code 13.10

In [10]:
# contrast Python's lazy range with the materialized JAX array from arange
for obj in (range(3), jnp.arange(3)):
    print(type(obj))
Out[10]:
<class 'range'>
<class 'jaxlib.xla_extension.DeviceArray'>

Code 13.11

In [11]:
# simulate the observed survivor counts from each pond's true log-odds
dsim["Si"] = dist.Binomial(dsim.Ni.values, logits=dsim.true_a.values).sample(
    random.PRNGKey(0)
)

Code 13.12

In [12]:
# no-pooling estimate: raw observed survival proportion per pond
dsim["p_nopool"] = dsim.Si / dsim.Ni

Code 13.13

In [13]:
# pond index shifted to 0-based for array indexing inside the model
dat = dict(Si=dsim.Si.values, Ni=dsim.Ni.values, pond=dsim.pond.values - 1)


def model(pond, Ni, Si):
    # hyper-priors for the population of ponds
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # partially pooled pond intercepts
    a_pond = numpyro.sample(
        "a_pond", dist.Normal(a_bar, sigma), sample_shape=pond.shape
    )
    logit_p = a_pond[pond]
    numpyro.sample("Si", dist.Binomial(Ni, logits=logit_p), obs=Si)


m13_3 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_3.run(random.PRNGKey(0), **dat)
Out[13]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[13]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[13]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[13]:
  0%|          | 0/1000 [00:00<?, ?it/s]

Code 13.14

In [14]:
# summarize the posterior with 89% credible intervals
m13_3.print_summary(0.89)
Out[14]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
     a_bar      1.55      0.28      1.54      1.09      1.97   2640.58      1.00
 a_pond[0]      3.01      1.36      2.89      0.89      5.11   2146.07      1.00
 a_pond[1]      3.03      1.39      2.91      0.76      5.06   2518.07      1.00
 a_pond[2]      0.74      0.89      0.71     -0.65      2.18   3881.46      1.00
 a_pond[3]      3.01      1.31      2.90      1.14      5.21   2339.89      1.00
 a_pond[4]      2.99      1.35      2.85      0.83      5.02   2502.75      1.00
 a_pond[5]      2.97      1.35      2.82      0.85      5.04   2394.69      1.00
 a_pond[6]     -0.04      0.83     -0.03     -1.27      1.35   2422.02      1.00
 a_pond[7]      3.01      1.36      2.89      0.88      5.06   1866.54      1.00
 a_pond[8]      1.66      1.07      1.59     -0.10      3.25   3248.79      1.00
 a_pond[9]      1.62      1.03      1.52     -0.08      3.12   2579.92      1.00
a_pond[10]     -1.72      1.09     -1.64     -3.41     -0.03   3104.02      1.00
a_pond[11]     -1.74      1.13     -1.63     -3.57     -0.09   2501.22      1.00
a_pond[12]      3.00      1.35      2.87      0.68      4.90   2192.21      1.00
a_pond[13]      3.05      1.39      2.88      0.85      5.25   2250.77      1.00
a_pond[14]     -0.04      0.84     -0.03     -1.40      1.29   3785.95      1.00
a_pond[15]      1.51      0.75      1.46      0.33      2.67   3063.76      1.00
a_pond[16]      3.40      1.25      3.25      1.43      5.30   2503.37      1.00
a_pond[17]      2.24      0.94      2.16      0.72      3.66   3027.82      1.00
a_pond[18]      3.44      1.31      3.30      1.27      5.35   2444.29      1.00
a_pond[19]      0.59      0.69      0.57     -0.55      1.65   3086.36      1.00
a_pond[20]      1.55      0.80      1.50      0.26      2.74   2602.71      1.00
a_pond[21]      2.28      0.94      2.23      0.80      3.73   2771.97      1.00
a_pond[22]      0.58      0.64      0.57     -0.48      1.57   2846.15      1.00
a_pond[23]     -1.07      0.70     -1.03     -2.23     -0.01   2388.72      1.00
a_pond[24]      3.43      1.29      3.29      1.38      5.28   1926.15      1.00
a_pond[25]      1.54      0.76      1.49      0.39      2.79   3489.25      1.00
a_pond[26]      3.43      1.30      3.26      1.43      5.53   1902.42      1.00
a_pond[27]      1.06      0.71      1.03     -0.10      2.12   2640.76      1.00
a_pond[28]      2.25      0.94      2.17      0.82      3.71   3871.51      1.00
a_pond[29]     -0.21      0.63     -0.20     -1.24      0.77   3303.80      1.00
a_pond[30]     -3.19      0.99     -3.06     -4.73     -1.70   1657.59      1.00
a_pond[31]     -0.17      0.40     -0.16     -0.77      0.50   3286.27      1.00
a_pond[32]      0.82      0.43      0.81      0.10      1.47   3592.00      1.00
a_pond[33]      0.64      0.41      0.64     -0.05      1.25   2996.21      1.00
a_pond[34]     -1.75      0.56     -1.70     -2.61     -0.84   3351.40      1.00
a_pond[35]      1.72      0.54      1.70      0.85      2.61   3152.36      1.00
a_pond[36]      0.82      0.43      0.81      0.22      1.60   2752.63      1.00
a_pond[37]      4.02      1.20      3.87      2.22      5.90   1909.16      1.00
a_pond[38]      4.01      1.16      3.88      2.18      5.62   2083.64      1.00
a_pond[39]      3.07      0.87      2.98      1.73      4.44   2040.58      1.00
a_pond[40]      3.06      0.88      2.98      1.65      4.35   1979.37      1.00
a_pond[41]      1.73      0.56      1.70      0.89      2.65   3205.70      1.00
a_pond[42]      3.06      0.87      2.97      1.70      4.40   2023.91      1.00
a_pond[43]      2.47      0.67      2.41      1.35      3.46   2418.37      1.00
a_pond[44]      3.07      0.85      2.97      1.74      4.36   2668.19      1.00
a_pond[45]     -1.29      0.41     -1.26     -1.88     -0.62   3217.64      1.00
a_pond[46]      1.44      0.42      1.42      0.74      2.08   2950.12      1.00
a_pond[47]      0.23      0.33      0.22     -0.29      0.74   3520.28      1.00
a_pond[48]      0.58      0.36      0.57     -0.05      1.10   3651.30      1.00
a_pond[49]      3.33      0.81      3.26      2.05      4.60   2688.02      1.00
a_pond[50]      1.27      0.39      1.26      0.64      1.90   3220.04      1.00
a_pond[51]      3.35      0.83      3.30      2.10      4.65   2192.53      1.00
a_pond[52]      2.39      0.59      2.37      1.40      3.22   2908.23      1.00
a_pond[53]      0.71      0.35      0.70      0.14      1.28   3976.04      1.00
a_pond[54]      0.34      0.34      0.35     -0.18      0.91   3431.03      1.00
a_pond[55]      2.80      0.70      2.76      1.71      3.82   2395.95      1.00
a_pond[56]      2.10      0.54      2.06      1.28      2.97   2852.21      1.00
a_pond[57]      0.58      0.35      0.58     -0.00      1.13   3492.97      1.00
a_pond[58]      0.46      0.35      0.47     -0.02      1.07   3989.22      1.00
a_pond[59]      3.34      0.81      3.26      2.09      4.57   2403.33      1.00
     sigma      1.87      0.26      1.86      1.46      2.25    877.63      1.00

Number of divergences: 0

Code 13.15

In [15]:
# partial-pooling estimate: posterior mean survival probability per pond
post = m13_3.get_samples()
dsim["p_partpool"] = jnp.mean(expit(post["a_pond"]), 0)

Code 13.16

In [16]:
# true survival probability implied by each pond's simulated intercept
dsim["p_true"] = expit(dsim.true_a.values)

Code 13.17

In [17]:
# absolute estimation error of each approach against the known truth
nopool_error = (dsim.p_nopool - dsim.p_true).abs()
partpool_error = (dsim.p_partpool - dsim.p_true).abs()

Code 13.18

In [18]:
# compare per-pond errors: filled dots = no pooling, open circles = partial
# pooling; ponds are ordered by density (small ponds on the left)
plt.scatter(range(1, 61), nopool_error, label="nopool", alpha=0.8)
plt.gca().set(xlabel="pond", ylabel="absolute error")
plt.scatter(
    range(1, 61),
    partpool_error,
    label="partpool",
    s=50,
    edgecolor="black",
    facecolor="none",
)
plt.legend()
plt.show()
Out[18]:
No description has been provided for this image

Code 13.19

In [19]:
# average the errors within each density group for comparison
dsim["nopool_error"] = nopool_error
dsim["partpool_error"] = partpool_error
nopool_avg = dsim.groupby("Ni")["nopool_error"].mean()
partpool_avg = dsim.groupby("Ni")["partpool_error"].mean()

Code 13.20

In [20]:
# repeat the whole pond simulation with a fresh seed, reusing the already
# defined model from m13_3 (no need to redefine it)
a_bar = 1.5
sigma = 1.5
nponds = 60
Ni = jnp.repeat(jnp.array([5, 10, 25, 35]), repeats=15)
a_pond = dist.Normal(a_bar, sigma).sample(random.PRNGKey(5006), (nponds,))
dsim = pd.DataFrame(dict(pond=range(1, nponds + 1), Ni=Ni, true_a=a_pond))
dsim["Si"] = dist.Binomial(dsim.Ni.values, logits=dsim.true_a.values).sample(
    random.PRNGKey(0)
)
dsim["p_nopool"] = dsim.Si / dsim.Ni
newdat = dict(Si=dsim.Si.values, Ni=dsim.Ni.values, pond=dsim.pond.values - 1)
m13_3new = MCMC(
    NUTS(m13_3.sampler.model), num_warmup=1000, num_samples=1000, num_chains=4
)
m13_3new.run(random.PRNGKey(0), **newdat)

post = m13_3new.get_samples()
dsim["p_partpool"] = jnp.mean(expit(post["a_pond"]), 0)
dsim["p_true"] = expit(dsim.true_a.values)
nopool_error = (dsim.p_nopool - dsim.p_true).abs()
partpool_error = (dsim.p_partpool - dsim.p_true).abs()
plt.scatter(range(1, 61), nopool_error, label="nopool", alpha=0.8)
plt.gca().set(xlabel="pond", ylabel="absolute error")
plt.scatter(
    range(1, 61),
    partpool_error,
    label="partpool",
    s=50,
    edgecolor="black",
    facecolor="none",
)
plt.legend()
plt.show()
Out[20]:
  0%|          | 0/2000 [00:00<?, ?it/s]
Out[20]:
  0%|          | 0/2000 [00:00<?, ?it/s]
Out[20]:
  0%|          | 0/2000 [00:00<?, ?it/s]
Out[20]:
  0%|          | 0/2000 [00:00<?, ?it/s]
Out[20]:
No description has been provided for this image

Code 13.21

In [21]:
chimpanzees = pd.read_csv("../data/chimpanzees.csv", sep=";")
d = chimpanzees
# treatment coding: 1=R/N, 2=L/N, 3=R/P, 4=L/P (prosocial side x partner)
d["treatment"] = 1 + d.prosoc_left + 2 * d.condition

# shift actor / block / treatment to 0-based indices for array indexing
dat_list = dict(
    pulled_left=d.pulled_left.values,
    actor=d.actor.values - 1,
    block_id=d.block.values - 1,
    treatment=d.treatment.values - 1,
)


def model(actor, block_id, treatment, pulled_left=None, link=False):
    # hyper-priors
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    sigma_g = numpyro.sample("sigma_g", dist.Exponential(1))
    # adaptive priors
    a = numpyro.sample("a", dist.Normal(a_bar, sigma_a), sample_shape=(7,))
    g = numpyro.sample("g", dist.Normal(0, sigma_g), sample_shape=(6,))
    b = numpyro.sample("b", dist.Normal(0, 0.5), sample_shape=(4,))
    logit_p = a[actor] + g[block_id] + b[treatment]
    if link:
        # expose p so Predictive can return posterior probabilities directly
        numpyro.deterministic("p", expit(logit_p))
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m13_4 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_4.run(random.PRNGKey(0), **dat_list)
print("Number of divergences:", m13_4.get_extra_fields()["diverging"].sum())
Out[21]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[21]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[21]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[21]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[21]:
Number of divergences: 7

Code 13.22

In [22]:
m13_4.print_summary()
# keep chains separate so ArviZ can compute per-chain diagnostics
post = m13_4.get_samples(group_by_chain=True)
az.plot_forest(post, combined=True, hdi_prob=0.89)  # also plot
plt.show()
Out[22]:
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      a[0]     -0.38      0.37     -0.38     -0.98      0.22    597.32      1.00
      a[1]      4.67      1.29      4.49      2.78      6.67    948.84      1.00
      a[2]     -0.69      0.37     -0.70     -1.30     -0.07    566.81      1.00
      a[3]     -0.70      0.39     -0.69     -1.36     -0.09    512.00      1.00
      a[4]     -0.39      0.37     -0.39     -1.01      0.19    546.34      1.00
      a[5]      0.56      0.38      0.56     -0.08      1.15    607.70      1.00
      a[6]      2.09      0.47      2.10      1.26      2.81    757.08      1.00
     a_bar      0.57      0.72      0.57     -0.61      1.74   1482.01      1.00
      b[0]     -0.11      0.30     -0.11     -0.59      0.37    614.87      1.00
      b[1]      0.41      0.30      0.41     -0.07      0.93    623.30      1.00
      b[2]     -0.46      0.31     -0.45     -0.99     -0.01    574.28      1.00
      b[3]      0.30      0.30      0.31     -0.18      0.81    620.34      1.00
      g[0]     -0.18      0.23     -0.13     -0.56      0.10    556.08      1.01
      g[1]      0.05      0.20      0.03     -0.24      0.41    613.18      1.01
      g[2]      0.06      0.19      0.03     -0.26      0.38    596.79      1.00
      g[3]      0.02      0.19      0.01     -0.28      0.34    947.51      1.00
      g[4]     -0.03      0.19     -0.01     -0.35      0.27    933.78      1.00
      g[5]      0.13      0.22      0.09     -0.21      0.47    485.27      1.00
   sigma_a      2.03      0.65      1.93      1.06      2.96   1194.46      1.00
   sigma_g      0.23      0.18      0.19      0.00      0.47    309.51      1.01

Number of divergences: 7
Out[22]:
No description has been provided for this image

Code 13.23

In [23]:
def model(actor, treatment, pulled_left):
    # same as m13_4 but with the block cluster dropped entirely
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    a = numpyro.sample("a", dist.Normal(a_bar, sigma_a), sample_shape=(7,))
    b = numpyro.sample("b", dist.Normal(0, 0.5), sample_shape=(4,))
    logit_p = a[actor] + b[treatment]
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m13_5 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_5.run(
    random.PRNGKey(14),
    dat_list["actor"],
    dat_list["treatment"],
    dat_list["pulled_left"],
)
Out[23]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[23]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[23]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[23]:
  0%|          | 0/1000 [00:00<?, ?it/s]

Code 13.24

In [24]:
# compare models with and without the block cluster; nearly identical WAIC
# suggests block contributes little (sigma_g is small)
az.compare(
    {"m13.4": az.from_numpyro(m13_4), "m13.5": az.from_numpyro(m13_5)},
    ic="waic",
    scale="deviance",
)
Out[24]:
UserWarning: The default method used to estimate the weights for each model,has changed from BB-pseudo-BMA to stacking
Out[24]:
rank waic p_waic d_waic weight se dse warning waic_scale
m13.5 0 531.103967 8.554003 0.000000 1.0 19.200204 0.000000 False deviance
m13.4 1 532.297533 10.828377 1.193565 0.0 19.411863 1.873273 False deviance

Code 13.25

In [25]:
def model(actor, block_id, treatment, pulled_left):
    # like m13_4 but treatment effects also get a learned (pooled) scale
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    sigma_g = numpyro.sample("sigma_g", dist.Exponential(1))
    sigma_b = numpyro.sample("sigma_b", dist.Exponential(1))
    a = numpyro.sample("a", dist.Normal(a_bar, sigma_a), sample_shape=(7,))
    g = numpyro.sample("g", dist.Normal(0, sigma_g), sample_shape=(6,))
    b = numpyro.sample("b", dist.Normal(0, sigma_b), sample_shape=(4,))
    logit_p = a[actor] + g[block_id] + b[treatment]
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m13_6 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_6.run(random.PRNGKey(16), **dat_list)
print("Number of divergences:", m13_6.get_extra_fields()["diverging"].sum())
# compare the posterior mean treatment effects of the two models
{
    "m13.4": jnp.mean(m13_4.get_samples()["b"], 0),
    "m13.6": jnp.mean(m13_6.get_samples()["b"], 0),
}
Out[25]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[25]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[25]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[25]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[25]:
Number of divergences: 10
Out[25]:
{'m13.4': DeviceArray([-0.11273014,  0.41450262, -0.45547444,  0.30242366], dtype=float32),
 'm13.6': DeviceArray([-0.11401905,  0.3812172 , -0.44742095,  0.27241236], dtype=float32)}

Code 13.26

In [26]:
def model():
    # the devil's funnel: x's scale depends on v (centered parameterization),
    # creating geometry that NUTS cannot explore well — expect divergences
    v = numpyro.sample("v", dist.Normal(0, 3))
    # unused local binding removed; the sample statement alone registers x
    numpyro.sample("x", dist.Normal(0, jnp.exp(v)))


m13_7 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_7.run(random.PRNGKey(0))
m13_7.print_summary()
Out[26]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[26]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[26]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[26]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[26]:
                mean       std    median      5.0%     95.0%     n_eff     r_hat
         v      1.24      2.58      1.22     -2.75      4.49      4.12      1.59
         x     43.93    349.99      0.01    -43.64     51.78     49.60      1.06

Number of divergences: 201

Code 13.27

In [27]:
def model():
    # non-centered parameterization: sample a standard normal z and scale it
    # deterministically, removing the funnel geometry from the sampler's view
    v = numpyro.sample("v", dist.Normal(0, 3))
    z = numpyro.sample("z", dist.Normal(0, 1))
    numpyro.deterministic("x", z * jnp.exp(v))


m13_7nc = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_7nc.run(random.PRNGKey(0))
m13_7nc.print_summary(exclude_deterministic=False)
Out[27]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[27]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[27]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[27]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[27]:
                mean       std    median      5.0%     95.0%     n_eff     r_hat
         v     -0.04      3.00     -0.13     -4.85      5.01   2019.83      1.00
         x      2.92    149.31     -0.00    -26.96     30.74   1719.69      1.00
         z     -0.02      0.98     -0.01     -1.66      1.57   1963.93      1.00

Number of divergences: 0

Code 13.28

In [28]:
# re-run m13_4 with a higher acceptance target; smaller step sizes reduce
# divergences but do not fix the underlying centered-parameterization geometry
m13_4b = MCMC(
    NUTS(m13_4.sampler.model, target_accept_prob=0.99),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m13_4b.run(random.PRNGKey(13), **dat_list)
jnp.sum(m13_4b.get_extra_fields()["diverging"])
Out[28]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[28]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[28]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[28]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[28]:
DeviceArray(11, dtype=int32)

Code 13.29

In [29]:
def model(actor, block_id, treatment, pulled_left):
    # non-centered version of m13_4: sample standardized offsets z and x,
    # then scale/shift them outside the sample statements
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    sigma_g = numpyro.sample("sigma_g", dist.Exponential(1))
    z = numpyro.sample("z", dist.Normal(0, 1), sample_shape=(7,))
    x = numpyro.sample("x", dist.Normal(0, 1), sample_shape=(6,))
    b = numpyro.sample("b", dist.Normal(0, 0.5), sample_shape=(4,))
    logit_p = a_bar + z[actor] * sigma_a + x[block_id] * sigma_g + b[treatment]
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m13_4nc = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_4nc.run(random.PRNGKey(16), **dat_list)
print("Number of divergences:", m13_4nc.get_extra_fields()["diverging"].sum())
Out[29]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[29]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[29]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[29]:
  0%|          | 0/1000 [00:00<?, ?it/s]
Out[29]:
Number of divergences: 0

Code 13.30

In [30]:
# effective sample size per parameter for the centered (c) and
# non-centered (nc) parameterizations, computed from per-chain draws
neff_c = {
    k: effective_sample_size(v)
    for k, v in m13_4.get_samples(group_by_chain=True).items()
}
neff_nc = {
    k: effective_sample_size(v)
    for k, v in m13_4nc.get_samples(group_by_chain=True).items()
}
par_names = []
# align rows: z/x in the nc model correspond to a/g in the centered model
keys_c = ["b", "a", "g", "a_bar", "sigma_a", "sigma_g"]
keys_nc = ["b", "z", "x", "a_bar", "sigma_a", "sigma_g"]
for k in keys_c:
    if jnp.ndim(neff_c[k]) == 0:
        par_names += [k]
    else:
        par_names += [k + "[{}]".format(i) for i in range(neff_c[k].size)]
# flatten each dict of arrays into a single vector, in key order
neff_c = jnp.concatenate([neff_c[k].reshape(-1) for k in keys_c])
neff_nc = jnp.concatenate([neff_nc[k].reshape(-1) for k in keys_nc])
neff_table = pd.DataFrame(dict(neff_c=neff_c, neff_nc=neff_nc))
neff_table.index = par_names
neff_table.round()
Out[30]:
neff_c neff_nc
b[0] 615.0 1095.0
b[1] 623.0 1072.0
b[2] 574.0 904.0
b[3] 620.0 987.0
a[0] 597.0 382.0
a[1] 949.0 1045.0
a[2] 567.0 379.0
a[3] 512.0 383.0
a[4] 546.0 369.0
a[5] 608.0 367.0
a[6] 757.0 640.0
g[0] 556.0 1675.0
g[1] 613.0 2287.0
g[2] 597.0 1947.0
g[3] 948.0 2062.0
g[4] 934.0 2232.0
g[5] 485.0 1443.0
a_bar 1482.0 429.0
sigma_a 1194.0 647.0
sigma_g 310.0 849.0

Code 13.31

In [31]:
# posterior predictions for actor 2, block 1, across all four treatments
chimp = 2
d_pred = dict(
    actor=jnp.repeat(chimp, 4) - 1,
    treatment=jnp.arange(4),
    block_id=jnp.repeat(1, 4) - 1,
)
# link=True makes the model emit the deterministic site "p"
p = Predictive(m13_4.sampler.model, m13_4.get_samples())(
    random.PRNGKey(0), link=True, **d_pred
)["p"]
p_mu = jnp.mean(p, 0)
p_ci = jnp.percentile(p, q=jnp.array([5.5, 94.5]), axis=0)

Code 13.32

In [32]:
# peek at the first few flattened draws of each posterior parameter
post = m13_4.get_samples()
{k: v.reshape(-1)[:5] for k, v in post.items()}
Out[32]:
{'a': DeviceArray([ 0.17939359,  3.2472124 , -0.42817664, -0.3104607 ,
               0.09759563], dtype=float32),
 'a_bar': DeviceArray([ 0.57802844,  0.5723409 ,  0.24400227, -0.14573385,
               0.60223407], dtype=float32),
 'b': DeviceArray([-0.35315567,  0.07911116, -0.7404461 , -0.29147664,
              -0.1611932 ], dtype=float32),
 'g': DeviceArray([-0.06372303, -0.07348639, -0.07706347, -0.00371453,
               0.11766291], dtype=float32),
 'sigma_a': DeviceArray([2.015479 , 2.4767263, 3.077415 , 1.6289241, 1.7870352], dtype=float32),
 'sigma_g': DeviceArray([0.05992268, 0.21187922, 0.09104711, 0.2905614 , 0.26771992],            dtype=float32)}

Code 13.33

In [33]:
# marginal posterior density of actor 5's intercept (0-based index 4)
az.plot_kde(post["a"][:, 4])
plt.show()
Out[33]:
No description has been provided for this image

Code 13.34

In [34]:
def p_link(treatment, actor=0, block_id=0):
    """Posterior probability of pulling left for a given actor/block/treatment.

    Indexes the global `post` draws (shape: draws x cluster) and maps the
    summed log-odds through the logistic function.
    """
    logodds = (
        post["a"][:, actor]
        + post["g"][:, block_id]
        + post["b"][:, treatment]
    )
    return expit(logodds)

Code 13.35

In [35]:
# evaluate p_link for each of the 4 treatments (actor 2, block 1)
p_raw = lax.map(lambda i: p_link(i, actor=1, block_id=0), jnp.arange(4))
p_mu = jnp.mean(p_raw, 0)
p_ci = jnp.percentile(p_raw, jnp.array([5.5, 94.5]), 0)

Code 13.36

In [36]:
def p_link_abar(treatment):
    """Predicted probability of pulling left for the average actor.

    Uses the population mean intercept a_bar (ignoring block effects)
    plus the treatment effect, transformed by the logistic function.
    """
    return expit(post["a_bar"] + post["b"][:, treatment])

Code 13.37

In [37]:
# predictions for the average actor across the 4 treatments
p_raw = lax.map(p_link_abar, jnp.arange(4))
p_mu = jnp.mean(p_raw, 1)
p_ci = jnp.percentile(p_raw, jnp.array([5.5, 94.5]), 1)

plt.subplot(
    xlabel="treatment", ylabel="proportion pulled left", ylim=(0, 1), xlim=(0.9, 4.1)
)
plt.gca().set(xticks=range(1, 5), xticklabels=["R/N", "L/N", "R/P", "L/P"])
plt.plot(range(1, 5), p_mu)
# shade the 89% compatibility interval around the mean prediction
plt.fill_between(range(1, 5), p_ci[0], p_ci[1], color="k", alpha=0.2)
plt.show()
Out[37]:
No description has been provided for this image

Code 13.38

In [38]:
# simulate new (previously unobserved) actors from the posterior population
a_sim = dist.Normal(post["a_bar"], post["sigma_a"]).sample(random.PRNGKey(0))


def p_link_asim(treatment):
    # prediction for a simulated actor: their sampled intercept plus the
    # treatment effect, on the probability scale
    logodds = a_sim + post["b"][:, treatment]
    return expit(logodds)


p_raw_asim = lax.map(p_link_asim, jnp.arange(4))

Code 13.39

In [39]:
plt.subplot(
    xlabel="treatment", ylabel="proportion pulled left", ylim=(0, 1), xlim=(0.9, 4.1)
)
plt.gca().set(xticks=range(1, 5), xticklabels=["R/N", "L/N", "R/P", "L/P"])
# one line per simulated actor (first 100), showing individual variation
for i in range(100):
    plt.plot(range(1, 5), p_raw_asim[:, i], color="k", alpha=0.25)
plt.show()  # added for consistency with every other plotting cell
Out[39]:
No description has been provided for this image

Code 13.40

In [40]:
# Bangladesh fertility survey; note district 54 is missing, so the raw
# district codes are NOT contiguous
bangladesh = pd.read_csv("../data/bangladesh.csv", sep=";")
d = bangladesh
jnp.sort(d.district.unique())
Out[40]:
DeviceArray([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
             16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
             31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
             46, 47, 48, 49, 50, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61],            dtype=int32)

Code 13.41

In [41]:
# re-code districts as contiguous 0-based category codes for indexing
d["district_id"] = d.district.astype("category").cat.codes
jnp.sort(d.district_id.unique())
Out[41]:
DeviceArray([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,
             15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
             30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
             45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59],            dtype=int8)

Comments

Comments powered by Disqus