Chapter 5. The Many Variables & The Spurious Waffles
In [ ]:
!pip install -q numpyro arviz daft networkx
In [0]:
import collections
import itertools
import math
import os
import arviz as az
import daft
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
Code 5.1¶
In [1]:
# load data and copy
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce


# standardize variables: center on the mean, scale by the standard deviation
def _zscore(x):
    return (x - x.mean()) / x.std()


d["A"] = _zscore(d.MedianAgeMarriage)
d["D"] = _zscore(d.Divorce)
Code 5.2¶
In [2]:
# Standard deviation of median age at marriage (in years): one unit of the
# standardized A corresponds to this many years.
d.MedianAgeMarriage.std()
Out[2]:
Code 5.3¶
In [3]:
# m5.1: bivariate regression of divorce D on median age at marriage A.
def model(A, D=None):
    # regularizing priors for standardized variables
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # recorded as a deterministic site so Predictive can return mu later
    mu = numpyro.deterministic("mu", a + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


# quadratic (Laplace) approximation of the posterior, fit by SVI
m5_1 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_1, optim.Adam(1), Trace_ELBO(), A=d.A.values, D=d.D.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_1 = svi_result.params
Out[3]:
Code 5.4¶
In [4]:
# Prior predictive simulation: plot 20 regression lines implied by the priors.
predictive = Predictive(m5_1.model, num_samples=1000, return_sites=["mu"])
prior_pred = predictive(random.PRNGKey(10), A=jnp.array([-2, 2]))
mu = prior_pred["mu"]
plt.subplot(xlim=(-2, 2), ylim=(-2, 2))
for i in range(20):
    plt.plot([-2, 2], mu[i], "k", alpha=0.4)
Out[4]:
Code 5.5¶
In [5]:
# compute percentile interval of mean
# (posterior mean line and 89% interval for m5.1 over a grid of A values)
A_seq = jnp.linspace(start=-3, stop=3.2, num=30)
post = m5_1.sample_posterior(random.PRNGKey(1), p5_1, sample_shape=(1000,))
post.pop("mu")  # drop the deterministic site so Predictive recomputes mu at A_seq
post_pred = Predictive(m5_1.model, post)(random.PRNGKey(2), A=A_seq)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)
# plot it all
az.plot_pair(d[["D", "A"]].to_dict(orient="list"))
plt.plot(A_seq, mu_mean, "k")
plt.fill_between(A_seq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
plt.show()
Out[5]:
Code 5.6¶
In [6]:
d["M"] = d.Marriage.pipe(lambda x: (x - x.mean()) / x.std())
def model(M, D=None):
a = numpyro.sample("a", dist.Normal(0, 0.2))
bM = numpyro.sample("bM", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + bM * M
numpyro.sample("D", dist.Normal(mu, sigma), obs=D)
m5_2 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_2, optim.Adam(1), Trace_ELBO(), M=d.M.values, D=d.D.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_2 = svi_result.params
Out[6]:
Code 5.7¶
In [7]:
# Draw the DAG A -> D, A -> M, M -> D with daft at fixed node coordinates.
dag5_1 = nx.DiGraph()
dag5_1.add_edges_from([("A", "D"), ("A", "M"), ("M", "D")])
pgm = daft.PGM()
coordinates = {"A": (0, 0), "D": (1, 1), "M": (2, 0)}
for node in dag5_1.nodes:
    pgm.add_node(node, node, *coordinates[node])
for edge in dag5_1.edges:
    pgm.add_edge(*edge)
with plt.rc_context({"figure.constrained_layout.use": False}):
    pgm.render()
# flip the y-axis so the coordinates above read top-down as intended
plt.gca().invert_yaxis()
plt.show()
Out[7]:
Code 5.8¶
In [8]:
# Implied conditional independencies of the DAG A -> D, A -> M.
# For each node pair, search conditioning sets of increasing size and keep
# only *minimal* d-separating sets (supersets of a found set are skipped).
DMA_dag2 = nx.DiGraph()
DMA_dag2.add_edges_from([("A", "D"), ("A", "M")])
conditional_independencies = collections.defaultdict(list)
for edge in itertools.combinations(sorted(DMA_dag2.nodes), 2):
    remaining = sorted(set(DMA_dag2.nodes) - set(edge))
    for size in range(len(remaining) + 1):
        for subset in itertools.combinations(remaining, size):
            # skip supersets of an already-found separating set
            if any(cond.issubset(set(subset)) for cond in conditional_independencies[edge]):
                continue
            # NOTE(review): nx.d_separated was renamed nx.is_d_separator in
            # newer networkx releases — confirm against the installed version.
            if nx.d_separated(DMA_dag2, {edge[0]}, {edge[1]}, set(subset)):
                conditional_independencies[edge].append(set(subset))
                print(f"{edge[0]} _||_ {edge[1]}" + (f" | {' '.join(subset)}" if subset else ""))
Out[8]:
Code 5.9¶
In [9]:
# Same minimal conditional-independency search for the full DAG
# A -> D, A -> M, M -> D (duplicated from the previous cell; only the graph
# differs — this DAG implies no conditional independencies).
DMA_dag1 = nx.DiGraph()
DMA_dag1.add_edges_from([("A", "D"), ("A", "M"), ("M", "D")])
conditional_independencies = collections.defaultdict(list)
for edge in itertools.combinations(sorted(DMA_dag1.nodes), 2):
    remaining = sorted(set(DMA_dag1.nodes) - set(edge))
    for size in range(len(remaining) + 1):
        for subset in itertools.combinations(remaining, size):
            # skip supersets of an already-found separating set
            if any(cond.issubset(set(subset)) for cond in conditional_independencies[edge]):
                continue
            if nx.d_separated(DMA_dag1, {edge[0]}, {edge[1]}, set(subset)):
                conditional_independencies[edge].append(set(subset))
                print(f"{edge[0]} _||_ {edge[1]}" + (f" | {' '.join(subset)}" if subset else ""))
Code 5.10¶
In [10]:
# m5.3: multiple regression of divorce D on both marriage rate M and age A.
def model(M, A, D=None):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_3, optim.Adam(1), Trace_ELBO(), M=d.M.values, A=d.A.values, D=d.D.values
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_3 = svi_result.params
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[10]:
Out[10]:
Code 5.11¶
In [11]:
# Compare the bA and bM posteriors across the three models in a forest plot.
fits = [
    ("m5.1", m5_1, p5_1, 1),
    ("m5.2", m5_2, p5_2, 2),
    ("m5.3", m5_3, p5_3, 3),
]
coeftab = {
    label: guide.sample_posterior(random.PRNGKey(seed), params, sample_shape=(1, 1000))
    for label, guide, params, seed in fits
}
az.plot_forest(
    list(coeftab.values()),
    model_names=list(coeftab.keys()),
    var_names=["bA", "bM"],
    hdi_prob=0.89,
)
plt.show()
Out[11]:
Code 5.12¶
In [12]:
# Simulate the fork A -> M, A -> D: age causes both marriage and divorce.
N = 50  # number of simulated States
age = dist.Normal().sample(random.PRNGKey(0), sample_shape=(N,))  # sim A
mar = dist.Normal(age).sample(random.PRNGKey(1))  # sim A -> M
div = dist.Normal(age).sample(random.PRNGKey(2))  # sim A -> D
Code 5.13¶
In [13]:
# m5.4: regress marriage rate M on age at marriage A (the A -> M leg).
def model(A, M=None):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    # fix: the site was misnamed "bA" while the variable is bAM, which would
    # mislabel posterior summaries; the site now matches the variable name
    bAM = numpyro.sample("bAM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bAM * A)
    numpyro.sample("M", dist.Normal(mu, sigma), obs=M)


m5_4 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_4, optim.Adam(0.1), Trace_ELBO(), A=d.A.values, M=d.M.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_4 = svi_result.params
Out[13]:
Code 5.14¶
In [14]:
# Residuals of M after accounting for A: observed M minus posterior mean of mu.
post = m5_4.sample_posterior(random.PRNGKey(1), p5_4, sample_shape=(1000,))
post.pop("mu")  # drop deterministic site so Predictive recomputes it
post_pred = Predictive(m5_4.model, post)(random.PRNGKey(2), A=d.A.values)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_resid = d.M.values - mu_mean
Code 5.15¶
In [15]:
# call predictive without specifying new data
# so it uses original data
# (posterior predictive for m5.3: mean line intervals and simulated D)
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, sample_shape=(int(1e4),))
post.pop("mu")  # recompute mu inside Predictive
post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2), M=d.M.values, A=d.A.values)
mu = post_pred["mu"]
# summarize samples across cases
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)
# simulate observations
# again no new data, so uses original data
D_sim = post_pred["D"]
D_PI = jnp.percentile(D_sim, q=jnp.array([5.5, 94.5]), axis=0)
Code 5.16¶
In [16]:
# Posterior predictive check: predicted vs observed divorce rate per State.
ax = plt.subplot(
    ylim=(float(mu_PI.min()), float(mu_PI.max())),
    xlabel="Observed divorce",
    ylabel="Predicted divorce",
)
plt.plot(d.D, mu_mean, "o")
x = jnp.linspace(mu_PI.min(), mu_PI.max(), 101)
plt.plot(x, x, "--")  # perfect-prediction diagonal
for i in range(d.shape[0]):
    plt.plot([d.D[i]] * 2, mu_PI[:, i], "b")  # 89% interval for each State
fig = plt.gcf()  # keep a handle so the next cell can annotate this figure
Out[16]:
Code 5.17¶
In [17]:
# Annotate selected States on the figure from the previous cell.
for i in range(d.shape[0]):
    if d.Loc[i] in ["ID", "UT", "RI", "ME"]:
        ax.annotate(
            d.Loc[i], (d.D[i], mu_mean[i]), xytext=(-25, -5), textcoords="offset pixels"
        )
fig  # re-display the annotated figure
Out[17]:
Code 5.18¶
In [18]:
# Simulate a spurious association: x_real causes both x_spur and y, so x_spur
# correlates with y despite having no causal effect on it.
# NOTE: this rebinds d; the next cell reloads the divorce data.
N = 100  # number of cases
# x_real as Gaussian with mean 0 and stddev 1
x_real = dist.Normal().sample(random.PRNGKey(0), (N,))
# x_spur as Gaussian with mean=x_real
x_spur = dist.Normal(x_real).sample(random.PRNGKey(1))
# y as Gaussian with mean=x_real
y = dist.Normal(x_real).sample(random.PRNGKey(2))
# bind all together in data frame
d = pd.DataFrame({"y": y, "x_real": x_real, "x_spur": x_spur})
Code 5.19¶
In [19]:
# Reload the divorce data and fit the full generative system m5.3_A:
# the A -> M sub-model and the A -> D <- M sub-model jointly, so that
# counterfactual interventions on A also propagate through M.
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce
d["A"] = d.MedianAgeMarriage.pipe(lambda x: (x - x.mean()) / x.std())
d["D"] = d.Divorce.pipe(lambda x: (x - x.mean()) / x.std())
d["M"] = d.Marriage.pipe(lambda x: (x - x.mean()) / x.std())


def model(A, M=None, D=None):
    # A -> M
    aM = numpyro.sample("aM", dist.Normal(0, 0.2))
    bAM = numpyro.sample("bAM", dist.Normal(0, 0.5))
    sigma_M = numpyro.sample("sigma_M", dist.Exponential(1))
    mu_M = aM + bAM * A
    # when M=None (counterfactual simulation) this samples M given A
    M = numpyro.sample("M", dist.Normal(mu_M, sigma_M), obs=M)
    # A -> D <- M
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bM * M + bA * A
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_3_A = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_3_A,
    optim.Adam(0.1),
    Trace_ELBO(),
    A=d.A.values,
    M=d.M.values,
    D=d.D.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_3_A = svi_result.params
Out[19]:
Code 5.20¶
In [20]:
# Grid of standardized age values to intervene on (30 points in [-2, 2]).
A_seq = jnp.linspace(-2, 2, num=30)
Code 5.21¶
In [21]:
# prep data
sim_dat = dict(A=A_seq)
# simulate M and then D, using A_seq
# (M is unobserved here, so the model first samples M from A, then D from A and M)
post = m5_3_A.sample_posterior(random.PRNGKey(1), p5_3_A, sample_shape=(1000,))
s = Predictive(m5_3_A.model, post)(random.PRNGKey(2), **sim_dat)
Code 5.22¶
In [22]:
plt.plot(sim_dat["A"], jnp.mean(s["D"], 0))
plt.gca().set(ylim=(-2, 2), xlabel="manipulated A", ylabel="counterfactual D")
plt.fill_between(
sim_dat["A"],
*jnp.percentile(s["D"], q=jnp.array([5.5, 94.5]), axis=0),
color="k",
alpha=0.2
)
plt.title("Total counterfactual effect of A on D")
plt.show()
Out[22]:
Code 5.23¶
In [23]:
# new data frame, standardized to mean 26.1 and stddev 1.24
# (the sample mean/sd of MedianAgeMarriage, so 20 and 30 are in years)
sim2_dat = dict(A=(jnp.array([20, 30]) - 26.1) / 1.24)
s2 = Predictive(m5_3_A.model, post, return_sites=["M", "D"])(
    random.PRNGKey(2), **sim2_dat
)
# expected causal effect on D of raising median marriage age from 20 to 30
jnp.mean(s2["D"][:, 1] - s2["D"][:, 0])
Out[23]:
Code 5.24¶
In [24]:
# Counterfactual: manipulate M over a grid while holding A fixed at 0;
# because we intervene on M, the A -> M path is broken by construction.
sim_dat = dict(M=jnp.linspace(-2, 2, num=30), A=0)
s = Predictive(m5_3_A.model, post)(random.PRNGKey(2), **sim_dat)["D"]
plt.plot(sim_dat["M"], jnp.mean(s, 0))
# fix: the manipulated variable on the x-axis is M — the label previously
# read "manipulated A", copied from the previous counterfactual plot
plt.gca().set(ylim=(-2, 2), xlabel="manipulated M", ylabel="counterfactual D")
plt.fill_between(
    sim_dat["M"],
    *jnp.percentile(s, q=jnp.array([5.5, 94.5]), axis=0),
    color="k",
    alpha=0.2
)
plt.title("Total counterfactual effect of M on D")
plt.show()
Out[24]:
Code 5.25¶
In [25]:
# Intervention grid for the manual (non-Predictive) counterfactual below.
A_seq = jnp.linspace(-2, 2, num=30)
Code 5.26¶
In [26]:
# Manual counterfactual: simulate M from the posterior of the A -> M sub-model.
post = m5_3_A.sample_posterior(random.PRNGKey(1), p5_3_A, sample_shape=(1000,))
post = {k: v[..., None] for k, v in post.items()}  # add an axis to broadcast over A_seq
M_sim = dist.Normal(post["aM"] + post["bAM"] * A_seq).sample(random.PRNGKey(1))
Code 5.27¶
In [27]:
D_sim = dist.Normal(post["a"] + post["bA"] * A_seq + post["bM"] * M_sim).sample(
random.PRNGKey(1)
)
Code 5.28¶
In [28]:
# Load the primate milk data and inspect its columns.
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk
d.info()
d.head()
Out[28]:
Out[28]:
Code 5.29¶
In [29]:
d["K"] = d["kcal.per.g"].pipe(lambda x: (x - x.mean()) / x.std())
d["N"] = d["neocortex.perc"].pipe(lambda x: (x - x.mean()) / x.std())
d["M"] = d.mass.map(math.log).pipe(lambda x: (x - x.mean()) / x.std())
Code 5.30¶
In [30]:
# Draft model with vague priors. N contains NaN (missing neocortex values),
# so the fit fails; validation_enabled turns this into a readable error.
def model(N, K):
    a = numpyro.sample("a", dist.Normal(0, 1))
    bN = numpyro.sample("bN", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bN * N
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


with numpyro.validation_enabled():
    try:
        m5_5_draft = AutoLaplaceApproximation(model)
        svi = SVI(
            model, m5_5_draft, optim.Adam(1), Trace_ELBO(), N=d.N.values, K=d.K.values
        )
        svi_result = svi.run(random.PRNGKey(0), 1000)
        p5_5_draft = svi_result.params
    except ValueError as e:
        print(str(e))  # show the validation error instead of crashing the notebook
Out[30]:
Code 5.31¶
In [31]:
d["neocortex.perc"]
Out[31]:
Code 5.32¶
In [32]:
# Complete-case subset: keep rows where K, N, and M are all present.
# fix: the original `d.iloc[...dropna().index]` fed index *labels* to the
# positional indexer, which only works because d has a default RangeIndex;
# dropna(subset=...) selects the same rows and is correct for any index.
dcc = d.dropna(subset=["K", "N", "M"])
Code 5.33¶
In [33]:
# Refit the draft model (still vague priors) on complete cases only.
def model(N, K=None):
    a = numpyro.sample("a", dist.Normal(0, 1))
    bN = numpyro.sample("bN", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bN * N)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_5_draft = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_5_draft, optim.Adam(0.1), Trace_ELBO(), N=dcc.N.values, K=dcc.K.values
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_5_draft = svi_result.params
Out[33]:
Code 5.34¶
In [34]:
# Prior predictive for the vague draft priors: plot 50 implied regression lines.
xseq = jnp.array([-2, 2])
prior_pred = Predictive(model, num_samples=1000)(random.PRNGKey(1), N=xseq)
mu = prior_pred["mu"]
plt.subplot(xlim=xseq, ylim=xseq)
for i in range(50):
    plt.plot(xseq, mu[i], "k", alpha=0.3)
Out[34]:
Code 5.35¶
In [35]:
# m5.5: same regression of K on N, now with tighter regularizing priors.
def model(N, K=None):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bN = numpyro.sample("bN", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bN * N)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_5 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_5, optim.Adam(1), Trace_ELBO(), N=dcc.N.values, K=dcc.K.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_5 = svi_result.params
Out[35]:
Code 5.36¶
In [36]:
# Posterior summary for m5.5.
post = m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[36]:
Code 5.37¶
In [37]:
# Posterior mean line and 89% interval for m5.5 over the observed range of N.
xseq = jnp.linspace(start=dcc.N.min() - 0.15, stop=dcc.N.max() + 0.15, num=30)
post = m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1000,))
post.pop("mu")  # drop deterministic site so Predictive recomputes it at xseq
post_pred = Predictive(m5_5.model, post)(random.PRNGKey(2), N=xseq)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)
az.plot_pair(dcc[["N", "K"]].to_dict(orient="list"))
plt.plot(xseq, mu_mean, "k")
plt.fill_between(xseq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
plt.show()
Out[37]:
Code 5.38¶
In [38]:
# m5.6: bivariate regression of milk energy K on log body mass M.
def model(M, K=None):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_6 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_6, optim.Adam(1), Trace_ELBO(), M=dcc.M.values, K=dcc.K.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_6 = svi_result.params
post = m5_6.sample_posterior(random.PRNGKey(1), p5_6, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[38]:
Out[38]:
Code 5.39¶
In [39]:
# m5.7: multiple regression of K on both N and M together.
def model(N, M, K=None):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bN = numpyro.sample("bN", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bN * N + bM * M)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_7 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_7,
    optim.Adam(1),
    Trace_ELBO(),
    N=dcc.N.values,
    M=dcc.M.values,
    K=dcc.K.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_7 = svi_result.params
post = m5_7.sample_posterior(random.PRNGKey(1), p5_7, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[39]:
Out[39]:
Code 5.40¶
In [40]:
# Compare bM and bN across the three milk models in a forest plot.
coeftab = {
    "m5.5": m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1, 1000)),
    "m5.6": m5_6.sample_posterior(random.PRNGKey(2), p5_6, sample_shape=(1, 1000)),
    "m5.7": m5_7.sample_posterior(random.PRNGKey(3), p5_7, sample_shape=(1, 1000)),
}
az.plot_forest(
    list(coeftab.values()),
    model_names=list(coeftab.keys()),
    var_names=["bM", "bN"],
    hdi_prob=0.89,
)
plt.show()
Out[40]:
Code 5.41¶
In [41]:
# Counterfactual: vary N over its observed range while holding M = 0.
xseq = jnp.linspace(start=dcc.N.min() - 0.15, stop=dcc.N.max() + 0.15, num=30)
post = m5_7.sample_posterior(random.PRNGKey(1), p5_7, sample_shape=(1000,))
post.pop("mu")
post_pred = Predictive(m5_7.model, post)(random.PRNGKey(2), M=0, N=xseq)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)
# fix: the x-axis plots the N grid (xseq spans dcc.N), so the limits must
# come from dcc.N — the original used dcc.M and framed the wrong window
plt.subplot(xlim=(dcc.N.min(), dcc.N.max()), ylim=(dcc.K.min(), dcc.K.max()))
plt.plot(xseq, mu_mean, "k")
plt.fill_between(xseq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
plt.show()
Out[41]:
Code 5.42¶
In [42]:
# Simulate a masking relationship:
# M -> K <- N
# M -> N
n = 100
M = dist.Normal().sample(random.PRNGKey(0), (n,))
N = dist.Normal(M).sample(random.PRNGKey(1))
K = dist.Normal(N - M).sample(random.PRNGKey(2))  # K gets opposite effects of N and M
d_sim = pd.DataFrame({"K": K, "N": N, "M": M})
Code 5.43¶
In [43]:
# Same masking outcome under two alternative causal structures.
# M -> K <- N
# N -> M
n = 100
N = dist.Normal().sample(random.PRNGKey(0), (n,))
M = dist.Normal(N).sample(random.PRNGKey(1))
K = dist.Normal(N - M).sample(random.PRNGKey(2))
d_sim2 = pd.DataFrame({"K": K, "N": N, "M": M})

# M -> K <- N
# M <- U -> N
n = 100
# fix: the unobserved confound U must be simulated and both N and M drawn
# from it; the original never created U and drew M = Normal(M) from the
# *previous cell's* stale M (a hidden-state bug).
U = dist.Normal().sample(random.PRNGKey(3), (n,))
N = dist.Normal(U).sample(random.PRNGKey(4))
M = dist.Normal(U).sample(random.PRNGKey(5))
K = dist.Normal(N - M).sample(random.PRNGKey(6))
d_sim3 = pd.DataFrame({"K": K, "N": N, "M": M})
Code 5.44¶
In [44]:
# Enumerate candidate equivalent DAGs: flip each of the three edges of dag5_7
# in every combination (2^3 orientations) and keep only the acyclic graphs.
dag5_7 = nx.DiGraph()
dag5_7.add_edges_from([("M", "K"), ("N", "K"), ("M", "N")])
coordinates = {"M": (0, 0.5), "K": (1, 1), "N": (2, 0.5)}  # NOTE(review): unused in this cell
MElist = []
for i in range(2):
    for j in range(2):
        for k in range(2):
            new_dag = nx.DiGraph()
            # reverse an edge when its flip flag is 1
            new_dag.add_edges_from(
                [edge[::-1] if flip else edge for edge, flip in zip(dag5_7.edges, (i, j, k))]
            )
            if not list(nx.simple_cycles(new_dag)):
                MElist.append(new_dag)
Code 5.45¶
In [45]:
# Load the Howell height data and inspect it.
Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d = Howell1
d.info()
d.head()
Out[45]:
Out[45]:
Code 5.46¶
In [46]:
# Prior simulation for an indicator-variable model: mu_male combines two
# priors (base 178/20 plus the 0/10 difference), so it is more uncertain
# a priori than mu_female.
mu_female = dist.Normal(178, 20).sample(random.PRNGKey(0), (int(1e4),))
diff = dist.Normal(0, 10).sample(random.PRNGKey(1), (int(1e4),))
mu_male = dist.Normal(178, 20).sample(random.PRNGKey(2), (int(1e4),)) + diff
print_summary({"mu_female": mu_female, "mu_male": mu_male}, 0.89, False)
Out[46]:
Code 5.47¶
In [47]:
d["sex"] = jnp.where(d.male.values == 1, 1, 0)
d.sex
Out[47]:
Code 5.48¶
In [48]:
# m5.8: index-variable model — one height intercept per sex category.
def model(sex, height):
    # a is a vector of intercepts, one per category in `sex`
    a = numpyro.sample("a", dist.Normal(178, 20).expand([len(set(sex))]))
    sigma = numpyro.sample("sigma", dist.Uniform(0, 50))
    mu = a[sex]  # pick each observation's intercept by its category index
    numpyro.sample("height", dist.Normal(mu, sigma), obs=height)


m5_8 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_8, optim.Adam(1), Trace_ELBO(), sex=d.sex.values, height=d.height.values
)
svi_result = svi.run(random.PRNGKey(0), 2000)
p5_8 = svi_result.params
post = m5_8.sample_posterior(random.PRNGKey(1), p5_8, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[48]:
Out[48]:
Code 5.49¶
In [49]:
# Contrast: expected female-minus-male height difference, computed per sample.
post = m5_8.sample_posterior(random.PRNGKey(1), p5_8, sample_shape=(1000,))
post["diff_fm"] = post["a"][:, 0] - post["a"][:, 1]
print_summary(post, 0.89, False)
Out[49]:
Code 5.50¶
In [50]:
# Reload the milk data and list the taxonomic clades.
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk
d.clade.unique()
Out[50]:
Code 5.51¶
In [51]:
d["clade_id"] = d.clade.astype("category").cat.codes
Code 5.52¶
In [52]:
d["K"] = d["kcal.per.g"].pipe(lambda x: (x - x.mean()) / x.std())
def model(clade_id, K):
a = numpyro.sample("a", dist.Normal(0, 0.5).expand([len(set(clade_id))]))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a[clade_id]
numpyro.sample("height", dist.Normal(mu, sigma), obs=K)
m5_9 = AutoLaplaceApproximation(model)
svi = SVI(
model, m5_9, optim.Adam(1), Trace_ELBO(), clade_id=d.clade_id.values, K=d.K.values
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_9 = svi_result.params
post = m5_9.sample_posterior(random.PRNGKey(1), p5_9, sample_shape=(1000,))
labels = ["a[" + str(i) + "]:" + s for i, s in enumerate(sorted(d.clade.unique()))]
az.plot_forest({"a": post["a"][None, ...]}, hdi_prob=0.89)
plt.gca().set(yticklabels=labels[::-1], xlabel="expected kcal (std)")
plt.show()
Out[52]:
Out[52]:
Code 5.53¶
In [53]:
# Randomly assign each species to one of 4 fictional "houses": draw without
# replacement from a pool of 8 slots per house (d has fewer rows than 32).
key = random.PRNGKey(63)
d["house"] = random.choice(key, jnp.repeat(jnp.arange(4), 8), d.shape[:1], False)
Code 5.54¶
In [54]:
# m5.10: milk energy with two categorical predictors (clade and house).
def model(clade_id, house, K):
    a = numpyro.sample("a", dist.Normal(0, 0.5).expand([len(set(clade_id))]))
    h = numpyro.sample("h", dist.Normal(0, 0.5).expand([len(set(house))]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a[clade_id] + h[house]
    # fix: observed site renamed from "height" (copy-paste from the Howell
    # model) to "K", the variable actually modeled; the site is always
    # observed, so the rename does not change the fit or posterior keys
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_10 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_10,
    optim.Adam(1),
    Trace_ELBO(),
    clade_id=d.clade_id.values,
    house=d.house.values,
    K=d.K.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_10 = svi_result.params
Out[54]:
Comments
Comments powered by Disqus