Chapter 8. Conditional Manatees

In [ ]:
!pip install -q numpyro arviz
In [0]:
import math
import os
import warnings

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO, log_likelihood
from numpyro.infer.autoguide import AutoLaplaceApproximation

# Optionally render inline figures as SVG (enabled when an "SVG" env var is set).
if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
# Show warnings as a single "Category: message" line; the other arguments passed
# to formatwarning (e.g. filename, line number) are accepted but ignored.
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
    category.__name__, message
)
# Global plot style, and force JAX/NumPyro onto the CPU backend.
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")

Code 8.1

In [1]:
# Load the ruggedness data; `d` is just another name for the same DataFrame.
rugged = pd.read_csv("../data/rugged.csv", sep=";")
d = rugged

# Natural log of GDP per capita in 2000 (missing GDP gives NaN; those rows are dropped below).
d["log_gdp"] = d["rgdppc_2000"].map(math.log)

# Keep only countries that report GDP; copy so new columns do not touch `d`.
dd = d.loc[d["rgdppc_2000"].notna()].copy()

# Rescale: log GDP as a proportion of its mean, ruggedness as a proportion of its max.
dd["log_gdp_std"] = dd["log_gdp"] / dd["log_gdp"].mean()
dd["rugged_std"] = dd["rugged"] / dd["rugged"].max()

Code 8.2

In [2]:
def model(rugged_std, log_gdp_std=None):
    """Linear regression of standardized log GDP on ruggedness (vague priors).

    Args:
        rugged_std: ruggedness divided by its maximum (scalar or array).
        log_gdp_std: observed log GDP as a proportion of its mean, or ``None``
            to simulate the outcome (prior / predictive draws).
    """
    a = numpyro.sample("a", dist.Normal(1, 1))
    b = numpyro.sample("b", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # Center ruggedness; 0.215 is presumably the sample mean of rugged_std (confirm).
    rugged_centered = rugged_std - 0.215
    mu = numpyro.deterministic("mu", a + b * rugged_centered)
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


# Fit a Laplace (quadratic) approximation with SVI: Adam(0.1), 1000 steps.
m8_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_1,
    optim.Adam(0.1),
    Trace_ELBO(),
    log_gdp_std=dd.log_gdp_std.values,
    rugged_std=dd.rugged_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p8_1 = svi_result.params
Out[2]:
100%|██████████| 1000/1000 [00:01<00:00, 944.10it/s, init loss: 810.4496, avg. loss [951-1000]: -91.4252]

Code 8.3

In [3]:
# Prior predictive simulation for m8_1: draw 1000 values of a, b and sigma from the
# priors only (rugged_std=0 is a placeholder; no outcome is conditioned on).
predictive = Predictive(m8_1.model, num_samples=1000, return_sites=["a", "b", "sigma"])
prior = predictive(random.PRNGKey(7), rugged_std=0)

# set up the plot dimensions; dashed lines mark the observed min/max of the outcome
ax = plt.subplot(xlim=(0, 1), ylim=(0.5, 1.5), xlabel="ruggedness", ylabel="log GDP")
ax.axhline(dd.log_gdp_std.min(), ls="--")
ax.axhline(dd.log_gdp_std.max(), ls="--")

# draw 50 lines from the prior: recompute `mu` on a grid for each prior draw
rugged_seq = jnp.linspace(-0.1, 1.1, num=30)
mu = Predictive(m8_1.model, prior, return_sites=["mu"])(
    random.PRNGKey(7), rugged_std=rugged_seq
)["mu"]
# mu is (draws, grid points); transpose so each of the first 50 draws is one line
ax.plot(rugged_seq, mu[:50].T, "k", alpha=0.3)
ax.set_title("m8_1: regression lines sampled from the prior")
plt.show()
Out[3]:
No description has been provided for this image

Code 8.4

In [4]:
# Share of prior slope draws whose magnitude exceeds 0.6
jnp.mean(jnp.abs(prior["b"]) > 0.6)
Out[4]:
DeviceArray(0.564, dtype=float32)

Code 8.5

In [5]:
def model(rugged_std, log_gdp_std=None):
    """Ruggedness regression with tightened priors on the intercept and slope.

    Args:
        rugged_std: ruggedness divided by its maximum (scalar or array).
        log_gdp_std: observed log GDP as a proportion of its mean, or ``None``
            to simulate the outcome.
    """
    a = numpyro.sample("a", dist.Normal(1, 0.1))
    b = numpyro.sample("b", dist.Normal(0, 0.3))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # 0.215 presumably is the mean of rugged_std, so `a` sits at average ruggedness (confirm).
    offset = rugged_std - 0.215
    mu = numpyro.deterministic("mu", a + b * offset)
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


# Refit m8_1 with the new priors (Laplace approximation via SVI, Adam(0.1), 1000 steps).
m8_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_1,
    optim.Adam(0.1),
    Trace_ELBO(),
    log_gdp_std=dd.log_gdp_std.values,
    rugged_std=dd.rugged_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p8_1 = svi_result.params
Out[5]:
100%|██████████| 1000/1000 [00:01<00:00, 950.36it/s, init loss: 852.5239, avg. loss [951-1000]: -94.8615]

Code 8.6

In [6]:
# 1000 draws from the fitted approximation; summarize with 89% intervals,
# leaving the per-observation `mu` site out of the table.
post = m8_1.sample_posterior(random.PRNGKey(1), p8_1, sample_shape=(1000,))
print_summary(
    {site: draws for site, draws in post.items() if site != "mu"},
    prob=0.89,
    group_by_chain=False,
)
Out[6]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      1.00      0.01      1.00      0.98      1.02    931.50      1.00
         b      0.00      0.06      0.00     -0.08      0.10   1111.63      1.00
     sigma      0.14      0.01      0.14      0.13      0.15    949.29      1.00

Code 8.7

In [7]:
# Continent index for the intercepts: 0 for African nations, 1 for everyone else.
dd["cid"] = jnp.where(dd["cont_africa"].values != 1, 1, 0)

Code 8.8

In [8]:
def model(cid, rugged_std, log_gdp_std=None):
    """Ruggedness regression with one intercept per continent group and a shared slope.

    Args:
        cid: integer index into the intercept vector (as encoded in dd.cid:
            0 = Africa, 1 = elsewhere).
        rugged_std: ruggedness divided by its maximum.
        log_gdp_std: observed outcome, or ``None`` to simulate it.
    """
    a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))  # one intercept per group
    b = numpyro.sample("b", dist.Normal(0, 0.3))  # slope shared by both groups
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    group_intercept = a[cid]
    mu = numpyro.deterministic("mu", group_intercept + b * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


# Laplace approximation via SVI (Adam(0.1), 1000 steps).
m8_2 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_2,
    optim.Adam(0.1),
    Trace_ELBO(),
    log_gdp_std=dd.log_gdp_std.values,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p8_2 = svi_result.params
Out[8]:
100%|██████████| 1000/1000 [00:00<00:00, 1061.70it/s, init loss: 1785.1527, avg. loss [951-1000]: -127.6097]

Code 8.9

In [9]:
# Pointwise log-likelihood of the observed data under each model's posterior,
# packed into ArviZ InferenceData with a leading chain axis of size 1.
post = m8_1.sample_posterior(random.PRNGKey(2), p8_1, sample_shape=(1000,))
logprob = log_likelihood(
    m8_1.model,
    post,
    log_gdp_std=dd.log_gdp_std.values,
    rugged_std=dd.rugged_std.values,
)
az8_1 = az.from_dict(
    {}, log_likelihood={site: jnp.expand_dims(ll, 0) for site, ll in logprob.items()}
)

post = m8_2.sample_posterior(random.PRNGKey(2), p8_2, sample_shape=(1000,))
logprob = log_likelihood(
    m8_2.model,
    post,
    log_gdp_std=dd.log_gdp_std.values,
    cid=dd.cid.values,
    rugged_std=dd.rugged_std.values,
)
az8_2 = az.from_dict(
    {}, log_likelihood={site: jnp.expand_dims(ll, 0) for site, ll in logprob.items()}
)

az.compare({"m8.1": az8_1, "m8.2": az8_2}, ic="waic", scale="deviance")
Out[9]:
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
Out[9]:
rank waic p_waic d_waic weight se dse warning waic_scale
m8.2 0 -252.36 4.15389 0 0.9998 13.2547 0 True deviance
m8.1 1 -188.818 2.65329 63.5418 0.000200227 14.8249 14.9592 False deviance

Code 8.10

In [10]:
# Posterior summary for m8_2 (89% intervals), without the per-observation `mu` draws.
post = m8_2.sample_posterior(random.PRNGKey(1), p8_2, sample_shape=(1000,))
print_summary(
    {site: draws for site, draws in post.items() if site != "mu"},
    prob=0.89,
    group_by_chain=False,
)
Out[10]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]      0.88      0.02      0.88      0.86      0.90   1049.96      1.00
      a[1]      1.05      0.01      1.05      1.03      1.07    824.00      1.00
         b     -0.05      0.05     -0.05     -0.13      0.02    999.08      1.00
     sigma      0.11      0.01      0.11      0.10      0.12    961.35      1.00

Code 8.11

In [11]:
# Posterior contrast a[0] - a[1] (cid 0 = Africa, cid 1 = elsewhere) and its 89% interval.
post = m8_2.sample_posterior(random.PRNGKey(1), p8_2, sample_shape=(1000,))
diff_a1_a2 = jnp.subtract(post["a"][:, 0], post["a"][:, 1])
jnp.percentile(diff_a1_a2, jnp.array([5.5, 94.5]))
Out[11]:
DeviceArray([-0.19981882, -0.13967244], dtype=float32)

Code 8.12

In [12]:
# Ruggedness grid for plotting, slightly wider than the observed [0, 1] range
# (same span as the prior-line grid used for m8_1).
rugged_seq = jnp.linspace(start=-0.1, stop=1.1, num=30)

# Posterior of m8_2. The stored `mu` draws belong to the observed rows, so leave
# them out and let Predictive recompute `mu` on rugged_seq. Building a filtered
# dict here (instead of popping from a dict made in another cell) keeps this cell
# re-runnable.
post = m8_2.sample_posterior(random.PRNGKey(1), p8_2, sample_shape=(1000,))
post = {k: v for k, v in post.items() if k != "mu"}
predictive = Predictive(m8_2.model, post, return_sites=["mu"])

# compute mu over samples, fixing cid=1 (not Africa)
mu_NotAfrica = predictive(random.PRNGKey(2), cid=1, rugged_std=rugged_seq)["mu"]

# compute mu over samples, fixing cid=0 (Africa)
mu_Africa = predictive(random.PRNGKey(2), cid=0, rugged_std=rugged_seq)["mu"]

# summarize to means and 97% intervals (1.5th-98.5th percentiles)
mu_NotAfrica_mu = jnp.mean(mu_NotAfrica, 0)
mu_NotAfrica_ci = jnp.percentile(mu_NotAfrica, q=jnp.array([1.5, 98.5]), axis=0)
mu_Africa_mu = jnp.mean(mu_Africa, 0)
mu_Africa_ci = jnp.percentile(mu_Africa, q=jnp.array([1.5, 98.5]), axis=0)

Code 8.13

In [13]:
def model(cid, rugged_std, log_gdp_std=None):
    """Ruggedness regression with a separate intercept AND slope per continent group.

    Args:
        cid: integer group index (as encoded in dd.cid: 0 = Africa, 1 = elsewhere).
        rugged_std: ruggedness divided by its maximum.
        log_gdp_std: observed outcome, or ``None`` to simulate it.
    """
    a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
    b = numpyro.sample("b", dist.Normal(0, 0.3).expand([2]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # pick each observation's group-specific intercept and slope
    a_cid, b_cid = a[cid], b[cid]
    mu = numpyro.deterministic("mu", a_cid + b_cid * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


# Laplace approximation via SVI (Adam(0.1), 1000 steps).
m8_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_3,
    optim.Adam(0.1),
    Trace_ELBO(),
    log_gdp_std=dd.log_gdp_std.values,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p8_3 = svi_result.params
Out[13]:
100%|██████████| 1000/1000 [00:00<00:00, 1080.73it/s, init loss: 1670.7773, avg. loss [951-1000]: -132.0968]

Code 8.14

In [14]:
# Posterior summary for m8_3 (89% intervals), without the per-observation `mu` draws.
post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, sample_shape=(1000,))
print_summary(
    {site: draws for site, draws in post.items() if site != "mu"},
    prob=0.89,
    group_by_chain=False,
)
Out[14]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]      0.89      0.02      0.89      0.86      0.91   1009.20      1.00
      a[1]      1.05      0.01      1.05      1.04      1.07    755.33      1.00
      b[0]      0.13      0.07      0.13      0.01      0.24   1045.06      1.00
      b[1]     -0.15      0.06     -0.14     -0.23     -0.05   1003.36      1.00
     sigma      0.11      0.01      0.11      0.10      0.12    810.01      1.00

Code 8.15

In [15]:
# WAIC comparison of m8.1, m8.2 and m8.3. For each model: draw 1000 posterior samples
# (same key for all), compute the pointwise log-likelihood of the observed data, and
# wrap it as ArviZ InferenceData (v[None] adds a leading chain axis of size 1).
post = m8_1.sample_posterior(random.PRNGKey(2), p8_1, sample_shape=(1000,))
logprob = log_likelihood(
    m8_1.model, post, rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values
)
az8_1 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})

post = m8_2.sample_posterior(random.PRNGKey(2), p8_2, sample_shape=(1000,))
logprob = log_likelihood(
    m8_2.model,
    post,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
    log_gdp_std=dd.log_gdp_std.values,
)
# m8.2 must be stored in az8_2 so the comparison below uses this cell's result,
# not a value left over from an earlier cell.
az8_2 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})

post = m8_3.sample_posterior(random.PRNGKey(2), p8_3, sample_shape=(1000,))
logprob = log_likelihood(
    m8_3.model,
    post,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
    log_gdp_std=dd.log_gdp_std.values,
)
az8_3 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})

az.compare({"m8.1": az8_1, "m8.2": az8_2, "m8.3": az8_3}, ic="waic", scale="deviance")
Out[15]:
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
Out[15]:
rank waic p_waic d_waic weight se dse warning waic_scale
m8.3 0 -259.176 5.10348 0 0.824888 13.4328 0 True deviance
m8.2 1 -252.36 4.15389 6.81647 0.175111 14.6901 6.67691 True deviance
m8.1 2 -188.818 2.65329 70.3582 4.90447e-08 14.6588 15.3423 False deviance

Code 8.16

In [16]:
# Per-observation WAIC contributions for m8.3, on the deviance scale.
waic_list = az.waic(az8_3, pointwise=True, scale="deviance")["waic_i"].values
Out[16]:
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details

Code 8.17

In [17]:
# plot non-Africa - cid=1, with the posterior mean line and 97% band of the
# interaction model m8_3 (separate slope per continent group)
d_A0 = dd[dd["cid"] == 1]
az.plot_pair(d_A0[["rugged_std", "log_gdp_std"]].to_dict(orient="list"))
plt.gca().set(
    xlim=(-0.01, 1.01),
    xlabel="ruggedness (standardized)",
    ylabel="log GDP (as proportion of mean)",
)
# Build the predictive from m8_3 explicitly; the stored per-observation `mu`
# draws are dropped so Predictive recomputes `mu` on rugged_seq.
post_m8_3 = m8_3.sample_posterior(random.PRNGKey(1), p8_3, sample_shape=(1000,))
post_m8_3 = {k: v for k, v in post_m8_3.items() if k != "mu"}
mu = Predictive(m8_3.model, post_m8_3, return_sites=["mu"])(
    random.PRNGKey(2), cid=1, rugged_std=rugged_seq
)["mu"]
mu_mean = jnp.mean(mu, 0)
mu_ci = jnp.percentile(mu, q=jnp.array([1.5, 98.5]), axis=0)  # 97% interval
plt.plot(rugged_seq, mu_mean, "k")
plt.fill_between(rugged_seq, mu_ci[0], mu_ci[1], color="k", alpha=0.2)
plt.title("Non-African nations")
plt.show()
Out[17]:
No description has been provided for this image

Code 8.18

In [18]:
# Expected difference in log GDP (Africa minus elsewhere) under m8_3 across a
# ruggedness grid that extends slightly beyond [0, 1].
rugged_seq = jnp.linspace(start=-0.2, stop=1.2, num=30)
post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, sample_shape=(1000,))
del post["mu"]  # recompute mu on the grid instead of reusing per-observation draws
predictive = Predictive(m8_3.model, post, return_sites=["mu"])
muA, muN = (
    predictive(random.PRNGKey(2), cid=group, rugged_std=rugged_seq)["mu"]
    for group in (0, 1)
)
delta = muA - muN

Code 8.19

In [19]:
# Load the tulips data; `d` and `tulips` name the same DataFrame.
d = tulips = pd.read_csv("../data/tulips.csv", sep=";")
d.info()  # prints column dtypes and non-null counts
d.head()  # first rows, displayed as the cell output
Out[19]:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 27 entries, 0 to 26
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   bed     27 non-null     object 
 1   water   27 non-null     int64  
 2   shade   27 non-null     int64  
 3   blooms  27 non-null     float64
dtypes: float64(1), int64(2), object(1)
memory usage: 992.0+ bytes
Out[19]:
bed water shade blooms
0 a 1 1 0.00
1 a 1 2 0.00
2 a 1 3 111.04
3 a 2 1 183.47
4 a 2 2 59.16

Code 8.20

In [20]:
# Outcome as a proportion of its maximum; predictors centered on their means.
d["blooms_std"] = d["blooms"] / d["blooms"].max()
d["water_cent"] = d["water"] - d["water"].mean()
d["shade_cent"] = d["shade"] - d["shade"].mean()

Code 8.21

In [21]:
# Share of 10,000 draws from Normal(0.5, 1) that land outside [0, 1].
a = dist.Normal(0.5, 1).sample(random.PRNGKey(0), (10_000,))
jnp.mean(~((a >= 0) & (a <= 1)))
Out[21]:
DeviceArray(0.6182, dtype=float32)

Code 8.22

In [22]:
# Same check with the tighter Normal(0.5, 0.25) prior.
a = dist.Normal(0.5, 0.25).sample(random.PRNGKey(0), (10_000,))
jnp.mean(~((a >= 0) & (a <= 1)))
Out[22]:
DeviceArray(0.0471, dtype=float32)

Code 8.23

In [23]:
def model(water_cent, shade_cent, blooms_std=None):
    """Tulip blooms as additive (no interaction) linear effects of water and shade.

    Args:
        water_cent: centered water level.
        shade_cent: centered shade level.
        blooms_std: observed blooms as a proportion of the maximum, or ``None``
            to simulate the outcome.
    """
    a = numpyro.sample("a", dist.Normal(0.5, 0.25))
    bw = numpyro.sample("bw", dist.Normal(0, 0.25))
    bs = numpyro.sample("bs", dist.Normal(0, 0.25))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    water_term = bw * water_cent
    shade_term = bs * shade_cent
    mu = numpyro.deterministic("mu", a + water_term + shade_term)
    numpyro.sample("blooms_std", dist.Normal(mu, sigma), obs=blooms_std)


# Laplace approximation via SVI (Adam with step size 1, 1000 steps).
m8_4 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_4,
    optim.Adam(1),
    Trace_ELBO(),
    blooms_std=d.blooms_std.values,
    water_cent=d.water_cent.values,
    shade_cent=d.shade_cent.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p8_4 = svi_result.params
Out[23]:
100%|██████████| 1000/1000 [00:00<00:00, 1306.10it/s, init loss: 753.9799, avg. loss [951-1000]: -9.9633]

Code 8.24

In [24]:
def model(water_cent, shade_cent, blooms_std=None):
    """Tulip blooms model with a water x shade interaction (m8_5).

    Args:
        water_cent: centered water level.
        shade_cent: centered shade level.
        blooms_std: observed blooms as a proportion of the maximum, or ``None``
            to simulate the outcome.
    """
    a = numpyro.sample("a", dist.Normal(0.5, 0.25))
    bw = numpyro.sample("bw", dist.Normal(0, 0.25))
    bs = numpyro.sample("bs", dist.Normal(0, 0.25))
    bws = numpyro.sample("bws", dist.Normal(0, 0.25))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # Record mu as a deterministic site, like the other models in this notebook, so
    # Predictive(..., return_sites=["mu"]) can produce posterior lines for m8_5.
    mu = numpyro.deterministic(
        "mu", a + bw * water_cent + bs * shade_cent + bws * water_cent * shade_cent
    )
    numpyro.sample("blooms_std", dist.Normal(mu, sigma), obs=blooms_std)


# Laplace approximation via SVI (Adam with step size 1, 1000 steps).
m8_5 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_5,
    optim.Adam(1),
    Trace_ELBO(),
    shade_cent=d.shade_cent.values,
    water_cent=d.water_cent.values,
    blooms_std=d.blooms_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_5 = svi_result.params
Out[24]:
100%|██████████| 1000/1000 [00:00<00:00, 1205.86it/s, init loss: 133.8938, avg. loss [951-1000]: -16.3747]

Code 8.25

In [25]:
# Triptych of blooms vs. water at centered shade -1, 0, 1, each panel with 20
# posterior regression lines from m8_4.
# The posterior draws do not depend on the panel, so sample them once; drop the
# stored per-observation `mu` so Predictive recomputes it on the water grid.
post = m8_4.sample_posterior(random.PRNGKey(1), p8_4, sample_shape=(1000,))
post.pop("mu")
mu_predictive = Predictive(m8_4.model, post, return_sites=["mu"])
water_grid = jnp.arange(-1, 2)

_, axes = plt.subplots(1, 3, figsize=(9, 3), sharey=True)  # 3 plots in 1 row
for ax, s in zip(axes, range(-1, 2)):
    idx = d.shade_cent == s
    ax.scatter(d.water_cent[idx], d.blooms_std[idx])
    ax.set(
        xlim=(-1.1, 1.1),
        ylim=(-0.1, 1.1),
        xlabel="water",
        ylabel="blooms",
        title=f"shade_cent = {s}",
    )
    mu = mu_predictive(random.PRNGKey(2), shade_cent=s, water_cent=water_grid)["mu"]
    # mu is (draws, 3 water levels); transpose so each of the first 20 draws is one line
    ax.plot(water_grid, mu[:20].T, "k", alpha=0.3)
Out[25]:
No description has been provided for this image

Code 8.26

In [26]:
# 1000 prior-only draws of every m8_5 parameter; the predictor values are placeholders.
predictive = Predictive(
    m8_5.model,
    return_sites=["a", "bw", "bs", "bws", "sigma"],
    num_samples=1000,
)
prior = predictive(random.PRNGKey(7), shade_cent=0, water_cent=0)

Code 8.27

In [27]:
# Load the nettle data; `d` and `nettle` name the same DataFrame.
d = nettle = pd.read_csv("../data/nettle.csv", sep=";")
# languages per capita: num.lang divided by k.pop
d["lang.per.cap"] = d["num.lang"].div(d["k.pop"])

Comments

Comments powered by Disqus