# Chapter 5. The Many Variables & The Spurious Waffles

In [0]:
import itertools
import math
import os

import arviz as az
import daft
import matplotlib.pyplot as plt
import pandas as pd
from causalgraphicalmodels import CausalGraphicalModel

import jax.numpy as jnp
from jax import lax, random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import ELBO, SVI, Predictive
from numpyro.infer.autoguide import AutoLaplaceApproximation

if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
az.style.use("arviz-darkgrid")


#### Code 5.1¶

In [1]:
# load the raw data and work on a copy so the raw frame stays untouched
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce.copy()

# standardize variables (z-score: subtract the mean, divide by the sample std)
d["A"] = d.MedianAgeMarriage.pipe(lambda x: (x - x.mean()) / x.std())
d["D"] = d.Divorce.pipe(lambda x: (x - x.mean()) / x.std())


#### Code 5.2¶

In [2]:
# sample standard deviation (pandas default ddof=1) of the raw median age at marriage
d["MedianAgeMarriage"].std()

Out[2]:
1.2436303013880823

#### Code 5.3¶

In [3]:
def model(A, D=None):
    """m5.1: regress standardized divorce rate D on standardized median marriage age A."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


# Laplace (quadratic) approximation, optimized for 1000 Adam steps via lax.scan
m5_1 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_1, optim.Adam(1), ELBO(), A=d.A.values, D=d.D.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p5_1 = svi.get_params(state)


#### Code 5.4¶

In [4]:
# prior predictive: 1000 prior draws of mu evaluated at A = -2 and A = 2; plot 20 lines
predictive = Predictive(m5_1.model, num_samples=1000, return_sites=["mu"])
prior_pred = predictive(random.PRNGKey(10), A=jnp.array([-2, 2]))
mu = prior_pred["mu"]
plt.subplot(xlim=(-2, 2), ylim=(-2, 2))
for i in range(20):
    plt.plot([-2, 2], mu[i], "k", alpha=0.4)

Out[4]:

#### Code 5.5¶

In [5]:
# compute percentile interval of mean
A_seq = jnp.linspace(start=-3, stop=3.2, num=30)
post = m5_1.sample_posterior(random.PRNGKey(1), p5_1, (1000,))
post_pred = Predictive(m5_1.model, post)(random.PRNGKey(2), A=A_seq)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=(5.5, 94.5), axis=0)

# plot it all: A must come first so it lands on the x-axis, matching the A -> D line below
az.plot_pair(d[["A", "D"]].to_dict(orient="list"))
plt.plot(A_seq, mu_mean, "k")
plt.fill_between(A_seq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
plt.show()

Out[5]:

#### Code 5.6¶

In [6]:
# standardize marriage rate
d["M"] = d.Marriage.pipe(lambda x: (x - x.mean()) / x.std())


def model(M, D=None):
    """m5.2: regress standardized divorce rate D on standardized marriage rate M."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bM * M
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_2 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_2, optim.Adam(1), ELBO(), M=d.M.values, D=d.D.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p5_2 = svi.get_params(state)


#### Code 5.7¶

In [7]:
# DAG 5.1: A -> D, A -> M, M -> D, drawn with daft at fixed coordinates
dag5_1 = CausalGraphicalModel(
    nodes=["A", "D", "M"], edges=[("A", "D"), ("A", "M"), ("M", "D")]
)
pgm = daft.PGM()
coordinates = {"A": (0, 0), "D": (1, 1), "M": (2, 0)}
for node in dag5_1.dag.nodes:
    pgm.add_node(node, node, *coordinates[node])
for edge in dag5_1.dag.edges:
    pgm.add_edge(*edge)
with plt.rc_context({"figure.constrained_layout.use": False}):
    pgm.render()
plt.gca().invert_yaxis()

Out[7]:

#### Code 5.8¶

In [8]:
# DAG D <- A -> M: print each implied independence whose conditioning set is minimal,
# i.e. no other relation for the same (X, Y) pair conditions on a subset of it.
DMA_dag2 = CausalGraphicalModel(nodes=["A", "D", "M"], edges=[("A", "D"), ("A", "M")])
all_independencies = DMA_dag2.get_all_independence_relationships()
for s in all_independencies:
    if all(
        t[0] != s[0] or t[1] != s[1] or not t[2].issubset(s[2])
        for t in all_independencies
        if t != s
    ):
        print(s)

Out[8]:
('D', 'M', {'A'})


#### Code 5.9¶

In [9]:
# DAG 1 (A -> D, A -> M, M -> D): same minimal-independence filter as Code 5.8.
# Every pair is adjacent, so nothing is expected to print.
DMA_dag1 = CausalGraphicalModel(
    nodes=["A", "D", "M"], edges=[("A", "D"), ("A", "M"), ("M", "D")]
)
all_independencies = DMA_dag1.get_all_independence_relationships()
for s in all_independencies:
    if all(
        t[0] != s[0] or t[1] != s[1] or not t[2].issubset(s[2])
        for t in all_independencies
        if t != s
    ):
        print(s)


#### Code 5.10¶

In [10]:
def model(M, A, D=None):
    """m5.3: multiple regression of D on both marriage rate M and median marriage age A."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_3 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_3, optim.Adam(1), ELBO(), M=d.M.values, A=d.A.values, D=d.D.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p5_3 = svi.get_params(state)
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (1000,))
print_summary(post, 0.89, False)

Out[10]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a     -0.00      0.10     -0.01     -0.16      0.14   1049.96      1.00
bA     -0.61      0.16     -0.61     -0.86     -0.36    822.38      1.00
bM     -0.06      0.16     -0.06     -0.31      0.19    984.99      1.00
mu[0]      0.36      0.13      0.37      0.15      0.57    921.63      1.00
mu[1]      0.32      0.21      0.32     -0.01      0.66    900.77      1.00
mu[2]      0.12      0.10      0.12     -0.03      0.28    995.35      1.00
mu[3]      0.76      0.21      0.75      0.43      1.10    911.65      1.00
mu[4]     -0.35      0.12     -0.35     -0.52     -0.14   1070.56      1.00
mu[5]      0.12      0.15      0.12     -0.13      0.35    861.01      1.00
mu[6]     -0.71      0.17     -0.70     -0.95     -0.42   1053.27      1.00
mu[7]     -0.31      0.20     -0.32     -0.63      0.03    865.14      1.00
mu[8]     -1.74      0.40     -1.74     -2.28     -1.00    804.96      1.00
mu[9]     -0.12      0.14     -0.12     -0.37      0.08   1072.06      1.00
mu[10]      0.04      0.12      0.04     -0.14      0.23    858.83      1.00
mu[11]     -0.49      0.30     -0.49     -0.91      0.05    872.77      1.00
mu[12]      1.30      0.27      1.30      0.87      1.72    867.10      1.00
mu[13]     -0.43      0.13     -0.43     -0.61     -0.21   1125.00      1.00
mu[14]      0.18      0.11      0.18      0.00      0.35    987.73      1.00
mu[15]      0.30      0.11      0.30      0.11      0.47    909.55      1.00
mu[16]      0.48      0.13      0.49      0.27      0.70    875.38      1.00
mu[17]      0.58      0.15      0.58      0.35      0.82    867.29      1.00
mu[18]      0.07      0.10      0.06     -0.09      0.21   1002.97      1.00
mu[19]     -0.07      0.27     -0.07     -0.50      0.34   1018.90      1.00
mu[20]     -0.58      0.15     -0.58     -0.80     -0.32   1013.12      1.00
mu[21]     -1.13      0.24     -1.13     -1.47     -0.74    928.11      1.00
mu[22]     -0.11      0.16     -0.11     -0.38      0.12   1058.69      1.00
mu[23]     -0.05      0.21     -0.05     -0.38      0.27   1029.22      1.00
mu[24]      0.14      0.11      0.13     -0.05      0.30   1018.21      1.00
mu[25]      0.24      0.15      0.24      0.03      0.50    987.36      1.00
mu[26]      0.20      0.14      0.19     -0.04      0.42   1004.82      1.00
mu[27]      0.33      0.14      0.32      0.10      0.53    944.68      1.00
mu[28]     -0.31      0.14     -0.31     -0.53     -0.09   1106.73      1.00
mu[29]     -0.72      0.19     -0.72     -1.03     -0.42   1099.72      1.00
mu[30]      0.12      0.10      0.12     -0.03      0.28    992.25      1.00
mu[31]     -1.09      0.24     -1.10     -1.44     -0.70    891.79      1.00
mu[32]      0.17      0.10      0.17      0.00      0.33    973.21      1.00
mu[33]      0.26      0.24      0.26     -0.13      0.63    903.59      1.00
mu[34]     -0.07      0.15     -0.07     -0.34      0.14   1057.21      1.00
mu[35]      0.75      0.18      0.75      0.49      1.05    865.16      1.00
mu[36]      0.04      0.11      0.04     -0.15      0.20   1071.57      1.00
mu[37]     -0.44      0.17     -0.44     -0.70     -0.17   1096.13      1.00
mu[38]     -0.97      0.21     -0.97     -1.31     -0.63   1057.69      1.00
mu[39]     -0.14      0.11     -0.14     -0.32      0.04   1109.30      1.00
mu[40]      0.22      0.11      0.22      0.04      0.39    962.55      1.00
mu[41]      0.43      0.16      0.42      0.18      0.69    929.76      1.00
mu[42]      0.39      0.12      0.40      0.18      0.58    890.80      1.00
mu[43]      1.19      0.30      1.20      0.72      1.70    930.35      1.00
mu[44]     -0.36      0.15     -0.35     -0.59     -0.12   1105.88      1.00
mu[45]     -0.18      0.11     -0.18     -0.34      0.01    937.86      1.00
mu[46]      0.05      0.10      0.05     -0.12      0.21    891.13      1.00
mu[47]      0.48      0.13      0.48      0.27      0.70    875.38      1.00
mu[48]     -0.08      0.14     -0.08     -0.32      0.13   1065.16      1.00
mu[49]      0.74      0.34      0.73      0.16      1.22    951.13      1.00
sigma      0.80      0.08      0.79      0.68      0.92    971.23      1.00



#### Code 5.11¶

In [11]:
# 1000 posterior draws per model with a leading singleton axis, which ArviZ reads as
# (chain, draw); then compare the bA / bM coefficients across m5.1, m5.2 and m5.3.
coeftab = {
    name: guide.sample_posterior(random.PRNGKey(seed), params, (1, 1000))
    for seed, (name, guide, params) in enumerate(
        [("m5.1", m5_1, p5_1), ("m5.2", m5_2, p5_2), ("m5.3", m5_3, p5_3)], start=1
    )
}
az.plot_forest(
    list(coeftab.values()),
    model_names=list(coeftab),
    var_names=["bA", "bM"],
    hdi_prob=0.89,
)
plt.show()

Out[11]:

#### Code 5.12¶

In [12]:
N = 50  # number of simulated States
# A ~ Normal(0, 1); M and D are each Normal(A, 1), so they are linked only through A
age = dist.Normal(0.0, 1.0).sample(random.PRNGKey(0), (N,))  # sim A
mar = dist.Normal(loc=age).sample(random.PRNGKey(1))  # sim A -> M
div = dist.Normal(loc=age).sample(random.PRNGKey(2))  # sim A -> D


#### Code 5.13¶

In [13]:
def model(A, M=None):
    """m5.4: regress standardized marriage rate M on median marriage age A (for residuals)."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    # site name matches the Python variable (was "bA", which clashed with m5.1/m5.3 naming)
    bAM = numpyro.sample("bAM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bAM * A)
    numpyro.sample("M", dist.Normal(mu, sigma), obs=M)


m5_4 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_4, optim.Adam(0.1), ELBO(), A=d.A.values, M=d.M.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p5_4 = svi.get_params(state)


#### Code 5.14¶

In [14]:
# residual of M after regressing it on A: observed M minus posterior-mean prediction
post = m5_4.sample_posterior(random.PRNGKey(1), p5_4, (1000,))
post_pred = Predictive(m5_4.model, post)(random.PRNGKey(2), A=d["A"].values)
mu = post_pred["mu"]
mu_mean = mu.mean(axis=0)
mu_resid = d["M"].values - mu_mean


#### Code 5.15¶

In [15]:
# posterior predictive for m5.3 evaluated at the observed M and A of every state
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (10_000,))
post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2), M=d.M.values, A=d.A.values)
mu = post_pred["mu"]

# per-state mean and 89% percentile interval of mu
mu_mean = mu.mean(axis=0)
mu_PI = jnp.percentile(mu, q=(5.5, 94.5), axis=0)

# D was not passed, so Predictive simulated it; summarize those simulated outcomes too
D_sim = post_pred["D"]
D_PI = jnp.percentile(D_sim, q=(5.5, 94.5), axis=0)


#### Code 5.16¶

In [16]:
# posterior predictive check: predicted vs observed divorce, with an identity line and
# a vertical 89% interval of mu per state
ax = plt.subplot(
    ylim=(float(mu_PI.min()), float(mu_PI.max())),
    xlabel="Observed divorce",
    ylabel="Predicted divorce",
)
plt.plot(d.D, mu_mean, "o")
x = jnp.linspace(mu_PI.min(), mu_PI.max(), 101)
plt.plot(x, x, "--")
for i in range(d.shape[0]):
    plt.plot([d.D[i]] * 2, mu_PI[:, i], "b")
fig = plt.gcf()

Out[16]:

#### Code 5.17¶

In [17]:
# label a few selected states on the posterior predictive plot from the previous cell
for i in range(d.shape[0]):
    if d.Loc[i] in ["ID", "UT", "RI", "ME"]:
        ax.annotate(
            d.Loc[i], (d.D[i], mu_mean[i]), xytext=(-25, -5), textcoords="offset pixels"
        )
fig

Out[17]:

#### Code 5.18¶

In [18]:
N = 100  # number of cases
# x_real ~ Normal(0, 1); x_spur and y are each Normal(x_real, 1), so x_spur is
# associated with y only through x_real
x_real = dist.Normal(0.0, 1.0).sample(random.PRNGKey(0), (N,))
x_spur = dist.Normal(loc=x_real).sample(random.PRNGKey(1))
y = dist.Normal(loc=x_real).sample(random.PRNGKey(2))
d = pd.DataFrame(dict(y=y, x_real=x_real, x_spur=x_spur))


#### Code 5.19¶

In [19]:
# reload the raw data (d was overwritten by the simulation in Code 5.18) and standardize
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce
d["A"] = d.MedianAgeMarriage.pipe(lambda x: (x - x.mean()) / x.std())
d["D"] = d.Divorce.pipe(lambda x: (x - x.mean()) / x.std())
d["M"] = d.Marriage.pipe(lambda x: (x - x.mean()) / x.std())


def model(A, M=None, D=None):
    """m5.3_A: joint model with A -> M and A -> D <- M.

    When M is None the M site is sampled rather than observed, which the
    counterfactual simulations below use to propagate A through M.
    """
    # A -> M
    aM = numpyro.sample("aM", dist.Normal(0, 0.2))
    bAM = numpyro.sample("bAM", dist.Normal(0, 0.5))
    sigma_M = numpyro.sample("sigma_M", dist.Exponential(1))
    mu_M = aM + bAM * A
    M = numpyro.sample("M", dist.Normal(mu_M, sigma_M), obs=M)
    # A -> D <- M
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bM * M + bA * A
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_3_A = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_3_A, optim.Adam(0.1), ELBO(), A=d.A.values, M=d.M.values, D=d.D.values
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p5_3_A = svi.get_params(state)


#### Code 5.20¶

In [20]:
A_seq = jnp.linspace(start=-2, stop=2, num=30)  # 30 manipulated values of standardized A


#### Code 5.21¶

In [21]:
# counterfactual inputs: only A is supplied, so M and then D are simulated from the model
sim_dat = {"A": A_seq}

# 1000 posterior draws of m5.3_A pushed through the generative model at each A_seq value
post = m5_3_A.sample_posterior(random.PRNGKey(1), p5_3_A, (1000,))
s = Predictive(m5_3_A.model, post)(random.PRNGKey(2), A=sim_dat["A"])


#### Code 5.22¶

In [22]:
# mean and 89% interval of simulated D across the manipulated A values
plt.plot(sim_dat["A"], s["D"].mean(axis=0))
plt.gca().set(ylim=(-2, 2), xlabel="manipulated A", ylabel="counterfactual D")
plt.fill_between(
    sim_dat["A"],
    *jnp.percentile(s["D"], q=(5.5, 94.5), axis=0),
    color="k",
    alpha=0.2,
)
plt.title("Total counterfactual effect of A on D")
plt.show()

Out[22]:

#### Code 5.23¶

In [23]:
# manipulate median marriage age from 20 to 30, standardized with mean 26.1 and std 1.24
sim2_dat = {"A": (jnp.array([20, 30]) - 26.1) / 1.24}
s2 = Predictive(m5_3_A.model, post, return_sites=["M", "D"])(
    random.PRNGKey(2), A=sim2_dat["A"]
)
# expected change in standardized D between the two manipulations
(s2["D"][:, 1] - s2["D"][:, 0]).mean()

Out[23]:
DeviceArray(-4.588497, dtype=float32)

#### Code 5.24¶

In [24]:
# counterfactual effect of manipulating M while A is held at 0
sim_dat = dict(M=jnp.linspace(-2, 2, num=30), A=0)
s = Predictive(m5_3_A.model, post)(random.PRNGKey(2), **sim_dat)["D"]

plt.plot(sim_dat["M"], jnp.mean(s, 0))
plt.gca().set(ylim=(-2, 2), xlabel="manipulated M", ylabel="counterfactual D")
plt.fill_between(
    sim_dat["M"], *jnp.percentile(s, q=(5.5, 94.5), axis=0), color="k", alpha=0.2
)
plt.title("Total counterfactual effect of M on D")
plt.show()

Out[24]:

#### Code 5.25¶

In [25]:
A_seq = jnp.linspace(-2.0, 2.0, 30)  # grid of manipulated A values for manual simulation


#### Code 5.26¶

In [26]:
post = m5_3_A.sample_posterior(random.PRNGKey(1), p5_3_A, (1000,))
# append a trailing axis so each posterior draw broadcasts against the A_seq grid
post = {name: jnp.expand_dims(draws, -1) for name, draws in post.items()}
M_sim = dist.Normal(post["aM"] + post["bAM"] * A_seq).sample(random.PRNGKey(1))


#### Code 5.27¶

In [27]:
# D responds to A directly and to the simulated M (the A -> M -> D path)
D_sim = dist.Normal(
    loc=post["a"] + post["bA"] * A_seq + post["bM"] * M_sim
).sample(random.PRNGKey(1))


#### Code 5.28¶

In [28]:
# primate milk data; `d` and `milk` name the same frame
d = milk = pd.read_csv("../data/milk.csv", sep=";")
d.info()

Out[28]:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 29 entries, 0 to 28
Data columns (total 8 columns):
#   Column          Non-Null Count  Dtype
---  ------          --------------  -----
1   species         29 non-null     object
2   kcal.per.g      29 non-null     float64
3   perc.fat        29 non-null     float64
4   perc.protein    29 non-null     float64
5   perc.lactose    29 non-null     float64
6   mass            29 non-null     float64
7   neocortex.perc  17 non-null     float64
dtypes: float64(6), object(2)
memory usage: 1.9+ KB

Out[28]:
clade species kcal.per.g perc.fat perc.protein perc.lactose mass neocortex.perc
0 Strepsirrhine Eulemur fulvus 0.49 16.60 15.42 67.98 1.95 55.16
1 Strepsirrhine E macaco 0.51 19.27 16.91 63.82 2.09 NaN
2 Strepsirrhine E mongoz 0.46 14.11 16.85 69.04 2.51 NaN
3 Strepsirrhine E rubriventer 0.48 14.91 13.18 71.91 1.62 NaN
4 Strepsirrhine Lemur catta 0.60 27.28 19.50 53.22 2.19 NaN

#### Code 5.29¶

In [29]:
# z-score energy density (K) and neocortex percent (N); pandas skips NaNs in mean/std
d["K"] = (d["kcal.per.g"] - d["kcal.per.g"].mean()) / d["kcal.per.g"].std()
d["N"] = (d["neocortex.perc"] - d["neocortex.perc"].mean()) / d["neocortex.perc"].std()
# body mass is log-transformed before standardizing
d["M"] = d["mass"].map(math.log).pipe(lambda x: (x - x.mean()) / x.std())


#### Code 5.30¶

In [30]:
def model(N, K):
    """Draft regression of K on N with wide Normal(0, 1) priors."""
    a = numpyro.sample("a", dist.Normal(0, 1))
    bN = numpyro.sample("bN", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bN * N
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# This cell is meant to fail: N contains NaN (missing neocortex.perc), and with
# validation enabled the resulting NaN loc is rejected; the error message is printed.
with numpyro.validation_enabled():
    try:
        m5_5_draft = AutoLaplaceApproximation(model)
        svi = SVI(model, m5_5_draft, optim.Adam(1), ELBO(), N=d.N.values, K=d.K.values)
        init_state = svi.init(random.PRNGKey(0))
        state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
        p5_5_draft = svi.get_params(state)
    except ValueError as e:
        print(str(e))

Out[30]:
The parameter loc has invalid values


#### Code 5.31¶

In [31]:
d.loc[:, "neocortex.perc"]  # show the raw column, including its missing values

Out[31]:
0     55.16
1       NaN
2       NaN
3       NaN
4       NaN
5     64.54
6     64.54
7     67.64
8       NaN
9     68.85
10    58.85
11    61.69
12    60.32
13      NaN
14      NaN
15    69.97
16      NaN
17    70.41
18      NaN
19    73.40
20      NaN
21    67.53
22      NaN
23    71.26
24    72.60
25      NaN
26    70.24
27    76.30
28    75.49
Name: neocortex.perc, dtype: float64

#### Code 5.32¶

In [32]:
# complete cases: drop rows missing any of K, N, M. dropna selects by label, whereas the
# previous `d.iloc[... .index]` only worked because d happens to have a RangeIndex.
dcc = d.dropna(subset=["K", "N", "M"])


#### Code 5.33¶

In [33]:
def model(N, K=None):
    """Draft regression of K on N (vague priors), fit on complete cases only."""
    a = numpyro.sample("a", dist.Normal(0, 1))
    bN = numpyro.sample("bN", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bN * N)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_5_draft = AutoLaplaceApproximation(model)
svi = SVI(model, m5_5_draft, optim.Adam(0.1), ELBO(), N=dcc.N.values, K=dcc.K.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p5_5_draft = svi.get_params(state)


#### Code 5.34¶

In [34]:
# prior predictive lines of the draft model (the `model` defined for m5_5_draft)
xseq = jnp.array([-2, 2])
prior_pred = Predictive(model, num_samples=1000)(random.PRNGKey(1), N=xseq)
mu = prior_pred["mu"]
plt.subplot(xlim=xseq, ylim=xseq)
for i in range(50):
    plt.plot(xseq, mu[i], "k", alpha=0.3)

Out[34]:

#### Code 5.35¶

In [35]:
def model(N, K=None):
    """m5.5: regression of K on N with tighter Normal(0, 0.2) / Normal(0, 0.5) priors."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bN = numpyro.sample("bN", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bN * N)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_5 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_5, optim.Adam(1), ELBO(), N=dcc.N.values, K=dcc.K.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p5_5 = svi.get_params(state)


#### Code 5.36¶

In [36]:
post = m5_5.sample_posterior(random.PRNGKey(1), p5_5, (1000,))
print_summary(post, prob=0.89, group_by_chain=False)  # 89% intervals, single chain

Out[36]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a      0.05      0.16      0.05     -0.21      0.29    931.50      1.00
bN      0.13      0.23      0.13     -0.21      0.53   1111.88      1.00
mu[0]     -0.22      0.51     -0.22     -0.99      0.60    911.65      1.00
mu[1]     -0.02      0.20     -0.01     -0.36      0.27    906.66      1.00
mu[2]     -0.02      0.20     -0.01     -0.36      0.27    906.66      1.00
mu[3]      0.05      0.16      0.05     -0.21      0.29    931.28      1.00
mu[4]      0.08      0.17      0.08     -0.18      0.36    940.50      1.00
mu[5]     -0.14      0.38     -0.13     -0.74      0.45    889.84      1.00
mu[6]     -0.08      0.28     -0.07     -0.52      0.36    874.50      1.00
mu[7]     -0.11      0.33     -0.10     -0.65      0.37    884.63      1.00
mu[8]      0.10      0.18      0.10     -0.18      0.41    964.72      1.00
mu[9]      0.11      0.19      0.11     -0.17      0.43    975.68      1.00
mu[10]      0.17      0.28      0.17     -0.25      0.63   1037.05      1.00
mu[11]      0.05      0.16      0.05     -0.21      0.29    931.71      1.00
mu[12]      0.13      0.21      0.13     -0.18      0.49    996.35      1.00
mu[13]      0.16      0.25      0.15     -0.21      0.58   1024.00      1.00
mu[14]      0.11      0.19      0.10     -0.17      0.43    971.43      1.00
mu[15]      0.24      0.37      0.23     -0.39      0.80   1067.60      1.00
mu[16]      0.22      0.35      0.21     -0.37      0.74   1061.13      1.00
sigma      1.05      0.18      1.03      0.78      1.35    944.03      1.00



#### Code 5.37¶

In [37]:
# posterior mean and 89% interval of mu over a grid slightly wider than the observed N
xseq = jnp.linspace(dcc.N.min() - 0.15, dcc.N.max() + 0.15, num=30)
post = m5_5.sample_posterior(random.PRNGKey(1), p5_5, (1000,))
post_pred = Predictive(m5_5.model, post)(random.PRNGKey(2), N=xseq)
mu = post_pred["mu"]
mu_mean = mu.mean(axis=0)
mu_PI = jnp.percentile(mu, q=(5.5, 94.5), axis=0)
az.plot_pair(dcc[["N", "K"]].to_dict(orient="list"))
plt.plot(xseq, mu_mean, "k")
plt.fill_between(xseq, *mu_PI, color="k", alpha=0.2)
plt.show()

Out[37]:

#### Code 5.38¶

In [38]:
def model(M, K=None):
    """m5.6: regression of K on standardized log body mass M."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_6 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_6, optim.Adam(1), ELBO(), M=dcc.M.values, K=dcc.K.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p5_6 = svi.get_params(state)
post = m5_6.sample_posterior(random.PRNGKey(1), p5_6, (1000,))
print_summary(post, 0.89, False)

Out[38]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a      0.06      0.16      0.06     -0.20      0.29    931.50      1.00
bM     -0.28      0.20     -0.28     -0.61      0.03   1088.44      1.00
mu[0]      0.18      0.18      0.19     -0.11      0.47    944.38      1.00
mu[1]      0.02      0.16      0.02     -0.24      0.26    932.90      1.00
mu[2]      0.02      0.16      0.02     -0.24      0.26    933.51      1.00
mu[3]      0.14      0.17      0.15     -0.11      0.42    958.01      1.00
mu[4]      0.36      0.27      0.36     -0.07      0.78    872.95      1.00
mu[5]      0.65      0.45      0.65     -0.03      1.39    898.89      1.00
mu[6]      0.42      0.31      0.43     -0.07      0.89    878.58      1.00
mu[7]      0.49      0.35      0.50     -0.09      1.00    885.51      1.00
mu[8]      0.22      0.20      0.23     -0.12      0.51    906.05      1.00
mu[9]      0.10      0.16      0.10     -0.17      0.33    940.69      1.00
mu[10]     -0.12      0.20     -0.13     -0.40      0.21    989.01      1.00
mu[11]      0.02      0.16      0.02     -0.24      0.26    933.51      1.00
mu[12]     -0.30      0.29     -0.30     -0.75      0.18   1050.56      1.00
mu[13]     -0.43      0.38     -0.44     -1.07      0.12   1073.41      1.00
mu[14]     -0.32      0.31     -0.33     -0.81      0.16   1055.58      1.00
mu[15]     -0.29      0.28     -0.29     -0.74      0.16   1047.84      1.00
mu[16]     -0.37      0.34     -0.38     -0.95      0.12   1064.72      1.00
sigma      0.99      0.17      0.98      0.72      1.26    957.10      1.00



#### Code 5.39¶

In [39]:
def model(N, M, K=None):
    """m5.7: regression of K on both neocortex percent N and log body mass M."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bN = numpyro.sample("bN", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bN * N + bM * M)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_7 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_7, optim.Adam(1), ELBO(), N=dcc.N.values, M=dcc.M.values, K=dcc.K.values
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p5_7 = svi.get_params(state)
post = m5_7.sample_posterior(random.PRNGKey(1), p5_7, (1000,))
print_summary(post, 0.89, False)

Out[39]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a      0.04      0.13      0.03     -0.17      0.24   1049.96      1.00
bM     -0.67      0.23     -0.67     -1.02     -0.28    841.72      1.00
bN      0.66      0.25      0.66      0.26      1.06    884.24      1.00
mu[0]     -1.02      0.46     -1.03     -1.74     -0.27    949.86      1.00
mu[1]     -0.38      0.19     -0.37     -0.65     -0.05   1072.17      1.00
mu[2]     -0.39      0.19     -0.38     -0.67     -0.06   1063.97      1.00
mu[3]      0.25      0.16      0.26      0.01      0.49    968.49      1.00
mu[4]      0.90      0.33      0.90      0.39      1.41    877.16      1.00
mu[5]      0.48      0.38      0.48     -0.14      1.06   1015.82      1.00
mu[6]      0.26      0.26      0.25     -0.13      0.68   1041.49      1.00
mu[7]      0.26      0.30      0.25     -0.20      0.71   1039.34      1.00
mu[8]      0.70      0.26      0.70      0.32      1.14    868.14      1.00
mu[9]      0.46      0.20      0.46      0.15      0.80    873.70      1.00
mu[10]      0.26      0.22      0.26     -0.07      0.64    895.39      1.00
mu[11]     -0.06      0.13     -0.06     -0.27      0.14    992.13      1.00
mu[12]     -0.39      0.24     -0.39     -0.81     -0.03    864.89      1.00
mu[13]     -0.56      0.32     -0.57     -1.01      0.00    871.32      1.00
mu[14]     -0.56      0.27     -0.56     -0.97     -0.11    858.25      1.00
mu[15]      0.19      0.30      0.19     -0.29      0.64    915.24      1.00
mu[16]     -0.10      0.30     -0.10     -0.58      0.36    895.42      1.00
sigma      0.76      0.13      0.75      0.54      0.94   1030.06      1.00



#### Code 5.40¶

In [40]:
# 1000 posterior draws per model with a leading singleton axis, which ArviZ reads as
# (chain, draw); then compare the bM / bN coefficients across m5.5, m5.6 and m5.7.
coeftab = {
    name: guide.sample_posterior(random.PRNGKey(seed), params, (1, 1000))
    for seed, (name, guide, params) in enumerate(
        [("m5.5", m5_5, p5_5), ("m5.6", m5_6, p5_6), ("m5.7", m5_7, p5_7)], start=1
    )
}
az.plot_forest(
    list(coeftab.values()),
    model_names=list(coeftab),
    var_names=["bM", "bN"],
    hdi_prob=0.89,
)
plt.show()

Out[40]:

#### Code 5.41¶

In [41]:
# counterfactual plot for m5.7: vary M over (slightly beyond) its observed range while N is
# held at 0. The grid, the input it feeds and the x-limits now all refer to M.
xseq = jnp.linspace(start=dcc.M.min() - 0.15, stop=dcc.M.max() + 0.15, num=30)
post = m5_7.sample_posterior(random.PRNGKey(1), p5_7, (1000,))
post_pred = Predictive(m5_7.model, post)(random.PRNGKey(2), M=xseq, N=0)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=(5.5, 94.5), axis=0)
plt.subplot(
    xlim=(dcc.M.min(), dcc.M.max()),
    ylim=(dcc.K.min(), dcc.K.max()),
    xlabel="log body mass (std)",
    ylabel="kilocal per g (std)",
)
plt.plot(xseq, mu_mean, "k")
plt.fill_between(xseq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
plt.title("Counterfactual holding N = 0")
plt.show()

Out[41]:

#### Code 5.42¶

In [42]:
# simulate the DAG  M -> K <- N  together with  M -> N
n = 100
M = dist.Normal(0.0, 1.0).sample(random.PRNGKey(0), (n,))
N = dist.Normal(loc=M).sample(random.PRNGKey(1))
K = dist.Normal(loc=N - M).sample(random.PRNGKey(2))
d_sim = pd.DataFrame(dict(K=K, N=N, M=M))


#### Code 5.43¶

In [43]:
# M -> K <- N
# N -> M
n = 100
N = dist.Normal().sample(random.PRNGKey(0), (n,))
M = dist.Normal(N).sample(random.PRNGKey(1))
K = dist.Normal(N - M).sample(random.PRNGKey(2))
d_sim2 = pd.DataFrame({"K": K, "N": N, "M": M})

# M -> K <- N
# M <- U -> N   (U is an unobserved common cause of M and N)
n = 100
U = dist.Normal().sample(random.PRNGKey(3), (n,))
N = dist.Normal(U).sample(random.PRNGKey(4))
M = dist.Normal(U).sample(random.PRNGKey(5))
K = dist.Normal(N - M).sample(random.PRNGKey(6))
d_sim3 = pd.DataFrame({"K": K, "N": N, "M": M})


#### Code 5.44¶

In [44]:
dag5_7 = CausalGraphicalModel(
    nodes=["M", "K", "N"], edges=[("M", "K"), ("N", "K"), ("M", "N")]
)
coordinates = {"M": (0, 0.5), "K": (1, 1), "N": (2, 0.5)}
nodes = list(dag5_7.dag.nodes.keys())
edges = list(dag5_7.dag.edges.keys())


def _independencies(dag):
    """Return the conditional independencies implied by `dag` as an order-insensitive set.

    Each relation ``(X, Y, Z)`` from ``get_all_independence_relationships`` is
    normalized to ``(frozenset({X, Y}), frozenset(Z))`` so that relation order and
    pair order do not matter when comparing two DAGs.
    """
    return {
        (frozenset(rel[:2]), frozenset(rel[2]))
        for rel in dag.get_all_independence_relationships()
    }


# Try every orientation of the edges. Flipping directions keeps the skeleton, so a
# candidate is Markov equivalent to dag5_7 iff it implies the same independencies.
target_independencies = _independencies(dag5_7)
MElist = []
for flips in itertools.product((False, True), repeat=len(edges)):
    oriented = [edge[::-1] if flip else edge for edge, flip in zip(edges, flips)]
    try:
        new_dag = CausalGraphicalModel(nodes=nodes, edges=oriented)
    except Exception:
        # NOTE(review): presumably raised when the orientation creates a cycle — confirm
        continue
    if _independencies(new_dag) == target_independencies:
        MElist.append(new_dag)


#### Code 5.45¶

In [45]:
# !Kung height/weight data; `d` and `Howell1` name the same frame
d = Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d.info()

Out[45]:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 544 entries, 0 to 543
Data columns (total 4 columns):
#   Column  Non-Null Count  Dtype
---  ------  --------------  -----
0   height  544 non-null    float64
1   weight  544 non-null    float64
2   age     544 non-null    float64
3   male    544 non-null    int64
dtypes: float64(3), int64(1)
memory usage: 17.1 KB

Out[45]:
height weight age male
0 151.765 47.825606 63.0 1
1 139.700 36.485807 63.0 0
2 136.525 31.864838 65.0 0
3 156.845 53.041915 41.0 1
4 145.415 41.276872 51.0 0

#### Code 5.46¶

In [46]:
# prior for the indicator-variable parameterization: the male mean adds a Normal(0, 10)
# offset, so it is a priori more uncertain than the female mean
mu_female = dist.Normal(178, 20).sample(random.PRNGKey(0), (10_000,))
diff = dist.Normal(0, 10).sample(random.PRNGKey(1), (10_000,))
mu_male = diff + dist.Normal(178, 20).sample(random.PRNGKey(2), (10_000,))
print_summary(dict(mu_female=mu_female, mu_male=mu_male), prob=0.89, group_by_chain=False)

Out[46]:
                 mean       std    median      5.5%     94.5%     n_eff     r_hat
mu_female    178.21     20.22    178.24    147.19    211.84   9943.61      1.00
mu_male    178.10     22.36    178.51    142.35    213.41  10190.57      1.00



#### Code 5.47¶

In [47]:
# zero-based index variable: 1 for male, 0 for female
d["sex"] = jnp.where(d["male"].values == 1, 1, 0)
d["sex"]

Out[47]:
0      1
1      0
2      0
3      1
4      0
..
539    1
540    1
541    0
542    1
543    1
Name: sex, Length: 544, dtype: int32

#### Code 5.48¶

In [48]:
def model(sex, height):
    """m5.8: one mean height per sex category (index variable), shared sigma."""
    a = numpyro.sample("a", dist.Normal(178, 20).expand([len(set(sex))]))
    sigma = numpyro.sample("sigma", dist.Uniform(0, 50))
    mu = a[sex]
    numpyro.sample("height", dist.Normal(mu, sigma), obs=height)


m5_8 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_8, optim.Adam(1), ELBO(), sex=d.sex.values, height=d.height.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(2000))
p5_8 = svi.get_params(state)
post = m5_8.sample_posterior(random.PRNGKey(1), p5_8, (1000,))
print_summary(post, 0.89, False)

Out[48]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a[0]    135.02      1.63    135.07    132.32    137.46    931.50      1.00
a[1]    142.56      1.73    142.54    140.02    145.51   1111.51      1.00
sigma     27.32      0.84     27.32     26.03     28.71    951.62      1.00



#### Code 5.49¶

In [49]:
post = m5_8.sample_posterior(random.PRNGKey(1), p5_8, (1000,))
# contrast between the two categories: a[0] (sex == 0, female) minus a[1] (male)
post["diff_fm"] = jnp.subtract(post["a"][:, 0], post["a"][:, 1])
print_summary(post, prob=0.89, group_by_chain=False)

Out[49]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a[0]    135.02      1.63    135.07    132.32    137.46    931.50      1.00
a[1]    142.56      1.73    142.54    140.02    145.51   1111.51      1.00
diff_fm     -7.54      2.38     -7.47    -11.77     -4.32    876.56      1.00
sigma     27.32      0.84     27.32     26.03     28.71    951.62      1.00



#### Code 5.50¶

In [50]:
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk
d.clade.unique()  # the distinct clades (the array shown as this cell's output)

Out[50]:
array(['Strepsirrhine', 'New World Monkey', 'Old World Monkey', 'Ape'],
dtype=object)

#### Code 5.51¶

In [51]:
# integer code per clade; inferred categories are sorted, so code i corresponds to
# sorted(d.clade.unique())[i]
d["clade_id"] = d["clade"].astype("category").cat.codes


#### Code 5.52¶

In [52]:
d["K"] = d["kcal.per.g"].pipe(lambda x: (x - x.mean()) / x.std())


def model(clade_id, K):
    """m5.9: one intercept per clade (index variable) for standardized kcal.per.g."""
    a = numpyro.sample("a", dist.Normal(0, 0.5).expand([len(set(clade_id))]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a[clade_id]
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# build a fresh SVI for m5.9 (previously the stale `svi` from m5.8 was being re-used)
m5_9 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_9, optim.Adam(1), ELBO(), clade_id=d.clade_id.values, K=d.K.values
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p5_9 = svi.get_params(state)
post = m5_9.sample_posterior(random.PRNGKey(1), p5_9, (1000,))
labels = ["a[" + str(i) + "]:" + s for i, s in enumerate(sorted(d.clade.unique()))]
az.plot_forest({"a": post["a"][None, ...]}, hdi_prob=0.89)
plt.gca().set(yticklabels=labels[::-1], xlabel="expected kcal (std)")
plt.show()

Out[52]:

#### Code 5.53¶

In [53]:
# assign each species to one of 4 "houses": draw without replacement from 8 slots per house
key = random.PRNGKey(63)
d["house"] = random.choice(
    key, jnp.repeat(jnp.arange(4), 8), shape=d.shape[:1], replace=False
)


#### Code 5.54¶

In [54]:
def model(clade_id, house, K):
h = numpyro.sample("h", dist.Normal(0, 0.5).expand([len(set(house))]))
sigma = numpyro.sample("sigma", dist.Exponential(1))
numpyro.sample("height", dist.Normal(mu, sigma), obs=K)

m5_10 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m5_10,