Chapter 8. Conditional Manatees

In [0]:
import math
import os
import warnings

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd

import jax.numpy as jnp
from jax import lax, random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import ELBO, SVI, Predictive, log_likelihood
from numpyro.infer.autoguide import AutoLaplaceApproximation

if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
    category.__name__, message
)
az.style.use("arviz-darkgrid")

Code 8.1

In [1]:
rugged = pd.read_csv("../data/rugged.csv", sep=";")
d = rugged

# make log version of outcome
d["log_gdp"] = d["rgdppc_2000"].apply(math.log)

# extract countries with GDP data
dd = d[d["rgdppc_2000"].notnull()].copy()

# rescale variables
dd["log_gdp_std"] = dd.log_gdp / dd.log_gdp.mean()
dd["rugged_std"] = dd.rugged / dd.rugged.max()

Code 8.2

In [2]:
def model(rugged_std, log_gdp_std=None):
    a = numpyro.sample("a", dist.Normal(1, 1))
    b = numpyro.sample("b", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + b * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


m8_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_1,
    optim.Adam(0.1),
    ELBO(),
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p8_1 = svi.get_params(state)
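
The lax.scan call above runs 1,000 jit-compiled SVI steps. A minimal uncompiled sketch of the same loop, assuming the svi and init_state defined in this cell:

state = init_state
losses = []
for _ in range(1000):
    state, loss = svi.update(state)  # one optimization step; returns updated state and loss
    losses.append(loss)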

Code 8.3

In [3]:
predictive = Predictive(m8_1.model, num_samples=1000, return_sites=["a", "b", "sigma"])
prior = predictive(random.PRNGKey(7), rugged_std=0)

# set up the plot dimensions
plt.subplot(xlim=(0, 1), ylim=(0.5, 1.5), xlabel="ruggedness", ylabel="log GDP")
plt.gca().axhline(dd.log_gdp_std.min(), ls="--")
plt.gca().axhline(dd.log_gdp_std.max(), ls="--")

# draw 50 lines from the prior
rugged_seq = jnp.linspace(-0.1, 1.1, num=30)
mu = Predictive(m8_1.model, prior, return_sites=["mu"])(
    random.PRNGKey(7), rugged_std=rugged_seq
)["mu"]
for i in range(50):
    plt.plot(rugged_seq, mu[i], "k", alpha=0.3)
Out[3]:
[figure: 50 regression lines sampled from the prior, with dashed lines at the minimum and maximum observed log GDP]

Code 8.4

In [4]:
jnp.sum(jnp.abs(prior["b"]) > 0.6) / prior["b"].shape[0]
Out[4]:
DeviceArray(0.564, dtype=float32)
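
More than half of these prior slopes are implausibly strong. A hedged sketch of where the 0.6 threshold comes from: since rugged_std spans roughly [0, 1], a slope larger than the observed range of the outcome would imply a line steeper than any relationship the data could show:

dd.log_gdp_std.max() - dd.log_gdp_std.min()  # roughly 0.6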

Code 8.5

In [5]:
def model(rugged_std, log_gdp_std=None):
    a = numpyro.sample("a", dist.Normal(1, 0.1))
    b = numpyro.sample("b", dist.Normal(0, 0.3))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + b * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


m8_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_1,
    optim.Adam(0.1),
    ELBO(),
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p8_1 = svi.get_params(state)

Code 8.6

In [6]:
post = m8_1.sample_posterior(random.PRNGKey(1), p8_1, (1000,))
print_summary({k: v for k, v in post.items() if k != "mu"}, 0.89, False)
Out[6]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      1.00      0.01      1.00      0.98      1.02    931.50      1.00
         b      0.00      0.06      0.00     -0.08      0.10   1111.63      1.00
     sigma      0.14      0.01      0.14      0.13      0.15    949.29      1.00

Code 8.7

In [7]:
# make an index variable: 0 = Africa, 1 = not Africa
dd["cid"] = jnp.where(dd.cont_africa.values == 1, 0, 1)

Code 8.8

In [8]:
def model(cid, rugged_std, log_gdp_std=None):
    a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
    b = numpyro.sample("b", dist.Normal(0, 0.3))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a[cid] + b * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


m8_2 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_2,
    optim.Adam(0.1),
    ELBO(),
    cid=dd.cid.values,
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p8_2 = svi.get_params(state)

Code 8.9

In [9]:
post = m8_1.sample_posterior(random.PRNGKey(2), p8_1, (1000,))
logprob = log_likelihood(
    m8_1.model, post, rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values
)
az8_1 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
post = m8_2.sample_posterior(random.PRNGKey(2), p8_2, (1000,))
logprob = log_likelihood(
    m8_2.model,
    post,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
    log_gdp_std=dd.log_gdp_std.values,
)
az8_2 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
az.compare({"m8.1": az8_1, "m8.2": az8_2}, ic="waic", scale="deviance")
Out[9]:
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details

      rank      waic   p_waic   d_waic       weight       se      dse  warning  waic_scale
m8.2     0  -252.36   4.15389  0.00000     0.999156  13.0673   0.0000     True    deviance
m8.1     1  -188.818  2.65329  63.5418  0.000843778  15.4119  14.9593    False    deviance

Code 8.10

In [10]:
post = m8_2.sample_posterior(random.PRNGKey(1), p8_2, (1000,))
print_summary({k: v for k, v in post.items() if k != "mu"}, 0.89, False)
Out[10]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]      0.88      0.02      0.88      0.86      0.90   1049.96      1.00
      a[1]      1.05      0.01      1.05      1.03      1.07    824.00      1.00
         b     -0.05      0.05     -0.05     -0.13      0.02    999.08      1.00
     sigma      0.11      0.01      0.11      0.10      0.12    961.35      1.00

Code 8.11

In [11]:
post = m8_2.sample_posterior(random.PRNGKey(1), p8_2, (1000,))
diff_a1_a2 = post["a"][:, 0] - post["a"][:, 1]
jnp.percentile(diff_a1_a2, q=(5.5, 94.5))
Out[11]:
DeviceArray([-0.19982037, -0.139674  ], dtype=float32)

Code 8.12

In [12]:
rugged_seq = jnp.linspace(start=-0.1, stop=1.1, num=30)

# compute mu over samples, fixing cid=1
predictive = Predictive(m8_2.model, post, return_sites=["mu"])
mu_NotAfrica = predictive(random.PRNGKey(2), cid=1, rugged_std=rugged_seq)["mu"]

# compute mu over samples, fixing cid=0
mu_Africa = predictive(random.PRNGKey(2), cid=0, rugged_std=rugged_seq)["mu"]

# summarize to means and intervals
mu_NotAfrica_mu = jnp.mean(mu_NotAfrica, 0)
mu_NotAfrica_ci = jnp.percentile(mu_NotAfrica, q=(1.5, 98.5), axis=0)
mu_Africa_mu = jnp.mean(mu_Africa, 0)
mu_Africa_ci = jnp.percentile(mu_Africa, q=(1.5, 98.5), axis=0)
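
These summaries reproduce the book's figure of both regressions; a minimal plotting sketch using the quantities just computed:

plt.plot(rugged_seq, mu_Africa_mu, "b")
plt.fill_between(rugged_seq, mu_Africa_ci[0], mu_Africa_ci[1], color="b", alpha=0.2)
plt.plot(rugged_seq, mu_NotAfrica_mu, "k")
plt.fill_between(rugged_seq, mu_NotAfrica_ci[0], mu_NotAfrica_ci[1], color="k", alpha=0.2)
plt.gca().set(xlabel="ruggedness", ylabel="log GDP (as proportion of mean)")
plt.show()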

Code 8.13

In [13]:
def model(cid, rugged_std, log_gdp_std=None):
    a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
    b = numpyro.sample("b", dist.Normal(0, 0.3).expand([2]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a[cid] + b[cid] * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


m8_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_3,
    optim.Adam(0.1),
    ELBO(),
    cid=dd.cid.values,
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p8_3 = svi.get_params(state)

Code 8.14

In [14]:
post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, (1000,))
print_summary({k: v for k, v in post.items() if k != "mu"}, 0.89, False)
Out[14]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]      0.89      0.02      0.89      0.86      0.91   1009.20      1.00
      a[1]      1.05      0.01      1.05      1.04      1.07    755.33      1.00
      b[0]      0.13      0.07      0.13      0.01      0.24   1045.06      1.00
      b[1]     -0.15      0.06     -0.14     -0.23     -0.05   1003.36      1.00
     sigma      0.11      0.01      0.11      0.10      0.12    810.01      1.00

Code 8.15

In [15]:
post = m8_1.sample_posterior(random.PRNGKey(2), p8_1, (1000,))
logprob = log_likelihood(
    m8_1.model, post, rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values
)
az8_1 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
post = m8_2.sample_posterior(random.PRNGKey(2), p8_2, (1000,))
logprob = log_likelihood(
    m8_2.model,
    post,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
    log_gdp_std=dd.log_gdp_std.values,
)
az8_2 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
post = m8_3.sample_posterior(random.PRNGKey(2), p8_3, (1000,))
logprob = log_likelihood(
    m8_3.model,
    post,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
    log_gdp_std=dd.log_gdp_std.values,
)
az8_3 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
az.compare({"m8.1": az8_1, "m8.2": az8_2, "m8.3": az8_3}, ic="waic", scale="deviance")
Out[15]:
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details

      rank      waic   p_waic   d_waic       weight       se      dse  warning  waic_scale
m8.3     0  -259.176  5.10349  0.00000     0.830774  13.3606   0.0000     True    deviance
m8.2     1  -252.36   4.15389  6.81647     0.169225  14.9721   6.6764     True    deviance
m8.1     2  -188.818  2.65329  70.3583  3.07569e-07  14.9125  15.3428    False    deviance

Code 8.16

In [16]:
waic_list = az.waic(az8_3, pointwise=True, scale="deviance").waic_i.values
Out[16]:
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
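
The pointwise values let you see which countries the interaction model finds surprising. A hedged sketch, assuming waic_list aligns with dd's row order and that rugged.csv's country column is present (dd_waic is just an illustrative name):

dd_waic = dd[["country"]].copy()
dd_waic["waic_i"] = waic_list
dd_waic.sort_values("waic_i", ascending=False).head()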

Code 8.17

In [17]:
# plot non-Africa - cid=1
d_A0 = dd[dd["cid"] == 1]
az.plot_pair(d_A0[["rugged_std", "log_gdp_std"]].to_dict(orient="list"))
plt.gca().set(
    xlim=(-0.01, 1.01),
    xlabel="ruggedness (standardized)",
    ylabel="log GDP (as proportion of mean)",
)
# use the interaction model's posterior, not the stale m8_2 predictive
post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, (1000,))
predictive = Predictive(m8_3.model, post, return_sites=["mu"])
mu = predictive(random.PRNGKey(2), cid=1, rugged_std=rugged_seq)["mu"]
mu_mean = jnp.mean(mu, 0)
mu_ci = jnp.percentile(mu, q=(1.5, 98.5), axis=0)
plt.plot(rugged_seq, mu_mean, "k")
plt.fill_between(rugged_seq, mu_ci[0], mu_ci[1], color="k", alpha=0.2)
plt.title("Non-African nations")
plt.show()
Out[17]:
[figure: Non-African nations: posterior mean regression line and 97% interval over ruggedness]

Code 8.18

In [18]:
rugged_seq = jnp.linspace(start=-0.2, stop=1.2, num=30)
post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, (1000,))
predictive = Predictive(m8_3.model, post, return_sites=["mu"])
muA = predictive(random.PRNGKey(2), cid=0, rugged_std=rugged_seq)["mu"]
muN = predictive(random.PRNGKey(2), cid=1, rugged_std=rugged_seq)["mu"]
delta = muA - muN
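
delta holds, for each posterior sample and each ruggedness value, the expected difference in log GDP between an African and a non-African nation. A minimal sketch summarizing and plotting it:

delta_mean = jnp.mean(delta, 0)
delta_ci = jnp.percentile(delta, q=(1.5, 98.5), axis=0)
plt.plot(rugged_seq, delta_mean, "k")
plt.fill_between(rugged_seq, delta_ci[0], delta_ci[1], color="k", alpha=0.2)
plt.gca().axhline(0, ls="--")
plt.gca().set(xlabel="ruggedness", ylabel="expected difference in log GDP")
plt.show()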

Code 8.19

In [19]:
tulips = pd.read_csv("../data/tulips.csv", sep=";")
d = tulips
d.info()
d.head()
Out[19]:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 27 entries, 0 to 26
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   bed     27 non-null     object 
 1   water   27 non-null     int64  
 2   shade   27 non-null     int64  
 3   blooms  27 non-null     float64
dtypes: float64(1), int64(2), object(1)
memory usage: 992.0+ bytes
Out[19]:
  bed  water  shade  blooms
0   a      1      1    0.00
1   a      1      2    0.00
2   a      1      3  111.04
3   a      2      1  183.47
4   a      2      2   59.16

Code 8.20

In [20]:
d["blooms_std"] = d.blooms / d.blooms.max()
d["water_cent"] = d.water - d.water.mean()
d["shade_cent"] = d.shade - d.shade.mean()

Code 8.21

In [21]:
a = dist.Normal(0.5, 1).sample(random.PRNGKey(0), (int(1e4),))
jnp.sum((a < 0) | (a > 1)) / a.shape[0]
Out[21]:
DeviceArray(0.6182, dtype=float32)

Code 8.22

In [22]:
a = dist.Normal(0.5, 0.25).sample(random.PRNGKey(0), (int(1e4),))
jnp.sum((a < 0) | (a > 1)) / a.shape[0]
Out[22]:
DeviceArray(0.0471, dtype=float32)

Code 8.23

In [23]:
def model(water_cent, shade_cent, blooms_std=None):
    a = numpyro.sample("a", dist.Normal(0.5, 0.25))
    bw = numpyro.sample("bw", dist.Normal(0, 0.25))
    bs = numpyro.sample("bs", dist.Normal(0, 0.25))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bw * water_cent + bs * shade_cent)
    numpyro.sample("blooms_std", dist.Normal(mu, sigma), obs=blooms_std)


m8_4 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_4,
    optim.Adam(1),
    ELBO(),
    shade_cent=d.shade_cent.values,
    water_cent=d.water_cent.values,
    blooms_std=d.blooms_std.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p8_4 = svi.get_params(state)

Code 8.24

In [24]:
def model(water_cent, shade_cent, blooms_std=None):
    a = numpyro.sample("a", dist.Normal(0.5, 0.25))
    bw = numpyro.sample("bw", dist.Normal(0, 0.25))
    bs = numpyro.sample("bs", dist.Normal(0, 0.25))
    bws = numpyro.sample("bws", dist.Normal(0, 0.25))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic(
        "mu", a + bw * water_cent + bs * shade_cent + bws * water_cent * shade_cent
    )
    numpyro.sample("blooms_std", dist.Normal(mu, sigma), obs=blooms_std)


m8_5 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_5,
    optim.Adam(1),
    ELBO(),
    shade_cent=d.shade_cent.values,
    water_cent=d.water_cent.values,
    blooms_std=d.blooms_std.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p8_5 = svi.get_params(state)

Code 8.25

In [25]:
_, axes = plt.subplots(1, 3, figsize=(9, 3), sharey=True)  # 3 plots in 1 row
for ax, s in zip(axes, range(-1, 2)):
    idx = d.shade_cent == s
    ax.scatter(d.water_cent[idx], d.blooms_std[idx])
    ax.set(xlim=(-1.1, 1.1), ylim=(-0.1, 1.1), xlabel="water", ylabel="blooms")
    post = m8_4.sample_posterior(random.PRNGKey(1), p8_4, (1000,))
    mu = Predictive(m8_4.model, post, return_sites=["mu"])(
        random.PRNGKey(2), shade_cent=s, water_cent=jnp.arange(-1, 2)
    )["mu"]
    for i in range(20):
        ax.plot(range(-1, 2), mu[i], "k", alpha=0.3)
Out[25]:
[figure: triptych of m8_4 posterior predictions for blooms at shade_cent = -1, 0, 1]
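
The same loop draws the interaction model's predictions for the matching row of the book's figure; a hedged sketch of the inner computation with m8_5 (this assumes mu is recorded as a deterministic site in m8_5, as in the cell for Code 8.24):

post = m8_5.sample_posterior(random.PRNGKey(1), p8_5, (1000,))
mu = Predictive(m8_5.model, post, return_sites=["mu"])(
    random.PRNGKey(2), shade_cent=0, water_cent=jnp.arange(-1, 2)  # middle panel; loop shade_cent over -1, 0, 1
)["mu"]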

Code 8.26

In [26]:
predictive = Predictive(
    m8_5.model, num_samples=1000, return_sites=["a", "bw", "bs", "bws", "sigma"]
)
prior = predictive(random.PRNGKey(7), water_cent=0, shade_cent=0)
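
Pushing these prior samples through the same triptych loop gives the prior predictive version of the plot; a minimal sketch, assuming the prior dict sampled above:

_, axes = plt.subplots(1, 3, figsize=(9, 3), sharey=True)
for ax, s in zip(axes, range(-1, 2)):
    ax.set(xlim=(-1.1, 1.1), ylim=(-0.5, 1.5), xlabel="water", ylabel="blooms")
    ax.axhline(0, ls="--")  # observed blooms_std lies in [0, 1]
    ax.axhline(1, ls="--")
    mu = Predictive(m8_5.model, prior, return_sites=["mu"])(
        random.PRNGKey(2), shade_cent=s, water_cent=jnp.arange(-1, 2)
    )["mu"]
    for i in range(20):
        ax.plot(range(-1, 2), mu[i], "k", alpha=0.3)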

Code 8.27

In [27]:
nettle = pd.read_csv("../data/nettle.csv", sep=";")
d = nettle
d["lang.per.cap"] = d["num.lang"] / d["k.pop"]
