Chapter 8. Conditional Manatees
In [ ]:
!pip install -q numpyro arviz
In [0]:
import math
import os
import warnings
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO, log_likelihood
from numpyro.infer.autoguide import AutoLaplaceApproximation
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
Code 8.1¶
In [1]:
# Load the ruggedness data (';'-separated). `d` is an alias of `rugged`,
# so the log column added below is visible through both names.
rugged = pd.read_csv("../data/rugged.csv", sep=";")
d = rugged
# natural log of the outcome; rows with missing GDP stay NaN
d["log_gdp"] = d["rgdppc_2000"].map(math.log)
# keep only countries that have GDP data (an independent copy, not a view)
dd = d.dropna(subset=["rgdppc_2000"]).copy()
# rescale: log GDP relative to its mean, ruggedness relative to its maximum
dd["log_gdp_std"] = dd["log_gdp"] / dd["log_gdp"].mean()
dd["rugged_std"] = dd["rugged"] / dd["rugged"].max()
Code 8.2¶
In [2]:
def model(rugged_std, log_gdp_std=None):
    """Ruggedness model m8.1 with wide priors.

    log_gdp_std ~ Normal(a + b * (rugged_std - 0.215), sigma).
    0.215 centres ruggedness -- presumably its sample mean (TODO confirm).
    With log_gdp_std=None the outcome is sampled instead of observed.
    """
    a = numpyro.sample("a", dist.Normal(1, 1))
    b = numpyro.sample("b", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    rugged_c = rugged_std - 0.215
    mu = numpyro.deterministic("mu", a + b * rugged_c)
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


# Laplace approximation fitted with SVI (Adam, lr=0.1, 1000 steps).
m8_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m8_1, optim.Adam(0.1), Trace_ELBO(),
    rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_1 = svi_result.params
Out[2]:
Code 8.3¶
In [3]:
# Prior predictive simulation for m8.1: draw a, b, sigma from the priors
# (rugged_std is only a placeholder here), then trace the implied mean lines.
predictive = Predictive(m8_1.model, num_samples=1000, return_sites=["a", "b", "sigma"])
prior = predictive(random.PRNGKey(7), rugged_std=0)
# empty axes spanning the plausible outcome range
ax = plt.subplot(xlim=(0, 1), ylim=(0.5, 1.5), xlabel="ruggedness", ylabel="log GDP")
# dashed guides at the smallest and largest observed log GDP
for bound in (dd.log_gdp_std.min(), dd.log_gdp_std.max()):
    ax.axhline(bound, ls="--")
# mean lines implied by the prior draws over a slightly widened grid
rugged_seq = jnp.linspace(-0.1, 1.1, num=30)
mu = Predictive(m8_1.model, prior, return_sites=["mu"])(random.PRNGKey(7), rugged_std=rugged_seq)["mu"]
# draw the first 50 of them
for i in range(50):
    ax.plot(rugged_seq, mu[i], "k", alpha=0.3)
Out[3]:
Code 8.4¶
In [4]:
# share of prior slope draws steeper than 0.6 in magnitude
jnp.mean(jnp.abs(prior["b"]) > 0.6)
Out[4]:
Code 8.5¶
In [5]:
def model(rugged_std, log_gdp_std=None):
    """Ruggedness model m8.1 refitted with tighter priors on a and b.

    log_gdp_std ~ Normal(a + b * (rugged_std - 0.215), sigma),
    a ~ Normal(1, 0.1), b ~ Normal(0, 0.3), sigma ~ Exponential(1).
    """
    a = numpyro.sample("a", dist.Normal(1, 0.1))
    b = numpyro.sample("b", dist.Normal(0, 0.3))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    rugged_c = rugged_std - 0.215
    mu = numpyro.deterministic("mu", a + b * rugged_c)
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


# Replaces the earlier m8_1 / p8_1 fit.
m8_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m8_1, optim.Adam(0.1), Trace_ELBO(),
    rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_1 = svi_result.params
Out[5]:
Code 8.6¶
In [6]:
post = m8_1.sample_posterior(random.PRNGKey(1), p8_1, sample_shape=(1000,))
# 89% summary of the parameters only (the per-observation "mu" is skipped)
print_summary({name: draws for name, draws in post.items() if name != "mu"}, prob=0.89, group_by_chain=False)
Out[6]:
Code 8.7¶
In [7]:
# continent index: 0 for African nations (cont_africa == 1), 1 for all others
dd["cid"] = jnp.where(dd.cont_africa.values != 1, 1, 0)
Code 8.8¶
In [8]:
def model(cid, rugged_std, log_gdp_std=None):
    """m8.2: one intercept per continent index (cid in {0, 1}), shared slope.

    log_gdp_std ~ Normal(a[cid] + b * (rugged_std - 0.215), sigma).
    """
    a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
    b = numpyro.sample("b", dist.Normal(0, 0.3))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    intercept = a[cid]
    mu = numpyro.deterministic("mu", intercept + b * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


m8_2 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m8_2, optim.Adam(0.1), Trace_ELBO(),
    cid=dd.cid.values, rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_2 = svi_result.params
Out[8]:
Code 8.9¶
In [9]:
def _loglik_idata(loglik):
    """Wrap a dict of pointwise log-likelihood arrays as ArviZ InferenceData.

    `[None]` prepends a chain axis of length 1 to every array.
    """
    return az.from_dict({}, log_likelihood={site: vals[None] for site, vals in loglik.items()})


# m8.1: no continent information
post = m8_1.sample_posterior(random.PRNGKey(2), p8_1, sample_shape=(1000,))
logprob = log_likelihood(
    m8_1.model, post, rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values
)
az8_1 = _loglik_idata(logprob)
# m8.2: continent-specific intercepts
post = m8_2.sample_posterior(random.PRNGKey(2), p8_2, sample_shape=(1000,))
logprob = log_likelihood(
    m8_2.model, post,
    cid=dd.cid.values, rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values,
)
az8_2 = _loglik_idata(logprob)
az.compare({"m8.1": az8_1, "m8.2": az8_2}, ic="waic", scale="deviance")
Out[9]:
Out[9]:
Code 8.10¶
In [10]:
post = m8_2.sample_posterior(random.PRNGKey(1), p8_2, sample_shape=(1000,))
# 89% summary of the parameters only (the per-observation "mu" is skipped)
print_summary({name: draws for name, draws in post.items() if name != "mu"}, prob=0.89, group_by_chain=False)
Out[10]:
Code 8.11¶
In [11]:
post = m8_2.sample_posterior(random.PRNGKey(1), p8_2, sample_shape=(1000,))
# intercept for cid=0 (Africa) minus intercept for cid=1 (elsewhere)
a_africa, a_other = post["a"][:, 0], post["a"][:, 1]
diff_a1_a2 = a_africa - a_other
# central 89% interval of the difference
jnp.percentile(diff_a1_a2, q=jnp.array([5.5, 94.5]))
Out[11]:
Code 8.12¶
In [12]:
# Posterior mean and 97% interval of mu for each continent under m8.2.
# The grid spans slightly beyond the [0, 1] range of rugged_std. It previously
# started at -1, a typo for -0.1 (Code 8.3 uses -0.1 to 1.1).
rugged_seq = jnp.linspace(start=-0.1, stop=1.1, num=30)
# Predictive must recompute the deterministic "mu" on the new grid, so drop the
# stored draws. Filtering instead of `post.pop("mu")` keeps this cell re-runnable
# (a second pop raised KeyError). Assumes `post` holds the m8.2 draws from Code 8.11.
post_no_mu = {k: v for k, v in post.items() if k != "mu"}
predictive = Predictive(m8_2.model, post_no_mu, return_sites=["mu"])
# compute mu over samples, fixing cid=1 (not Africa)
mu_NotAfrica = predictive(random.PRNGKey(2), cid=1, rugged_std=rugged_seq)["mu"]
# compute mu over samples, fixing cid=0 (Africa)
mu_Africa = predictive(random.PRNGKey(2), cid=0, rugged_std=rugged_seq)["mu"]
# summarize to means and 97% intervals (1.5th-98.5th percentiles)
mu_NotAfrica_mu = jnp.mean(mu_NotAfrica, 0)
mu_NotAfrica_ci = jnp.percentile(mu_NotAfrica, q=jnp.array([1.5, 98.5]), axis=0)
mu_Africa_mu = jnp.mean(mu_Africa, 0)
mu_Africa_ci = jnp.percentile(mu_Africa, q=jnp.array([1.5, 98.5]), axis=0)
Code 8.13¶
In [13]:
def model(cid, rugged_std, log_gdp_std=None):
    """m8.3: intercept AND slope indexed by continent (cid in {0, 1}).

    log_gdp_std ~ Normal(a[cid] + b[cid] * (rugged_std - 0.215), sigma).
    """
    a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
    b = numpyro.sample("b", dist.Normal(0, 0.3).expand([2]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    intercept, slope = a[cid], b[cid]
    mu = numpyro.deterministic("mu", intercept + slope * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


m8_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m8_3, optim.Adam(0.1), Trace_ELBO(),
    cid=dd.cid.values, rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_3 = svi_result.params
Out[13]:
Code 8.14¶
In [14]:
post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, sample_shape=(1000,))
# 89% summary of the parameters only (the per-observation "mu" is skipped)
print_summary({name: draws for name, draws in post.items() if name != "mu"}, prob=0.89, group_by_chain=False)
Out[14]:
Code 8.15¶
In [15]:
# WAIC comparison of m8.1 (no continent), m8.2 (continent intercepts) and
# m8.3 (continent intercepts and slopes). Each model's pointwise log-likelihood
# is wrapped in InferenceData with a leading chain axis of length 1 (`v[None]`).
post = m8_1.sample_posterior(random.PRNGKey(2), p8_1, sample_shape=(1000,))
logprob = log_likelihood(
    m8_1.model, post, rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values
)
az8_1 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
post = m8_2.sample_posterior(random.PRNGKey(2), p8_2, sample_shape=(1000,))
logprob = log_likelihood(
    m8_2.model,
    post,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
    log_gdp_std=dd.log_gdp_std.values,
)
# Stored as az8_2. It used to be written to az8_3 and immediately overwritten,
# so the comparison silently depended on az8_2 left over from Code 8.9.
az8_2 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
post = m8_3.sample_posterior(random.PRNGKey(2), p8_3, sample_shape=(1000,))
logprob = log_likelihood(
    m8_3.model,
    post,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
    log_gdp_std=dd.log_gdp_std.values,
)
az8_3 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
az.compare({"m8.1": az8_1, "m8.2": az8_2, "m8.3": az8_3}, ic="waic", scale="deviance")
Out[15]:
Out[15]:
Code 8.16¶
In [16]:
# pointwise WAIC (deviance scale) of the interaction model, one value per observation
waic_m8_3 = az.waic(az8_3, pointwise=True, scale="deviance")
waic_list = waic_m8_3["waic_i"].values
Out[16]:
Code 8.17¶
In [17]:
# plot non-Africa - cid=1
d_A0 = dd[dd["cid"] == 1]
az.plot_pair(d_A0[["rugged_std", "log_gdp_std"]].to_dict(orient="list"))
plt.gca().set(
    xlim=(-0.01, 1.01),
    xlabel="ruggedness (standardized)",
    ylabel="log GDP (as proportion of mean)",
)
# Predictions come from the interaction model m8.3 (one slope per continent).
# The `predictive` left over from Code 8.12 wraps m8.2, whose slope is shared
# across continents, so reusing it plotted the wrong model.
post_m8_3 = m8_3.sample_posterior(random.PRNGKey(1), p8_3, sample_shape=(1000,))
post_m8_3.pop("mu")  # let Predictive recompute mu on the new grid
predictive_m8_3 = Predictive(m8_3.model, post_m8_3, return_sites=["mu"])
# grid defined here so the cell does not depend on another cell's rugged_seq
rugged_seq = jnp.linspace(start=-0.1, stop=1.1, num=30)
mu = predictive_m8_3(random.PRNGKey(2), cid=1, rugged_std=rugged_seq)["mu"]
mu_mean = jnp.mean(mu, 0)
# 97% interval (1.5th-98.5th percentiles)
mu_ci = jnp.percentile(mu, q=jnp.array([1.5, 98.5]), axis=0)
plt.plot(rugged_seq, mu_mean, "k")
plt.fill_between(rugged_seq, mu_ci[0], mu_ci[1], color="k", alpha=0.2)
plt.title("Non-African nations")
plt.show()
Out[17]:
Code 8.18¶
In [18]:
# Contrast in expected log GDP between Africa (cid=0) and elsewhere (cid=1)
# under m8.3, over a ruggedness grid from -0.2 to 1.2.
rugged_seq = jnp.linspace(start=-0.2, stop=1.2, num=30)
post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, sample_shape=(1000,))
del post["mu"]  # recomputed by Predictive on the new grid
predictive = Predictive(m8_3.model, post, return_sites=["mu"])
muA, muN = (predictive(random.PRNGKey(2), cid=c, rugged_std=rugged_seq)["mu"] for c in (0, 1))
delta = muA - muN
Code 8.19¶
In [19]:
# load the tulips data (';'-separated); `d` and `tulips` name the same frame
d = tulips = pd.read_csv("../data/tulips.csv", sep=";")
d.info()
d.head()
Out[19]:
Out[19]:
Code 8.20¶
In [20]:
# blooms scaled so the largest value is 1
d["blooms_std"] = d.blooms / d.blooms.max()
# water and shade centred on their means
for col in ("water", "shade"):
    d[f"{col}_cent"] = d[col] - d[col].mean()
Code 8.21¶
In [21]:
# fraction of Normal(0.5, 1) intercept draws falling outside [0, 1]
a = dist.Normal(0.5, 1).sample(random.PRNGKey(0), (10_000,))
jnp.mean((a < 0) | (a > 1))
Out[21]:
Code 8.22¶
In [22]:
# fraction of Normal(0.5, 0.25) intercept draws falling outside [0, 1]
a = dist.Normal(0.5, 0.25).sample(random.PRNGKey(0), (10_000,))
jnp.mean((a < 0) | (a > 1))
Out[22]:
Code 8.23¶
In [23]:
def model(water_cent, shade_cent, blooms_std=None):
    """m8.4: blooms on centred water and shade, no interaction term.

    blooms_std ~ Normal(a + bw * water_cent + bs * shade_cent, sigma).
    """
    a = numpyro.sample("a", dist.Normal(0.5, 0.25))
    bw = numpyro.sample("bw", dist.Normal(0, 0.25))
    bs = numpyro.sample("bs", dist.Normal(0, 0.25))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    linear_pred = a + bw * water_cent + bs * shade_cent
    mu = numpyro.deterministic("mu", linear_pred)
    numpyro.sample("blooms_std", dist.Normal(mu, sigma), obs=blooms_std)


# Laplace approximation fitted with SVI (Adam, lr=1, 1000 steps).
m8_4 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m8_4, optim.Adam(1), Trace_ELBO(),
    water_cent=d.water_cent.values, shade_cent=d.shade_cent.values, blooms_std=d.blooms_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_4 = svi_result.params
Out[23]:
Code 8.24¶
In [24]:
def model(water_cent, shade_cent, blooms_std=None):
    """m8.5: m8.4 plus a water x shade interaction term (bws).

    blooms_std ~ Normal(a + bw*water + bs*shade + bws*water*shade, sigma).
    mu is registered as a deterministic site, like every other model in this
    notebook, so Predictive(..., return_sites=["mu"]) works for m8.5 too
    (previously it was a plain local and could not be requested).
    """
    a = numpyro.sample("a", dist.Normal(0.5, 0.25))
    bw = numpyro.sample("bw", dist.Normal(0, 0.25))
    bs = numpyro.sample("bs", dist.Normal(0, 0.25))
    bws = numpyro.sample("bws", dist.Normal(0, 0.25))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic(
        "mu", a + bw * water_cent + bs * shade_cent + bws * water_cent * shade_cent
    )
    numpyro.sample("blooms_std", dist.Normal(mu, sigma), obs=blooms_std)


m8_5 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_5,
    optim.Adam(1),
    Trace_ELBO(),
    shade_cent=d.shade_cent.values,
    water_cent=d.water_cent.values,
    blooms_std=d.blooms_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_5 = svi_result.params
Out[24]:
Code 8.25¶
In [25]:
# Posterior mean lines of the no-interaction model m8.4, one panel per shade level.
# The posterior and Predictive do not depend on the panel, so they are built once
# instead of being rebuilt on every iteration of the loop.
post = m8_4.sample_posterior(random.PRNGKey(1), p8_4, sample_shape=(1000,))
post.pop("mu")  # let Predictive recompute mu for the new inputs
predictive_m8_4 = Predictive(m8_4.model, post, return_sites=["mu"])
water_seq = jnp.arange(-1, 2)
_, axes = plt.subplots(1, 3, figsize=(9, 3), sharey=True)  # 3 plots in 1 row
for ax, s in zip(axes, range(-1, 2)):
    idx = d.shade_cent == s
    ax.scatter(d.water_cent[idx], d.blooms_std[idx])
    ax.set(
        xlim=(-1.1, 1.1),
        ylim=(-0.1, 1.1),
        xlabel="water",
        ylabel="blooms",
        title=f"m8.4 post: shade = {s}",
    )
    mu = predictive_m8_4(random.PRNGKey(2), shade_cent=s, water_cent=water_seq)["mu"]
    # first 20 posterior draws of the mean line
    for i in range(20):
        ax.plot(water_seq, mu[i], "k", alpha=0.3)
Out[25]:
Code 8.26¶
In [26]:
# 1000 draws of m8.5's parameters from their priors; the zero inputs are placeholders
prior_sites = ["a", "bw", "bs", "bws", "sigma"]
predictive = Predictive(m8_5.model, num_samples=1000, return_sites=prior_sites)
prior = predictive(random.PRNGKey(7), water_cent=0, shade_cent=0)
Code 8.27¶
In [27]:
# load the nettle data (';'-separated); `d` and `nettle` name the same frame
d = nettle = pd.read_csv("../data/nettle.csv", sep=";")
# number of languages divided by k.pop
d["lang.per.cap"] = d["num.lang"].div(d["k.pop"])
Comments
Comments powered by Disqus