Chapter 6. The Haunted DAG & The Causal Terror
In [ ]:
!pip install -q numpyro arviz daft networkx
In [0]:
import collections
import itertools
import os
import warnings
import arviz as az
import daft
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import jax.numpy as jnp
from jax import lax, random
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
Code 6.1¶
In [1]:
with numpyro.handlers.seed(rng_seed=1914):
    N = 200  # number of grant proposals
    p = 0.1  # fraction of proposals that get funded
    # newsworthiness and trustworthiness are drawn independently of each other
    nw = numpyro.sample("nw", dist.Normal().expand([N]))
    tw = numpyro.sample("tw", dist.Normal().expand([N]))
    # fund the proposals whose combined score lies in the top 10%
    s = nw + tw
    q = jnp.quantile(s, 1 - p)
# boolean mask of the funded proposals
selected = s >= q
# correlation between trustworthiness and newsworthiness among funded proposals
jnp.corrcoef(tw[selected], nw[selected])[0, 1]
Out[1]:
Code 6.2¶
In [2]:
N = 100  # number of individuals
with numpyro.handlers.seed(rng_seed=909):
    # total height of each individual
    height = numpyro.sample("height", dist.Normal(10, 2).expand([N]))
    # leg length as a proportion of total height
    leg_prop = numpyro.sample("prop", dist.Uniform(0.4, 0.5).expand([N]))
    # each leg = proportion * height + its own small measurement error
    left_error = numpyro.sample("left_error", dist.Normal(0, 0.02).expand([N]))
    right_error = numpyro.sample("right_error", dist.Normal(0, 0.02).expand([N]))
    leg_left = leg_prop * height + left_error
    leg_right = leg_prop * height + right_error
# combine into a data frame
d = pd.DataFrame({"height": height, "leg_left": leg_left, "leg_right": leg_right})
Code 6.3¶
In [3]:
def model(leg_left, leg_right, height):
    """height ~ Normal(a + bl * leg_left + br * leg_right, sigma); the two legs are nearly collinear."""
    a = numpyro.sample("a", dist.Normal(10, 100))
    bl = numpyro.sample("bl", dist.Normal(2, 10))
    br = numpyro.sample("br", dist.Normal(2, 10))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("height", dist.Normal(a + bl * leg_left + br * leg_right, sigma), obs=height)


m6_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_1,
    optim.Adam(0.1),
    Trace_ELBO(),
    **{name: d[name].values for name in ("leg_left", "leg_right", "height")},
)
svi_result = svi.run(random.PRNGKey(0), 2000)
p6_1 = svi_result.params
post = m6_1.sample_posterior(random.PRNGKey(1), p6_1, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
Out[3]:
Out[3]:
Code 6.4¶
In [4]:
# 89% intervals of every parameter of m6_1
az.plot_forest(data=post, hdi_prob=0.89)
plt.show()
Out[4]:
Code 6.5¶
In [5]:
# joint posterior of the two leg coefficients: a long, narrow negative ridge
post = m6_1.sample_posterior(random.PRNGKey(1), p6_1, sample_shape=(1000,))
az.plot_pair(post, var_names=["br", "bl"], scatter_kwargs=dict(alpha=0.1))
plt.show()
Out[5]:
Code 6.6¶
In [6]:
# the sum of the two leg coefficients is well identified even though each one is not
sum_blbr = post["br"] + post["bl"]
az.plot_kde(sum_blbr, label="sum of bl and br")
plt.show()
Out[6]:
Code 6.7¶
In [7]:
def model(leg_left, height):
    """height ~ Normal(a + bl * leg_left, sigma): a single leg as predictor."""
    a = numpyro.sample("a", dist.Normal(10, 100))
    bl = numpyro.sample("bl", dist.Normal(2, 10))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("height", dist.Normal(a + bl * leg_left, sigma), obs=height)


m6_2 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_2,
    optim.Adam(1),
    Trace_ELBO(),
    **{name: d[name].values for name in ("leg_left", "height")},
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_2 = svi_result.params
post = m6_2.sample_posterior(random.PRNGKey(1), p6_2, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
Out[7]:
Out[7]:
Code 6.8¶
In [8]:
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk  # alias: the standardized columns below are added to `milk` as well


def standardize(x):
    """Center a Series on its mean and scale it by its (pandas default, ddof=1) std."""
    return (x - x.mean()) / x.std()


for new_col, raw_col in [("K", "kcal.per.g"), ("F", "perc.fat"), ("L", "perc.lactose")]:
    d[new_col] = standardize(d[raw_col])
Code 6.9¶
In [9]:
# kcal.per.g regressed on perc.fat
def model(F, K):
a = numpyro.sample("a", dist.Normal(0, 0.2))
bF = numpyro.sample("bF", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + bF * F
numpyro.sample("K", dist.Normal(mu, sigma), obs=K)
m6_3 = AutoLaplaceApproximation(model)
svi = SVI(model, m6_3, optim.Adam(1), Trace_ELBO(), F=d.F.values, K=d.K.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_3 = svi_result.params
# kcal.per.g regressed on perc.lactose
def model(L, K):
a = numpyro.sample("a", dist.Normal(0, 0.2))
bL = numpyro.sample("bL", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + bL * L
numpyro.sample("K", dist.Normal(mu, sigma), obs=K)
m6_4 = AutoLaplaceApproximation(model)
svi = SVI(model, m6_4, optim.Adam(1), Trace_ELBO(), L=d.L.values, K=d.K.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_4 = svi_result.params
post = m6_3.sample_posterior(random.PRNGKey(1), p6_3, sample_shape=(1000,))
print_summary(post, 0.89, False)
post = m6_4.sample_posterior(random.PRNGKey(1), p6_4, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[9]:
Out[9]:
Code 6.10¶
In [10]:
def model(F, L, K):
    """K on both standardized fat and lactose, two strongly correlated predictors."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bF = numpyro.sample("bF", dist.Normal(0, 0.5))
    bL = numpyro.sample("bL", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("K", dist.Normal(a + bF * F + bL * L, sigma), obs=K)


m6_5 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_5,
    optim.Adam(1),
    Trace_ELBO(),
    **{col: d[col].values for col in ("F", "L", "K")},
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_5 = svi_result.params
post = m6_5.sample_posterior(random.PRNGKey(1), p6_5, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
Out[10]:
Out[10]:
Code 6.11¶
In [11]:
# pairwise scatter of the raw (unstandardized) milk variables
az.plot_pair({col: d[col].tolist() for col in ("kcal.per.g", "perc.fat", "perc.lactose")})
plt.show()
Out[11]:
Code 6.12¶
In [12]:
# reload the raw milk data (without the standardized columns); `d` aliases `milk`
d = milk = pd.read_csv("../data/milk.csv", sep=";")
def sim_coll(i, r=0.9):
    """Posterior std of the ``perc.fat`` slope after adding a collinear predictor.

    Simulates a predictor ``x`` whose correlation with ``perc.fat`` is about ``r``
    (reading ``perc.fat`` and ``kcal.per.g`` from the global frame ``d``). It then
    fits ``kcal.per.g ~ perc.fat + x`` with a Laplace approximation.

    Args:
        i: replicate index; the PRNG keys 3*i, 3*i+1 and 3*i+2 are derived from it.
        r: target correlation between ``x`` and ``perc.fat``.

    Returns:
        Sample standard deviation of the posterior draws of ``b_perc.fat``.
    """
    perc_fat = d["perc.fat"].values
    # ddof=1 matches R's var() used in the book, so var(x) ~= var(perc.fat) and
    # cor(x, perc.fat) ~= r.
    sd = jnp.sqrt((1 - r**2) * jnp.var(perc_fat, ddof=1))
    x = dist.Normal(r * perc_fat, sd).sample(random.PRNGKey(3 * i))

    def model(perc_fat, kcal_per_g):
        intercept = numpyro.sample("intercept", dist.Normal(0, 10))
        b_perc_fat = numpyro.sample("b_perc.fat", dist.Normal(0, 10))
        b_x = numpyro.sample("b_x", dist.Normal(0, 10))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(2))
        mu = intercept + b_perc_fat * perc_fat + b_x * x
        numpyro.sample("kcal.per.g", dist.Normal(mu, sigma), obs=kcal_per_g)

    guide = AutoLaplaceApproximation(model)
    svi = SVI(
        model,
        guide,
        optim.Adam(0.01),
        Trace_ELBO(),
        perc_fat=perc_fat,
        kcal_per_g=d["kcal.per.g"].values,
    )
    svi_result = svi.run(random.PRNGKey(3 * i + 1), 20000, progress_bar=False)
    samples = guide.sample_posterior(
        random.PRNGKey(3 * i + 2), svi_result.params, sample_shape=(1000,)
    )
    # sqrt(diag(cov)) of the joint draws equals each parameter's sample std (ddof=1)
    return jnp.std(samples["b_perc.fat"], ddof=1)
def rep_sim_coll(r=0.9, n=100):
    """Mean of ``sim_coll`` over replicates 0..n-1 at correlation ``r``, ignoring NaNs."""
    replicate_sds = lax.map(lambda rep: sim_coll(rep, r=r), jnp.arange(n))
    return jnp.nanmean(replicate_sds)
# sweep the correlation between perc.fat and the simulated predictor over [0, 0.99]
r_seq = jnp.arange(start=0, stop=1, step=0.01)
stddev = lax.map(lambda r_val: rep_sim_coll(r=r_val, n=100), r_seq)
fig, ax = plt.subplots()
ax.plot(r_seq, stddev)
ax.set_xlabel("correlation")
plt.show()
Out[12]:
Code 6.13¶
In [13]:
with numpyro.handlers.seed(rng_seed=71):
    N = 100  # number of plants
    # initial heights
    h0 = numpyro.sample("h0", dist.Normal(10, 2).expand([N]))
    # first half untreated (0), second half treated (1)
    treatment = jnp.repeat(jnp.arange(2), repeats=N // 2)
    # treatment lowers the fungus probability from 0.5 to 0.1
    fungus = numpyro.sample(
        "fungus", dist.Binomial(total_count=1, probs=0.5 - 0.4 * treatment)
    )
    # fungus cuts expected growth from 5 to 2
    h1 = h0 + numpyro.sample("diff", dist.Normal(5 - 3 * fungus))
    # compose a clean data frame
    d = pd.DataFrame(dict(h0=h0, h1=h1, treatment=treatment, fungus=fungus))
print_summary(dict(zip(d.columns, d.T.values)), prob=0.89, group_by_chain=False)
Out[13]:
Code 6.14¶
In [14]:
# prior predictive for the growth proportion p ~ LogNormal(0, 0.25)
sim_p = dist.LogNormal(0, 0.25).sample(random.PRNGKey(0), sample_shape=(10_000,))
print_summary(dict(sim_p=sim_p), prob=0.89, group_by_chain=False)
Out[14]:
Code 6.15¶
In [15]:
def model(h0, h1):
    """h1 ~ Normal(h0 * p, sigma): growth as a proportion p of initial height."""
    p = numpyro.sample("p", dist.LogNormal(0, 0.25))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("h1", dist.Normal(h0 * p, sigma), obs=h1)


m6_6 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m6_6, optim.Adam(1), Trace_ELBO(), **{col: d[col].values for col in ("h0", "h1")}
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_6 = svi_result.params
post = m6_6.sample_posterior(random.PRNGKey(1), p6_6, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
Out[15]:
Out[15]:
Code 6.16¶
In [16]:
def model(treatment, fungus, h0, h1):
    """Growth proportion depends on treatment AND fungus (a post-treatment variable)."""
    a = numpyro.sample("a", dist.LogNormal(0, 0.2))
    bt = numpyro.sample("bt", dist.Normal(0, 0.5))
    bf = numpyro.sample("bf", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("h1", dist.Normal(h0 * (a + bt * treatment + bf * fungus), sigma), obs=h1)


m6_7 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_7,
    optim.Adam(0.3),
    Trace_ELBO(),
    **{col: d[col].values for col in ("treatment", "fungus", "h0", "h1")},
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_7 = svi_result.params
post = m6_7.sample_posterior(random.PRNGKey(1), p6_7, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
Out[16]:
Out[16]:
Code 6.17¶
In [17]:
def model(treatment, h0, h1):
    """Growth proportion depends on treatment only (fungus omitted)."""
    a = numpyro.sample("a", dist.LogNormal(0, 0.2))
    bt = numpyro.sample("bt", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("h1", dist.Normal(h0 * (a + bt * treatment), sigma), obs=h1)


m6_8 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_8,
    optim.Adam(1),
    Trace_ELBO(),
    **{col: d[col].values for col in ("treatment", "h0", "h1")},
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_8 = svi_result.params
post = m6_8.sample_posterior(random.PRNGKey(1), p6_8, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
Out[17]:
Out[17]:
Code 6.18¶
In [3]:
# plant experiment DAG: H0 -> H1 <- F <- T
plant_dag = nx.DiGraph([("H0", "H1"), ("F", "H1"), ("T", "F")])
pgm = daft.PGM()
coordinates = {"H0": (0, 0), "T": (4, 0), "F": (3, 0), "H1": (2, 0)}
for name in plant_dag.nodes:
    pgm.add_node(name, name, *coordinates[name])
for parent, child in plant_dag.edges:
    pgm.add_edge(parent, child)
with plt.rc_context({"figure.constrained_layout.use": False}):
    pgm.render()
Out[3]:
Code 6.19¶
In [36]:
# `nx.d_separated` was deprecated in networkx 3.3 in favour of `nx.is_d_separator`
# (same arguments) and is slated for removal. The install cell does not pin
# networkx, so use whichever name this version provides.
d_separator = getattr(nx, "is_d_separator", None) or nx.d_separated
conditional_independencies = collections.defaultdict(list)
for edge in itertools.combinations(sorted(plant_dag.nodes), 2):
    # candidate conditioning variables: every node other than the pair itself
    remaining = sorted(set(plant_dag.nodes) - set(edge))
    for size in range(len(remaining) + 1):
        for subset in itertools.combinations(remaining, size):
            # report only minimal conditioning sets: skip supersets of one already found
            if any(cond.issubset(set(subset)) for cond in conditional_independencies[edge]):
                continue
            if d_separator(plant_dag, {edge[0]}, {edge[1]}, set(subset)):
                conditional_independencies[edge].append(set(subset))
                print(f"{edge[0]} _||_ {edge[1]}" + (f" | {' '.join(subset)}" if subset else ""))
Out[36]:
Code 6.20¶
In [20]:
with numpyro.handlers.seed(rng_seed=71):
    N = 1000
    h0 = numpyro.sample("h0", dist.Normal(10, 2).expand([N]))
    treatment = jnp.repeat(jnp.arange(2), repeats=N // 2)
    # moisture M: an unobserved common cause of both fungus and growth
    M = numpyro.sample("M", dist.Bernoulli(probs=0.5).expand([N]))
    # fungus is more likely with moisture (+0.4) and less likely with treatment (-0.4);
    # probabilities stay within [0.1, 0.9]
    fungus = numpyro.sample(
        "fungus",
        dist.Binomial(total_count=1, probs=(0.5 - treatment * 0.4 + 0.4 * M)),
    )
    # moisture, not fungus, drives growth in this simulation
    h1 = h0 + numpyro.sample("diff", dist.Normal(5 + 3 * M))
    d2 = pd.DataFrame({"h0": h0, "h1": h1, "treatment": treatment, "fungus": fungus})
Code 6.21¶
In [21]:
def sim_happiness(seed=1977, N_years=1000, max_age=65, N_births=20, aom=18):
    """Vectorised simulation of the age / marriage / happiness population.

    There is one cohort of ``N_births`` people per year for ``N_years`` years. Each
    cohort's happiness values are evenly spaced on [-2, 2] and never change. Ages
    are those at the end of the simulation. For each age from ``aom`` through
    ``max_age``, every still-unmarried person old enough gets one Bernoulli draw
    with logit ``happiness - 4`` to marry. The draws for age ``k`` use PRNG key
    ``seed + k``. People older than ``max_age`` are dropped.

    Returns:
        DataFrame with columns ``age``, ``married`` (0/1) and ``happiness``.
    """
    # final age of every individual: cohort t (1-based) ends at age t
    A = jnp.repeat(jnp.arange(1, N_years + 1), N_births)
    # the same happiness ladder for every cohort
    H = jnp.tile(jnp.linspace(-2, 2, N_births), N_years)
    # nobody starts out married
    M = jnp.zeros(N_years * N_births, dtype=jnp.int32)

    def marry_at_age(age, married_so_far):
        draws = dist.Bernoulli(logits=(H - 4)).sample(random.PRNGKey(seed + age))
        eligible = (A >= age) & (married_so_far == 0)
        return jnp.where(eligible, draws, married_so_far)

    M = lax.fori_loop(aom, max_age + 1, marry_at_age, M)
    # mortality: keep only those still alive
    alive = A <= max_age
    return pd.DataFrame({"age": A[alive], "married": M[alive], "happiness": H[alive]})
# simulate the population and summarise each column
d = sim_happiness(seed=1977, N_years=1000)
print_summary(dict(zip(d.columns, d.T.values)), prob=0.89, group_by_chain=False)
Out[21]:
Code 6.22¶
In [22]:
# keep adults only, and rescale age so that 18 -> 0 and 65 -> 1
d2 = d.loc[d["age"] > 17].copy()
d2["A"] = (d2["age"] - 18) / (65 - 18)
Code 6.23¶
In [23]:
# marriage status is already 0/1, so it can index the intercepts directly
d2["mid"] = d2["married"]


def model(mid, A, happiness):
    """Happiness on age with a separate intercept per marriage status."""
    a = numpyro.sample("a", dist.Normal(0, 1).expand([len(set(mid))]))
    bA = numpyro.sample("bA", dist.Normal(0, 2))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("happiness", dist.Normal(a[mid] + bA * A, sigma), obs=happiness)


m6_9 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_9,
    optim.Adam(1),
    Trace_ELBO(),
    **{col: d2[col].values for col in ("mid", "A", "happiness")},
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_9 = svi_result.params
post = m6_9.sample_posterior(random.PRNGKey(1), p6_9, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
Out[23]:
Out[23]:
Code 6.24¶
In [24]:
def model(A, happiness):
    """Happiness on age alone (marriage status omitted)."""
    a = numpyro.sample("a", dist.Normal(0, 1))
    bA = numpyro.sample("bA", dist.Normal(0, 2))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("happiness", dist.Normal(a + bA * A, sigma), obs=happiness)


m6_10 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_10,
    optim.Adam(1),
    Trace_ELBO(),
    **{col: d2[col].values for col in ("A", "happiness")},
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_10 = svi_result.params
post = m6_10.sample_posterior(random.PRNGKey(1), p6_10, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
Out[24]:
Out[24]:
Code 6.25¶
In [25]:
N = 200  # number of grandparent-parent-child triads
# true direct effects used by the simulation
b_GP = 1  # G -> P
b_GC = 0  # G -> C (no direct effect)
b_PC = 1  # P -> C
b_U = 2  # U -> P and U -> C
Code 6.26¶
In [26]:
with numpyro.handlers.seed(rng_seed=1):
    # unobserved neighbourhood effect, coded as -1 / +1
    U = 2 * numpyro.sample("U", dist.Bernoulli(probs=0.5).expand([N])) - 1
    G = numpyro.sample("G", dist.Normal().expand([N]))
    P = numpyro.sample("P", dist.Normal(b_GP * G + b_U * U))
    C = numpyro.sample("C", dist.Normal(b_PC * P + b_GC * G + b_U * U))
d = pd.DataFrame(dict(C=C, P=P, G=G, U=U))
Code 6.27¶
In [27]:
def model(P, G, C):
    """Child education on parent and grandparent; U is unobserved and omitted."""
    a = numpyro.sample("a", dist.Normal(0, 1))
    b_PC = numpyro.sample("b_PC", dist.Normal(0, 1))
    b_GC = numpyro.sample("b_GC", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("C", dist.Normal(a + b_PC * P + b_GC * G, sigma), obs=C)


m6_11 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_11,
    optim.Adam(0.3),
    Trace_ELBO(),
    **{col: d[col].values for col in ("P", "G", "C")},
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_11 = svi_result.params
post = m6_11.sample_posterior(random.PRNGKey(1), p6_11, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
Out[27]:
Out[27]:
Code 6.28¶
In [28]:
def model(P, G, U, C):
    """Child education on parent, grandparent and the (here observed) U."""
    a = numpyro.sample("a", dist.Normal(0, 1))
    b_PC = numpyro.sample("b_PC", dist.Normal(0, 1))
    b_GC = numpyro.sample("b_GC", dist.Normal(0, 1))
    # This site is the coefficient on U, not U itself. Name it "b_U" so the
    # summary is not confused with the observed U column.
    b_U = numpyro.sample("b_U", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + b_PC * P + b_GC * G + b_U * U
    numpyro.sample("C", dist.Normal(mu, sigma), obs=C)


m6_12 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_12,
    optim.Adam(1),
    Trace_ELBO(),
    P=d.P.values,
    G=d.G.values,
    U=d.U.values,
    C=d.C.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_12 = svi_result.params
post = m6_12.sample_posterior(random.PRNGKey(1), p6_12, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[28]:
Out[28]:
Code 6.29¶
In [29]:
dag_6_1 = nx.DiGraph()
dag_6_1.add_edges_from(
    [("X", "Y"), ("U", "X"), ("A", "U"), ("A", "C"), ("C", "Y"), ("U", "B"), ("C", "B")]
)
# back-door paths: undirected X ~ Y paths whose first edge points INTO X
backdoor_paths = [
    path
    for path in nx.all_simple_paths(dag_6_1.to_undirected(), "X", "Y")
    if dag_6_1.has_edge(path[1], "X")
]
# candidate covariates: not X or Y, not the unobserved U, not downstream of X
remaining = sorted(
    set(dag_6_1.nodes) - {"X", "Y", "U"} - set(nx.descendants(dag_6_1, "X"))
)
adjustment_sets = []
for size in range(len(remaining) + 1):
    for subset in map(set, itertools.combinations(remaining, size)):
        # report only minimal sets: skip supersets of a set already found
        if any(found.issubset(subset) for found in adjustment_sets):
            continue
        # A path is blocked at an interior node z (neighbours x, y) when either:
        #   - z is a collider x -> z <- y and neither z nor any descendant of z is in subset, or
        #   - z is not a collider and z is in subset.
        # The subset is a valid adjustment set when it blocks every back-door path.
        blocks_every_path = all(
            any(
                (z not in subset and not set(nx.descendants(dag_6_1, z)) & subset)
                if dag_6_1.has_edge(x, z) and dag_6_1.has_edge(y, z)
                else z in subset
                for x, z, y in zip(path[:-2], path[1:-1], path[2:])
            )
            for path in backdoor_paths
        )
        if blocks_every_path:
            adjustment_sets.append(subset)
            print(subset)
Out[29]:
Code 6.30¶
In [30]:
dag_6_2 = nx.DiGraph()
dag_6_2.add_edges_from(
    [("S", "A"), ("A", "D"), ("S", "M"), ("M", "D"), ("S", "W"), ("W", "D"), ("A", "M")]
)
# back-door paths: undirected W ~ D paths whose first edge points INTO W
backdoor_paths = [
    path
    for path in nx.all_simple_paths(dag_6_2.to_undirected(), "W", "D")
    if dag_6_2.has_edge(path[1], "W")
]
# candidate covariates: not W or D, not downstream of W
remaining = sorted(
    set(dag_6_2.nodes) - {"W", "D"} - set(nx.descendants(dag_6_2, "W"))
)
adjustment_sets = []
for size in range(len(remaining) + 1):
    for subset in map(set, itertools.combinations(remaining, size)):
        # report only minimal sets: skip supersets of a set already found
        if any(found.issubset(subset) for found in adjustment_sets):
            continue
        # A path is blocked at an interior node z (neighbours x, y) when either:
        #   - z is a collider x -> z <- y and neither z nor any descendant of z is in subset, or
        #   - z is not a collider and z is in subset.
        # The subset is a valid adjustment set when it blocks every back-door path.
        blocks_every_path = all(
            any(
                (z not in subset and not set(nx.descendants(dag_6_2, z)) & subset)
                if dag_6_2.has_edge(x, z) and dag_6_2.has_edge(y, z)
                else z in subset
                for x, z, y in zip(path[:-2], path[1:-1], path[2:])
            )
            for path in backdoor_paths
        )
        if blocks_every_path:
            adjustment_sets.append(subset)
            print(subset)
Out[30]:
Code 6.31¶
In [31]:
# `nx.d_separated` was deprecated in networkx 3.3 in favour of `nx.is_d_separator`
# (same arguments) and is slated for removal. The install cell does not pin
# networkx, so use whichever name this version provides.
d_separator = getattr(nx, "is_d_separator", None) or nx.d_separated
conditional_independencies = collections.defaultdict(list)
for edge in itertools.combinations(sorted(dag_6_2.nodes), 2):
    # candidate conditioning variables: every node other than the pair itself
    remaining = sorted(set(dag_6_2.nodes) - set(edge))
    for size in range(len(remaining) + 1):
        for subset in itertools.combinations(remaining, size):
            # report only minimal conditioning sets: skip supersets of one already found
            if any(cond.issubset(set(subset)) for cond in conditional_independencies[edge]):
                continue
            if d_separator(dag_6_2, {edge[0]}, {edge[1]}, set(subset)):
                conditional_independencies[edge].append(set(subset))
                print(f"{edge[0]} _||_ {edge[1]}" + (f" | {' '.join(subset)}" if subset else ""))
Out[31]:
Comments
Comments powered by Disqus