Chapter 7. Ulysses’ Compass
In [ ]:
!pip install -q numpyro arviz
In [0]:
import os
import warnings
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
import jax.numpy as jnp
from jax import lax, random, vmap
from jax.scipy.special import logsumexp
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO, init_to_value, log_likelihood
from numpyro.infer.autoguide import AutoLaplaceApproximation
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
Code 7.1¶
In [1]:
# Species names with their brain volumes (cc) and body masses (kg),
# gathered into one data frame.
sppnames = [
    "afarensis", "africanus", "habilis", "boisei",
    "rudolfensis", "ergaster", "sapiens",
]
brainvolcc = jnp.array([438, 452, 612, 521, 752, 871, 1350])
masskg = jnp.array([37.0, 35.5, 34.5, 41.5, 55.5, 61.0, 53.5])
d = pd.DataFrame(dict(species=sppnames, brain=brainvolcc, mass=masskg))
Code 7.2¶
In [2]:
# Body mass becomes a z-score; brain volume is rescaled so the largest value is 1.
d["mass_std"] = (d["mass"] - d["mass"].mean()) / d["mass"].std()
d["brain_std"] = d["brain"] / d["brain"].max()
Code 7.3¶
In [3]:
def model(mass_std, brain_std=None):
    """Linear model of standardized brain volume on standardized body mass.

    The observation scale is exp(log_sigma), so it is always positive.
    """
    intercept = numpyro.sample("a", dist.Normal(0.5, 1))
    slope = numpyro.sample("b", dist.Normal(0, 10))
    log_scale = numpyro.sample("log_sigma", dist.Normal(0, 1))
    mean = numpyro.deterministic("mu", intercept + slope * mass_std)
    numpyro.sample("brain_std", dist.Normal(mean, jnp.exp(log_scale)), obs=brain_std)


# Laplace-approximation guide for the model, optimized with SVI/Adam.
m7_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    guide=m7_1,
    optim=optim.Adam(0.3),
    loss=Trace_ELBO(),
    mass_std=d.mass_std.values,
    brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p7_1 = svi_result.params
Out[3]:
Code 7.4¶
In [4]:
def model(mass_std, brain_std):
    """Linear regression with wide Normal(0, 10) priors and a HalfCauchy(2) scale."""
    alpha = numpyro.sample("intercept", dist.Normal(0, 10))
    beta = numpyro.sample("b_mass_std", dist.Normal(0, 10))
    scale = numpyro.sample("sigma", dist.HalfCauchy(2))
    mean = alpha + beta * mass_std
    numpyro.sample("brain_std", dist.Normal(mean, scale), obs=brain_std)


# Fit with a Laplace approximation, then draw 1000 posterior samples.
m7_1_OLS = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    guide=m7_1_OLS,
    optim=optim.Adam(0.01),
    loss=Trace_ELBO(),
    mass_std=d.mass_std.values,
    brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p7_1_OLS = svi_result.params
post = m7_1_OLS.sample_posterior(
    random.PRNGKey(1), p7_1_OLS, sample_shape=(1000,)
)
Out[4]:
Code 7.5¶
In [5]:
# R^2 of m7.1: one minus the ratio of residual variance to outcome variance.
post = m7_1.sample_posterior(random.PRNGKey(12), p7_1, sample_shape=(1000,))
s = Predictive(m7_1.model, posterior_samples=post)(
    random.PRNGKey(2), d.mass_std.values
)
# residuals of the mean simulated outcome against the observed outcome
r = jnp.mean(s["brain_std"], axis=0) - d.brain_std.values
resid_var = jnp.var(r, ddof=1)
outcome_var = jnp.var(d.brain_std.values, ddof=1)
1 - resid_var / outcome_var
Out[5]:
Code 7.6¶
In [6]:
def R2_is_bad(quap_fit):
    """Return R^2 on the brain-size data for a fitted approximation.

    ``quap_fit`` is a ``(guide, params)`` pair. 1000 posterior draws are
    pushed through the guide's model, and the mean simulated ``brain_std``
    is compared with the observed values in the global frame ``d``.
    """
    guide, fitted = quap_fit
    draws = guide.sample_posterior(random.PRNGKey(1), fitted, sample_shape=(1000,))
    sims = Predictive(guide.model, draws)(random.PRNGKey(2), d.mass_std.values)
    observed = d.brain_std.values
    residuals = jnp.mean(sims["brain_std"], 0) - observed
    return 1 - jnp.var(residuals, ddof=1) / jnp.var(observed, ddof=1)
Code 7.7¶
In [7]:
def model(mass_std, brain_std=None):
    """Second-degree polynomial model of brain volume on body mass."""
    intercept = numpyro.sample("a", dist.Normal(0.5, 1))
    coefs = numpyro.sample("b", dist.Normal(0, 10).expand([2]))
    log_scale = numpyro.sample("log_sigma", dist.Normal(0, 1))
    mean = numpyro.deterministic(
        "mu", intercept + coefs[0] * mass_std + coefs[1] * mass_std**2
    )
    numpyro.sample("brain_std", dist.Normal(mean, jnp.exp(log_scale)), obs=brain_std)


# Start the polynomial coefficients at zero before optimizing.
m7_2 = AutoLaplaceApproximation(
    model, init_loc_fn=init_to_value(values={"b": jnp.repeat(0.0, 2)})
)
svi = SVI(
    model,
    guide=m7_2,
    optim=optim.Adam(0.3),
    loss=Trace_ELBO(),
    mass_std=d.mass_std.values,
    brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=2000)
p7_2 = svi_result.params
Out[7]:
Code 7.8¶
In [8]:
def model(mass_std, brain_std=None):
    """Third-degree polynomial model of brain volume on body mass."""
    intercept = numpyro.sample("a", dist.Normal(0.5, 1))
    coefs = numpyro.sample("b", dist.Normal(0, 10).expand([3]))
    log_scale = numpyro.sample("log_sigma", dist.Normal(0, 1))
    mean = numpyro.deterministic(
        "mu",
        intercept
        + coefs[0] * mass_std
        + coefs[1] * mass_std**2
        + coefs[2] * mass_std**3,
    )
    numpyro.sample("brain_std", dist.Normal(mean, jnp.exp(log_scale)), obs=brain_std)


# Start the polynomial coefficients at zero before optimizing.
m7_3 = AutoLaplaceApproximation(
    model, init_loc_fn=init_to_value(values={"b": jnp.repeat(0.0, 3)})
)
svi = SVI(
    model,
    guide=m7_3,
    optim=optim.Adam(0.01),
    loss=Trace_ELBO(),
    mass_std=d.mass_std.values,
    brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=2000)
p7_3 = svi_result.params
def model(mass_std, brain_std=None):
    """Fourth-degree polynomial model of brain volume on body mass."""
    intercept = numpyro.sample("a", dist.Normal(0.5, 1))
    coefs = numpyro.sample("b", dist.Normal(0, 10).expand([4]))
    log_scale = numpyro.sample("log_sigma", dist.Normal(0, 1))
    # the trailing axis holds mass_std raised to the powers 1..4
    basis = jnp.power(mass_std[..., None], jnp.arange(1, 5))
    mean = numpyro.deterministic("mu", intercept + jnp.sum(coefs * basis, -1))
    numpyro.sample("brain_std", dist.Normal(mean, jnp.exp(log_scale)), obs=brain_std)


# Start the polynomial coefficients at zero before optimizing.
m7_4 = AutoLaplaceApproximation(
    model, init_loc_fn=init_to_value(values={"b": jnp.repeat(0.0, 4)})
)
svi = SVI(
    model,
    guide=m7_4,
    optim=optim.Adam(0.01),
    loss=Trace_ELBO(),
    mass_std=d.mass_std.values,
    brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=2000)
p7_4 = svi_result.params
def model(mass_std, brain_std=None):
    """Fifth-degree polynomial model of brain volume on body mass."""
    intercept = numpyro.sample("a", dist.Normal(0.5, 1))
    coefs = numpyro.sample("b", dist.Normal(0, 10).expand([5]))
    log_scale = numpyro.sample("log_sigma", dist.Normal(0, 1))
    # the trailing axis holds mass_std raised to the powers 1..5
    basis = jnp.power(mass_std[..., None], jnp.arange(1, 6))
    mean = numpyro.deterministic("mu", intercept + jnp.sum(coefs * basis, -1))
    numpyro.sample("brain_std", dist.Normal(mean, jnp.exp(log_scale)), obs=brain_std)


# Start the polynomial coefficients at zero before optimizing.
m7_5 = AutoLaplaceApproximation(
    model, init_loc_fn=init_to_value(values={"b": jnp.repeat(0.0, 5)})
)
svi = SVI(
    model,
    guide=m7_5,
    optim=optim.Adam(0.01),
    loss=Trace_ELBO(),
    mass_std=d.mass_std.values,
    brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=2000)
p7_5 = svi_result.params
Out[8]:
Code 7.9¶
In [9]:
def model(mass_std, brain_std=None):
    """Sixth-degree polynomial model of brain volume on body mass."""
    intercept = numpyro.sample("a", dist.Normal(0.5, 1))
    coefs = numpyro.sample("b", dist.Normal(0, 10).expand([6]))
    log_scale = numpyro.sample("log_sigma", dist.Normal(0, 1))
    # the trailing axis holds mass_std raised to the powers 1..6
    basis = jnp.power(mass_std[..., None], jnp.arange(1, 7))
    mean = numpyro.deterministic("mu", intercept + jnp.sum(coefs * basis, -1))
    numpyro.sample("brain_std", dist.Normal(mean, jnp.exp(log_scale)), obs=brain_std)


# Start the polynomial coefficients at zero before optimizing.
m7_6 = AutoLaplaceApproximation(
    model, init_loc_fn=init_to_value(values={"b": jnp.repeat(0.0, 6)})
)
svi = SVI(
    model,
    guide=m7_6,
    optim=optim.Adam(0.003),
    loss=Trace_ELBO(),
    mass_std=d.mass_std.values,
    brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=5000)
p7_6 = svi_result.params
Out[9]:
Code 7.10¶
In [10]:
# Posterior mean and 89% interval of mu for m7.1 over a grid of body masses.
post = m7_1.sample_posterior(random.PRNGKey(1), p7_1, sample_shape=(1000,))
# remove the sampled "mu" site; presumably so Predictive recomputes it on the grid
post.pop("mu")
mass_seq = jnp.linspace(d.mass_std.min(), d.mass_std.max(), num=100)
l = Predictive(m7_1.model, post, return_sites=["mu"])(
    random.PRNGKey(2), mass_std=mass_seq
)
l = l["mu"]
mu = jnp.mean(l, axis=0)
ci = jnp.percentile(l, jnp.array([5.5, 94.5]), axis=0)

# Data points, mean curve, shaded interval, and R^2 in the title.
az.plot_pair(d[["mass_std", "brain_std"]].to_dict(orient="list"))
plt.plot(mass_seq, mu, "k")
plt.fill_between(mass_seq, ci[0], ci[1], color="k", alpha=0.2)
r2_m7_1 = R2_is_bad((m7_1, p7_1)).item()
plt.title("m7.1: R^2 = {:0.2f}".format(r2_m7_1))
plt.show()
Out[10]:
Code 7.11¶
In [11]:
# Drop the row whose index label is i from the data frame.
i = 1
d_minus_i = d.drop(index=i)
Code 7.12¶
In [12]:
# Information entropy of a two-outcome distribution: -sum(p * log(p)).
p = jnp.array([0.3, 0.7])
-jnp.sum(jnp.log(p) * p)
Out[12]:
Code 7.13¶
In [13]:
def lppd_fn(seed, quad, params, num_samples=1000):
    """Pointwise log-pointwise-predictive-density for the brain-size data.

    Args:
        seed: PRNG key used to draw the posterior samples.
        quad: fitted guide exposing ``sample_posterior`` and ``model``.
        params: optimized guide parameters.
        num_samples: number of posterior draws to average over.

    Returns:
        One value per observation: the log of the mean likelihood across draws.
    """
    # Use the caller's key; previously a fixed PRNGKey(1) was used and `seed` was ignored.
    post = quad.sample_posterior(seed, params, sample_shape=(num_samples,))
    logprob = log_likelihood(quad.model, post, d.mass_std.values, d.brain_std.values)
    logprob = logprob["brain_std"]
    # log-mean-exp over the draw axis, computed stably with logsumexp
    return logsumexp(logprob, 0) - jnp.log(logprob.shape[0])


lppd_fn(random.PRNGKey(1), m7_1, p7_1, int(1e4))
Out[13]:
Code 7.14¶
In [14]:
# Same lppd computation spelled out: one log-mean-exp per observation.
post = m7_1.sample_posterior(random.PRNGKey(1), p7_1, sample_shape=(int(1e4),))
logprob = log_likelihood(m7_1.model, post, d.mass_std.values, d.brain_std.values)[
    "brain_std"
]
# rows are posterior draws, columns are observations
ns, n = logprob.shape


def f(i):
    """Log of the average likelihood of observation ``i`` over all draws."""
    return logsumexp(logprob[:, i]) - jnp.log(ns)


lppd = vmap(f)(jnp.arange(n))
lppd
Out[14]:
Code 7.15¶
In [15]:
# Total lppd of each polynomial model, from m7.1 up to m7.6.
[
    jnp.sum(lppd_fn(random.PRNGKey(1), guide, fitted)).item()
    for guide, fitted in (
        (m7_1, p7_1),
        (m7_2, p7_2),
        (m7_3, p7_3),
        (m7_4, p7_4),
        (m7_5, p7_5),
        (m7_6, p7_6),
    )
]
Out[15]:
Out[15]:
Code 7.16¶
In [16]:
def model(mm, y, b_sigma):
    """Linear model over the columns of the design matrix ``mm``.

    The first coefficient is the learnable parameter ``a``. When ``mm`` has
    more than one column, the remaining coefficients ``b`` get a
    Normal(0, b_sigma) prior. The outcome scale is fixed at 1.
    """
    coef_vec = numpyro.param("a", jnp.array([0.0]))
    n_cols = mm.shape[1]
    if n_cols > 1:
        slopes = numpyro.sample("b", dist.Normal(0, b_sigma).expand([n_cols - 1]))
        coef_vec = jnp.concatenate([coef_vec, slopes])
    mean = jnp.matmul(mm, coef_vec)
    numpyro.sample("y", dist.Normal(mean, 1), obs=y)
def sim_train_test(i, N=20, k=3, rho=[0.15, -0.4], b_sigma=100):
    """Simulate a train/test pair, fit on train, return both deviances.

    Args:
        i: simulation index, folded into the PRNG keys.
        N: number of cases in each of the train and test samples.
        k: number of design-matrix columns (an intercept plus k - 1 predictors).
        rho: correlations between the outcome (column 0) and the first
            ``len(rho)`` predictors.
        b_sigma: prior standard deviation of the slope coefficients.

    Returns:
        Array ``[dev_train, dev_test]``.
    """
    rho_vec = jnp.array(rho)
    n_rho = len(rho)
    # The correlation matrix must hold the outcome plus every entry of rho.
    # This was hard-coded as max(k, 3), which is only right when len(rho) == 2.
    n_dim = max(k, 1 + n_rho)
    Rho = jnp.identity(n_dim)
    Rho = Rho.at[1 : n_rho + 1, 0].set(rho_vec)
    Rho = Rho.at[0, 1 : n_rho + 1].set(rho_vec)

    # training sample: column 0 is the outcome, the rest are predictors
    X_train = dist.MultivariateNormal(jnp.zeros(n_dim), Rho).sample(
        random.fold_in(random.PRNGKey(0), i), (N,)
    )
    mm_train = jnp.ones((N, 1))
    if k > 1:
        mm_train = jnp.concatenate([mm_train, X_train[:, 1:k]], axis=1)

    if k > 1:
        m = AutoLaplaceApproximation(
            model, init_loc_fn=init_to_value(values={"b": jnp.zeros(k - 1)})
        )
    else:
        # intercept-only model has no latent sites, so use an empty guide
        m = lambda mm, y, b_sigma: None
    svi = SVI(
        model,
        m,
        optim.Adam(0.3),
        Trace_ELBO(),
        mm=mm_train,
        y=X_train[:, 0],
        b_sigma=b_sigma,
    )
    svi_result = svi.run(random.fold_in(random.PRNGKey(1), i), 1000, progress_bar=False)
    params = svi_result.params

    # point estimates: the param "a" followed by the guide median of "b"
    coefs = params["a"]
    if k > 1:
        coefs = jnp.concatenate([coefs, m.median(params)["b"]])

    logprob = dist.Normal(jnp.matmul(mm_train, coefs)).log_prob(X_train[:, 0])
    dev_train = (-2) * jnp.sum(logprob)

    # a fresh test sample from the same distribution, scored with the same coefficients
    X_test = dist.MultivariateNormal(jnp.zeros(n_dim), Rho).sample(
        random.fold_in(random.PRNGKey(2), i), (N,)
    )
    mm_test = jnp.ones((N, 1))
    if k > 1:
        mm_test = jnp.concatenate([mm_test, X_test[:, 1:k]], axis=1)
    logprob = dist.Normal(jnp.matmul(mm_test, coefs)).log_prob(X_test[:, 0])
    dev_test = (-2) * jnp.sum(logprob)
    return jnp.stack([dev_train, dev_test])
def dev_fn(N, k):
    """Mean and SD of train/test deviance over 10,000 simulations (sequential map)."""
    print(k)

    def one_sim(i):
        return sim_train_test(i, N, k)

    r = lax.map(one_sim, jnp.arange(int(1e4)))
    # layout: [mean train, mean test, sd train, sd test]
    return jnp.concatenate([jnp.mean(r, 0), jnp.std(r, 0)])


# One column of summary statistics per model size k = 1..5.
N = 20
kseq = range(1, 6)
dev = jnp.stack([dev_fn(N, n_par) for n_par in kseq], axis=1)
Out[16]:
Code 7.17¶
In [17]:
def dev_fn(N, k):
    """Mean and SD of train/test deviance over 10,000 simulations (vectorized with vmap)."""
    print(k)

    def one_sim(i):
        return sim_train_test(i, N, k)

    r = vmap(one_sim)(jnp.arange(int(1e4)))
    # layout: [mean train, mean test, sd train, sd test]
    return jnp.concatenate([jnp.mean(r, 0), jnp.std(r, 0)])
Code 7.18¶
In [18]:
# In-sample (blue) and out-of-sample (black) deviance with +/- 1 SD bars.
plt.subplot(
    xlabel="number of parameters",
    ylabel="deviance",
    xlim=(0.9, 5.2),
    ylim=(jnp.min(dev[0]).item() - 5, jnp.max(dev[0]).item() + 12),
)
plt.title("N = {}".format(N))

# rows 0/1 of dev are the mean train/test deviances; test points are shifted right by 0.1
plt.scatter(jnp.arange(1, 6), dev[0], s=80, color="b")
plt.scatter(jnp.arange(1.1, 6), dev[1], s=80, color="k")

# rows 2/3 of dev are the matching standard deviations
pts_int = (dev[0] - dev[2], dev[0] + dev[2])
pts_out = (dev[1] - dev[3], dev[1] + dev[3])
plt.vlines(jnp.arange(1, 6), pts_int[0], pts_int[1], color="b")
plt.vlines(jnp.arange(1.1, 6), pts_out[0], pts_out[1], color="k")

# labels attached to the two-parameter model
plt.annotate(
    "in", (2, dev[0][1]), xytext=(-25, -5), textcoords="offset pixels", color="b"
)
plt.annotate("out", (2.1, dev[1][1]), xytext=(10, -5), textcoords="offset pixels")
for sd_label, sd_y in (("+1SD", pts_out[1][1]), ("-1SD", pts_out[0][1])):
    plt.annotate(
        sd_label,
        (2.1, sd_y),
        xytext=(10, -5),
        textcoords="offset pixels",
        fontsize=12,
    )
plt.show()
Out[18]:
Code 7.19¶
In [19]:
cars = pd.read_csv("../data/cars.csv", sep=",")


def model(speed, cars_dist):
    """Linear regression of stopping distance on speed."""
    intercept = numpyro.sample("a", dist.Normal(0, 100))
    slope = numpyro.sample("b", dist.Normal(0, 10))
    scale = numpyro.sample("sigma", dist.Exponential(1))
    mean = intercept + slope * speed
    numpyro.sample("dist", dist.Normal(mean, scale), obs=cars_dist)


# Fit with a Laplace approximation, then draw 1000 posterior samples.
m = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    guide=m,
    optim=optim.Adam(1),
    loss=Trace_ELBO(),
    speed=cars.speed.values,
    cars_dist=cars.dist.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=5000)
params = svi_result.params
post = m.sample_posterior(random.PRNGKey(94), params, sample_shape=(1000,))
Out[19]:
Code 7.20¶
In [20]:
n_samples = 1000


def logprob_fn(s):
    """Log-likelihood of every car observation under posterior draw ``s``."""
    draw_mean = post["a"][s] + post["b"][s] * cars.speed.values
    return dist.Normal(draw_mean, post["sigma"][s]).log_prob(cars.dist.values)


# out_axes=1 places the draw index on the second axis of the result
logprob = vmap(logprob_fn, out_axes=1)(jnp.arange(n_samples))
Code 7.21¶
In [21]:
# lppd per case: log of the mean likelihood over the posterior draws (axis 1).
n_cases = cars.shape[0]
lppd = logsumexp(logprob, axis=1) - jnp.log(n_samples)
Code 7.22¶
In [22]:
# WAIC penalty per case: variance of the log-likelihood across posterior draws.
pWAIC = jnp.var(logprob, axis=1)
Code 7.23¶
In [23]:
# WAIC on the deviance scale: -2 * (total lppd - total penalty).
-2 * (lppd.sum() - pWAIC.sum())
Out[23]:
Code 7.24¶
In [24]:
# Standard error of WAIC from the pointwise contributions.
waic_vec = -2 * (lppd - pWAIC)
jnp.sqrt(n_cases * waic_vec.var())
Out[24]:
Code 7.25¶
In [25]:
# Simulate the plant-growth data under a fixed seed.
with numpyro.handlers.seed(rng_seed=71):
    # number of plants
    N = 100

    # simulate initial heights
    h0 = numpyro.sample("h0", dist.Normal(10, 2).expand([N]))

    # first half of the plants is untreated (0), second half treated (1)
    treatment = jnp.repeat(jnp.arange(2), repeats=N // 2)
    # fungus probability is 0.5 without treatment and 0.1 with it
    fungus = numpyro.sample(
        "fungus", dist.Binomial(total_count=1, probs=(0.5 - 0.4 * treatment))
    )
    # mean growth is 5 without fungus and 2 with it
    growth = numpyro.sample("diff", dist.Normal(5 - 3 * fungus))
    h1 = h0 + growth

    # compose a clean data frame
    d = pd.DataFrame(dict(h0=h0, h1=h1, treatment=treatment, fungus=fungus))
def model(h0, h1):
    """Final height modelled as initial height times a growth proportion ``p``."""
    proportion = numpyro.sample("p", dist.LogNormal(0, 0.25))
    scale = numpyro.sample("sigma", dist.Exponential(1))
    mean = h0 * proportion
    numpyro.sample("h1", dist.Normal(mean, scale), obs=h1)


m6_6 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    guide=m6_6,
    optim=optim.Adam(0.1),
    loss=Trace_ELBO(),
    h0=d.h0.values,
    h1=d.h1.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p6_6 = svi_result.params
def model(treatment, fungus, h0, h1):
    """Growth proportion as a linear function of treatment and fungus."""
    base = numpyro.sample("a", dist.LogNormal(0, 0.2))
    treat_eff = numpyro.sample("bt", dist.Normal(0, 0.5))
    fungus_eff = numpyro.sample("bf", dist.Normal(0, 0.5))
    scale = numpyro.sample("sigma", dist.Exponential(1))
    proportion = base + treat_eff * treatment + fungus_eff * fungus
    mean = h0 * proportion
    numpyro.sample("h1", dist.Normal(mean, scale), obs=h1)


m6_7 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    guide=m6_7,
    optim=optim.Adam(0.3),
    loss=Trace_ELBO(),
    treatment=d.treatment.values,
    fungus=d.fungus.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p6_7 = svi_result.params
def model(treatment, h0, h1):
    """Growth proportion as a linear function of treatment only."""
    base = numpyro.sample("a", dist.LogNormal(0, 0.2))
    treat_eff = numpyro.sample("bt", dist.Normal(0, 0.5))
    scale = numpyro.sample("sigma", dist.Exponential(1))
    proportion = base + treat_eff * treatment
    mean = h0 * proportion
    numpyro.sample("h1", dist.Normal(mean, scale), obs=h1)


m6_8 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    guide=m6_8,
    optim=optim.Adam(1),
    loss=Trace_ELBO(),
    treatment=d.treatment.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p6_8 = svi_result.params
# WAIC of m6.7 on the deviance scale.
post = m6_7.sample_posterior(random.PRNGKey(11), p6_7, sample_shape=(1000,))
logprob = log_likelihood(
    m6_7.model,
    post,
    treatment=d.treatment.values,
    fungus=d.fungus.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
# Store the pointwise log-likelihood in the dedicated `log_likelihood` group,
# as the later cells do; keeping it under `sample_stats` is deprecated in ArviZ.
# The leading axis added by [None, ...] is the single chain.
az6_7 = az.from_dict({}, log_likelihood={"h1": logprob["h1"][None, ...]})
az.waic(az6_7, scale="deviance")
Out[25]:
Out[25]:
Code 7.26¶
In [26]:
# Build one InferenceData per model (same PRNG key for each) and compare by WAIC.
# [None, ...] adds the leading chain axis.

# m6.6: no predictors
post = m6_6.sample_posterior(random.PRNGKey(77), p6_6, sample_shape=(1000,))
logprob = log_likelihood(m6_6.model, post, h0=d.h0.values, h1=d.h1.values)
az6_6 = az.from_dict(posterior={}, log_likelihood={"h1": logprob["h1"][None, ...]})

# m6.7: treatment and fungus
post = m6_7.sample_posterior(random.PRNGKey(77), p6_7, sample_shape=(1000,))
logprob = log_likelihood(
    m6_7.model,
    post,
    treatment=d.treatment.values,
    fungus=d.fungus.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
az6_7 = az.from_dict(posterior={}, log_likelihood={"h1": logprob["h1"][None, ...]})

# m6.8: treatment only
post = m6_8.sample_posterior(random.PRNGKey(77), p6_8, sample_shape=(1000,))
logprob = log_likelihood(
    m6_8.model,
    post,
    treatment=d.treatment.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
az6_8 = az.from_dict(posterior={}, log_likelihood={"h1": logprob["h1"][None, ...]})

az.compare(
    {"m6.6": az6_6, "m6.7": az6_7, "m6.8": az6_8},
    ic="waic",
    scale="deviance",
)
Out[26]:
Out[26]:
Code 7.27¶
In [27]:
# Standard error of the WAIC difference between m6.7 and m6.8.

# pointwise WAIC for m6.7
post = m6_7.sample_posterior(random.PRNGKey(91), p6_7, sample_shape=(1000,))
logprob = log_likelihood(
    m6_7.model,
    post,
    treatment=d.treatment.values,
    fungus=d.fungus.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
az6_7 = az.from_dict(posterior={}, log_likelihood={"h1": logprob["h1"][None, ...]})
waic_m6_7 = az.waic(az6_7, pointwise=True, scale="deviance")

# pointwise WAIC for m6.8
post = m6_8.sample_posterior(random.PRNGKey(91), p6_8, sample_shape=(1000,))
logprob = log_likelihood(
    m6_8.model,
    post,
    treatment=d.treatment.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
az6_8 = az.from_dict(posterior={}, log_likelihood={"h1": logprob["h1"][None, ...]})
waic_m6_8 = az.waic(az6_8, pointwise=True, scale="deviance")

# sqrt(n * var) of the pointwise differences
n = waic_m6_7.n_data_points
diff_m6_7_m6_8 = waic_m6_7.waic_i.values - waic_m6_8.waic_i.values
jnp.sqrt(n * jnp.var(diff_m6_7_m6_8))
Out[27]:
Out[27]:
Code 7.28¶
In [28]:
# Interval around a difference of 40.0: standard error 10.4 times a multiplier of 2.6.
40.0 + (jnp.array([-1, 1]) * 10.4) * 2.6
Out[28]:
Code 7.29¶
In [29]:
# WAIC comparison table for the three plant models, then its plot.
compare = az.compare(
    {"m6.6": az6_6, "m6.7": az6_7, "m6.8": az6_8},
    ic="waic",
    scale="deviance",
)
az.plot_compare(compare)
plt.show()
Out[29]:
Out[29]:
Code 7.30¶
In [30]:
# Standard error of the WAIC difference between m6.6 and m6.8.
post = m6_6.sample_posterior(random.PRNGKey(92), p6_6, sample_shape=(1000,))
logprob = log_likelihood(m6_6.model, post, h0=d.h0.values, h1=d.h1.values)
az6_6 = az.from_dict(posterior={}, log_likelihood={"h1": logprob["h1"][None, ...]})
waic_m6_6 = az.waic(az6_6, pointwise=True, scale="deviance")

# sqrt(n * var) of the pointwise differences
diff_m6_6_m6_8 = waic_m6_6.waic_i.values - waic_m6_8.waic_i.values
jnp.sqrt(n * jnp.var(diff_m6_6_m6_8))
Out[30]:
Code 7.31¶
In [31]:
# Recompute pointwise WAIC for all three models with one shared PRNG key.

# m6.6
post = m6_6.sample_posterior(random.PRNGKey(93), p6_6, sample_shape=(1000,))
logprob = log_likelihood(m6_6.model, post, h0=d.h0.values, h1=d.h1.values)
az6_6 = az.from_dict(posterior={}, log_likelihood={"h1": logprob["h1"][None, ...]})
waic_m6_6 = az.waic(az6_6, pointwise=True, scale="deviance")

# m6.7
post = m6_7.sample_posterior(random.PRNGKey(93), p6_7, sample_shape=(1000,))
logprob = log_likelihood(
    m6_7.model,
    post,
    treatment=d.treatment.values,
    fungus=d.fungus.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
az6_7 = az.from_dict(posterior={}, log_likelihood={"h1": logprob["h1"][None, ...]})
waic_m6_7 = az.waic(az6_7, pointwise=True, scale="deviance")

# m6.8
post = m6_8.sample_posterior(random.PRNGKey(93), p6_8, sample_shape=(1000,))
logprob = log_likelihood(
    m6_8.model,
    post,
    treatment=d.treatment.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
az6_8 = az.from_dict(posterior={}, log_likelihood={"h1": logprob["h1"][None, ...]})
waic_m6_8 = az.waic(az6_8, pointwise=True, scale="deviance")


def dSE(waic1, waic2):
    """Standard error of the pointwise WAIC difference between two models."""
    pointwise_diff = waic1.waic_i.values - waic2.waic_i.values
    return jnp.sqrt(n * jnp.var(pointwise_diff))


# Table of difference standard errors for every pair of models.
data = {"m6.6": waic_m6_6, "m6.7": waic_m6_7, "m6.8": waic_m6_8}
pd.DataFrame(
    {
        row: {col: dSE(row_waic, col_waic) for col, col_waic in data.items()}
        for row, row_waic in data.items()
    }
)
Out[31]:
Out[31]:
Code 7.32¶
In [32]:
# Load the divorce data and add z-scored columns:
# A = median age at marriage, D = divorce rate, M = marriage rate.
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce
d["A"] = (d["MedianAgeMarriage"] - d["MedianAgeMarriage"].mean()) / d[
    "MedianAgeMarriage"
].std()
d["D"] = (d["Divorce"] - d["Divorce"].mean()) / d["Divorce"].std()
d["M"] = (d["Marriage"] - d["Marriage"].mean()) / d["Marriage"].std()
def model(A, D=None):
    """Divorce rate regressed on median age at marriage (both standardized)."""
    intercept = numpyro.sample("a", dist.Normal(0, 0.2))
    age_eff = numpyro.sample("bA", dist.Normal(0, 0.5))
    scale = numpyro.sample("sigma", dist.Exponential(1))
    mean = numpyro.deterministic("mu", intercept + age_eff * A)
    numpyro.sample("D", dist.Normal(mean, scale), obs=D)


m5_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    guide=m5_1,
    optim=optim.Adam(1),
    loss=Trace_ELBO(),
    A=d.A.values,
    D=d.D.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p5_1 = svi_result.params
def model(M, D=None):
    """Divorce rate regressed on marriage rate (both standardized)."""
    intercept = numpyro.sample("a", dist.Normal(0, 0.2))
    marriage_eff = numpyro.sample("bM", dist.Normal(0, 0.5))
    scale = numpyro.sample("sigma", dist.Exponential(1))
    mean = intercept + marriage_eff * M
    numpyro.sample("D", dist.Normal(mean, scale), obs=D)


m5_2 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    guide=m5_2,
    optim=optim.Adam(1),
    loss=Trace_ELBO(),
    M=d.M.values,
    D=d.D.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p5_2 = svi_result.params
def model(M, A, D=None):
    """Divorce rate regressed on marriage rate and median age at marriage."""
    intercept = numpyro.sample("a", dist.Normal(0, 0.2))
    marriage_eff = numpyro.sample("bM", dist.Normal(0, 0.5))
    age_eff = numpyro.sample("bA", dist.Normal(0, 0.5))
    scale = numpyro.sample("sigma", dist.Exponential(1))
    mean = numpyro.deterministic("mu", intercept + marriage_eff * M + age_eff * A)
    numpyro.sample("D", dist.Normal(mean, scale), obs=D)


m5_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    guide=m5_3,
    optim=optim.Adam(1),
    loss=Trace_ELBO(),
    M=d.M.values,
    A=d.A.values,
    D=d.D.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p5_3 = svi_result.params
Out[32]:
Code 7.33¶
In [33]:
# Posterior draws and pointwise log-likelihood for each divorce model,
# packed into InferenceData (a leading chain axis is added to every array).

# m5.1
post = m5_1.sample_posterior(random.PRNGKey(24071847), p5_1, sample_shape=(1000,))
logprob = log_likelihood(m5_1.model, post, A=d.A.values, D=d.D.values)
logprob = logprob["D"]
az5_1 = az.from_dict(
    posterior={name: draws[None, ...] for name, draws in post.items()},
    log_likelihood={"D": logprob[None, ...]},
)

# m5.2
post = m5_2.sample_posterior(random.PRNGKey(24071847), p5_2, sample_shape=(1000,))
logprob = log_likelihood(m5_2.model, post, M=d.M.values, D=d.D.values)
logprob = logprob["D"]
az5_2 = az.from_dict(
    posterior={name: draws[None, ...] for name, draws in post.items()},
    log_likelihood={"D": logprob[None, ...]},
)

# m5.3
post = m5_3.sample_posterior(random.PRNGKey(24071847), p5_3, sample_shape=(1000,))
logprob = log_likelihood(
    m5_3.model, post, A=d.A.values, M=d.M.values, D=d.D.values
)
logprob = logprob["D"]
az5_3 = az.from_dict(
    posterior={name: draws[None, ...] for name, draws in post.items()},
    log_likelihood={"D": logprob[None, ...]},
)

az.compare(
    {"m5.1": az5_1, "m5.2": az5_2, "m5.3": az5_3},
    ic="waic",
    scale="deviance",
)
Out[33]:
Out[33]:
Code 7.34¶
In [34]:
# Pointwise PSIS and WAIC for m5.3.
PSIS_m5_3 = az.loo(az5_3, pointwise=True, scale="deviance")
WAIC_m5_3 = az.waic(az5_3, pointwise=True, scale="deviance")

# per-observation variance of the log-likelihood over all chain/draw samples
stacked_ll = az5_3.log_likelihood.stack(sample=("chain", "draw"))
penalty = stacked_ll.var(dim="sample")

# Pareto k against the variance penalty, one open circle per observation.
plt.plot(PSIS_m5_3.pareto_k.values, penalty.D.values, "o", mfc="none")
plt.gca().set(xlabel="PSIS Pareto k", ylabel="WAIC penalty")
plt.show()
Out[34]:
Out[34]:
Code 7.35¶
In [35]:
def model(M, A, D=None):
    """Same regression as m5.3 but with a Student-t (2 d.o.f.) likelihood."""
    intercept = numpyro.sample("a", dist.Normal(0, 0.2))
    marriage_eff = numpyro.sample("bM", dist.Normal(0, 0.5))
    age_eff = numpyro.sample("bA", dist.Normal(0, 0.5))
    scale = numpyro.sample("sigma", dist.Exponential(1))
    mean = intercept + marriage_eff * M + age_eff * A
    numpyro.sample("D", dist.StudentT(2, mean, scale), obs=D)


m5_3t = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    guide=m5_3t,
    optim=optim.Adam(0.3),
    loss=Trace_ELBO(),
    M=d.M.values,
    A=d.A.values,
    D=d.D.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p5_3t = svi_result.params
Out[35]:
Comments