Chapter 13. Models With Memory
In [ ]:
!pip install -q numpyro arviz
In [0]:
import os
import warnings
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
import jax.numpy as jnp
from jax import lax, random
from jax.scipy.special import expit
import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import effective_sample_size
from numpyro.infer import MCMC, NUTS, Predictive
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)
Code 13.1¶
In [1]:
# Read the semicolon-separated reed frog data; `d` and `reedfrogs` are two
# names for the same DataFrame.
d = reedfrogs = pd.read_csv("../data/reedfrogs.csv", sep=";")
d.head()
Out[1]:
Code 13.2¶
In [2]:
# Tank cluster variable: one 0-based integer index per row of d.
d["tank"] = jnp.arange(len(d))
dat = {"S": d.surv.values, "N": d.density.values, "tank": d.tank.values}


def model(tank, N, S):
    """Binomial survival model with a separate, fixed-prior intercept per tank.

    Args:
        tank: integer index picking the intercept used by each observation.
        N: number of binomial trials for each observation.
        S: observed counts, conditioned on via ``obs``.
    """
    intercepts = numpyro.sample("a", dist.Normal(0, 1.5), sample_shape=tank.shape)
    numpyro.sample("S", dist.Binomial(N, logits=intercepts[tank]), obs=S)


# Approximate the posterior with NUTS.
m13_1 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m13_1.run(random.PRNGKey(0), **dat)
Out[2]:
Out[2]:
Out[2]:
Out[2]:
Code 13.3¶
In [3]:
def model(tank, N, S):
    """Multilevel binomial survival model: tank intercepts share a learned prior.

    Args:
        tank: integer index picking the intercept used by each observation.
        N: number of binomial trials for each observation.
        S: observed counts, conditioned on via ``obs``.
    """
    # Hyper-parameters of the distribution the tank intercepts are drawn from.
    pop_mean = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    pop_scale = numpyro.sample("sigma", dist.Exponential(1))
    intercepts = numpyro.sample(
        "a", dist.Normal(pop_mean, pop_scale), sample_shape=tank.shape
    )
    numpyro.sample("S", dist.Binomial(N, logits=intercepts[tank]), obs=S)


m13_2 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m13_2.run(random.PRNGKey(0), **dat)
Out[3]:
Out[3]:
Out[3]:
Out[3]:
Code 13.4¶
In [4]:
# WAIC comparison of the two tank models, reported on the deviance scale.
az.compare(
    {
        label: az.from_numpyro(fit)
        for label, fit in (("m13.1", m13_1), ("m13.2", m13_2))
    },
    ic="waic",
    scale="deviance",
)
Out[4]:
Out[4]:
Code 13.5¶
In [5]:
# extract NumPyro samples
post = m13_2.get_samples()

# compute the posterior MEAN intercept (log-odds) for each tank,
# then transform it to a probability with the logistic function
d["propsurv.est"] = expit(jnp.mean(post["a"], 0))

# display raw proportions surviving in each tank
# (x positions number the rows of d from 1, so x and y always match in length)
plt.plot(jnp.arange(1, d.shape[0] + 1), d.propsurv, "o", alpha=0.5, zorder=3)
plt.gca().set(ylim=(-0.05, 1.05), xlabel="tank", ylabel="proportion survival")
plt.gca().set(xticks=[1, 16, 32, 48], xticklabels=[1, 16, 32, 48])

# overlay the model-based estimates as open circles
plt.plot(jnp.arange(1, d.shape[0] + 1), d["propsurv.est"], "ko", mfc="w")

# mark posterior mean probability across tanks
plt.gca().axhline(y=jnp.mean(expit(post["a_bar"])), c="k", ls="--", lw=1)

# draw vertical dividers between tank densities
plt.gca().axvline(x=16.5, c="k", lw=0.5)
plt.gca().axvline(x=32.5, c="k", lw=0.5)
plt.annotate("small tanks", (8, 0), ha="center")
plt.annotate("medium tanks", (16 + 8, 0), ha="center")
plt.annotate("large tanks", (32 + 8, 0), ha="center")
plt.show()
Out[5]:
Code 13.6¶
In [6]:
# show first 100 populations in the posterior
plt.subplot(xlim=(-3, 4), ylim=(0, 0.35), xlabel="log-odds survive", ylabel="Density")
x = jnp.linspace(-3, 4, 101)  # evaluation grid, identical for every curve
for i in range(100):
    plt.plot(
        x,
        jnp.exp(dist.Normal(post["a_bar"][i], post["sigma"][i]).log_prob(x)),
        "k",
        alpha=0.2,
    )
plt.show()

# sample 8000 imaginary tanks from the posterior distribution
# NOTE: random.randint treats maxval as exclusive, so the upper bound must be
# the number of posterior draws for the last draw to be eligible as well.
idxs = random.randint(
    random.PRNGKey(1), (8000,), minval=0, maxval=post["a_bar"].shape[0]
)
sim_tanks = dist.Normal(post["a_bar"][idxs], post["sigma"][idxs]).sample(
    random.PRNGKey(2)
)

# transform to probability and visualize
az.plot_kde(expit(sim_tanks), bw=0.3)
plt.show()
Out[6]:
Out[6]:
Code 13.7¶
In [7]:
# Settings for the simulated ponds: mean and spread of the log-odds
# intercepts, and the number of ponds.
a_bar, sigma = 1.5, 1.5
nponds = 60
# Four pond sizes, each repeated 15 times (4 * 15 == nponds).
Ni = jnp.repeat(jnp.asarray([5, 10, 25, 35]), 15)
Code 13.8¶
In [8]:
# Draw one true log-odds intercept per pond.
a_pond = dist.Normal(a_bar, sigma).sample(
    random.PRNGKey(5005), sample_shape=(nponds,)
)
Code 13.9¶
In [9]:
# One row per simulated pond: 1-based id, size, and true intercept.
dsim = pd.DataFrame(
    {"pond": range(1, nponds + 1), "Ni": Ni, "true_a": a_pond}
)
Code 13.10¶
In [10]:
# Show the type of a Python range next to that of a jnp.arange result.
print(type(range(3)), type(jnp.arange(3)), sep="\n")
Out[10]:
Code 13.11¶
In [11]:
# Simulate the survivor count of each pond from its size and true log-odds.
dsim["Si"] = dist.Binomial(
    dsim["Ni"].values, logits=dsim["true_a"].values
).sample(random.PRNGKey(0))
Code 13.12¶
In [12]:
# No-pooling estimate: the raw proportion of survivors in each pond.
dsim["p_nopool"] = dsim["Si"] / dsim["Ni"]
Code 13.13¶
In [13]:
# Model inputs; pond ids are shifted from 1-based to 0-based for indexing.
dat = {
    "Si": dsim.Si.values,
    "Ni": dsim.Ni.values,
    "pond": dsim.pond.values - 1,
}


def model(pond, Ni, Si):
    """Multilevel binomial model with one partially pooled intercept per pond.

    Args:
        pond: 0-based integer index picking each observation's intercept.
        Ni: number of binomial trials for each observation.
        Si: observed counts, conditioned on via ``obs``.
    """
    pop_mean = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    pop_scale = numpyro.sample("sigma", dist.Exponential(1))
    intercepts = numpyro.sample(
        "a_pond", dist.Normal(pop_mean, pop_scale), sample_shape=pond.shape
    )
    numpyro.sample("Si", dist.Binomial(Ni, logits=intercepts[pond]), obs=Si)


m13_3 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m13_3.run(random.PRNGKey(0), **dat)
Out[13]:
Out[13]:
Out[13]:
Out[13]:
Code 13.14¶
In [14]:
# Posterior summary with 89% intervals.
m13_3.print_summary(prob=0.89)
Out[14]:
Code 13.15¶
In [15]:
post = m13_3.get_samples()
# Partial-pooling estimate: per-pond survival probability averaged over draws.
dsim["p_partpool"] = expit(post["a_pond"]).mean(axis=0)
Code 13.16¶
In [16]:
# True survival probability implied by each pond's simulated log-odds.
dsim["p_true"] = expit(dsim["true_a"].values)
Code 13.17¶
In [17]:
# Absolute distance of each estimate from the true probability.
nopool_error = abs(dsim["p_nopool"] - dsim["p_true"])
partpool_error = abs(dsim["p_partpool"] - dsim["p_true"])
Code 13.18¶
In [18]:
# Absolute error per pond: filled markers for the no-pooling estimate,
# open black circles for the partial-pooling estimate.
plt.scatter(range(1, 61), nopool_error, label="nopool", alpha=0.8)
plt.xlabel("pond")
plt.ylabel("absolute error")
plt.scatter(
    range(1, 61), partpool_error,
    label="partpool", s=50, edgecolor="black", facecolor="none",
)
plt.legend()
plt.show()
Out[18]:
Code 13.19¶
In [19]:
# Store both error columns, then average each within pond size (Ni).
dsim["nopool_error"], dsim["partpool_error"] = nopool_error, partpool_error
nopool_avg, partpool_avg = (
    dsim.groupby("Ni")[column].mean()
    for column in ("nopool_error", "partpool_error")
)
Code 13.20¶
In [20]:
# Repeat the pond simulation with a different seed for the true intercepts.
a_bar, sigma = 1.5, 1.5
nponds = 60
Ni = jnp.repeat(jnp.asarray([5, 10, 25, 35]), 15)
a_pond = dist.Normal(a_bar, sigma).sample(
    random.PRNGKey(5006), sample_shape=(nponds,)
)
dsim = pd.DataFrame({"pond": range(1, nponds + 1), "Ni": Ni, "true_a": a_pond})
dsim["Si"] = dist.Binomial(
    dsim["Ni"].values, logits=dsim["true_a"].values
).sample(random.PRNGKey(0))
dsim["p_nopool"] = dsim["Si"] / dsim["Ni"]
newdat = {
    "Si": dsim.Si.values,
    "Ni": dsim.Ni.values,
    "pond": dsim.pond.values - 1,
}

# Refit the model function held by m13_3's sampler on the new data, with
# longer warmup and sampling.
m13_3new = MCMC(
    NUTS(m13_3.sampler.model),
    num_warmup=1000,
    num_samples=1000,
    num_chains=4,
)
m13_3new.run(random.PRNGKey(0), **newdat)
post = m13_3new.get_samples()

# Recompute the estimates and their absolute errors.
dsim["p_partpool"] = expit(post["a_pond"]).mean(axis=0)
dsim["p_true"] = expit(dsim["true_a"].values)
nopool_error = abs(dsim["p_nopool"] - dsim["p_true"])
partpool_error = abs(dsim["p_partpool"] - dsim["p_true"])

# Same error plot as before, for the new simulation.
plt.scatter(range(1, nponds + 1), nopool_error, label="nopool", alpha=0.8)
plt.xlabel("pond")
plt.ylabel("absolute error")
plt.scatter(
    range(1, nponds + 1), partpool_error,
    label="partpool", s=50, edgecolor="black", facecolor="none",
)
plt.legend()
plt.show()
Out[20]:
Out[20]:
Out[20]:
Out[20]:
Out[20]:
Code 13.21¶
In [21]:
d = chimpanzees = pd.read_csv("../data/chimpanzees.csv", sep=";")
# Combine prosoc_left and condition into one treatment code
# (values 1..4, assuming both columns are 0/1 indicators; TODO confirm).
d["treatment"] = 1 + d.prosoc_left + 2 * d.condition
# Index variables are shifted by one so they can be used as 0-based indices.
dat_list = {
    "pulled_left": d.pulled_left.values,
    "actor": d.actor.values - 1,
    "block_id": d.block.values - 1,
    "treatment": d.treatment.values - 1,
}


def model(actor, block_id, treatment, pulled_left=None, link=False):
    """Multilevel logistic model with actor, block, and treatment effects.

    Args:
        actor: 0-based index into the 7 actor intercepts.
        block_id: 0-based index into the 6 block effects.
        treatment: 0-based index into the 4 treatment effects.
        pulled_left: observed outcomes; ``None`` leaves the site unobserved.
        link: when true, also record ``p = expit(logit_p)`` as a
            deterministic site.
    """
    # hyper-priors
    pop_mean = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    actor_scale = numpyro.sample("sigma_a", dist.Exponential(1))
    block_scale = numpyro.sample("sigma_g", dist.Exponential(1))
    # adaptive priors
    actor_eff = numpyro.sample(
        "a", dist.Normal(pop_mean, actor_scale), sample_shape=(7,)
    )
    block_eff = numpyro.sample("g", dist.Normal(0, block_scale), sample_shape=(6,))
    treat_eff = numpyro.sample("b", dist.Normal(0, 0.5), sample_shape=(4,))
    logit_p = actor_eff[actor] + block_eff[block_id] + treat_eff[treatment]
    if link:
        numpyro.deterministic("p", expit(logit_p))
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m13_4 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m13_4.run(random.PRNGKey(0), **dat_list)
print("Number of divergences:", m13_4.get_extra_fields()["diverging"].sum())
Out[21]:
Out[21]:
Out[21]:
Out[21]:
Out[21]:
Code 13.22¶
In [22]:
# Text summary of m13_4, followed by a forest plot of the same posterior.
m13_4.print_summary()
# Keep the chain dimension so ArviZ can tell the chains apart.
post = m13_4.get_samples(group_by_chain=True)
az.plot_forest(data=post, combined=True, hdi_prob=0.89)
plt.show()
Out[22]:
Out[22]:
Code 13.23¶
In [23]:
def model(actor, treatment, pulled_left):
    """Multilevel logistic model with actor and treatment effects only.

    Args:
        actor: 0-based index into the 7 actor intercepts.
        treatment: 0-based index into the 4 treatment effects.
        pulled_left: observed outcomes, conditioned on via ``obs``.
    """
    pop_mean = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    actor_scale = numpyro.sample("sigma_a", dist.Exponential(1))
    actor_eff = numpyro.sample(
        "a", dist.Normal(pop_mean, actor_scale), sample_shape=(7,)
    )
    treat_eff = numpyro.sample("b", dist.Normal(0, 0.5), sample_shape=(4,))
    logit_p = actor_eff[actor] + treat_eff[treatment]
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m13_5 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
# Pass the three inputs positionally, in the order of the model signature.
m13_5.run(
    random.PRNGKey(14),
    *(dat_list[name] for name in ("actor", "treatment", "pulled_left")),
)
Out[23]:
Out[23]:
Out[23]:
Out[23]:
Code 13.24¶
In [24]:
# WAIC comparison of the models with and without block effects (deviance scale).
az.compare(
    {
        label: az.from_numpyro(fit)
        for label, fit in (("m13.4", m13_4), ("m13.5", m13_5))
    },
    ic="waic",
    scale="deviance",
)
Out[24]:
Out[24]:
Code 13.25¶
In [25]:
def model(actor, block_id, treatment, pulled_left):
    """Multilevel logistic model where treatment effects also get a learned scale.

    Args:
        actor: 0-based index into the 7 actor intercepts.
        block_id: 0-based index into the 6 block effects.
        treatment: 0-based index into the 4 treatment effects.
        pulled_left: observed outcomes, conditioned on via ``obs``.
    """
    pop_mean = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    actor_scale = numpyro.sample("sigma_a", dist.Exponential(1))
    block_scale = numpyro.sample("sigma_g", dist.Exponential(1))
    treat_scale = numpyro.sample("sigma_b", dist.Exponential(1))
    actor_eff = numpyro.sample(
        "a", dist.Normal(pop_mean, actor_scale), sample_shape=(7,)
    )
    block_eff = numpyro.sample("g", dist.Normal(0, block_scale), sample_shape=(6,))
    treat_eff = numpyro.sample("b", dist.Normal(0, treat_scale), sample_shape=(4,))
    logit_p = actor_eff[actor] + block_eff[block_id] + treat_eff[treatment]
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m13_6 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m13_6.run(random.PRNGKey(16), **dat_list)
print("Number of divergences:", m13_6.get_extra_fields()["diverging"].sum())
# Posterior mean treatment effects of both models, side by side.
{
    label: jnp.mean(fit.get_samples()["b"], 0)
    for label, fit in (("m13.4", m13_4), ("m13.6", m13_6))
}
Out[25]:
Out[25]:
Out[25]:
Out[25]:
Out[25]:
Out[25]:
Code 13.26¶
In [26]:
def model():
    """Two-variable model in which the scale of ``x`` is ``exp(v)``."""
    log_scale = numpyro.sample("v", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(0, jnp.exp(log_scale)))


m13_7 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m13_7.run(random.PRNGKey(0))
m13_7.print_summary()
Out[26]:
Out[26]:
Out[26]:
Out[26]:
Out[26]:
Code 13.27¶
In [27]:
def model():
    """Non-centered form: sample a standard normal ``z`` and derive ``x``."""
    log_scale = numpyro.sample("v", dist.Normal(0, 3))
    std_draw = numpyro.sample("z", dist.Normal(0, 1))
    # x is recorded as a deterministic function of the two sampled sites.
    numpyro.deterministic("x", std_draw * jnp.exp(log_scale))


m13_7nc = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m13_7nc.run(random.PRNGKey(0))
# Include the deterministic site x in the printed summary.
m13_7nc.print_summary(exclude_deterministic=False)
Out[27]:
Out[27]:
Out[27]:
Out[27]:
Out[27]:
Code 13.28¶
In [28]:
# Refit m13_4's model function with a higher NUTS target acceptance
# probability, then count the divergent transitions.
m13_4b = MCMC(
    NUTS(m13_4.sampler.model, target_accept_prob=0.99),
    num_warmup=500, num_samples=500, num_chains=4,
)
m13_4b.run(random.PRNGKey(13), **dat_list)
m13_4b.get_extra_fields()["diverging"].sum()
Out[28]:
Out[28]:
Out[28]:
Out[28]:
Out[28]:
Code 13.29¶
In [29]:
def model(actor, block_id, treatment, pulled_left):
    """Non-centered version of the actor/block/treatment logistic model.

    Actor and block effects are sampled as standard normals (``z``, ``x``) and
    multiplied by their scales inside the linear predictor.

    Args:
        actor: 0-based index into the 7 standardized actor effects.
        block_id: 0-based index into the 6 standardized block effects.
        treatment: 0-based index into the 4 treatment effects.
        pulled_left: observed outcomes, conditioned on via ``obs``.
    """
    pop_mean = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    actor_scale = numpyro.sample("sigma_a", dist.Exponential(1))
    block_scale = numpyro.sample("sigma_g", dist.Exponential(1))
    actor_std = numpyro.sample("z", dist.Normal(0, 1), sample_shape=(7,))
    block_std = numpyro.sample("x", dist.Normal(0, 1), sample_shape=(6,))
    treat_eff = numpyro.sample("b", dist.Normal(0, 0.5), sample_shape=(4,))
    logit_p = (
        pop_mean
        + actor_std[actor] * actor_scale
        + block_std[block_id] * block_scale
        + treat_eff[treatment]
    )
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m13_4nc = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m13_4nc.run(random.PRNGKey(16), **dat_list)
print("Number of divergences:", m13_4nc.get_extra_fields()["diverging"].sum())
Out[29]:
Out[29]:
Out[29]:
Out[29]:
Out[29]:
Code 13.30¶
In [30]:
# Effective sample size of every site, for the centered (m13_4) and
# non-centered (m13_4nc) fits.
neff_c = {
    site: effective_sample_size(draws)
    for site, draws in m13_4.get_samples(group_by_chain=True).items()
}
neff_nc = {
    site: effective_sample_size(draws)
    for site, draws in m13_4nc.get_samples(group_by_chain=True).items()
}

# Site order for each fit; position i in one list is paired with position i
# in the other (a <-> z, g <-> x).
keys_c = ["b", "a", "g", "a_bar", "sigma_a", "sigma_g"]
keys_nc = ["b", "z", "x", "a_bar", "sigma_a", "sigma_g"]

# Row labels come from the centered names: scalars keep their name, vector
# sites expand to name[0], name[1], ...
par_names = []
for k in keys_c:
    par_names.extend(
        [k]
        if jnp.ndim(neff_c[k]) == 0
        else [k + "[{}]".format(i) for i in range(neff_c[k].size)]
    )

# Flatten both dicts into aligned 1-D arrays and tabulate.
neff_c = jnp.concatenate([jnp.ravel(neff_c[k]) for k in keys_c])
neff_nc = jnp.concatenate([jnp.ravel(neff_nc[k]) for k in keys_nc])
neff_table = pd.DataFrame({"neff_c": neff_c, "neff_nc": neff_nc})
neff_table.index = par_names
neff_table.round()
Out[30]:
Code 13.31¶
In [31]:
# Predict for one chimp (1-based id) across all four treatments in block 1.
chimp = 2
d_pred = {
    "actor": jnp.repeat(chimp, 4) - 1,
    "treatment": jnp.arange(4),
    "block_id": jnp.repeat(1, 4) - 1,
}
# link=True makes the model record the probability site "p".
p = Predictive(m13_4.sampler.model, m13_4.get_samples())(
    random.PRNGKey(0), link=True, **d_pred
)["p"]
# Summarise along axis 0: mean and 89% percentile interval.
p_mu = p.mean(axis=0)
p_ci = jnp.percentile(p, q=jnp.array([5.5, 94.5]), axis=0)
Code 13.32¶
In [32]:
post = m13_4.get_samples()
# Peek at the first five flattened values of every posterior site.
{site: draws.reshape(-1)[:5] for site, draws in post.items()}
Out[32]:
Code 13.33¶
In [33]:
# Density of the posterior draws in column 4 of "a" (the fifth actor).
az.plot_kde(post["a"][:, 4])
plt.show()
Out[33]:
Code 13.34¶
In [34]:
def p_link(treatment, actor=0, block_id=0):
    """Return per-draw probabilities for one treatment/actor/block combination.

    Reads the module-level ``post`` samples and sums the selected columns of
    "a", "g", and "b" on the log-odds scale before applying ``expit``.
    """
    return expit(
        post["a"][:, actor] + post["g"][:, block_id] + post["b"][:, treatment]
    )
Code 13.35¶
In [35]:
# lax.map stacks one p_link result per treatment along the leading axis, so
# p_raw has one row per treatment and one column per posterior draw.
p_raw = lax.map(lambda i: p_link(i, actor=1, block_id=0), jnp.arange(4))
# Summarise over the draws (axis 1) to get one mean and one 89% interval per
# treatment; reducing over axis 0 would average across treatments instead.
p_mu = jnp.mean(p_raw, 1)
p_ci = jnp.percentile(p_raw, jnp.array([5.5, 94.5]), 1)
Code 13.36¶
In [36]:
def p_link_abar(treatment):
    """Return per-draw probabilities for a treatment using ``a_bar`` as intercept.

    Reads the module-level ``post`` samples; no actor or block term is added.
    """
    return expit(post["a_bar"] + post["b"][:, treatment])
Code 13.37¶
In [37]:
# One row per treatment, one column per posterior draw.
p_raw = lax.map(p_link_abar, jnp.arange(4))
# Mean and 89% percentile interval over the draws of each treatment.
p_mu = p_raw.mean(axis=1)
p_ci = jnp.percentile(p_raw, jnp.array([5.5, 94.5]), axis=1)

plt.subplot(
    xlabel="treatment",
    ylabel="proportion pulled left",
    ylim=(0, 1),
    xlim=(0.9, 4.1),
)
plt.xticks(range(1, 5), ["R/N", "L/N", "R/P", "L/P"])
plt.plot(range(1, 5), p_mu)
# The two rows of p_ci are the lower and upper interval bounds.
plt.fill_between(range(1, 5), *p_ci, color="k", alpha=0.2)
plt.show()
Out[37]:
Code 13.38¶
In [38]:
# One simulated actor intercept per posterior draw of (a_bar, sigma_a).
a_sim = dist.Normal(post["a_bar"], post["sigma_a"]).sample(random.PRNGKey(0))


def p_link_asim(treatment):
    """Return per-draw probabilities for a treatment using the simulated intercepts.

    Reads the module-level ``a_sim`` and ``post``.
    """
    return expit(a_sim + post["b"][:, treatment])


p_raw_asim = lax.map(p_link_asim, jnp.arange(4))
Code 13.39¶
In [39]:
plt.subplot(
    xlabel="treatment", ylabel="proportion pulled left", ylim=(0, 1), xlim=(0.9, 4.1)
)
plt.gca().set(xticks=range(1, 5), xticklabels=["R/N", "L/N", "R/P", "L/P"])
# One line per simulated actor: the first 100 columns of p_raw_asim, each
# giving that actor's probability under the four treatments.
for i in range(100):
    plt.plot(range(1, 5), p_raw_asim[:, i], color="k", alpha=0.25)
# Finish the figure explicitly, as every other plotting cell does.
plt.show()
Out[39]:
Code 13.40¶
In [40]:
# Load the Bangladesh data and list the distinct district codes in order.
d = bangladesh = pd.read_csv("../data/bangladesh.csv", sep=";")
jnp.sort(d["district"].unique())
Out[40]:
Code 13.41¶
In [41]:
# Recode district as pandas category codes, then list the distinct codes.
d["district_id"] = d["district"].astype("category").cat.codes
jnp.sort(d["district_id"].unique())
Out[41]:
Comments
Comments powered by Disqus