Chapter 13. Models With Memory

In [0]:
import os
import warnings

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd

import jax.numpy as jnp
from jax import lax, random
from jax.scipy.special import expit

import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import effective_sample_size
from numpyro.infer import MCMC, NUTS, Predictive

if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
    category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_host_device_count(4)

Code 13.1

In [1]:
reedfrogs = pd.read_csv("../data/reedfrogs.csv", sep=";")
d = reedfrogs
d.head()
Out[1]:
   density pred   size  surv  propsurv
0       10   no    big     9       0.9
1       10   no    big    10       1.0
2       10   no    big     7       0.7
3       10   no    big    10       1.0
4       10   no  small     9       0.9

Code 13.2

In [2]:
# make the tank cluster variable
d["tank"] = jnp.arange(d.shape[0])

dat = dict(S=d.surv.values, N=d.density.values, tank=d.tank.values)

# approximate posterior
def model(tank, N, S):
    a = numpyro.sample("a", dist.Normal(0, 1.5), sample_shape=tank.shape)
    logit_p = a[tank]
    numpyro.sample("S", dist.Binomial(N, logits=logit_p), obs=S)


m13_1 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_1.run(random.PRNGKey(0), **dat)

Code 13.3

In [3]:
def model(tank, N, S):
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    a = numpyro.sample("a", dist.Normal(a_bar, sigma), sample_shape=tank.shape)
    logit_p = a[tank]
    numpyro.sample("S", dist.Binomial(N, logits=logit_p), obs=S)


m13_2 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_2.run(random.PRNGKey(0), **dat)

Code 13.4

In [4]:
az.compare(
    {"m13.1": az.from_numpyro(m13_1), "m13.2": az.from_numpyro(m13_2)},
    ic="waic",
    scale="deviance",
)
Out[4]:
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
Out[4]:
       rank     waic   p_waic   d_waic      weight       se      dse  warning  waic_scale
m13.2     0  201.392  21.4062        0    0.992273  4.54082        0     True    deviance
m13.1     1  214.284  25.3695  12.8923  0.00772698  6.98608  3.95414     True    deviance
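
The effective number of parameters tells the shrinkage story: m13.2 has 48 tank intercepts but only about 21 effective parameters. As a quick check that is not part of the transcribed code, WAIC for the multilevel model alone can be read off with az.waic:

# Not in the original notebook: WAIC for m13.2 on its own, to read the
# effective number of parameters (p_waic) directly.
az.waic(az.from_numpyro(m13_2), scale="deviance")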

Code 13.5

In [5]:
# extract NumPyro samples
post = m13_2.get_samples()

# compute mean intercept for each tank
# also transform to probability with logistic
d["propsurv.est"] = expit(jnp.mean(post["a"], 0))

# display raw proportions surviving in each tank
plt.plot(jnp.arange(1, 49), d.propsurv, "o", alpha=0.5, zorder=3)
plt.gca().set(ylim=(-0.05, 1.05), xlabel="tank", ylabel="proportion survival")
plt.gca().set(xticks=[1, 16, 32, 48], xticklabels=[1, 16, 32, 48])

# overlay posterior means
plt.plot(jnp.arange(1, 49), d["propsurv.est"], "ko", mfc="w")

# mark posterior mean probability across tanks
plt.gca().axhline(y=jnp.mean(expit(post["a_bar"])), c="k", ls="--", lw=1)

# draw vertical dividers between tank densities
plt.gca().axvline(x=16.5, c="k", lw=0.5)
plt.gca().axvline(x=32.5, c="k", lw=0.5)
plt.annotate("small tanks", (8, 0), ha="center")
plt.annotate("medium tanks", (16 + 8, 0), ha="center")
plt.annotate("large tanks", (32 + 8, 0), ha="center")
plt.show()
Out[5]:

Code 13.6

In [6]:
# show first 100 populations in the posterior
plt.subplot(xlim=(-3, 4), ylim=(0, 0.35), xlabel="log-odds survive", ylabel="Density")
for i in range(100):
    x = jnp.linspace(-3, 4, 101)
    plt.plot(
        x,
        jnp.exp(dist.Normal(post["a_bar"][i], post["sigma"][i]).log_prob(x)),
        "k",
        alpha=0.2,
    )
plt.show()

# sample 8000 imaginary tanks from the posterior distribution
# (there are 2000 posterior draws: 4 chains x 500 samples; maxval is exclusive)
idxs = random.randint(random.PRNGKey(1), (8000,), minval=0, maxval=2000)
sim_tanks = dist.Normal(post["a_bar"][idxs], post["sigma"][idxs]).sample(
    random.PRNGKey(2)
)

# transform to probability and visualize
az.plot_kde(expit(sim_tanks), bw=0.3)
plt.show()
Out[6]:
Out[6]:

Code 13.7

In [7]:
a_bar = 1.5
sigma = 1.5
nponds = 60
Ni = jnp.repeat(jnp.array([5, 10, 25, 35]), repeats=15)

Code 13.8

In [8]:
a_pond = dist.Normal(a_bar, sigma).sample(random.PRNGKey(5005), (nponds,))

Code 13.9

In [9]:
dsim = pd.DataFrame(dict(pond=range(1, nponds + 1), Ni=Ni, true_a=a_pond))

Code 13.10

In [10]:
print(type(range(3)))
print(type(jnp.arange(3)))
Out[10]:
<class 'range'>
<class 'jax.interpreters.xla.DeviceArray'>

Code 13.11

In [11]:
dsim["Si"] = dist.Binomial(dsim.Ni.values, logits=dsim.true_a.values).sample(
    random.PRNGKey(0)
)

Code 13.12

In [12]:
dsim["p_nopool"] = dsim.Si / dsim.Ni

Code 13.13

In [13]:
dat = dict(Si=dsim.Si.values, Ni=dsim.Ni.values, pond=dsim.pond.values - 1)


def model(pond, Ni, Si):
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    a_pond = numpyro.sample(
        "a_pond", dist.Normal(a_bar, sigma), sample_shape=pond.shape
    )
    logit_p = a_pond[pond]
    numpyro.sample("Si", dist.Binomial(Ni, logits=logit_p), obs=Si)


m13_3 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_3.run(random.PRNGKey(0), **dat)

Code 13.14

In [14]:
m13_3.print_summary(0.89)
Out[14]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
     a_bar      1.54      0.28      1.54      1.06      1.95   2281.23      1.00
 a_pond[0]      2.98      1.36      2.84      0.90      5.13   2407.70      1.00
 a_pond[1]      2.98      1.36      2.83      0.89      5.12   2097.79      1.00
 a_pond[2]      0.73      0.90      0.69     -0.68      2.07   2404.93      1.00
 a_pond[3]      3.03      1.40      2.90      0.89      5.08   2713.74      1.00
 a_pond[4]      2.96      1.36      2.84      0.85      4.96   2800.48      1.00
 a_pond[5]      2.98      1.44      2.81      0.95      5.38   2431.15      1.00
 a_pond[6]     -0.05      0.87     -0.04     -1.49      1.27   3150.06      1.00
 a_pond[7]      3.04      1.38      2.93      0.95      5.15   2472.86      1.00
 a_pond[8]      1.62      1.01      1.54      0.06      3.17   2933.07      1.00
 a_pond[9]      1.67      1.05      1.61     -0.10      3.16   2438.75      1.00
a_pond[10]     -1.72      1.07     -1.60     -3.37     -0.08   2144.05      1.00
a_pond[11]     -1.71      1.11     -1.59     -3.39     -0.02   2004.38      1.00
a_pond[12]      2.99      1.36      2.89      0.86      5.00   2533.61      1.00
a_pond[13]      2.98      1.30      2.86      0.80      4.78   2383.16      1.00
a_pond[14]     -0.00      0.87     -0.01     -1.49      1.27   2081.98      1.00
a_pond[15]      1.56      0.78      1.50      0.30      2.69   2141.73      1.00
a_pond[16]      3.40      1.27      3.25      1.60      5.46   2136.25      1.00
a_pond[17]      2.27      0.93      2.18      0.81      3.65   2064.09      1.00
a_pond[18]      3.39      1.29      3.23      1.49      5.34   2448.16      1.00
a_pond[19]      0.59      0.64      0.58     -0.42      1.59   3363.59      1.00
a_pond[20]      1.54      0.76      1.48      0.27      2.73   2179.81      1.00
a_pond[21]      2.29      0.97      2.21      0.76      3.81   2483.35      1.00
a_pond[22]      0.57      0.65      0.55     -0.46      1.58   3108.46      1.00
a_pond[23]     -1.05      0.73     -1.03     -2.16      0.13   2560.64      1.00
a_pond[24]      3.36      1.25      3.22      1.53      5.34   2001.26      1.00
a_pond[25]      1.54      0.77      1.50      0.33      2.69   2960.84      1.00
a_pond[26]      3.41      1.31      3.28      1.44      5.43   1735.02      1.00
a_pond[27]      1.02      0.68      0.99     -0.07      2.03   3213.64      1.00
a_pond[28]      2.25      0.95      2.15      0.63      3.62   2589.98      1.00
a_pond[29]     -0.20      0.65     -0.21     -1.23      0.78   2735.11      1.00
a_pond[30]     -3.12      0.92     -3.02     -4.63     -1.82   1863.65      1.00
a_pond[31]     -0.15      0.38     -0.14     -0.77      0.45   2715.25      1.00
a_pond[32]      0.83      0.42      0.82      0.16      1.50   3916.83      1.00
a_pond[33]      0.66      0.40      0.65     -0.06      1.23   3212.74      1.00
a_pond[34]     -1.76      0.58     -1.72     -2.59     -0.79   2519.79      1.00
a_pond[35]      1.72      0.53      1.70      0.88      2.49   3092.25      1.00
a_pond[36]      0.83      0.45      0.82      0.16      1.57   3138.79      1.00
a_pond[37]      3.97      1.14      3.84      2.11      5.57   2225.12      1.00
a_pond[38]      4.00      1.17      3.85      2.32      5.84   1794.19      1.00
a_pond[39]      3.04      0.81      2.95      1.82      4.34   2007.28      1.00
a_pond[40]      3.05      0.85      2.99      1.79      4.44   2185.64      1.00
a_pond[41]      1.72      0.55      1.70      0.83      2.52   3055.49      1.00
a_pond[42]      3.05      0.85      2.96      1.73      4.34   2728.46      1.00
a_pond[43]      2.47      0.69      2.40      1.34      3.40   1884.27      1.00
a_pond[44]      3.04      0.81      2.97      1.79      4.32   2547.27      1.00
a_pond[45]     -1.27      0.41     -1.26     -1.96     -0.68   3249.49      1.00
a_pond[46]      1.43      0.41      1.42      0.78      2.09   3470.21      1.00
a_pond[47]      0.22      0.34      0.21     -0.33      0.75   3125.85      1.00
a_pond[48]      0.58      0.36      0.58     -0.01      1.14   3083.52      1.00
a_pond[49]      3.32      0.82      3.25      2.12      4.55   2114.18      1.00
a_pond[50]      1.27      0.42      1.25      0.64      1.96   3411.26      1.00
a_pond[51]      3.35      0.82      3.28      2.08      4.56   1814.55      1.00
a_pond[52]      2.40      0.56      2.36      1.47      3.23   2680.32      1.00
a_pond[53]      0.71      0.36      0.70      0.16      1.35   3768.80      1.00
a_pond[54]      0.33      0.34      0.32     -0.22      0.87   2863.52      1.00
a_pond[55]      2.79      0.67      2.73      1.80      3.88   2533.37      1.00
a_pond[56]      2.10      0.54      2.06      1.28      3.00   3273.72      1.00
a_pond[57]      0.58      0.35      0.58     -0.00      1.12   2897.91      1.00
a_pond[58]      0.46      0.35      0.46     -0.10      0.99   3336.77      1.00
a_pond[59]      3.33      0.83      3.26      2.05      4.63   2378.74      1.00
     sigma      1.86      0.27      1.84      1.45      2.29    734.34      1.00

Number of divergences: 0
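
Since the ponds were simulated with a_bar = 1.5 and sigma = 1.5 (Code 13.7), a quick sanity check, not part of the transcribed code, is to compare the recovered hyperparameters against those true values:

# Not in the original notebook: posterior means of the hyperparameters, to be
# compared with the simulation truth a_bar = 1.5, sigma = 1.5.
post_check = m13_3.get_samples()
print(jnp.mean(post_check["a_bar"]), jnp.mean(post_check["sigma"]))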

Code 13.15

In [15]:
post = m13_3.get_samples()
dsim["p_partpool"] = jnp.mean(expit(post["a_pond"]), 0)

Code 13.16

In [16]:
dsim["p_true"] = expit(dsim.true_a.values)

Code 13.17

In [17]:
nopool_error = (dsim.p_nopool - dsim.p_true).abs()
partpool_error = (dsim.p_partpool - dsim.p_true).abs()

Code 13.18

In [18]:
plt.scatter(range(1, 61), nopool_error, label="nopool", alpha=0.8)
plt.gca().set(xlabel="pond", ylabel="absolute error")
plt.scatter(
    range(1, 61),
    partpool_error,
    label="partpool",
    s=50,
    edgecolor="black",
    facecolor="none",
)
plt.legend()
plt.show()
Out[18]:

Code 13.19

In [19]:
dsim["nopool_error"] = nopool_error
dsim["partpool_error"] = partpool_error
nopool_avg = dsim.groupby("Ni")["nopool_error"].mean()
partpool_avg = dsim.groupby("Ni")["partpool_error"].mean()

Code 13.20

In [20]:
a_bar = 1.5
sigma = 1.5
nponds = 60
Ni = jnp.repeat(jnp.array([5, 10, 25, 35]), repeats=15)
a_pond = dist.Normal(a_bar, sigma).sample(random.PRNGKey(5006), (nponds,))
dsim = pd.DataFrame(dict(pond=range(1, nponds + 1), Ni=Ni, true_a=a_pond))
dsim["Si"] = dist.Binomial(dsim.Ni.values, logits=dsim.true_a.values).sample(
    random.PRNGKey(0)
)
dsim["p_nopool"] = dsim.Si / dsim.Ni
newdat = dict(Si=dsim.Si.values, Ni=dsim.Ni.values, pond=dsim.pond.values - 1)
m13_3new = MCMC(
    NUTS(m13_3.sampler.model), num_warmup=1000, num_samples=1000, num_chains=4
)
m13_3new.run(random.PRNGKey(0), **newdat)

post = m13_3new.get_samples()
dsim["p_partpool"] = jnp.mean(expit(post["a_pond"]), 0)
dsim["p_true"] = expit(dsim.true_a.values)
nopool_error = (dsim.p_nopool - dsim.p_true).abs()
partpool_error = (dsim.p_partpool - dsim.p_true).abs()
plt.scatter(range(1, 61), nopool_error, label="nopool", alpha=0.8)
plt.gca().set(xlabel="pond", ylabel="absolute error")
plt.scatter(
    range(1, 61),
    partpool_error,
    label="partpool",
    s=50,
    edgecolor="black",
    facecolor="none",
)
plt.legend()
plt.show()
Out[20]:

Code 13.21

In [21]:
chimpanzees = pd.read_csv("../data/chimpanzees.csv", sep=";")
d = chimpanzees
d["treatment"] = 1 + d.prosoc_left + 2 * d.condition

dat_list = dict(
    pulled_left=d.pulled_left.values,
    actor=d.actor.values - 1,
    block_id=d.block.values - 1,
    treatment=d.treatment.values - 1,
)


def model(actor, block_id, treatment, pulled_left=None, link=False):
    # hyper-priors
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    sigma_g = numpyro.sample("sigma_g", dist.Exponential(1))
    # adaptive priors
    a = numpyro.sample("a", dist.Normal(a_bar, sigma_a), sample_shape=(7,))
    g = numpyro.sample("g", dist.Normal(0, sigma_g), sample_shape=(6,))
    b = numpyro.sample("b", dist.Normal(0, 0.5), sample_shape=(4,))
    logit_p = a[actor] + g[block_id] + b[treatment]
    if link:
        numpyro.deterministic("p", expit(logit_p))
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m13_4 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_4.run(random.PRNGKey(0), **dat_list)
print("Number of divergences:", m13_4.get_extra_fields()["diverging"].sum())
Out[21]:
Number of divergences: 13

Code 13.22

In [22]:
m13_4.print_summary()
post = m13_4.get_samples(group_by_chain=True)
az.plot_forest(post, combined=True, hdi_prob=0.89)  # also plot
plt.show()
Out[22]:
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      a[0]     -0.36      0.36     -0.37     -0.95      0.22    486.62      1.00
      a[1]      4.62      1.20      4.42      2.73      6.42    878.09      1.00
      a[2]     -0.66      0.37     -0.68     -1.26     -0.05    492.32      1.01
      a[3]     -0.67      0.38     -0.68     -1.26     -0.03    511.05      1.00
      a[4]     -0.36      0.37     -0.38     -0.94      0.27    483.27      1.00
      a[5]      0.58      0.37      0.57     -0.01      1.19    493.90      1.00
      a[6]      2.10      0.47      2.07      1.33      2.85    704.89      1.00
     a_bar      0.59      0.73      0.58     -0.64      1.70   1060.45      1.00
      b[0]     -0.14      0.31     -0.12     -0.62      0.41    534.65      1.00
      b[1]      0.39      0.30      0.40     -0.10      0.88    517.28      1.01
      b[2]     -0.48      0.30     -0.47     -0.96      0.02    504.48      1.00
      b[3]      0.27      0.30      0.27     -0.22      0.76    546.26      1.00
      g[0]     -0.18      0.22     -0.14     -0.52      0.14    473.87      1.00
      g[1]      0.06      0.20      0.03     -0.31      0.35    658.15      1.00
      g[2]      0.07      0.19      0.05     -0.20      0.43    559.06      1.00
      g[3]      0.02      0.19      0.01     -0.29      0.35    677.37      1.00
      g[4]     -0.03      0.20     -0.01     -0.37      0.28   1006.22      1.00
      g[5]      0.13      0.22      0.09     -0.21      0.48    572.32      1.00
   sigma_a      1.99      0.63      1.89      1.07      2.94    818.82      1.00
   sigma_g      0.24      0.18      0.20      0.02      0.47    262.91      1.00

Number of divergences: 13
Out[22]:

Code 13.23

In [23]:
def model(actor, treatment, pulled_left):
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    a = numpyro.sample("a", dist.Normal(a_bar, sigma_a), sample_shape=(7,))
    b = numpyro.sample("b", dist.Normal(0, 0.5), sample_shape=(4,))
    logit_p = a[actor] + b[treatment]
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m13_5 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_5.run(
    random.PRNGKey(14),
    dat_list["actor"],
    dat_list["treatment"],
    dat_list["pulled_left"],
)

Code 13.24

In [24]:
az.compare(
    {"m13.4": az.from_numpyro(m13_4), "m13.5": az.from_numpyro(m13_5)},
    ic="waic",
    scale="deviance",
)
Out[24]:
       rank     waic   p_waic  d_waic   weight       se      dse  warning  waic_scale
m13.5     0  531.171  8.57096       0  0.68637  19.2942        0    False    deviance
m13.4     1  533.089   11.219  1.9172  0.31363  19.0931  1.92803    False    deviance

Code 13.25

In [25]:
def model(actor, block_id, treatment, pulled_left):
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    sigma_g = numpyro.sample("sigma_g", dist.Exponential(1))
    sigma_b = numpyro.sample("sigma_b", dist.Exponential(1))
    a = numpyro.sample("a", dist.Normal(a_bar, sigma_a), sample_shape=(7,))
    g = numpyro.sample("g", dist.Normal(0, sigma_g), sample_shape=(6,))
    b = numpyro.sample("b", dist.Normal(0, sigma_b), sample_shape=(4,))
    logit_p = a[actor] + g[block_id] + b[treatment]
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m13_6 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_6.run(random.PRNGKey(16), **dat_list)
print("Number of divergences:", m13_6.get_extra_fields()["diverging"].sum())
{
    "m13.4": jnp.mean(m13_4.get_samples()["b"], 0),
    "m13.6": jnp.mean(m13_6.get_samples()["b"], 0),
}
Out[25]:
Number of divergences: 22
Out[25]:
{'m13.4': DeviceArray([-0.13909166,  0.39071533, -0.47855413,  0.27035657], dtype=float32),
 'm13.6': DeviceArray([-0.17561208,  0.32285807, -0.5131155 ,  0.21562107], dtype=float32)}

Code 13.26

In [26]:
def model():
    v = numpyro.sample("v", dist.Normal(0, 3))
    x = numpyro.sample("x", dist.Normal(0, jnp.exp(v)))


m13_7 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_7.run(random.PRNGKey(0))
m13_7.print_summary()
Out[26]:
                mean       std    median      5.0%     95.0%     n_eff     r_hat
         v      2.50      1.71      2.40     -0.46      4.97      8.39      1.32
         x      2.73    170.57      0.10    -65.42     74.32    247.32      1.02

Number of divergences: 557

Code 13.27

In [27]:
def model():
    v = numpyro.sample("v", dist.Normal(0, 3))
    z = numpyro.sample("z", dist.Normal(0, 1))
    numpyro.deterministic("x", z * jnp.exp(v))


m13_7nc = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_7nc.run(random.PRNGKey(0))
m13_7nc.print_summary(exclude_deterministic=False)
Out[27]:
                mean       std    median      5.0%     95.0%     n_eff     r_hat
         v      0.03      2.93      0.08     -4.64      4.69   1774.07      1.00
         x     -7.80    358.56      0.00    -24.60     28.71   2048.67      1.00
         z     -0.00      0.99      0.01     -1.64      1.58   1847.30      1.00

Number of divergences: 0
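
To see what defeats the centered sampler, it helps to look at the joint posterior of v and x, which forms a funnel that narrows sharply as v decreases. The following sketch is not part of the original notebook; it reconstructs x from the non-centered draws and scatters it against v:

# Not in the original notebook: rebuild x = z * exp(v) from the non-centered
# draws and scatter it against v to show the funnel shape.
post_nc = m13_7nc.get_samples()
x_nc = post_nc["z"] * jnp.exp(post_nc["v"])
plt.scatter(x_nc, post_nc["v"], s=5, alpha=0.3)
plt.gca().set(xlim=(-20, 20), xlabel="x", ylabel="v")
plt.show()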

Code 13.28

In [28]:
m13_4b = MCMC(
    NUTS(m13_4.sampler.model, target_accept_prob=0.99),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m13_4b.run(random.PRNGKey(13), **dat_list)
jnp.sum(m13_4b.get_extra_fields()["diverging"])
Out[28]:
DeviceArray(0, dtype=int32)

Code 13.29

In [29]:
def model(actor, block_id, treatment, pulled_left):
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    sigma_g = numpyro.sample("sigma_g", dist.Exponential(1))
    z = numpyro.sample("z", dist.Normal(0, 1), sample_shape=(7,))
    x = numpyro.sample("x", dist.Normal(0, 1), sample_shape=(6,))
    b = numpyro.sample("b", dist.Normal(0, 0.5), sample_shape=(4,))
    logit_p = a_bar + z[actor] * sigma_a + x[block_id] * sigma_g + b[treatment]
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m13_4nc = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_4nc.run(random.PRNGKey(16), **dat_list)
print("Number of divergences:", m13_4nc.get_extra_fields()["diverging"].sum())
Out[29]:
Number of divergences: 0
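
Writing out z and x by hand is not the only route to the non-centered parameterization. A sketch of an alternative, assuming a reasonably recent NumPyro, uses the reparameterization handler; the adaptive priors are written with .expand so the handler sees their shapes, and LocScaleReparam(0) decenters them automatically:

# Not part of the transcribed code: automatic non-centering via numpyro's
# reparameterization handler (requires a version that provides LocScaleReparam).
from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam


def model_centered(actor, block_id, treatment, pulled_left=None):
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    sigma_g = numpyro.sample("sigma_g", dist.Exponential(1))
    a = numpyro.sample("a", dist.Normal(a_bar, sigma_a).expand([7]))
    g = numpyro.sample("g", dist.Normal(0, sigma_g).expand([6]))
    b = numpyro.sample("b", dist.Normal(0, 0.5).expand([4]))
    logit_p = a[actor] + g[block_id] + b[treatment]
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


# decenter the adaptive priors "a" and "g"
model_decentered = reparam(
    model_centered, config={"a": LocScaleReparam(0), "g": LocScaleReparam(0)}
)
m13_4r = MCMC(NUTS(model_decentered), num_warmup=500, num_samples=500, num_chains=4)
m13_4r.run(random.PRNGKey(16), **dat_list)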

Code 13.30

In [30]:
neff_c = {
    k: effective_sample_size(v.copy())
    for k, v in m13_4.get_samples(group_by_chain=True).items()
}
neff_nc = {
    k: effective_sample_size(v.copy())
    for k, v in m13_4nc.get_samples(group_by_chain=True).items()
}
par_names = []
keys_c = ["b", "a", "g", "a_bar", "sigma_a", "sigma_g"]
keys_nc = ["b", "z", "x", "a_bar", "sigma_a", "sigma_g"]
for k in keys_c:
    if jnp.ndim(neff_c[k]) == 0:
        par_names += [k]
    else:
        par_names += [k + "[{}]".format(i) for i in range(neff_c[k].size)]
neff_c = jnp.concatenate([neff_c[k].reshape(-1) for k in keys_c])
neff_nc = jnp.concatenate([neff_nc[k].reshape(-1) for k in keys_nc])
neff_table = pd.DataFrame(dict(neff_c=neff_c, neff_nc=neff_nc))
neff_table.index = par_names
neff_table.round()
Out[30]:
neff_c neff_nc
b[0] 535.0 1178.0
b[1] 517.0 1096.0
b[2] 504.0 1167.0
b[3] 546.0 1159.0
a[0] 487.0 452.0
a[1] 878.0 1169.0
a[2] 492.0 454.0
a[3] 511.0 444.0
a[4] 483.0 444.0
a[5] 494.0 519.0
a[6] 705.0 738.0
g[0] 474.0 2048.0
g[1] 658.0 2204.0
g[2] 559.0 2103.0
g[3] 677.0 1994.0
g[4] 1006.0 2131.0
g[5] 572.0 2221.0
a_bar 1060.0 517.0
sigma_a 819.0 934.0
sigma_g 263.0 1004.0

Code 13.31

In [31]:
chimp = 2
d_pred = dict(
    actor=jnp.repeat(chimp, 4) - 1,
    treatment=jnp.arange(4),
    block_id=jnp.repeat(1, 4) - 1,
)
p = Predictive(m13_4.sampler.model, m13_4.get_samples())(
    random.PRNGKey(0), link=True, **d_pred
)["p"]
p_mu = jnp.mean(p, 0)
p_ci = jnp.percentile(p, q=(5.5, 94.5), axis=0)

Code 13.32

In [32]:
post = m13_4.get_samples()
{k: v.reshape(-1)[:5] for k, v in post.items()}
Out[32]:
{'a': DeviceArray([-0.5197992 ,  3.056869  , -0.7414023 , -0.9056391 , -0.09762876], dtype=float32),
 'a_bar': DeviceArray([0.32702658, 0.78957605, 0.3658363 , 0.19608618, 0.05800754], dtype=float32),
 'b': DeviceArray([-0.15199736,  0.24597855, -0.6075848 ,  0.29595786, -0.1312438 ], dtype=float32),
 'g': DeviceArray([-0.2127662 ,  0.12210437,  0.6095109 ,  0.2989592 , -0.30063742], dtype=float32),
 'sigma_a': DeviceArray([2.1341236, 1.8951536, 2.0155072, 2.2705545, 1.2400742], dtype=float32),
 'sigma_g': DeviceArray([0.7917654 , 0.29069167, 0.85146064, 0.99559826, 0.33158094], dtype=float32)}

Code 13.33

In [33]:
az.plot_kde(post["a"][:, 4])
plt.show()
Out[33]:

Code 13.34

In [34]:
def p_link(treatment, actor=0, block_id=0):
    a, g, b = post["a"], post["g"], post["b"]
    logodds = a[:, actor] + g[:, block_id] + b[:, treatment]
    return expit(logodds)

Code 13.35

In [35]:
p_raw = lax.map(lambda i: p_link(i, actor=1, block_id=0), jnp.arange(4))
p_mu = jnp.mean(p_raw, 0)
p_ci = jnp.percentile(p_raw, (5.5, 94.5), 0)

Code 13.36

In [36]:
def p_link_abar(treatment):
    logodds = post["a_bar"] + post["b"][:, treatment]
    return expit(logodds)

Code 13.37

In [37]:
p_raw = lax.map(p_link_abar, jnp.arange(4))
p_mu = jnp.mean(p_raw, 1)
p_ci = jnp.percentile(p_raw, (5.5, 94.5), 1)

plt.subplot(
    xlabel="treatment", ylabel="proportion pulled left", ylim=(0, 1), xlim=(0.9, 4.1)
)
plt.gca().set(xticks=range(1, 5), xticklabels=["R/N", "L/N", "R/P", "L/P"])
plt.plot(range(1, 5), p_mu)
plt.fill_between(range(1, 5), p_ci[0], p_ci[1], color="k", alpha=0.2)
plt.show()
Out[37]:

Code 13.38

In [38]:
a_sim = dist.Normal(post["a_bar"], post["sigma_a"]).sample(random.PRNGKey(0))


def p_link_asim(treatment):
    logodds = a_sim + post["b"][:, treatment]
    return expit(logodds)


p_raw_asim = lax.map(p_link_asim, jnp.arange(4))

Code 13.39

In [39]:
plt.subplot(
    xlabel="treatment", ylabel="proportion pulled left", ylim=(0, 1), xlim=(0.9, 4.1)
)
plt.gca().set(xticks=range(1, 5), xticklabels=["R/N", "L/N", "R/P", "L/P"])
for i in range(100):
    plt.plot(range(1, 5), p_raw_asim[:, i], color="k", alpha=0.25)
Out[39]:

Code 13.40

In [40]:
bangladesh = pd.read_csv("../data/bangladesh.csv", sep=";")
d = bangladesh
jnp.sort(d.district.unique())
Out[40]:
DeviceArray([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
             16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
             31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
             46, 47, 48, 49, 50, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61], dtype=int32)

Code 13.41

In [41]:
d["district_id"] = d.district.astype("category").cat.codes
jnp.sort(d.district_id.unique())
Out[41]:
DeviceArray([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,
             15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
             30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
             45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59], dtype=int8)
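
With the contiguous district index in hand, the same varying-intercepts machinery from earlier in the chapter applies directly. The sketch below is not part of the transcribed code; it assumes the use.contraception column from the rethinking bangladesh data and fits one intercept per district:

# Not in the original notebook: a varying-intercepts model for contraceptive
# use by district, mirroring the structure of m13_3 (the column name
# `use.contraception` is assumed from the rethinking dataset).
dat_bd = dict(C=d["use.contraception"].values, district_id=d.district_id.values)


def model_bd(district_id, C=None):
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    a = numpyro.sample("a", dist.Normal(a_bar, sigma).expand([60]))
    logit_p = a[district_id]
    numpyro.sample("C", dist.Bernoulli(logits=logit_p), obs=C)


m13_bd = MCMC(NUTS(model_bd), num_warmup=500, num_samples=500, num_chains=4)
m13_bd.run(random.PRNGKey(0), **dat_bd)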
