# Chapter 6. The Haunted DAG & The Causal Terror

In [0]:
import os
import warnings

import arviz as az
import daft
import matplotlib.pyplot as plt
import pandas as pd
from causalgraphicalmodels import CausalGraphicalModel

import jax.numpy as jnp
from jax import lax, random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import ELBO, SVI
from numpyro.infer.autoguide import AutoLaplaceApproximation

if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
category.__name__, message
)
az.style.use("arviz-darkgrid")


#### Code 6.1¶

In [1]:
with numpyro.handlers.seed(rng_seed=1914):
N = 200  # num grant proposals
p = 0.1  # proportion to select
# uncorrelated newsworthiness and trustworthiness
nw = numpyro.sample("nw", dist.Normal().expand([N]))
tw = numpyro.sample("tw", dist.Normal().expand([N]))
# select top 10% of combined scores
s = nw + tw  # total score
q = jnp.quantile(s, 1 - p)  # top 10% threshold
selected = jnp.where(s >= q, True, False)
jnp.corrcoef(jnp.stack([tw[selected], nw[selected]], 0))[0, 1]

Out[1]:
DeviceArray(-0.6453402, dtype=float32)

#### Code 6.2¶

In [2]:
N = 100  # number of individuals
with numpyro.handlers.seed(rng_seed=909):
# sim total height of each
height = numpyro.sample("height", dist.Normal(10, 2).expand([N]))
# leg as proportion of height
leg_prop = numpyro.sample("prop", dist.Uniform(0.4, 0.5).expand([N]))
# sim left leg as proportion + error
leg_left = leg_prop * height + numpyro.sample(
"left_error", dist.Normal(0, 0.02).expand([N])
)
# sim right leg as proportion + error
leg_right = leg_prop * height + numpyro.sample(
"right_error", dist.Normal(0, 0.02).expand([N])
)
# combine into data frame
d = pd.DataFrame({"height": height, "leg_left": leg_left, "leg_right": leg_right})


#### Code 6.3¶

In [3]:
def model(leg_left, leg_right, height):
a = numpyro.sample("a", dist.Normal(10, 100))
bl = numpyro.sample("bl", dist.Normal(2, 10))
br = numpyro.sample("br", dist.Normal(2, 10))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + bl * leg_left + br * leg_right
numpyro.sample("height", dist.Normal(mu, sigma), obs=height)

m6_1 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m6_1,
ELBO(),
leg_left=d.leg_left.values,
leg_right=d.leg_right.values,
height=d.height.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(2000))
p6_1 = svi.get_params(state)
post = m6_1.sample_posterior(random.PRNGKey(1), p6_1, (1000,))
print_summary(post, 0.89, False)

Out[3]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a      0.79      0.34      0.78      0.25      1.29   1049.96      1.00
bl      1.58      2.26      1.59     -2.07      5.20    813.30      1.00
br      0.45      2.26      0.44     -3.35      3.98    805.83      1.00
sigma      0.66      0.05      0.66      0.59      0.74    969.28      1.00



#### Code 6.4¶

In [4]:
az.plot_forest(post, hdi_prob=0.89)
plt.show()

Out[4]:

#### Code 6.5¶

In [5]:
post = m6_1.sample_posterior(random.PRNGKey(1), p6_1, (1000,))
az.plot_pair(post, var_names=["br", "bl"], scatter_kwargs={"alpha": 0.1})
plt.show()

Out[5]:

#### Code 6.6¶

In [6]:
sum_blbr = post["bl"] + post["br"]
az.plot_kde(sum_blbr, label="sum of bl and br")
plt.show()

Out[6]:

#### Code 6.7¶

In [7]:
def model(leg_left, height):
a = numpyro.sample("a", dist.Normal(10, 100))
bl = numpyro.sample("bl", dist.Normal(2, 10))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + bl * leg_left
numpyro.sample("height", dist.Normal(mu, sigma), obs=height)

m6_2 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m6_2,
ELBO(),
leg_left=d.leg_left.values,
height=d.height.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p6_2 = svi.get_params(state)
post = m6_2.sample_posterior(random.PRNGKey(1), p6_2, (1000,))
print_summary(post, 0.89, False)

Out[7]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a      0.83      0.35      0.84      0.25      1.35    931.50      1.00
bl      2.02      0.08      2.02      1.91      2.15    940.42      1.00
sigma      0.67      0.05      0.67      0.60      0.75    949.10      1.00



#### Code 6.8¶

In [8]:
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk
d["K"] = d["kcal.per.g"].pipe(lambda x: (x - x.mean()) / x.std())
d["F"] = d["perc.fat"].pipe(lambda x: (x - x.mean()) / x.std())
d["L"] = d["perc.lactose"].pipe(lambda x: (x - x.mean()) / x.std())


#### Code 6.9¶

In [9]:
# kcal.per.g regressed on perc.fat
def model(F, K):
a = numpyro.sample("a", dist.Normal(0, 0.2))
bF = numpyro.sample("bF", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + bF * F
numpyro.sample("K", dist.Normal(mu, sigma), obs=K)

m6_3 = AutoLaplaceApproximation(model)
svi = SVI(model, m6_3, optim.Adam(1), ELBO(), F=d.F.values, K=d.K.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p6_3 = svi.get_params(state)

# kcal.per.g regressed on perc.lactose
def model(L, K):
a = numpyro.sample("a", dist.Normal(0, 0.2))
bL = numpyro.sample("bL", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + bL * L
numpyro.sample("K", dist.Normal(mu, sigma), obs=K)

m6_4 = AutoLaplaceApproximation(model)
svi = SVI(model, m6_4, optim.Adam(1), ELBO(), L=d.L.values, K=d.K.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p6_4 = svi.get_params(state)

post = m6_3.sample_posterior(random.PRNGKey(1), p6_3, (1000,))
print_summary(post, 0.89, False)
post = m6_4.sample_posterior(random.PRNGKey(1), p6_4, (1000,))
print_summary(post, 0.89, False)

Out[9]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a      0.01      0.08      0.01     -0.13      0.12    931.50      1.00
bF      0.86      0.09      0.86      0.73      1.01   1111.39      1.00
sigma      0.46      0.06      0.46      0.37      0.57    940.36      1.00

mean       std    median      5.5%     94.5%     n_eff     r_hat
a     -0.09      0.07     -0.09     -0.21      0.02    931.50      1.00
bL     -0.90      0.08     -0.90     -1.03     -0.78   1086.72      1.00
sigma      0.41      0.06      0.40      0.32      0.49    953.70      1.00



#### Code 6.10¶

In [10]:
def model(F, L, K):
a = numpyro.sample("a", dist.Normal(0, 0.2))
bF = numpyro.sample("bF", dist.Normal(0, 0.5))
bL = numpyro.sample("bL", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + bF * F + bL * L
numpyro.sample("K", dist.Normal(mu, sigma), obs=K)

m6_5 = AutoLaplaceApproximation(model)
svi = SVI(model, m6_5, optim.Adam(1), ELBO(), F=d.F.values, L=d.L.values, K=d.K.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p6_5 = svi.get_params(state)
post = m6_5.sample_posterior(random.PRNGKey(1), p6_5, (1000,))
print_summary(post, 0.89, False)

Out[10]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a      0.04      0.07      0.03     -0.07      0.13   1049.96      1.00
bF      0.25      0.19      0.25     -0.05      0.56    820.06      1.00
bL     -0.67      0.19     -0.67     -0.98     -0.36    869.27      1.00
sigma      0.39      0.05      0.39      0.31      0.46    931.90      1.00



#### Code 6.11¶

In [11]:
az.plot_pair(d[["kcal.per.g", "perc.fat", "perc.lactose"]].to_dict("list"))
plt.show()

Out[11]:

#### Code 6.12¶

In [12]:
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk

def sim_coll(i, r=0.9):
sd = jnp.sqrt((1 - r ** 2) * jnp.var(d["perc.fat"].values))
x = dist.Normal(r * d["perc.fat"].values, sd).sample(random.PRNGKey(3 * i))

def model(perc_fat, kcal_per_g):
intercept = numpyro.sample("intercept", dist.Normal(0, 10))
b_perc_flat = numpyro.sample("b_perc.fat", dist.Normal(0, 10))
b_x = numpyro.sample("b_x", dist.Normal(0, 10))
sigma = numpyro.sample("sigma", dist.HalfCauchy(2))
mu = intercept + b_perc_flat * perc_fat + b_x * x
numpyro.sample("kcal.per.g", dist.Normal(mu, sigma), obs=kcal_per_g)

m = AutoLaplaceApproximation(model)
svi = SVI(
model,
m,
ELBO(),
perc_fat=d["perc.fat"].values,
kcal_per_g=d["kcal.per.g"].values,
)
init_state = svi.init(random.PRNGKey(3 * i + 1))
state = lax.fori_loop(0, 20000, lambda i, x: svi.update(x)[0], init_state)
params = svi.get_params(state)
samples = m.sample_posterior(random.PRNGKey(3 * i + 2), params, (1000,))
vcov = jnp.cov(jnp.stack(list(samples.values()), axis=0))
stddev = jnp.sqrt(jnp.diag(vcov))  # stddev of parameter
return dict(zip(samples.keys(), stddev))["b_perc.fat"]

def rep_sim_coll(r=0.9, n=100):
stddev = lax.map(lambda i: sim_coll(i, r=r), jnp.arange(n))
return jnp.nanmean(stddev)

r_seq = jnp.arange(start=0, stop=1, step=0.01)
stddev = lax.map(lambda z: rep_sim_coll(r=z, n=100), r_seq)
plt.plot(r_seq, stddev)
plt.xlabel("correlation")
plt.show()

Out[12]:

#### Code 6.13¶

In [13]:
with numpyro.handlers.seed(rng_seed=71):
# number of plants
N = 100

# simulate initial heights
h0 = numpyro.sample("h0", dist.Normal(10, 2).expand([N]))

# assign treatments and simulate fungus and growth
treatment = jnp.repeat(jnp.arange(2), repeats=N // 2)
fungus = numpyro.sample(
"fungus", dist.Binomial(total_count=1, probs=(0.5 - treatment * 0.4))
)
h1 = h0 + numpyro.sample("diff", dist.Normal(5 - 3 * fungus))

# compose a clean data frame
d = pd.DataFrame({"h0": h0, "h1": h1, "treatment": treatment, "fungus": fungus})
print_summary(dict(zip(d.columns, d.T.values)), 0.89, False)

Out[13]:
                 mean       std    median      5.5%     94.5%     n_eff     r_hat
fungus      0.31      0.46      0.00      0.00      1.00     18.52      1.17
h0      9.73      1.95      9.63      7.05     13.33     80.22      0.99
h1     13.72      2.47     13.60     10.73     18.38     43.44      1.08
treatment      0.50      0.50      0.50      0.00      1.00      2.64       inf



#### Code 6.14¶

In [14]:
sim_p = dist.LogNormal(0, 0.25).sample(random.PRNGKey(0), (int(1e4),))
print_summary({"sim_p": sim_p}, 0.89, False)

Out[14]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
sim_p      1.04      0.27      1.00      0.63      1.44   9936.32      1.00



#### Code 6.15¶

In [15]:
def model(h0, h1):
p = numpyro.sample("p", dist.LogNormal(0, 0.25))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = h0 * p
numpyro.sample("h1", dist.Normal(mu, sigma), obs=h1)

m6_6 = AutoLaplaceApproximation(model)
svi = SVI(model, m6_6, optim.Adam(1), ELBO(), h0=d.h0.values, h1=d.h1.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p6_6 = svi.get_params(state)
post = m6_6.sample_posterior(random.PRNGKey(1), p6_6, (1000,))
print_summary(post, 0.89, False)

Out[15]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
p      6.05      1.04      5.95      4.47      7.68    988.88      1.00
sigma    186.18     13.70    186.00    163.04    205.58   1028.05      1.00



#### Code 6.16¶

In [16]:
def model(treatment, fungus, h0, h1):
a = numpyro.sample("a", dist.LogNormal(0, 0.2))
bt = numpyro.sample("bt", dist.Normal(0, 0.5))
bf = numpyro.sample("bf", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
p = a + bt * treatment + bf * fungus
mu = h0 * p
numpyro.sample("h1", dist.Normal(mu, sigma), obs=h1)

m6_7 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m6_7,
ELBO(),
treatment=d.treatment.values,
fungus=d.fungus.values,
h0=d.h0.values,
h1=d.h1.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p6_7 = svi.get_params(state)
post = m6_7.sample_posterior(random.PRNGKey(1), p6_7, (1000,))
print_summary(post, 0.89, False)

Out[16]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a      1.47      0.02      1.47      1.44      1.51   1049.13      1.00
bf     -0.28      0.03     -0.28     -0.33     -0.24    911.14      1.00
bt      0.01      0.03      0.01     -0.03      0.06   1123.14      1.00
sigma      1.25      0.08      1.25      1.11      1.37    982.60      1.00



#### Code 6.17¶

In [17]:
def model(treatment, h0, h1):
a = numpyro.sample("a", dist.LogNormal(0, 0.2))
bt = numpyro.sample("bt", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
p = a + bt * treatment
mu = h0 * p
numpyro.sample("h1", dist.Normal(mu, sigma), obs=h1)

m6_8 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m6_8,
ELBO(),
treatment=d.treatment.values,
h0=d.h0.values,
h1=d.h1.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p6_8 = svi.get_params(state)
post = m6_8.sample_posterior(random.PRNGKey(1), p6_8, (1000,))
print_summary(post, 0.89, False)

Out[17]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a      1.33      0.02      1.33      1.29      1.37    930.82      1.00
bt      0.13      0.04      0.12      0.08      0.19    880.02      1.00
sigma      1.73      0.12      1.73      1.55      1.94    948.82      1.00



#### Code 6.18¶

In [18]:
plant_dag = CausalGraphicalModel(
nodes=["H0", "H1", "F", "T"], edges=[("H0", "H1"), ("F", "H1"), ("T", "F")]
)
pgm = daft.PGM()
coordinates = {"H0": (0, 0), "T": (4, 0), "F": (3, 0), "H1": (2, 0)}
for node in plant_dag.dag.nodes:
for edge in plant_dag.dag.edges:
with plt.rc_context({"figure.constrained_layout.use": False}):
pgm.render()

Out[18]:

#### Code 6.19¶

In [19]:
all_independencies = plant_dag.get_all_independence_relationships()
for s in all_independencies:
if all(
t[0] != s[0] or t[1] != s[1] or not t[2].issubset(s[2])
for t in all_independencies
if t != s
):
print(s)

Out[19]:
('F', 'H0', set())
('H0', 'T', set())
('H1', 'T', {'F'})


#### Code 6.20¶

In [20]:
with numpyro.handlers.seed(rng_seed=71):
N = 1000
h0 = numpyro.sample("h0", dist.Normal(10, 2).expand([N]))
treatment = jnp.repeat(jnp.arange(2), repeats=N // 2)
M = numpyro.sample("M", dist.Bernoulli(probs=0.5).expand([N]))
fungus = numpyro.sample(
"fungus", dist.Binomial(total_count=1, probs=(0.5 - treatment * 0.4))
)
h1 = h0 + numpyro.sample("diff", dist.Normal(5 + 3 * M))
d2 = pd.DataFrame({"h0": h0, "h1": h1, "treatment": treatment, "fungus": fungus})


#### Code 6.21¶

In [21]:
def sim_happiness(seed=1977, N_years=1000, max_age=65, N_births=20, aom=18):
# age existing individuals & newborns
A = jnp.repeat(jnp.arange(1, N_years + 1), N_births)
# sim happiness trait - never changes
H = jnp.repeat(jnp.linspace(-2, 2, N_births)[None, :], N_years, 0).reshape(-1)
# not yet married
M = jnp.zeros(N_years * N_births, dtype=jnp.uint8)

def update_M(i, M):
# for each person over 17, chance get married
married = dist.Bernoulli(logits=(H - 4)).sample(random.PRNGKey(seed + i))
return jnp.where((A >= i) & (M == 0), married, M)

M = lax.fori_loop(aom, max_age + 1, update_M, M)
# mortality
deaths = A > max_age
A = A[~deaths]
H = H[~deaths]
M = M[~deaths]

d = pd.DataFrame({"age": A, "married": M, "happiness": H})
return d

d = sim_happiness(seed=1977, N_years=1000)
print_summary(dict(zip(d.columns, d.T.values)), 0.89, False)

Out[21]:
                 mean       std    median      5.5%     94.5%     n_eff     r_hat
age     33.00     18.77     33.00      1.00     58.00      2.51      2.64
happiness      0.00      1.21      0.00     -2.00      1.58    338.78      1.00
married      0.28      0.45      0.00      0.00      1.00     48.04      1.18



#### Code 6.22¶

In [22]:
d2 = d[d.age > 17].copy()  # only adults
d2["A"] = (d2.age - 18) / (65 - 18)


#### Code 6.23¶

In [23]:
d2["mid"] = d2.married

def model(mid, A, happiness):
a = numpyro.sample("a", dist.Normal(0, 1).expand([len(set(mid))]))
bA = numpyro.sample("bA", dist.Normal(0, 2))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a[mid] + bA * A
numpyro.sample("happiness", dist.Normal(mu, sigma), obs=happiness)

m6_9 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m6_9,
ELBO(),
mid=d2.mid.values,
A=d2.A.values,
happiness=d2.happiness.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p6_9 = svi.get_params(state)
post = m6_9.sample_posterior(random.PRNGKey(1), p6_9, (1000,))
print_summary(post, 0.89, False)

Out[23]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a[0]     -0.20      0.06     -0.20     -0.30     -0.10   1049.96      1.00
a[1]      1.23      0.09      1.23      1.09      1.37    898.97      1.00
bA     -0.69      0.11     -0.69     -0.88     -0.53   1126.51      1.00
sigma      1.02      0.02      1.02      0.98      1.05    966.00      1.00



#### Code 6.24¶

In [24]:
def model(A, happiness):
a = numpyro.sample("a", dist.Normal(0, 1))
bA = numpyro.sample("bA", dist.Normal(0, 2))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + bA * A
numpyro.sample("happiness", dist.Normal(mu, sigma), obs=happiness)

m6_10 = AutoLaplaceApproximation(model)
svi = SVI(
model, m6_10, optim.Adam(1), ELBO(), A=d2.A.values, happiness=d2.happiness.values
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p6_10 = svi.get_params(state)
post = m6_10.sample_posterior(random.PRNGKey(1), p6_10, (1000,))
print_summary(post, 0.89, False)

Out[24]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a      0.01      0.08      0.01     -0.12      0.12    931.50      1.00
bA     -0.01      0.13     -0.01     -0.22      0.21    940.88      1.00
sigma      1.21      0.03      1.21      1.17      1.26    949.78      1.00



#### Code 6.25¶

In [25]:
N = 200  # number of grandparent-parent-child triads
b_GP = 1  # direct effect of G on P
b_GC = 0  # direct effect of G on C
b_PC = 1  # direct effect of P on C
b_U = 2  # direct effect of U on P and C


#### Code 6.26¶

In [26]:
with numpyro.handlers.seed(rng_seed=1):
U = 2 * numpyro.sample("U", dist.Bernoulli(0.5).expand([N])) - 1
G = numpyro.sample("G", dist.Normal().expand([N]))
P = numpyro.sample("P", dist.Normal(b_GP * G + b_U * U))
C = numpyro.sample("C", dist.Normal(b_PC * P + b_GC * G + b_U * U))
d = pd.DataFrame({"C": C, "P": P, "G": G, "U": U})


#### Code 6.27¶

In [27]:
def model(P, G, C):
a = numpyro.sample("a", dist.Normal(0, 1))
b_PC = numpyro.sample("b_PC", dist.Normal(0, 1))
b_GC = numpyro.sample("b_GC", dist.Normal(0, 1))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + b_PC * P + b_GC * G
numpyro.sample("C", dist.Normal(mu, sigma), obs=C)

m6_11 = AutoLaplaceApproximation(model)
svi = SVI(
model, m6_11, optim.Adam(0.3), ELBO(), P=d.P.values, G=d.G.values, C=d.C.values
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p6_11 = svi.get_params(state)
post = m6_11.sample_posterior(random.PRNGKey(1), p6_11, (1000,))
print_summary(post, 0.89, False)

Out[27]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
a     -0.08      0.10     -0.09     -0.24      0.06   1049.96      1.00
b_GC     -0.71      0.11     -0.71     -0.89     -0.55    813.76      1.00
b_PC      1.72      0.04      1.72      1.65      1.79    982.64      1.00
sigma      1.39      0.07      1.39      1.28      1.49    968.54      1.00



#### Code 6.28¶

In [28]:
def model(P, G, U, C):
a = numpyro.sample("a", dist.Normal(0, 1))
b_PC = numpyro.sample("b_PC", dist.Normal(0, 1))
b_GC = numpyro.sample("b_GC", dist.Normal(0, 1))
b_U = numpyro.sample("U", dist.Normal(0, 1))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + b_PC * P + b_GC * G + b_U * U
numpyro.sample("C", dist.Normal(mu, sigma), obs=C)

m6_12 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m6_12,
ELBO(),
P=d.P.values,
G=d.G.values,
U=d.U.values,
C=d.C.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(2000))
p6_12 = svi.get_params(state)
post = m6_12.sample_posterior(random.PRNGKey(1), p6_12, (1000,))
print_summary(post, 0.89, False)

Out[28]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
U      1.89      0.17      1.90      1.61      2.14   1009.20      1.00
a     -0.03      0.08     -0.03     -0.15      0.10    765.90      1.00
b_GC      0.03      0.10      0.04     -0.13      0.20   1032.11      1.00
b_PC      1.02      0.07      1.02      0.90      1.14   1106.88      1.00
sigma      1.09      0.05      1.09      0.99      1.17    832.71      1.00



#### Code 6.29¶

In [29]:
dag_6_1 = CausalGraphicalModel(
nodes=["X", "Y", "C", "U", "B", "A"],
edges=[
("X", "Y"),
("U", "X"),
("A", "U"),
("A", "C"),
("C", "Y"),
("U", "B"),
("C", "B"),
],
)
if all(not t.issubset(s) for t in all_adjustment_sets if t != s):
if s != {"U"}:
print(s)

Out[29]:
frozenset({'C'})
frozenset({'A'})


#### Code 6.30¶

In [30]:
dag_6_2 = CausalGraphicalModel(
nodes=["S", "A", "D", "M", "W"],
edges=[
("S", "A"),
("A", "D"),
("S", "M"),
("M", "D"),
("S", "W"),
("W", "D"),
("A", "M"),
],
)
if all(not t.issubset(s) for t in all_adjustment_sets if t != s):
print(s)

Out[30]:
frozenset({'A', 'M'})
frozenset({'S'})


#### Code 6.31¶

In [31]:
all_independencies = dag_6_2.get_all_independence_relationships()
for s in all_independencies:
if all(
t[0] != s[0] or t[1] != s[1] or not t[2].issubset(s[2])
for t in all_independencies
if t != s
):
print(s)

Out[31]:
('S', 'D', {'W', 'A', 'M'})
('W', 'M', {'S'})
('W', 'A', {'S'})