Chapter 6. The Haunted DAG & The Causal Terror

In [ ]:
!pip install -q numpyro arviz daft networkx
In [0]:
import collections
import itertools
import os
import warnings

import arviz as az
import daft
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

import jax.numpy as jnp
from jax import lax, random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation

if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
    category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")

Code 6.1

In [1]:
with numpyro.handlers.seed(rng_seed=1914):
    N = 200  # num grant proposals
    p = 0.1  # proportion to select
    # uncorrelated newsworthiness and trustworthiness
    nw = numpyro.sample("nw", dist.Normal().expand([N]))
    tw = numpyro.sample("tw", dist.Normal().expand([N]))
    # select top 10% of combined scores
    s = nw + tw  # total score
    q = jnp.quantile(s, 1 - p)  # top 10% threshold
    selected = jnp.where(s >= q, True, False)
jnp.corrcoef(jnp.stack([tw[selected], nw[selected]], 0))[0, 1]
Out[1]:
DeviceArray(-0.6453402, dtype=float32)

Code 6.2

In [2]:
N = 100  # number of individuals
with numpyro.handlers.seed(rng_seed=909):
    # sim total height of each
    height = numpyro.sample("height", dist.Normal(10, 2).expand([N]))
    # leg as proportion of height
    leg_prop = numpyro.sample("prop", dist.Uniform(0.4, 0.5).expand([N]))
    # sim left leg as proportion + error
    leg_left = leg_prop * height + numpyro.sample(
        "left_error", dist.Normal(0, 0.02).expand([N])
    )
    # sim right leg as proportion + error
    leg_right = leg_prop * height + numpyro.sample(
        "right_error", dist.Normal(0, 0.02).expand([N])
    )
    # combine into data frame
    d = pd.DataFrame({"height": height, "leg_left": leg_left, "leg_right": leg_right})

Code 6.3

In [3]:
def model(leg_left, leg_right, height):
    a = numpyro.sample("a", dist.Normal(10, 100))
    bl = numpyro.sample("bl", dist.Normal(2, 10))
    br = numpyro.sample("br", dist.Normal(2, 10))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bl * leg_left + br * leg_right
    numpyro.sample("height", dist.Normal(mu, sigma), obs=height)


m6_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_1,
    optim.Adam(0.1),
    Trace_ELBO(),
    leg_left=d.leg_left.values,
    leg_right=d.leg_right.values,
    height=d.height.values,
)
svi_result = svi.run(random.PRNGKey(0), 2000)
p6_1 = svi_result.params
post = m6_1.sample_posterior(random.PRNGKey(1), p6_1, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[3]:
100%|██████████| 2000/2000 [00:01<00:00, 1291.68it/s, init loss: 62894.3672, avg. loss [1901-2000]: 112.9629]
Out[3]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      0.81      0.34      0.80      0.27      1.31   1049.96      1.00
        bl      2.61      2.28      2.63     -1.06      6.26    813.11      1.00
        br     -0.59      2.28     -0.60     -4.41      2.96    805.68      1.00
     sigma      0.67      0.05      0.67      0.60      0.74    968.52      1.00

Code 6.4

In [4]:
az.plot_forest(post, hdi_prob=0.89)
plt.show()
Out[4]:
No description has been provided for this image

Code 6.5

In [5]:
post = m6_1.sample_posterior(random.PRNGKey(1), p6_1, sample_shape=(1000,))
az.plot_pair(post, var_names=["br", "bl"], scatter_kwargs={"alpha": 0.1})
plt.show()
Out[5]:
No description has been provided for this image

Code 6.6

In [6]:
sum_blbr = post["bl"] + post["br"]
az.plot_kde(sum_blbr, label="sum of bl and br")
plt.show()
Out[6]:
No description has been provided for this image

Code 6.7

In [7]:
def model(leg_left, height):
    a = numpyro.sample("a", dist.Normal(10, 100))
    bl = numpyro.sample("bl", dist.Normal(2, 10))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bl * leg_left
    numpyro.sample("height", dist.Normal(mu, sigma), obs=height)


m6_2 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_2,
    optim.Adam(1),
    Trace_ELBO(),
    leg_left=d.leg_left.values,
    height=d.height.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_2 = svi_result.params
post = m6_2.sample_posterior(random.PRNGKey(1), p6_2, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[7]:
100%|██████████| 1000/1000 [00:01<00:00, 936.38it/s, init loss: 2486.2402, avg. loss [951-1000]: 108.4584]
Out[7]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      0.83      0.35      0.84      0.25      1.35    931.50      1.00
        bl      2.02      0.08      2.02      1.91      2.15    940.42      1.00
     sigma      0.67      0.05      0.67      0.60      0.75    949.09      1.00

Code 6.8

In [8]:
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk
d["K"] = d["kcal.per.g"].pipe(lambda x: (x - x.mean()) / x.std())
d["F"] = d["perc.fat"].pipe(lambda x: (x - x.mean()) / x.std())
d["L"] = d["perc.lactose"].pipe(lambda x: (x - x.mean()) / x.std())

Code 6.9

In [9]:
# kcal.per.g regressed on perc.fat
def model(F, K):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bF = numpyro.sample("bF", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bF * F
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m6_3 = AutoLaplaceApproximation(model)
svi = SVI(model, m6_3, optim.Adam(1), Trace_ELBO(), F=d.F.values, K=d.K.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_3 = svi_result.params


# kcal.per.g regressed on perc.lactose
def model(L, K):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bL = numpyro.sample("bL", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bL * L
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m6_4 = AutoLaplaceApproximation(model)
svi = SVI(model, m6_4, optim.Adam(1), Trace_ELBO(), L=d.L.values, K=d.K.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_4 = svi_result.params

post = m6_3.sample_posterior(random.PRNGKey(1), p6_3, sample_shape=(1000,))
print_summary(post, 0.89, False)
post = m6_4.sample_posterior(random.PRNGKey(1), p6_4, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[9]:
100%|██████████| 1000/1000 [00:00<00:00, 1426.15it/s, init loss: 198.3563, avg. loss [951-1000]: 20.1324]
100%|██████████| 1000/1000 [00:00<00:00, 1464.26it/s, init loss: 1449.6163, avg. loss [951-1000]: 15.5010]
Out[9]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      0.01      0.08      0.01     -0.13      0.12    931.50      1.00
        bF      0.86      0.09      0.86      0.73      1.01   1111.41      1.00
     sigma      0.46      0.06      0.46      0.37      0.57    940.36      1.00


                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      0.01      0.07      0.01     -0.10      0.11    931.50      1.00
        bL     -0.90      0.07     -0.90     -1.01     -0.78   1111.89      1.00
     sigma      0.39      0.05      0.39      0.31      0.48    957.39      1.00

Code 6.10

In [10]:
def model(F, L, K):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bF = numpyro.sample("bF", dist.Normal(0, 0.5))
    bL = numpyro.sample("bL", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bF * F + bL * L
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m6_5 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m6_5, optim.Adam(1), Trace_ELBO(), F=d.F.values, L=d.L.values, K=d.K.values
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_5 = svi_result.params
post = m6_5.sample_posterior(random.PRNGKey(1), p6_5, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[10]:
100%|██████████| 1000/1000 [00:00<00:00, 1346.20it/s, init loss: 1360.7051, avg. loss [951-1000]: 15.2050]
Out[10]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a     -0.02      0.07     -0.03     -0.13      0.07   1049.96      1.00
        bF      0.25      0.19      0.25     -0.05      0.56    823.80      1.00
        bL     -0.67      0.19     -0.67     -0.99     -0.37    875.48      1.00
     sigma      0.38      0.05      0.38      0.30      0.46    982.83      1.00

Code 6.11

In [11]:
az.plot_pair(d[["kcal.per.g", "perc.fat", "perc.lactose"]].to_dict("list"))
plt.show()
Out[11]:
No description has been provided for this image

Code 6.12

In [12]:
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk


def sim_coll(i, r=0.9):
    sd = jnp.sqrt((1 - r**2) * jnp.var(d["perc.fat"].values))
    x = dist.Normal(r * d["perc.fat"].values, sd).sample(random.PRNGKey(3 * i))

    def model(perc_fat, kcal_per_g):
        intercept = numpyro.sample("intercept", dist.Normal(0, 10))
        b_perc_flat = numpyro.sample("b_perc.fat", dist.Normal(0, 10))
        b_x = numpyro.sample("b_x", dist.Normal(0, 10))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(2))
        mu = intercept + b_perc_flat * perc_fat + b_x * x
        numpyro.sample("kcal.per.g", dist.Normal(mu, sigma), obs=kcal_per_g)

    m = AutoLaplaceApproximation(model)
    svi = SVI(
        model,
        m,
        optim.Adam(0.01),
        Trace_ELBO(),
        perc_fat=d["perc.fat"].values,
        kcal_per_g=d["kcal.per.g"].values,
    )
    svi_result = svi.run(random.PRNGKey(3 * i + 1), 20000, progress_bar=False)
    params = svi_result.params
    samples = m.sample_posterior(random.PRNGKey(3 * i + 2), params, sample_shape=(1000,))
    vcov = jnp.cov(jnp.stack(list(samples.values()), axis=0))
    stddev = jnp.sqrt(jnp.diag(vcov))  # stddev of parameter
    return dict(zip(samples.keys(), stddev))["b_perc.fat"]


def rep_sim_coll(r=0.9, n=100):
    stddev = lax.map(lambda i: sim_coll(i, r=r), jnp.arange(n))
    return jnp.nanmean(stddev)


r_seq = jnp.arange(start=0, stop=1, step=0.01)
stddev = lax.map(lambda z: rep_sim_coll(r=z, n=100), r_seq)
plt.plot(r_seq, stddev)
plt.xlabel("correlation")
plt.show()
Out[12]:
No description has been provided for this image

Code 6.13

In [13]:
with numpyro.handlers.seed(rng_seed=71):
    # number of plants
    N = 100

    # simulate initial heights
    h0 = numpyro.sample("h0", dist.Normal(10, 2).expand([N]))

    # assign treatments and simulate fungus and growth
    treatment = jnp.repeat(jnp.arange(2), repeats=N // 2)
    fungus = numpyro.sample(
        "fungus", dist.Binomial(total_count=1, probs=(0.5 - treatment * 0.4))
    )
    h1 = h0 + numpyro.sample("diff", dist.Normal(5 - 3 * fungus))

    # compose a clean data frame
    d = pd.DataFrame({"h0": h0, "h1": h1, "treatment": treatment, "fungus": fungus})
print_summary(dict(zip(d.columns, d.T.values)), 0.89, False)
Out[13]:
                 mean       std    median      5.5%     94.5%     n_eff     r_hat
     fungus      0.31      0.46      0.00      0.00      1.00     18.52      1.17
         h0      9.73      1.95      9.63      7.05     13.33     80.22      0.99
         h1     13.72      2.47     13.60     10.73     18.38     43.44      1.08
  treatment      0.50      0.50      0.50      0.00      1.00      2.64       inf

Code 6.14

In [14]:
sim_p = dist.LogNormal(0, 0.25).sample(random.PRNGKey(0), (int(1e4),))
print_summary({"sim_p": sim_p}, 0.89, False)
Out[14]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
     sim_p      1.04      0.27      1.00      0.63      1.44   9936.32      1.00

Code 6.15

In [15]:
def model(h0, h1):
    p = numpyro.sample("p", dist.LogNormal(0, 0.25))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = h0 * p
    numpyro.sample("h1", dist.Normal(mu, sigma), obs=h1)


m6_6 = AutoLaplaceApproximation(model)
svi = SVI(model, m6_6, optim.Adam(1), Trace_ELBO(), h0=d.h0.values, h1=d.h1.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_6 = svi_result.params
post = m6_6.sample_posterior(random.PRNGKey(1), p6_6, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[15]:
100%|██████████| 1000/1000 [00:00<00:00, 1057.45it/s, init loss: 279.8950, avg. loss [951-1000]: 200.2917]
Out[15]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         p      1.39      0.02      1.39      1.36      1.42    994.30      1.00
     sigma      1.84      0.13      1.84      1.65      2.06   1011.70      1.00

Code 6.16

In [16]:
def model(treatment, fungus, h0, h1):
    a = numpyro.sample("a", dist.LogNormal(0, 0.2))
    bt = numpyro.sample("bt", dist.Normal(0, 0.5))
    bf = numpyro.sample("bf", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    p = a + bt * treatment + bf * fungus
    mu = h0 * p
    numpyro.sample("h1", dist.Normal(mu, sigma), obs=h1)


m6_7 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_7,
    optim.Adam(0.3),
    Trace_ELBO(),
    treatment=d.treatment.values,
    fungus=d.fungus.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_7 = svi_result.params
post = m6_7.sample_posterior(random.PRNGKey(1), p6_7, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[16]:
100%|██████████| 1000/1000 [00:01<00:00, 763.78it/s, init loss: 151456.4062, avg. loss [951-1000]: 164.8744]
Out[16]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      1.47      0.03      1.47      1.43      1.51   1049.04      1.00
        bf     -0.28      0.03     -0.28     -0.33     -0.23    910.93      1.00
        bt      0.01      0.03      0.01     -0.03      0.06   1123.06      1.00
     sigma      1.39      0.10      1.39      1.21      1.54    976.96      1.00

Code 6.17

In [17]:
def model(treatment, h0, h1):
    a = numpyro.sample("a", dist.LogNormal(0, 0.2))
    bt = numpyro.sample("bt", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    p = a + bt * treatment
    mu = h0 * p
    numpyro.sample("h1", dist.Normal(mu, sigma), obs=h1)


m6_8 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_8,
    optim.Adam(1),
    Trace_ELBO(),
    treatment=d.treatment.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_8 = svi_result.params
post = m6_8.sample_posterior(random.PRNGKey(1), p6_8, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[17]:
100%|██████████| 1000/1000 [00:01<00:00, 855.77it/s, init loss: 87469.1172, avg. loss [951-1000]: 194.5041]
Out[17]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      1.33      0.02      1.33      1.29      1.37    930.82      1.00
        bt      0.13      0.04      0.12      0.08      0.19    880.02      1.00
     sigma      1.73      0.12      1.73      1.55      1.94    948.82      1.00

Code 6.18

In [3]:
plant_dag = nx.DiGraph()
plant_dag.add_edges_from([("H0", "H1"), ("F", "H1"), ("T", "F")])
pgm = daft.PGM()
coordinates = {"H0": (0, 0), "T": (4, 0), "F": (3, 0), "H1": (2, 0)}
for node in plant_dag.nodes:
    pgm.add_node(node, node, *coordinates[node])
for edge in plant_dag.edges:
    pgm.add_edge(*edge)
with plt.rc_context({"figure.constrained_layout.use": False}):
    pgm.render()
Out[3]:
No description has been provided for this image

Code 6.19

In [36]:
conditional_independencies = collections.defaultdict(list)
for edge in itertools.combinations(sorted(plant_dag.nodes), 2):
    remaining = sorted(set(plant_dag.nodes) - set(edge))
    for size in range(len(remaining) + 1):
        for subset in itertools.combinations(remaining, size):
            if any(cond.issubset(set(subset)) for cond in conditional_independencies[edge]):
                continue
            if nx.d_separated(plant_dag, {edge[0]}, {edge[1]}, set(subset)):
                conditional_independencies[edge].append(set(subset))
                print(f"{edge[0]} _||_ {edge[1]}" + (f" | {' '.join(subset)}" if subset else ""))
Out[36]:
F _||_ H0
H0 _||_ T
H1 _||_ T | F

Code 6.20

In [20]:
with numpyro.handlers.seed(rng_seed=71):
    N = 1000
    h0 = numpyro.sample("h0", dist.Normal(10, 2).expand([N]))
    treatment = jnp.repeat(jnp.arange(2), repeats=N // 2)
    M = numpyro.sample("M", dist.Bernoulli(probs=0.5).expand([N]))
    fungus = numpyro.sample(
        "fungus", dist.Binomial(total_count=1, probs=(0.5 - treatment * 0.4))
    )
    h1 = h0 + numpyro.sample("diff", dist.Normal(5 + 3 * M))
    d2 = pd.DataFrame({"h0": h0, "h1": h1, "treatment": treatment, "fungus": fungus})

Code 6.21

In [21]:
def sim_happiness(seed=1977, N_years=1000, max_age=65, N_births=20, aom=18):
    # age existing individuals & newborns
    A = jnp.repeat(jnp.arange(1, N_years + 1), N_births)
    # sim happiness trait - never changes
    H = jnp.repeat(jnp.linspace(-2, 2, N_births)[None, :], N_years, 0).reshape(-1)
    # not yet married
    M = jnp.zeros(N_years * N_births, dtype=jnp.int32)

    def update_M(i, M):
        # for each person over 17, chance get married
        married = dist.Bernoulli(logits=(H - 4)).sample(random.PRNGKey(seed + i))
        return jnp.where((A >= i) & (M == 0), married, M)

    M = lax.fori_loop(aom, max_age + 1, update_M, M)
    # mortality
    deaths = A > max_age
    A = A[~deaths]
    H = H[~deaths]
    M = M[~deaths]

    d = pd.DataFrame({"age": A, "married": M, "happiness": H})
    return d


d = sim_happiness(seed=1977, N_years=1000)
print_summary(dict(zip(d.columns, d.T.values)), 0.89, False)
Out[21]:
                 mean       std    median      5.5%     94.5%     n_eff     r_hat
        age     33.00     18.77     33.00      1.00     58.00      2.51      2.64
  happiness      0.00      1.21      0.00     -2.00      1.58    338.78      1.00
    married      0.28      0.45      0.00      0.00      1.00     48.04      1.18

Code 6.22

In [22]:
d2 = d[d.age > 17].copy()  # only adults
d2["A"] = (d2.age - 18) / (65 - 18)

Code 6.23

In [23]:
d2["mid"] = d2.married


def model(mid, A, happiness):
    a = numpyro.sample("a", dist.Normal(0, 1).expand([len(set(mid))]))
    bA = numpyro.sample("bA", dist.Normal(0, 2))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a[mid] + bA * A
    numpyro.sample("happiness", dist.Normal(mu, sigma), obs=happiness)


m6_9 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_9,
    optim.Adam(1),
    Trace_ELBO(),
    mid=d2.mid.values,
    A=d2.A.values,
    happiness=d2.happiness.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_9 = svi_result.params
post = m6_9.sample_posterior(random.PRNGKey(1), p6_9, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[23]:
100%|██████████| 1000/1000 [00:00<00:00, 1245.44it/s, init loss: 12964.1768, avg. loss [951-1000]: 1355.7842]
Out[23]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]     -0.20      0.06     -0.20     -0.30     -0.10   1049.96      1.00
      a[1]      1.23      0.09      1.23      1.09      1.37    898.97      1.00
        bA     -0.69      0.11     -0.69     -0.88     -0.53   1126.51      1.00
     sigma      1.02      0.02      1.02      0.98      1.05    966.00      1.00

Code 6.24

In [24]:
def model(A, happiness):
    a = numpyro.sample("a", dist.Normal(0, 1))
    bA = numpyro.sample("bA", dist.Normal(0, 2))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bA * A
    numpyro.sample("happiness", dist.Normal(mu, sigma), obs=happiness)


m6_10 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_10,
    optim.Adam(1),
    Trace_ELBO(),
    A=d2.A.values,
    happiness=d2.happiness.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_10 = svi_result.params
post = m6_10.sample_posterior(random.PRNGKey(1), p6_10, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[24]:
100%|██████████| 1000/1000 [00:00<00:00, 1244.73it/s, init loss: 19561.3906, avg. loss [951-1000]: 1520.8224]
Out[24]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      0.01      0.08      0.01     -0.12      0.12    931.50      1.00
        bA     -0.01      0.13     -0.01     -0.22      0.21    940.88      1.00
     sigma      1.21      0.03      1.21      1.17      1.26    949.78      1.00

Code 6.25

In [25]:
N = 200  # number of grandparent-parent-child triads
b_GP = 1  # direct effect of G on P
b_GC = 0  # direct effect of G on C
b_PC = 1  # direct effect of P on C
b_U = 2  # direct effect of U on P and C

Code 6.26

In [26]:
with numpyro.handlers.seed(rng_seed=1):
    U = 2 * numpyro.sample("U", dist.Bernoulli(0.5).expand([N])) - 1
    G = numpyro.sample("G", dist.Normal().expand([N]))
    P = numpyro.sample("P", dist.Normal(b_GP * G + b_U * U))
    C = numpyro.sample("C", dist.Normal(b_PC * P + b_GC * G + b_U * U))
    d = pd.DataFrame({"C": C, "P": P, "G": G, "U": U})

Code 6.27

In [27]:
def model(P, G, C):
    a = numpyro.sample("a", dist.Normal(0, 1))
    b_PC = numpyro.sample("b_PC", dist.Normal(0, 1))
    b_GC = numpyro.sample("b_GC", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + b_PC * P + b_GC * G
    numpyro.sample("C", dist.Normal(mu, sigma), obs=C)


m6_11 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_11,
    optim.Adam(0.3),
    Trace_ELBO(),
    P=d.P.values,
    G=d.G.values,
    C=d.C.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_11 = svi_result.params
post = m6_11.sample_posterior(random.PRNGKey(1), p6_11, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[27]:
100%|██████████| 1000/1000 [00:01<00:00, 733.78it/s, init loss: 4805.8169, avg. loss [951-1000]: 348.3594]
Out[27]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a     -0.08      0.10     -0.09     -0.24      0.06   1049.96      1.00
      b_GC     -0.71      0.11     -0.71     -0.89     -0.55    813.76      1.00
      b_PC      1.72      0.04      1.72      1.65      1.79    982.64      1.00
     sigma      1.39      0.07      1.39      1.28      1.49    968.54      1.00

Code 6.28

In [28]:
def model(P, G, U, C):
    a = numpyro.sample("a", dist.Normal(0, 1))
    b_PC = numpyro.sample("b_PC", dist.Normal(0, 1))
    b_GC = numpyro.sample("b_GC", dist.Normal(0, 1))
    b_U = numpyro.sample("U", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + b_PC * P + b_GC * G + b_U * U
    numpyro.sample("C", dist.Normal(mu, sigma), obs=C)


m6_12 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_12,
    optim.Adam(1),
    Trace_ELBO(),
    P=d.P.values,
    G=d.G.values,
    U=d.U.values,
    C=d.C.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_12 = svi_result.params
post = m6_12.sample_posterior(random.PRNGKey(1), p6_12, sample_shape=(1000,))
print_summary(post, 0.89, False)
Out[28]:
100%|██████████| 1000/1000 [00:01<00:00, 668.45it/s, init loss: 565.4859, avg. loss [951-1000]: 300.9767]
Out[28]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         U      1.87      0.17      1.88      1.59      2.11   1009.20      1.00
         a     -0.06      0.08     -0.05     -0.18      0.07    766.03      1.00
      b_GC      0.01      0.10      0.01     -0.15      0.17   1031.98      1.00
      b_PC      0.99      0.07      0.99      0.88      1.11   1106.62      1.00
     sigma      1.08      0.05      1.08      0.99      1.16    797.58      1.00

Code 6.29

In [29]:
dag_6_1 = nx.DiGraph()
dag_6_1.add_edges_from(
    [("X", "Y"), ("U", "X"), ("A", "U"), ("A", "C"), ("C", "Y"), ("U", "B"), ("C", "B")])
backdoor_paths = [path for path in nx.all_simple_paths(dag_6_1.to_undirected(), "X", "Y")
                  if dag_6_1.has_edge(path[1], "X")]
remaining = sorted(set(dag_6_1.nodes) - {"X", "Y", "U"} - set(nx.descendants(dag_6_1, "X")))
adjustment_sets = []
for size in range(len(remaining) + 1):
    for subset in itertools.combinations(remaining, size):
        subset = set(subset)
        if any(s.issubset(subset) for s in adjustment_sets):
            continue
        need_adjust = True
        for path in backdoor_paths:
            d_separated = False
            for x, z, y in zip(path[:-2], path[1:-1], path[2:]):
                if dag_6_1.has_edge(x, z) and dag_6_1.has_edge(y, z):
                    if set(nx.descendants(dag_6_1, z)) & subset:
                        continue
                    d_separated = z not in subset
                else:
                    d_separated = z in subset
                if d_separated:
                    break
            if not d_separated:
                need_adjust = False
                break
        if need_adjust:
            adjustment_sets.append(subset)
            print(subset)
Out[29]:
{'A'}
{'C'}

Code 6.30

In [30]:
dag_6_2 = nx.DiGraph()
dag_6_2.add_edges_from(
    [("S", "A"), ("A", "D"), ("S", "M"), ("M", "D"), ("S", "W"), ("W", "D"), ("A", "M")])
backdoor_paths = [path for path in nx.all_simple_paths(dag_6_2.to_undirected(), "W", "D")
                  if dag_6_2.has_edge(path[1], "W")]
remaining = sorted(set(dag_6_2.nodes) - {"W", "D"} - set(nx.descendants(dag_6_2, "W")))
adjustment_sets = []
for size in range(len(remaining) + 1):
    for subset in itertools.combinations(remaining, size):
        subset = set(subset)
        if any(s.issubset(subset) for s in adjustment_sets):
            continue
        need_adjust = True
        for path in backdoor_paths:
            d_separated = False
            for x, z, y in zip(path[:-2], path[1:-1], path[2:]):
                if dag_6_2.has_edge(x, z) and dag_6_2.has_edge(y, z):
                    if set(nx.descendants(dag_6_2, z)) & subset:
                        continue
                    d_separated = z not in subset
                else:
                    d_separated = z in subset
                if d_separated:
                    break
            if not d_separated:
                need_adjust = False
                break
        if need_adjust:
            adjustment_sets.append(subset)
            print(subset)
Out[30]:
{'S'}
{'A', 'M'}

Code 6.31

In [31]:
conditional_independencies = collections.defaultdict(list)
for edge in itertools.combinations(sorted(dag_6_2.nodes), 2):
    remaining = sorted(set(dag_6_2.nodes) - set(edge))
    for size in range(len(remaining) + 1):
        for subset in itertools.combinations(remaining, size):
            if any(cond.issubset(set(subset)) for cond in conditional_independencies[edge]):
                continue
            if nx.d_separated(dag_6_2, {edge[0]}, {edge[1]}, set(subset)):
                conditional_independencies[edge].append(set(subset))
                print(f"{edge[0]} _||_ {edge[1]}" + (f" | {' '.join(subset)}" if subset else ""))
Out[31]:
A _||_ W | S
D _||_ S | A M W
M _||_ W | S

Comments

Comments powered by Disqus