Chapter 5. The Many Variables & The Spurious Waffles

In [ ]:
!pip install -q numpyro arviz daft networkx
In [0]:
import collections
import itertools
import math
import os

import arviz as az
import daft
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation

if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]"arviz-darkgrid")

Code 5.1

In [1]:
# Read the divorce data; `d` is a second name for the same frame (not a copy).
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce

# z-score the predictor (A) and the outcome (D): subtract the mean, divide by the std
for short_name, column in (("A", "MedianAgeMarriage"), ("D", "Divorce")):
    d[short_name] = (d[column] - d[column].mean()) / d[column].std()

Code 5.2

In [2]:

Code 5.3

In [3]:
def model(A, D=None):
    """Linear regression of standardized divorce rate on standardized age at marriage.

    Args:
        A: predictor values (standardized median age at marriage).
        D: observed outcome values; when ``None`` the ``"D"`` site is sampled
            instead of being conditioned on.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # register mu as a deterministic site so Predictive can return it
    mu = numpyro.deterministic("mu", a + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


# Laplace (quadratic) approximation of the posterior, fitted with 1000 Adam steps.
m5_1 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_1, optim.Adam(1), Trace_ELBO(), A=d.A.values, D=d.D.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_1 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1196.77it/s, init loss: 2138.6682, avg. loss [951-1000]: 59.4385]

Code 5.4

In [4]:
# Prior predictive: 1000 draws of the regression line evaluated at A = -2 and A = 2.
predictive = Predictive(m5_1.model, num_samples=1000, return_sites=["mu"])
prior_pred = predictive(random.PRNGKey(10), A=jnp.array([-2, 2]))
mu = prior_pred["mu"]

# Overlay the first 20 prior lines on a [-2, 2] x [-2, 2] window.
plt.subplot(xlim=(-2, 2), ylim=(-2, 2))
for prior_line in mu[:20]:
    plt.plot([-2, 2], prior_line, "k", alpha=0.4)
No description has been provided for this image

Code 5.5

In [5]:
# compute percentile interval of mean
A_seq = jnp.linspace(start=-3, stop=3.2, num=30)
post = m5_1.sample_posterior(random.PRNGKey(1), p5_1, sample_shape=(1000,))
post_pred = Predictive(m5_1.model, post)(random.PRNGKey(2), A=A_seq)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)

# plot it all
# The first variable goes on the x-axis: A must come first so that the scatter
# shares its axes with the regression line, which is drawn against A_seq.
az.plot_pair(d[["A", "D"]].to_dict(orient="list"))
plt.plot(A_seq, mu_mean, "k")
plt.fill_between(A_seq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
No description has been provided for this image

Code 5.6

In [6]:
# standardized marriage rate
d["M"] = d.Marriage.pipe(lambda x: (x - x.mean()) / x.std())


def model(M, D=None):
    """Linear regression of standardized divorce rate on standardized marriage rate.

    Args:
        M: predictor values (standardized marriage rate).
        D: observed outcome values; when ``None`` the ``"D"`` site is sampled.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bM * M
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


# Laplace approximation fitted with 1000 Adam steps.
m5_2 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_2, optim.Adam(1), Trace_ELBO(), M=d.M.values, D=d.D.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_2 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1218.38it/s, init loss: 962.7464, avg. loss [951-1000]: 66.1313]

Code 5.7

In [7]:
# DAG in which A influences D directly and also through M.
dag5_1 = nx.DiGraph()
dag5_1.add_edges_from([("A", "D"), ("A", "M"), ("M", "D")])
pgm = daft.PGM()
coordinates = {"A": (0, 0), "D": (1, 1), "M": (2, 0)}
for node in dag5_1.nodes:
    pgm.add_node(node, node, *coordinates[node])
for edge in dag5_1.edges:
    pgm.add_edge(*edge)
# constrained layout is switched off only while daft renders the graph
with plt.rc_context({"figure.constrained_layout.use": False}):
    pgm.render()
No description has been provided for this image

Code 5.8

In [8]:
DMA_dag2 = nx.DiGraph()
DMA_dag2.add_edges_from([("A", "D"), ("A", "M")])
# For each unordered pair of nodes, collect the conditioning sets that d-separate
# it, trying smaller sets first and skipping supersets of a set already found.
conditional_independencies = collections.defaultdict(list)
for edge in itertools.combinations(sorted(DMA_dag2.nodes), 2):
    remaining = sorted(set(DMA_dag2.nodes) - set(edge))
    for size in range(len(remaining) + 1):
        for subset in itertools.combinations(remaining, size):
            if any(cond.issubset(set(subset)) for cond in conditional_independencies[edge]):
                continue  # implied by a smaller conditioning set
            if nx.d_separated(DMA_dag2, {edge[0]}, {edge[1]}, set(subset)):
                conditional_independencies[edge].append(set(subset))
                print(f"{edge[0]} _||_ {edge[1]}" + (f" | {' '.join(subset)}" if subset else ""))
D _||_ M | A

Code 5.9

In [9]:
DMA_dag1 = nx.DiGraph()
DMA_dag1.add_edges_from([("A", "D"), ("A", "M"), ("M", "D")])
# Same search as for the previous DAG: list the conditioning sets that d-separate
# each pair of nodes, skipping supersets of a set already found.
conditional_independencies = collections.defaultdict(list)
for edge in itertools.combinations(sorted(DMA_dag1.nodes), 2):
    remaining = sorted(set(DMA_dag1.nodes) - set(edge))
    for size in range(len(remaining) + 1):
        for subset in itertools.combinations(remaining, size):
            if any(cond.issubset(set(subset)) for cond in conditional_independencies[edge]):
                continue  # implied by a smaller conditioning set
            if nx.d_separated(DMA_dag1, {edge[0]}, {edge[1]}, set(subset)):
                conditional_independencies[edge].append(set(subset))
                print(f"{edge[0]} _||_ {edge[1]}" + (f" | {' '.join(subset)}" if subset else ""))

Code 5.10

In [10]:
def model(M, A, D=None):
    """Multiple regression of divorce rate on marriage rate and age at marriage.

    Args:
        M: standardized marriage rate.
        A: standardized median age at marriage.
        D: observed standardized divorce rate; when ``None`` the ``"D"`` site
            is sampled.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


# Laplace approximation fitted with 1000 Adam steps, then a posterior summary.
m5_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_3, optim.Adam(1), Trace_ELBO(), M=d.M.values, A=d.A.values, D=d.D.values
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_3 = svi_result.params
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, sample_shape=(1000,))
print_summary(post, 0.89, False)
100%|██████████| 1000/1000 [00:00<00:00, 1002.82it/s, init loss: 3201.7393, avg. loss [951-1000]: 59.5721]
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a     -0.00      0.10     -0.01     -0.16      0.14   1049.96      1.00
        bA     -0.61      0.16     -0.61     -0.86     -0.36    822.38      1.00
        bM     -0.06      0.16     -0.06     -0.31      0.19    984.99      1.00
     mu[0]      0.36      0.13      0.37      0.15      0.57    921.63      1.00
     mu[1]      0.32      0.21      0.32     -0.01      0.66    900.77      1.00
     mu[2]      0.12      0.10      0.12     -0.03      0.28    995.35      1.00
     mu[3]      0.76      0.21      0.75      0.43      1.10    911.65      1.00
     mu[4]     -0.35      0.12     -0.35     -0.52     -0.14   1070.56      1.00
     mu[5]      0.12      0.15      0.12     -0.13      0.35    861.01      1.00
     mu[6]     -0.71      0.17     -0.70     -0.95     -0.42   1053.26      1.00
     mu[7]     -0.31      0.20     -0.32     -0.63      0.03    865.14      1.00
     mu[8]     -1.74      0.40     -1.74     -2.28     -1.00    804.96      1.00
     mu[9]     -0.12      0.14     -0.12     -0.37      0.08   1072.06      1.00
    mu[10]      0.04      0.12      0.04     -0.14      0.23    858.83      1.00
    mu[11]     -0.49      0.30     -0.49     -0.91      0.05    872.77      1.00
    mu[12]      1.30      0.27      1.30      0.87      1.72    867.10      1.00
    mu[13]     -0.43      0.13     -0.43     -0.61     -0.21   1125.00      1.00
    mu[14]      0.18      0.11      0.18      0.00      0.35    987.73      1.00
    mu[15]      0.30      0.11      0.30      0.11      0.47    909.55      1.00
    mu[16]      0.48      0.13      0.49      0.27      0.70    875.38      1.00
    mu[17]      0.58      0.15      0.58      0.35      0.82    867.29      1.00
    mu[18]      0.07      0.10      0.06     -0.09      0.21   1002.97      1.00
    mu[19]     -0.07      0.27     -0.07     -0.50      0.34   1018.90      1.00
    mu[20]     -0.58      0.15     -0.58     -0.80     -0.32   1013.12      1.00
    mu[21]     -1.13      0.24     -1.13     -1.47     -0.74    928.11      1.00
    mu[22]     -0.11      0.16     -0.11     -0.38      0.12   1058.69      1.00
    mu[23]     -0.05      0.21     -0.05     -0.38      0.27   1029.22      1.00
    mu[24]      0.13      0.11      0.13     -0.05      0.30   1018.21      1.00
    mu[25]      0.24      0.15      0.24      0.03      0.50    987.36      1.00
    mu[26]      0.20      0.14      0.19     -0.04      0.42   1004.82      1.00
    mu[27]      0.33      0.14      0.32      0.10      0.53    944.68      1.00
    mu[28]     -0.31      0.14     -0.31     -0.53     -0.09   1106.73      1.00
    mu[29]     -0.72      0.19     -0.72     -1.03     -0.42   1099.72      1.00
    mu[30]      0.12      0.10      0.12     -0.03      0.28    992.25      1.00
    mu[31]     -1.09      0.24     -1.10     -1.44     -0.70    891.79      1.00
    mu[32]      0.17      0.10      0.17      0.00      0.33    973.21      1.00
    mu[33]      0.26      0.24      0.26     -0.13      0.63    903.59      1.00
    mu[34]     -0.07      0.15     -0.07     -0.34      0.14   1057.21      1.00
    mu[35]      0.75      0.18      0.75      0.49      1.05    865.16      1.00
    mu[36]      0.04      0.11      0.04     -0.15      0.20   1071.57      1.00
    mu[37]     -0.44      0.17     -0.44     -0.70     -0.17   1096.13      1.00
    mu[38]     -0.97      0.21     -0.97     -1.31     -0.63   1057.69      1.00
    mu[39]     -0.14      0.11     -0.14     -0.32      0.04   1109.30      1.00
    mu[40]      0.22      0.11      0.22      0.04      0.39    962.55      1.00
    mu[41]      0.43      0.16      0.42      0.18      0.69    929.76      1.00
    mu[42]      0.39      0.12      0.40      0.18      0.57    890.80      1.00
    mu[43]      1.19      0.30      1.20      0.72      1.70    930.35      1.00
    mu[44]     -0.36      0.15     -0.35     -0.59     -0.12   1105.88      1.00
    mu[45]     -0.18      0.11     -0.18     -0.34      0.01    937.86      1.00
    mu[46]      0.05      0.10      0.05     -0.12      0.21    891.13      1.00
    mu[47]      0.48      0.13      0.48      0.27      0.70    875.38      1.00
    mu[48]     -0.08      0.14     -0.08     -0.32      0.13   1065.16      1.00
    mu[49]      0.74      0.34      0.73      0.16      1.22    951.13      1.00
     sigma      0.80      0.08      0.79      0.68      0.92    971.25      1.00

Code 5.11

In [11]:
# Posterior draws from each model, with a leading chain axis of size 1 as ArviZ expects.
coeftab = {
    "m5.1": m5_1.sample_posterior(random.PRNGKey(1), p5_1, sample_shape=(1, 1000)),
    "m5.2": m5_2.sample_posterior(random.PRNGKey(2), p5_2, sample_shape=(1, 1000)),
    "m5.3": m5_3.sample_posterior(random.PRNGKey(3), p5_3, sample_shape=(1, 1000)),
}
# Compare the slope estimates across the three models.
az.plot_forest(
    list(coeftab.values()),
    model_names=list(coeftab.keys()),
    var_names=["bA", "bM"],
    hdi_prob=0.89,
)
No description has been provided for this image

Code 5.12

In [12]:
N = 50  # number of simulated States
# A is standard normal; M and D are each drawn around A with unit scale.
age = dist.Normal().sample(random.PRNGKey(0), (N,))  # sim A
mar = dist.Normal(loc=age).sample(random.PRNGKey(1))  # sim A -> M
div = dist.Normal(loc=age).sample(random.PRNGKey(2))  # sim A -> D

Code 5.13

In [13]:
def model(A, M=None):
    """Linear regression of standardized marriage rate on standardized age at marriage.

    Args:
        A: predictor values (standardized median age at marriage).
        M: observed standardized marriage rate; when ``None`` the ``"M"`` site
            is sampled.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    # NOTE(review): the site is registered as "bA" although the local name is bAM;
    # the name is kept so existing posterior keys do not change.
    bAM = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bAM * A)
    numpyro.sample("M", dist.Normal(mu, sigma), obs=M)


# Laplace approximation fitted with 1000 Adam steps.
m5_4 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_4, optim.Adam(0.1), Trace_ELBO(), A=d.A.values, M=d.M.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_4 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1218.61it/s, init loss: 2288.6685, avg. loss [951-1000]: 52.6188]

Code 5.14

In [14]:
# Posterior mean of mu at each observed A, then the residual marriage rate.
post = m5_4.sample_posterior(random.PRNGKey(1), p5_4, sample_shape=(1000,))
predict_M = Predictive(m5_4.model, post)
post_pred = predict_M(random.PRNGKey(2), A=d.A.values)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, axis=0)
mu_resid = d.M.values - mu_mean

Code 5.15

In [15]:
# Posterior predictions of m5_3 at the observed M and A values
# (D is left unset, so the "D" site is simulated).
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, sample_shape=(int(1e4),))
predict_D = Predictive(m5_3.model, post)
post_pred = predict_D(random.PRNGKey(2), M=d.M.values, A=d.A.values)
mu = post_pred["mu"]

# mean and 89% percentile interval of mu for each case
interval_q = jnp.array([5.5, 94.5])
mu_mean = jnp.mean(mu, axis=0)
mu_PI = jnp.percentile(mu, q=interval_q, axis=0)

# simulated observations and their 89% percentile interval
D_sim = post_pred["D"]
D_PI = jnp.percentile(D_sim, q=jnp.array([5.5, 94.5]), axis=0)

Code 5.16

In [16]:
# Posterior prediction plot: predicted vs. observed divorce rate.
ax = plt.subplot(
    ylim=(float(mu_PI.min()), float(mu_PI.max())),
    xlabel="Observed divorce",
    ylabel="Predicted divorce",
)
plt.plot(d.D, mu_mean, "o")
# dashed identity line: perfect prediction
x = jnp.linspace(mu_PI.min(), mu_PI.max(), 101)
plt.plot(x, x, "--")
# vertical segment per case spanning the percentile interval of mu
for i in range(d.shape[0]):
    plt.plot([d.D[i]] * 2, mu_PI[:, i], "b")
# keep a handle on the figure so a later cell can annotate and redisplay it
fig = plt.gcf()
No description has been provided for this image

Code 5.17

In [17]:
# Label selected states on the axes created for the posterior prediction plot.
for i in range(d.shape[0]):
    if d.Loc[i] in ["ID", "UT", "RI", "ME"]:
        ax.annotate(
            d.Loc[i], (d.D[i], mu_mean[i]), xytext=(-25, -5), textcoords="offset pixels"
        )
# redisplay the annotated figure
fig
No description has been provided for this image

Code 5.18

In [18]:
N = 100  # number of cases
# x_real: standard Gaussian draws
x_real = dist.Normal().sample(random.PRNGKey(0), sample_shape=(N,))
# x_spur and y: Gaussian draws centred on x_real
x_spur = dist.Normal(loc=x_real).sample(random.PRNGKey(1))
y = dist.Normal(loc=x_real).sample(random.PRNGKey(2))
# collect the three simulated variables in one data frame
d = pd.DataFrame(dict(y=y, x_real=x_real, x_spur=x_spur))

Code 5.19

In [19]:
# Reload and re-standardize the divorce data (`d` was reused for simulated data).
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce
d["A"] = d.MedianAgeMarriage.pipe(lambda x: (x - x.mean()) / x.std())
d["D"] = d.Divorce.pipe(lambda x: (x - x.mean()) / x.std())
d["M"] = d.Marriage.pipe(lambda x: (x - x.mean()) / x.std())


def model(A, M=None, D=None):
    """Joint model of the DAG A -> M, A -> D <- M.

    Args:
        A: standardized median age at marriage.
        M: observed standardized marriage rate; when ``None`` the ``"M"`` site
            is sampled and the sampled value feeds the D regression.
        D: observed standardized divorce rate; when ``None`` the ``"D"`` site
            is sampled.
    """
    # A -> M
    aM = numpyro.sample("aM", dist.Normal(0, 0.2))
    bAM = numpyro.sample("bAM", dist.Normal(0, 0.5))
    sigma_M = numpyro.sample("sigma_M", dist.Exponential(1))
    mu_M = aM + bAM * A
    M = numpyro.sample("M", dist.Normal(mu_M, sigma_M), obs=M)
    # A -> D <- M
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bM * M + bA * A
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


# Laplace approximation fitted on all three observed variables with 1000 Adam steps.
m5_3_A = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_3_A,
    optim.Adam(0.1),
    Trace_ELBO(),
    A=d.A.values,
    M=d.M.values,
    D=d.D.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_3_A = svi_result.params
100%|██████████| 1000/1000 [00:01<00:00, 734.26it/s, init loss: 10480.9580, avg. loss [951-1000]: 112.1909]

Code 5.20

In [20]:
# 30 evenly spaced values of A from -2 to 2
A_seq = jnp.linspace(-2, 2, num=30)

Code 5.21

In [21]:
# intervention data: only A is supplied, so M and D are both simulated
sim_dat = {"A": A_seq}

# draw posterior samples, then simulate M and D along A_seq
post = m5_3_A.sample_posterior(random.PRNGKey(1), p5_3_A, sample_shape=(1000,))
simulate = Predictive(m5_3_A.model, post)
s = simulate(random.PRNGKey(2), **sim_dat)

Code 5.22

In [22]:
# Mean simulated D along the manipulated A values, with an 89% percentile band.
plt.plot(sim_dat["A"], jnp.mean(s["D"], 0))
plt.gca().set(ylim=(-2, 2), xlabel="manipulated A", ylabel="counterfactual D")
plt.fill_between(
    sim_dat["A"],
    *jnp.percentile(s["D"], q=jnp.array([5.5, 94.5]), axis=0),
    color="k",
    alpha=0.2,
)
plt.title("Total counterfactual effect of A on D")
No description has been provided for this image

Code 5.23

In [23]:
# new data frame, standardized to mean 26.1 and stddev 1.24
sim2_dat = dict(A=(jnp.array([20, 30]) - 26.1) / 1.24)
s2 = Predictive(m5_3_A.model, post, return_sites=["M", "D"])(
    random.PRNGKey(2), **sim2_dat
)
# mean difference in simulated D between the second and the first A value
jnp.mean(s2["D"][:, 1] - s2["D"][:, 0])
DeviceArray(-4.6818223, dtype=float32)

Code 5.24

In [24]:
# Manipulate M directly while holding A at 0; only D is simulated.
sim_dat = dict(M=jnp.linspace(-2, 2, num=30), A=0)
s = Predictive(m5_3_A.model, post)(random.PRNGKey(2), **sim_dat)["D"]

plt.plot(sim_dat["M"], jnp.mean(s, 0))
# the x-axis shows the manipulated M values, so label it accordingly
plt.gca().set(ylim=(-2, 2), xlabel="manipulated M", ylabel="counterfactual D")
plt.fill_between(
    sim_dat["M"],
    *jnp.percentile(s, q=jnp.array([5.5, 94.5]), axis=0),
    color="k",
    alpha=0.2,
)
plt.title("Total counterfactual effect of M on D")
No description has been provided for this image

Code 5.25

In [25]:
# 30 evenly spaced values of A from -2 to 2
A_seq = jnp.linspace(-2, 2, num=30)

Code 5.26

In [26]:
post = m5_3_A.sample_posterior(random.PRNGKey(1), p5_3_A, sample_shape=(1000,))
# add a trailing axis so each posterior draw broadcasts against A_seq
post = {k: v[..., None] for k, v in post.items()}
# simulate M around its posterior mean using the posterior sigma_M
# (omitting the scale would silently fall back to a unit standard deviation)
M_sim = dist.Normal(post["aM"] + post["bAM"] * A_seq, post["sigma_M"]).sample(
    random.PRNGKey(1)
)

Code 5.27

In [27]:
# simulate D from the simulated M and A_seq, using the posterior sigma as scale;
# a key different from the one used for M_sim keeps the two noise draws independent
D_sim = dist.Normal(
    post["a"] + post["bA"] * A_seq + post["bM"] * M_sim, post["sigma"]
).sample(random.PRNGKey(2))

Code 5.28

In [28]:
# Load the primate milk data; `d` is a second name for the same frame.
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk
# column overview (printed) followed by the first rows (rich display)
d.info()
d.head()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 29 entries, 0 to 28
Data columns (total 8 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   clade           29 non-null     object 
 1   species         29 non-null     object 
 2   kcal.per.g      29 non-null     float64
 3   perc.fat        29 non-null     float64
 4   perc.protein    29 non-null     float64
 5   perc.lactose    29 non-null     float64
 6   mass            29 non-null     float64
 7   neocortex.perc  17 non-null     float64
dtypes: float64(6), object(2)
memory usage: 1.9+ KB
clade species kcal.per.g perc.fat perc.protein perc.lactose mass neocortex.perc
0 Strepsirrhine Eulemur fulvus 0.49 16.60 15.42 67.98 1.95 55.16
1 Strepsirrhine E macaco 0.51 19.27 16.91 63.82 2.09 NaN
2 Strepsirrhine E mongoz 0.46 14.11 16.85 69.04 2.51 NaN
3 Strepsirrhine E rubriventer 0.48 14.91 13.18 71.91 1.62 NaN
4 Strepsirrhine Lemur catta 0.60 27.28 19.50 53.22 2.19 NaN

Code 5.29

In [29]:
# standardized energy content (K), neocortex percent (N) and log body mass (M)
d["K"] = d["kcal.per.g"].pipe(lambda x: (x - x.mean()) / x.std())
d["N"] = d["neocortex.perc"].pipe(lambda x: (x - x.mean()) / x.std())
d["M"] = d.mass.map(math.log).pipe(lambda x: (x - x.mean()) / x.std())

Code 5.30

In [30]:
def model(N, K):
    """Linear regression of standardized kcal (K) on standardized neocortex percent (N).

    Args:
        N: predictor values.
        K: observed outcome values.
    """
    a = numpyro.sample("a", dist.Normal(0, 1))
    bN = numpyro.sample("bN", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bN * N
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# Fit on the full data with argument validation enabled; the failure is
# caught and printed because showing the error is the point of this cell.
with numpyro.validation_enabled():
    try:
        m5_5_draft = AutoLaplaceApproximation(model)
        svi = SVI(
            model, m5_5_draft, optim.Adam(1), Trace_ELBO(), N=d.N.values, K=d.K.values
        )
        svi_result = svi.run(random.PRNGKey(0), 1000)
        p5_5_draft = svi_result.params
    except ValueError as e:
        print(str(e))
Normal distribution got invalid loc parameter.

Code 5.31

In [31]:
0     55.16
1       NaN
2       NaN
3       NaN
4       NaN
5     64.54
6     64.54
7     67.64
8       NaN
9     68.85
10    58.85
11    61.69
12    60.32
13      NaN
14      NaN
15    69.97
16      NaN
17    70.41
18      NaN
19    73.40
20      NaN
21    67.53
22      NaN
23    71.26
24    72.60
25      NaN
26    70.24
27    76.30
28    75.49
Name: neocortex.perc, dtype: float64

Code 5.32

In [32]:
# complete cases: keep rows where K, N and M are all present
# (label-based, so it stays correct even if the index is not 0..n-1)
dcc = d.dropna(subset=["K", "N", "M"], how="any")

Code 5.33

In [33]:
def model(N, K=None):
    """Regression of standardized kcal (K) on neocortex percent (N) with Normal(0, 1) priors.

    Args:
        N: predictor values.
        K: observed outcome values; when ``None`` the ``"K"`` site is sampled.
    """
    a = numpyro.sample("a", dist.Normal(0, 1))
    bN = numpyro.sample("bN", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bN * N)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# Laplace approximation fitted on the complete cases with 1000 Adam steps.
m5_5_draft = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_5_draft, optim.Adam(0.1), Trace_ELBO(), N=dcc.N.values, K=dcc.K.values
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_5_draft = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1444.83it/s, init loss: 411.1621, avg. loss [951-1000]: 26.8758]

Code 5.34

In [34]:
# Prior predictive lines of the current `model`, evaluated at N = -2 and N = 2.
xseq = jnp.array([-2, 2])
prior_predictive = Predictive(model, num_samples=1000)
prior_pred = prior_predictive(random.PRNGKey(1), N=xseq)
mu = prior_pred["mu"]

# draw the first 50 prior lines on a square window
plt.subplot(xlim=xseq, ylim=xseq)
for prior_line in mu[:50]:
    plt.plot(xseq, prior_line, "k", alpha=0.3)
No description has been provided for this image

Code 5.35

In [35]:
def model(N, K=None):
    """Regression of standardized kcal (K) on neocortex percent (N) with tighter priors.

    Args:
        N: predictor values.
        K: observed outcome values; when ``None`` the ``"K"`` site is sampled.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bN = numpyro.sample("bN", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bN * N)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# Laplace approximation fitted on the complete cases with 1000 Adam steps.
m5_5 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_5, optim.Adam(1), Trace_ELBO(), N=dcc.N.values, K=dcc.K.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_5 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1399.46it/s, init loss: 414.1050, avg. loss [951-1000]: 24.6918]

Code 5.36

In [36]:
# draw 1000 samples from the fitted approximation m5_5 and summarize them
post = m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1000,))
print_summary(post, 0.89, False)
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      0.05      0.16      0.05     -0.21      0.29    931.50      1.00
        bN      0.13      0.23      0.13     -0.21      0.53   1111.88      1.00
     mu[0]     -0.22      0.51     -0.22     -0.99      0.60    911.65      1.00
     mu[1]     -0.02      0.20     -0.01     -0.36      0.27    906.66      1.00
     mu[2]     -0.02      0.20     -0.01     -0.36      0.27    906.66      1.00
     mu[3]      0.05      0.16      0.05     -0.21      0.29    931.28      1.00
     mu[4]      0.08      0.17      0.08     -0.18      0.36    940.50      1.00
     mu[5]     -0.14      0.38     -0.13     -0.74      0.45    889.84      1.00
     mu[6]     -0.08      0.28     -0.07     -0.52      0.36    874.50      1.00
     mu[7]     -0.11      0.33     -0.10     -0.65      0.37    884.63      1.00
     mu[8]      0.10      0.18      0.10     -0.18      0.41    964.72      1.00
     mu[9]      0.11      0.19      0.11     -0.17      0.43    975.68      1.00
    mu[10]      0.17      0.28      0.17     -0.25      0.63   1037.05      1.00
    mu[11]      0.05      0.16      0.05     -0.21      0.29    931.71      1.00
    mu[12]      0.13      0.21      0.13     -0.18      0.49    996.35      1.00
    mu[13]      0.16      0.25      0.15     -0.21      0.58   1024.00      1.00
    mu[14]      0.11      0.19      0.10     -0.17      0.43    971.43      1.00
    mu[15]      0.24      0.37      0.23     -0.39      0.80   1067.60      1.00
    mu[16]      0.22      0.35      0.21     -0.37      0.74   1061.13      1.00
     sigma      1.05      0.18      1.03      0.78      1.35    944.03      1.00

Code 5.37

In [37]:
# grid over the observed range of N, padded by 0.15 on each side
n_low = dcc.N.min() - 0.15
n_high = dcc.N.max() + 0.15
xseq = jnp.linspace(start=n_low, stop=n_high, num=30)

# posterior mean and 89% percentile interval of mu along the grid
post = m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1000,))
post_pred = Predictive(m5_5.model, post)(random.PRNGKey(2), N=xseq)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, axis=0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)

# scatter of the data with the regression line and its interval band
az.plot_pair(dcc[["N", "K"]].to_dict(orient="list"))
plt.plot(xseq, mu_mean, "k")
plt.fill_between(xseq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
No description has been provided for this image

Code 5.38

In [38]:
def model(M, K=None):
    """Regression of standardized kcal (K) on standardized log body mass (M).

    Args:
        M: predictor values.
        K: observed outcome values; when ``None`` the ``"K"`` site is sampled.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# Laplace approximation fitted on the complete cases, then a posterior summary.
m5_6 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_6, optim.Adam(1), Trace_ELBO(), M=dcc.M.values, K=dcc.K.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_6 = svi_result.params
post = m5_6.sample_posterior(random.PRNGKey(1), p5_6, sample_shape=(1000,))
print_summary(post, 0.89, False)
100%|██████████| 1000/1000 [00:00<00:00, 1457.13it/s, init loss: 756.0300, avg. loss [951-1000]: 23.9327]
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      0.06      0.16      0.06     -0.20      0.29    931.50      1.00
        bM     -0.28      0.20     -0.28     -0.61      0.03   1088.44      1.00
     mu[0]      0.18      0.18      0.19     -0.11      0.47    944.38      1.00
     mu[1]      0.02      0.16      0.02     -0.24      0.26    932.90      1.00
     mu[2]      0.02      0.16      0.02     -0.24      0.26    933.51      1.00
     mu[3]      0.14      0.17      0.15     -0.11      0.42    958.01      1.00
     mu[4]      0.36      0.27      0.36     -0.07      0.78    872.95      1.00
     mu[5]      0.65      0.45      0.65     -0.03      1.39    898.89      1.00
     mu[6]      0.42      0.31      0.43     -0.07      0.89    878.58      1.00
     mu[7]      0.49      0.35      0.50     -0.09      1.00    885.51      1.00
     mu[8]      0.22      0.20      0.23     -0.12      0.51    906.05      1.00
     mu[9]      0.10      0.16      0.10     -0.17      0.33    940.69      1.00
    mu[10]     -0.12      0.20     -0.13     -0.40      0.21    989.01      1.00
    mu[11]      0.02      0.16      0.02     -0.24      0.26    933.51      1.00
    mu[12]     -0.30      0.29     -0.30     -0.75      0.18   1050.56      1.00
    mu[13]     -0.43      0.38     -0.44     -1.07      0.12   1073.41      1.00
    mu[14]     -0.32      0.31     -0.33     -0.81      0.16   1055.58      1.00
    mu[15]     -0.29      0.28     -0.29     -0.74      0.16   1047.84      1.00
    mu[16]     -0.37      0.34     -0.38     -0.95      0.12   1064.72      1.00
     sigma      0.99      0.17      0.98      0.72      1.26    957.10      1.00

Code 5.39

In [39]:
def model(N, M, K=None):
    """Multiple regression of standardized kcal (K) on neocortex percent and log body mass.

    Args:
        N: standardized neocortex percent.
        M: standardized log body mass.
        K: observed outcome values; when ``None`` the ``"K"`` site is sampled.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bN = numpyro.sample("bN", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bN * N + bM * M)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# Laplace approximation fitted on the complete cases, then a posterior summary.
m5_7 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_7,
    optim.Adam(1),
    Trace_ELBO(),
    N=dcc.N.values,
    M=dcc.M.values,
    K=dcc.K.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_7 = svi_result.params
post = m5_7.sample_posterior(random.PRNGKey(1), p5_7, sample_shape=(1000,))
print_summary(post, 0.89, False)
100%|██████████| 1000/1000 [00:00<00:00, 1324.26it/s, init loss: 136.3944, avg. loss [951-1000]: 21.6292]
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      0.06      0.13      0.06     -0.15      0.26   1049.96      1.00
        bM     -0.68      0.23     -0.68     -1.06     -0.32    837.54      1.00
        bN      0.65      0.25      0.66      0.25      1.06    885.39      1.00
     mu[0]     -0.99      0.47     -1.00     -1.72     -0.22    953.05      1.00
     mu[1]     -0.36      0.20     -0.35     -0.64     -0.03   1076.34      1.00
     mu[2]     -0.37      0.20     -0.36     -0.65     -0.03   1068.21      1.00
     mu[3]      0.28      0.16      0.28      0.03      0.52    967.73      1.00
     mu[4]      0.94      0.33      0.94      0.43      1.44    876.35      1.00
     mu[5]      0.53      0.39      0.53     -0.10      1.11   1015.67      1.00
     mu[6]      0.30      0.26      0.30     -0.10      0.73   1041.23      1.00
     mu[7]      0.30      0.30      0.30     -0.16      0.76   1039.12      1.00
     mu[8]      0.73      0.27      0.73      0.35      1.17    867.10      1.00
     mu[9]      0.48      0.21      0.49      0.16      0.80    872.50      1.00
    mu[10]      0.27      0.23      0.28     -0.06      0.66    895.61      1.00
    mu[11]     -0.04      0.14     -0.04     -0.25      0.17    992.18      1.00
    mu[12]     -0.39      0.25     -0.39     -0.81     -0.00    864.74      1.00
    mu[13]     -0.56      0.32     -0.57     -1.03      0.01    871.33      1.00
    mu[14]     -0.55      0.27     -0.55     -0.98     -0.10    858.97      1.00
    mu[15]      0.20      0.30      0.20     -0.29      0.66    916.65      1.00
    mu[16]     -0.10      0.30     -0.10     -0.59      0.37    895.92      1.00
     sigma      0.77      0.14      0.77      0.55      0.97   1029.58      1.00

Code 5.40

In [40]:
# Posterior draws from each model, with a leading chain axis of size 1 as ArviZ expects.
coeftab = {
    "m5.5": m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1, 1000)),
    "m5.6": m5_6.sample_posterior(random.PRNGKey(2), p5_6, sample_shape=(1, 1000)),
    "m5.7": m5_7.sample_posterior(random.PRNGKey(3), p5_7, sample_shape=(1, 1000)),
}
# Compare the slope estimates across the three models.
az.plot_forest(
    list(coeftab.values()),
    model_names=list(coeftab.keys()),
    var_names=["bM", "bN"],
    hdi_prob=0.89,
)
No description has been provided for this image

Code 5.41

In [41]:
# Counterfactual plot: vary N over its observed range while holding M at 0.
xseq = jnp.linspace(start=dcc.N.min() - 0.15, stop=dcc.N.max() + 0.15, num=30)
post = m5_7.sample_posterior(random.PRNGKey(1), p5_7, sample_shape=(1000,))
post_pred = Predictive(m5_7.model, post)(random.PRNGKey(2), M=0, N=xseq)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)
# the x variable is N, so the x-limits come from N (not M)
plt.subplot(
    xlim=(dcc.N.min(), dcc.N.max()),
    ylim=(dcc.K.min(), dcc.K.max()),
    xlabel="neocortex percent (std)",
    ylabel="kcal per g (std)",
)
plt.plot(xseq, mu_mean, "k")
plt.fill_between(xseq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
No description has been provided for this image

Code 5.42

In [42]:
# Simulated data for the DAG
#   M -> K <- N
#   M -> N
n = 100
M = dist.Normal().sample(random.PRNGKey(0), sample_shape=(n,))
N = dist.Normal(loc=M).sample(random.PRNGKey(1))
K = dist.Normal(loc=N - M).sample(random.PRNGKey(2))
d_sim = pd.DataFrame(dict(K=K, N=N, M=M))

Code 5.43

In [43]:
# M -> K <- N
# N -> M
n = 100
N = dist.Normal().sample(random.PRNGKey(0), (n,))
M = dist.Normal(N).sample(random.PRNGKey(1))
K = dist.Normal(N - M).sample(random.PRNGKey(2))
d_sim2 = pd.DataFrame({"K": K, "N": N, "M": M})

# M -> K <- N
# M <- U -> N
n = 100
# U is the unobserved common cause: both M and N are drawn around it,
# which is what makes them associated in this simulation
U = dist.Normal().sample(random.PRNGKey(3), (n,))
N = dist.Normal(U).sample(random.PRNGKey(4))
M = dist.Normal(U).sample(random.PRNGKey(5))
K = dist.Normal(N - M).sample(random.PRNGKey(6))
d_sim3 = pd.DataFrame({"K": K, "N": N, "M": M})

Code 5.44

In [44]:
dag5_7 = nx.DiGraph()
dag5_7.add_edges_from([("M", "K"), ("N", "K"), ("M", "N")])
coordinates = {"M": (0, 0.5), "K": (1, 1), "N": (2, 0.5)}
# Enumerate every orientation of the three edges and keep the acyclic ones.
MElist = []
for flips in itertools.product(range(2), repeat=3):
    new_dag = nx.DiGraph()
    new_dag.add_edges_from(
        [edge[::-1] if flip else edge for edge, flip in zip(dag5_7.edges, flips)]
    )
    if not list(nx.simple_cycles(new_dag)):
        MElist.append(new_dag)

Code 5.45

In [45]:
# Load the Howell height data; `d` is a second name for the same frame.
Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d = Howell1
# column overview (printed) followed by the first rows (rich display)
d.info()
d.head()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 544 entries, 0 to 543
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   height  544 non-null    float64
 1   weight  544 non-null    float64
 2   age     544 non-null    float64
 3   male    544 non-null    int64  
dtypes: float64(3), int64(1)
memory usage: 17.1 KB
height weight age male
0 151.765 47.825606 63.0 1
1 139.700 36.485807 63.0 0
2 136.525 31.864838 65.0 0
3 156.845 53.041915 41.0 1
4 145.415 41.276872 51.0 0

Code 5.46

In [46]:
# Prior simulation: mu_male is a Normal(178, 20) draw plus a Normal(0, 10) difference.
n_draws = int(1e4)
mu_female = dist.Normal(178, 20).sample(random.PRNGKey(0), (n_draws,))
diff = dist.Normal(0, 10).sample(random.PRNGKey(1), (n_draws,))
mu_male = dist.Normal(178, 20).sample(random.PRNGKey(2), (n_draws,)) + diff
print_summary(dict(mu_female=mu_female, mu_male=mu_male), 0.89, False)
                 mean       std    median      5.5%     94.5%     n_eff     r_hat
  mu_female    178.21     20.22    178.24    147.19    211.84   9943.61      1.00
    mu_male    178.10     22.36    178.51    142.35    213.41  10190.57      1.00

Code 5.47

In [47]:
# index variable: 1 where male == 1, otherwise 0
d["sex"] = jnp.where(d.male.values == 1, 1, 0)
d.sex
0      1
1      0
2      0
3      1
4      0
539    1
540    1
541    0
542    1
543    1
Name: sex, Length: 544, dtype: int32

Code 5.48

In [48]:
def model(sex, height):
    """Height model with one mean per value of the ``sex`` index variable.

    Args:
        sex: integer index per observation, used to pick the matching entry of ``a``.
        height: observed heights.
    """
    # one Normal(178, 20) prior per distinct index value
    a = numpyro.sample("a", dist.Normal(178, 20).expand([len(set(sex))]))
    sigma = numpyro.sample("sigma", dist.Uniform(0, 50))
    mu = a[sex]
    numpyro.sample("height", dist.Normal(mu, sigma), obs=height)


# Laplace approximation fitted with 2000 Adam steps, then a posterior summary.
m5_8 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_8, optim.Adam(1), Trace_ELBO(), sex=d.sex.values, height=d.height.values
)
svi_result = svi.run(random.PRNGKey(0), 2000)
p5_8 = svi_result.params
post = m5_8.sample_posterior(random.PRNGKey(1), p5_8, sample_shape=(1000,))
print_summary(post, 0.89, False)
100%|██████████| 2000/2000 [00:01<00:00, 1815.39it/s, init loss: 5607.9023, avg. loss [1901-2000]: 2558.3149]
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]    135.02      1.63    135.07    132.32    137.46    931.50      1.00
      a[1]    142.56      1.73    142.54    140.02    145.51   1111.51      1.00
     sigma     27.32      0.84     27.32     26.03     28.71    951.62      1.00

Code 5.49

In [49]:
# draw posterior samples of m5_8 and add the contrast a[0] - a[1] as a derived quantity
post = m5_8.sample_posterior(random.PRNGKey(1), p5_8, sample_shape=(1000,))
post["diff_fm"] = post["a"][:, 0] - post["a"][:, 1]
print_summary(post, 0.89, False)
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]    135.02      1.63    135.07    132.32    137.46    931.50      1.00
      a[1]    142.56      1.73    142.54    140.02    145.51   1111.51      1.00
   diff_fm     -7.54      2.38     -7.47    -11.77     -4.32    876.56      1.00
     sigma     27.32      0.84     27.32     26.03     28.71    951.62      1.00

Code 5.50

In [50]:
# Reload the milk data and list the distinct clades.
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk
d.clade.unique()
array(['Strepsirrhine', 'New World Monkey', 'Old World Monkey', 'Ape'],

Code 5.51

In [51]:
# integer code per clade (categories are sorted), usable as an array index
d["clade_id"] = d.clade.astype("category").cat.codes

Code 5.52

In [52]:
d["K"] = d["kcal.per.g"].pipe(lambda x: (x - x.mean()) / x.std())


def model(clade_id, K):
    """Model of standardized kcal (K) with one intercept per clade.

    Args:
        clade_id: integer clade index per observation, used to index ``a``.
        K: observed standardized kcal values.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.5).expand([len(set(clade_id))]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a[clade_id]
    # the observed site holds K, so it is named "K" (not "height")
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# Laplace approximation fitted with 1000 Adam steps.
m5_9 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_9, optim.Adam(1), Trace_ELBO(), clade_id=d.clade_id.values, K=d.K.values
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_9 = svi_result.params
post = m5_9.sample_posterior(random.PRNGKey(1), p5_9, sample_shape=(1000,))
# forest plot of the clade intercepts, labelled with the sorted clade names
labels = ["a[" + str(i) + "]:" + s for i, s in enumerate(sorted(d.clade.unique()))]
az.plot_forest({"a": post["a"][None, ...]}, hdi_prob=0.89)
plt.gca().set(yticklabels=labels[::-1], xlabel="expected kcal (std)")
100%|██████████| 1000/1000 [00:00<00:00, 1407.72it/s, init loss: 94.6847, avg. loss [951-1000]: 35.5646]
No description has been provided for this image

Code 5.53

In [53]:
key = random.PRNGKey(63)
# assign each row a house id 0-3: draw one value per row, without replacement,
# from a pool that holds each id 8 times
d["house"] = random.choice(key, jnp.repeat(jnp.arange(4), 8), d.shape[:1], False)

Code 5.54

In [54]:
def model(clade_id, house, K):
    """Model of standardized kcal (K) with additive clade and house intercepts.

    Args:
        clade_id: integer clade index per observation, used to index ``a``.
        house: integer house index per observation, used to index ``h``.
        K: observed standardized kcal values.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.5).expand([len(set(clade_id))]))
    h = numpyro.sample("h", dist.Normal(0, 0.5).expand([len(set(house))]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a[clade_id] + h[house]
    # the observed site holds K, so it is named "K" (not "height")
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# Laplace approximation fitted with 1000 Adam steps.
m5_10 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_10,
    optim.Adam(1),
    Trace_ELBO(),
    clade_id=d.clade_id.values,
    house=d.house.values,
    K=d.K.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_10 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1231.03it/s, init loss: 491.4240, avg. loss [951-1000]: 35.8243]


Comments powered by Disqus