Chapter 15. Missing Data and Other Opportunities

In [0]:
import math
import os

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd

import jax.numpy as jnp
from jax import ops, random, vmap
from jax.scipy.special import expit

import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import print_summary
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, init_to_value

# When an SVG environment variable is present, ask the inline backend to
# render figures as SVG.
if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
az.style.use("arviz-darkgrid")
# Request 4 host devices from NumPyro -- presumably so the 4-chain MCMC runs
# below can execute in parallel (TODO confirm against the NumPyro docs).
numpyro.set_host_device_count(4)

Code 15.1

In [1]:
def sim_pancake(seed):
    """Simulate one pancake and return its two sides in random order.

    Two PRNG keys are derived from ``seed``: ``2 * seed`` picks one of the
    three pancakes uniformly, ``2 * seed + 1`` shuffles its two sides.
    """
    pick_key = random.PRNGKey(2 * seed)
    flip_key = random.PRNGKey(2 * seed + 1)
    # Rows are the two sides, columns are the three pancakes (1/1, 1/0, 0/0).
    side_table = jnp.array([[1, 1, 0], [1, 0, 0]])
    which = dist.Categorical(logits=jnp.ones(3)).sample(pick_key)
    return random.permutation(flip_key, side_table[:, which])


# Simulate 10,000 pancakes. out_axes=1 stacks the draws as columns, so row 0
# holds every "up" side and row 1 every "down" side.
pancakes = vmap(sim_pancake, out_axes=1)(jnp.arange(10000))
up, down = pancakes[0], pancakes[1]

# Among pancakes showing a 1 on top (1/1 or 1/0), what share is 1/1 (BB)?
up_is_1 = up == 1
num_11_10 = up_is_1.sum()
num_11 = (up_is_1 & (down == 1)).sum()
num_11 / num_11_10
Out[1]:
DeviceArray(0.6641716, dtype=float32)

Code 15.2

In [2]:
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce

# points: divorce rate against median age at marriage, one open circle per row
ax = az.plot_pair(
    d[["MedianAgeMarriage", "Divorce"]].to_dict(orient="list"),
    scatter_kwargs=dict(ms=15, mfc="none"),
)
ax.set(ylim=(4, 15), xlabel="Median age marriage", ylabel="Divorce rate")

# standard errors: a vertical segment of +/- one "Divorce SE" around each point
for x, divorce, se in zip(d["MedianAgeMarriage"], d["Divorce"], d["Divorce SE"]):
    plt.plot([x, x], [divorce - se, divorce + se], "k")
Out[2]:

Code 15.3

In [3]:
# Standardized variables. D_sd is the reported divorce standard error
# rescaled by the same factor used to standardize the divorce rate.
dlist = dict(
    D_obs=d.Divorce.pipe(lambda x: (x - x.mean()) / x.std()).values,
    D_sd=d["Divorce SE"].values / d.Divorce.std(),
    M=d.Marriage.pipe(lambda x: (x - x.mean()) / x.std()).values,
    A=d.MedianAgeMarriage.pipe(lambda x: (x - x.mean()) / x.std()).values,
    N=d.shape[0],
)


def model(A, M, D_sd, D_obs, N):
    """Divorce regression with measurement error on the outcome.

    The latent ``D_true`` is Normal around ``a + bA * A + bM * M`` with scale
    ``sigma``; the observed ``D_obs`` is Normal around ``D_true`` with the
    known scale ``D_sd``. ``N`` is accepted but not used: the shape of
    ``D_true`` follows from ``mu``.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bA * A + bM * M
    D_true = numpyro.sample("D_true", dist.Normal(mu, sigma))
    numpyro.sample("D_obs", dist.Normal(D_true, D_sd), obs=D_obs)


# num_warmup / num_samples are passed by keyword: newer NumPyro releases make
# them keyword-only, and older releases accept the same names.
m15_1 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_1.run(random.PRNGKey(0), **dlist)

Code 15.4

In [4]:
# Posterior summary with 89% intervals.
m15_1.print_summary(prob=0.89)
Out[4]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
 D_true[0]      1.17      0.37      1.17      0.57      1.74   2401.51      1.00
 D_true[1]      0.68      0.57      0.68     -0.22      1.60   3620.74      1.00
 D_true[2]      0.43      0.34      0.42     -0.14      0.96   3700.02      1.00
 D_true[3]      1.42      0.45      1.42      0.68      2.11   3091.26      1.00
 D_true[4]     -0.90      0.13     -0.90     -1.12     -0.70   4414.02      1.00
 D_true[5]      0.66      0.39      0.66      0.05      1.29   3429.74      1.00
 D_true[6]     -1.38      0.34     -1.38     -1.92     -0.85   3758.00      1.00
 D_true[7]     -0.34      0.49     -0.34     -1.06      0.48   3136.96      1.00
 D_true[8]     -1.88      0.59     -1.88     -2.77     -0.91   2435.42      1.00
 D_true[9]     -0.62      0.16     -0.62     -0.88     -0.36   4020.95      1.00
D_true[10]      0.77      0.27      0.78      0.33      1.20   3201.66      1.00
D_true[11]     -0.55      0.48     -0.56     -1.33      0.23   2645.62      1.00
D_true[12]      0.17      0.48      0.17     -0.62      0.89   1486.11      1.00
D_true[13]     -0.87      0.23     -0.87     -1.24     -0.48   3540.33      1.00
D_true[14]      0.56      0.31      0.55      0.12      1.09   3907.98      1.00
D_true[15]      0.29      0.38      0.29     -0.30      0.91   4415.77      1.00
D_true[16]      0.50      0.42      0.49     -0.14      1.15   4748.56      1.00
D_true[17]      1.25      0.34      1.24      0.73      1.82   3141.66      1.00
D_true[18]      0.43      0.38      0.43     -0.23      0.96   3675.14      1.00
D_true[19]      0.41      0.53      0.40     -0.44      1.23   2010.13      1.00
D_true[20]     -0.55      0.32     -0.55     -1.07     -0.06   3516.71      1.00
D_true[21]     -1.10      0.26     -1.10     -1.52     -0.71   3084.19      1.00
D_true[22]     -0.27      0.26     -0.26     -0.68      0.13   4074.25      1.00
D_true[23]     -1.00      0.29     -1.00     -1.45     -0.54   3238.19      1.00
D_true[24]      0.43      0.40      0.42     -0.19      1.07   3781.77      1.00
D_true[25]     -0.03      0.31     -0.03     -0.51      0.47   4161.37      1.00
D_true[26]      0.00      0.51      0.02     -0.84      0.80   3814.40      1.00
D_true[27]     -0.16      0.40     -0.16     -0.84      0.43   4237.18      1.00
D_true[28]     -0.26      0.48     -0.29     -1.04      0.49   3130.68      1.00
D_true[29]     -1.81      0.23     -1.81     -2.15     -1.39   3863.86      1.00
D_true[30]      0.18      0.44      0.19     -0.52      0.89   3790.34      1.00
D_true[31]     -1.66      0.17     -1.66     -1.92     -1.39   3620.41      1.00
D_true[32]      0.12      0.24      0.12     -0.27      0.49   3476.00      1.00
D_true[33]     -0.06      0.52     -0.04     -0.90      0.74   2154.18      1.00
D_true[34]     -0.12      0.22     -0.12     -0.45      0.23   4052.89      1.00
D_true[35]      1.28      0.42      1.27      0.55      1.89   3505.41      1.00
D_true[36]      0.23      0.35      0.23     -0.32      0.80   4023.45      1.00
D_true[37]     -1.02      0.22     -1.01     -1.36     -0.68   4202.55      1.00
D_true[38]     -0.93      0.56     -0.94     -1.93     -0.12   3218.85      1.00
D_true[39]     -0.68      0.32     -0.68     -1.22     -0.19   4395.19      1.00
D_true[40]      0.24      0.54      0.24     -0.61      1.10   3664.71      1.00
D_true[41]      0.75      0.33      0.74      0.19      1.27   3079.24      1.00
D_true[42]      0.19      0.18      0.19     -0.09      0.47   3873.93      1.00
D_true[43]      0.80      0.42      0.82      0.11      1.44   2345.56      1.00
D_true[44]     -0.40      0.51     -0.41     -1.26      0.36   2814.81      1.00
D_true[45]     -0.39      0.24     -0.40     -0.80     -0.02   3852.89      1.00
D_true[46]      0.15      0.31      0.16     -0.35      0.63   3782.98      1.00
D_true[47]      0.57      0.46      0.58     -0.18      1.29   4495.36      1.00
D_true[48]     -0.64      0.27     -0.64     -1.05     -0.21   3630.93      1.00
D_true[49]      0.84      0.62      0.85     -0.13      1.83   2526.49      1.00
         a     -0.05      0.10     -0.05     -0.22      0.09   2354.00      1.00
        bA     -0.62      0.16     -0.62     -0.88     -0.38   2044.70      1.00
        bM      0.05      0.16      0.05     -0.21      0.31   1410.64      1.00
     sigma      0.59      0.11      0.58      0.41      0.76    752.39      1.01

Number of divergences: 0

Code 15.5

In [5]:
# Standardized variables. Each *_sd is the reported standard error rescaled
# by the same factor used to standardize the corresponding rate.
dlist = dict(
    D_obs=d.Divorce.pipe(lambda x: (x - x.mean()) / x.std()).values,
    D_sd=d["Divorce SE"].values / d.Divorce.std(),
    M_obs=d.Marriage.pipe(lambda x: (x - x.mean()) / x.std()).values,
    M_sd=d["Marriage SE"].values / d.Marriage.std(),
    A=d.MedianAgeMarriage.pipe(lambda x: (x - x.mean()) / x.std()).values,
    N=d.shape[0],
)


def model(A, M_sd, M_obs, D_sd, D_obs, N):
    """Divorce regression with measurement error on both outcome and predictor.

    ``M_est`` (length ``N``, Normal(0, 1) prior) is the latent marriage rate;
    ``M_obs`` is Normal around it with known scale ``M_sd``. ``D_est`` is the
    latent divorce rate, Normal around ``a + bA * A + bM * M_est`` with scale
    ``sigma``; ``D_obs`` is Normal around it with known scale ``D_sd``.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    M_est = numpyro.sample("M_est", dist.Normal(0, 1).expand([N]))
    numpyro.sample("M_obs", dist.Normal(M_est, M_sd), obs=M_obs)
    mu = a + bA * A + bM * M_est
    D_est = numpyro.sample("D_est", dist.Normal(mu, sigma))
    numpyro.sample("D_obs", dist.Normal(D_est, D_sd), obs=D_obs)


# num_warmup / num_samples are passed by keyword: newer NumPyro releases make
# them keyword-only, and older releases accept the same names.
m15_2 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_2.run(random.PRNGKey(0), **dlist)

Code 15.6

In [6]:
post = m15_2.get_samples()
# Posterior means of the latent divorce and marriage rates.
D_est = post["D_est"].mean(axis=0)
M_est = post["M_est"].mean(axis=0)

# Observed values (filled blue), then posterior means (open black).
plt.plot(dlist["M_obs"], dlist["D_obs"], "bo", alpha=0.5)
plt.gca().set(xlabel="marriage rate (std)", ylabel="divorce rate (std)")
plt.plot(M_est, D_est, "ko", mfc="none")

# Join each observed point to its posterior-mean counterpart.
for i in range(len(d)):
    xs = [dlist["M_obs"][i], M_est[i]]
    ys = [dlist["D_obs"][i], D_est[i]]
    plt.plot(xs, ys, "k-", lw=1)
Out[6]:

Code 15.7

In [7]:
N = 500
# A is standard normal; M, D and A_obs are each unit-scale Normal draws
# centred on -A, A and A respectively, each with its own PRNG key.
A = dist.Normal().sample(random.PRNGKey(0), sample_shape=(N,))
M = dist.Normal(loc=-A).sample(random.PRNGKey(1))
D = dist.Normal(loc=A).sample(random.PRNGKey(2))
A_obs = dist.Normal(loc=A).sample(random.PRNGKey(3))

Code 15.8

In [8]:
N = 100
# S: standard-normal draws; H: Binomial(10) counts with success
# probability expit(S).
S = dist.Normal().sample(random.PRNGKey(0), sample_shape=(N,))
H = dist.Binomial(total_count=10, probs=expit(S)).sample(random.PRNGKey(1))

Code 15.9

In [9]:
# dogs completely random: one independent Bernoulli(0.5) draw per element of H.
# Without a sample shape a single scalar is drawn, and broadcasting then marks
# either every value or no value as missing.
D = dist.Bernoulli(0.5).sample(random.PRNGKey(2), H.shape)
# Hm is H with NaN wherever D == 1.
Hm = jnp.where(D == 1, jnp.nan, H)

Code 15.10

In [10]:
# D is 1 exactly where S is positive.
D = jnp.where(S > 0, 1, 0)
# Keep H where D is 0; everything else becomes NaN.
Hm = jnp.where(D == 0, H, jnp.nan)

Code 15.11

In [11]:
N = 1000
# Only the three sample statements consume randomness, so only they need to
# sit inside the seed handler; their order is unchanged.
with numpyro.handlers.seed(rng_seed=501):
    X = numpyro.sample("X", dist.Normal().expand([N]))
    S = numpyro.sample("S", dist.Normal().expand([N]))
    H = numpyro.sample("H", dist.Binomial(total_count=10, logits=2 + S - 2 * X))

# D is 1 where X exceeds 1; Hm is H with NaN at those positions.
D = jnp.where(X > 1, 1, 0)
Hm = jnp.where(D == 0, H, jnp.nan)

Code 15.12

In [12]:
# Fit using every simulated H value.
dat_list = dict(H=H, S=S)


def model(S, H):
    """Binomial(10) regression of H on S with a logit link.

    ``logit_p = a + bS * S``, with Normal(0, 1) and Normal(0, 0.5) priors on
    ``a`` and ``bS``.
    """
    a = numpyro.sample("a", dist.Normal(0, 1))
    bS = numpyro.sample("bS", dist.Normal(0, 0.5))
    logit_p = a + bS * S
    numpyro.sample("H", dist.Binomial(10, logits=logit_p), obs=H)


# num_warmup / num_samples are passed by keyword: newer NumPyro releases make
# them keyword-only, and older releases accept the same names.
m15_3 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_3.run(random.PRNGKey(0), **dat_list)
m15_3.print_summary()
Out[12]:
                mean       std    median      5.0%     95.0%     n_eff     r_hat
         a      1.32      0.03      1.32      1.28      1.36   1268.10      1.00
        bS      0.62      0.03      0.62      0.58      0.66   1078.05      1.00

Number of divergences: 0

Code 15.13

In [13]:
# Fit using only the rows where D == 0.
dat_list0 = dict(H=H[D == 0], S=S[D == 0])


def model(S, H):
    """Binomial(10) regression of H on S with a logit link.

    ``logit_p = a + bS * S``, with Normal(0, 1) and Normal(0, 0.5) priors on
    ``a`` and ``bS``.
    """
    a = numpyro.sample("a", dist.Normal(0, 1))
    bS = numpyro.sample("bS", dist.Normal(0, 0.5))
    logit_p = a + bS * S
    numpyro.sample("H", dist.Binomial(10, logits=logit_p), obs=H)


# num_warmup / num_samples are passed by keyword: newer NumPyro releases make
# them keyword-only, and older releases accept the same names.
m15_4 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_4.run(random.PRNGKey(0), **dat_list0)
m15_4.print_summary()
Out[13]:
                mean       std    median      5.0%     95.0%     n_eff     r_hat
         a      1.92      0.04      1.92      1.86      1.97   1007.18      1.00
        bS      0.72      0.03      0.72      0.67      0.78    784.14      1.00

Number of divergences: 0

Code 15.14

In [14]:
# D is 1 where X lies strictly between -1 and 1, i.e. |X| < 1.
D = jnp.where((X > -1) & (X < 1), 1, 0)

Code 15.15

In [15]:
N = 100
S = dist.Normal().sample(random.PRNGKey(0), sample_shape=(N,))
H = dist.Binomial(total_count=10, logits=S).sample(random.PRNGKey(1))
# D is 1 where H is below 5; Hm is H with NaN at those positions.
D = jnp.where(H < 5, 1, 0)
Hm = jnp.where(D == 0, H, jnp.nan)

Code 15.16

In [16]:
# `d` and `milk` refer to the same DataFrame; the derived columns below are
# added to it in place.
d = milk = pd.read_csv("../data/milk.csv", sep=";")
d["neocortex.prop"] = d["neocortex.perc"].div(100)
d["logmass"] = d["mass"].map(math.log)

Code 15.17

In [17]:
# Standardized variables as numpy arrays. The model below expects B to
# contain NaN for missing neocortex values and imputes them.
dat_list = dict(
    K=d["kcal.per.g"].pipe(lambda x: (x - x.mean()) / x.std()).values,
    B=d["neocortex.prop"].pipe(lambda x: (x - x.mean()) / x.std()).values,
    M=d.logmass.pipe(lambda x: (x - x.mean()) / x.std()).values,
)


def model(B, M, K):
    """Regress K on B and M while imputing the NaN entries of B.

    Every NaN in ``B`` is replaced by an element of the ``B_impute`` site.
    The merged vector is given a Normal(nu, sigma_B) likelihood and then
    enters the linear model ``mu = a + bB * B + bM * M`` for ``K``.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.5))
    nu = numpyro.sample("nu", dist.Normal(0, 0.5))
    bB = numpyro.sample("bB", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma_B = numpyro.sample("sigma_B", dist.Exponential(1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # One element per NaN in B. mask(False) masks out this site's own
    # Normal(0, 1) log-density; the imputed values are scored by the "B" and
    # "K" sites below instead.
    B_impute = numpyro.sample(
        "B_impute", dist.Normal(0, 1).expand([int(jnp.isnan(B).sum())]).mask(False)
    )
    # NOTE(review): ops.index_update was removed from later JAX releases,
    # where `jnp.asarray(B).at[idx].set(B_impute)` is the replacement --
    # confirm the installed version before switching.
    B = ops.index_update(B, jnp.nonzero(jnp.isnan(B))[0], B_impute)
    numpyro.sample("B", dist.Normal(nu, sigma_B), obs=B)
    mu = a + bB * B + bM * M
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# num_warmup / num_samples are passed by keyword: newer NumPyro releases make
# them keyword-only, and older releases accept the same names.
m15_5 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_5.run(random.PRNGKey(0), **dat_list)

Code 15.18

In [18]:
# Posterior summary with 89% intervals.
m15_5.print_summary(prob=0.89)
Out[18]:
                  mean       std    median      5.5%     94.5%     n_eff     r_hat
 B_impute[0]     -0.59      0.93     -0.61     -2.04      0.88   1728.72      1.00
 B_impute[1]     -0.71      0.96     -0.76     -2.36      0.72   1507.54      1.00
 B_impute[2]     -0.74      0.95     -0.77     -2.37      0.72   1925.34      1.00
 B_impute[3]     -0.30      0.92     -0.30     -1.77      1.12   2162.72      1.00
 B_impute[4]      0.45      0.89      0.43     -0.97      1.83   1891.39      1.00
 B_impute[5]     -0.18      0.92     -0.19     -1.59      1.34   2179.08      1.00
 B_impute[6]      0.17      0.89      0.18     -1.07      1.73   2225.82      1.00
 B_impute[7]      0.31      0.86      0.32     -1.10      1.58   2121.79      1.00
 B_impute[8]      0.51      0.87      0.54     -0.76      1.98   2257.69      1.00
 B_impute[9]     -0.44      0.93     -0.45     -1.91      1.06   1837.62      1.00
B_impute[10]     -0.28      0.89     -0.30     -1.63      1.19   2123.46      1.00
B_impute[11]      0.14      0.89      0.16     -1.28      1.56   2139.01      1.00
           a      0.04      0.17      0.04     -0.23      0.30   1876.58      1.00
          bB      0.50      0.24      0.51      0.15      0.89    663.53      1.01
          bM     -0.55      0.20     -0.55     -0.86     -0.23    765.52      1.01
          nu     -0.05      0.21     -0.05     -0.36      0.28   1606.00      1.00
       sigma      0.84      0.14      0.83      0.62      1.05    869.98      1.01
     sigma_B      1.01      0.17      0.99      0.75      1.27   1034.07      1.00

Number of divergences: 0

Code 15.19

In [19]:
# Boolean mask of rows whose neocortex.prop is present; every array in
# dat_list is subset with it.
obs_idx = d["neocortex.prop"].notnull().values
dat_list_obs = dict(
    K=dat_list["K"][obs_idx], B=dat_list["B"][obs_idx], M=dat_list["M"][obs_idx]
)


def model(B, M, K):
    """Complete-case regression of K on B and M.

    ``B`` is given a Normal(nu, sigma_B) likelihood and enters the linear
    model ``mu = a + bB * B + bM * M`` for ``K``; nothing is imputed.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.5))
    nu = numpyro.sample("nu", dist.Normal(0, 0.5))
    bB = numpyro.sample("bB", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma_B = numpyro.sample("sigma_B", dist.Exponential(1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("B", dist.Normal(nu, sigma_B), obs=B)
    mu = a + bB * B + bM * M
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# num_warmup / num_samples are passed by keyword: newer NumPyro releases make
# them keyword-only, and older releases accept the same names.
m15_6 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_6.run(random.PRNGKey(0), **dat_list_obs)
m15_6.print_summary(0.89)
Out[19]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      0.10      0.20      0.09     -0.22      0.43   2589.09      1.00
        bB      0.61      0.28      0.62      0.18      1.03   1375.60      1.00
        bM     -0.64      0.25     -0.65     -1.05     -0.27   1340.98      1.00
        nu      0.00      0.23      0.00     -0.37      0.37   2336.09      1.00
     sigma      0.87      0.18      0.84      0.62      1.15   1389.21      1.00
   sigma_B      1.05      0.19      1.03      0.73      1.31   2136.89      1.00

Number of divergences: 0

Code 15.20

In [20]:
# Forest plot of bB and bM for both fits, chains combined, 89% intervals.
fits = {"m15.5": m15_5, "m15.6": m15_6}
az.plot_forest(
    [az.from_numpyro(fit) for fit in fits.values()],
    model_names=list(fits),
    var_names=["bB", "bM"],
    combined=True,
    hdi_prob=0.89,
)
plt.show()
Out[20]:

Code 15.21

In [21]:
post = m15_5.get_samples()
# Posterior mean and 89% percentile interval of each imputed B value.
B_impute_mu = jnp.mean(post["B_impute"], 0)
B_impute_ci = jnp.percentile(post["B_impute"], q=(5.5, 94.5), axis=0)

# Positions of the NaN entries of B. B_impute has one element per NaN, so
# this also gives the number of intervals to draw.
miss_idx = pd.isna(dat_list["B"]).nonzero()[0]
n_miss = len(miss_idx)

# B vs K
plt.plot(dat_list["B"], dat_list["K"], "o")
plt.gca().set(xlabel="neocortex percent (std)", ylabel="kcal milk (std)")
Ki = dat_list["K"][miss_idx]
plt.plot(B_impute_mu, Ki, "ko", mfc="none")
for i in range(n_miss):
    # horizontal interval for imputed B at the observed K
    plt.plot(B_impute_ci[:, i], jnp.repeat(Ki[i], 2), "k", lw=1)
plt.show()

# M vs B
plt.plot(dat_list["M"], dat_list["B"], "o")
plt.gca().set(xlabel="log body mass (std)", ylabel="neocortex percent (std)")
Mi = dat_list["M"][miss_idx]
plt.plot(Mi, B_impute_mu, "ko", mfc="none")
for i in range(n_miss):
    # vertical interval for imputed B at the observed M
    plt.plot(jnp.repeat(Mi[i], 2), B_impute_ci[:, i], "k", lw=1)
Out[21]:
Out[21]:

Code 15.22

In [22]:
def model(B, M, K):
    """Regress K on B and M, imputing NaN entries of B from a joint (M, B) model.

    ``B_merge`` is ``B`` with each NaN replaced by an element of ``B_impute``.
    The pairs ``(M, B_merge)`` get a bivariate Normal likelihood with mean
    ``(muM, muB)`` and covariance ``outer(Sigma_BM, Sigma_BM) * Rho_BM``, and
    ``K`` is Normal around ``a + bB * B_merge + bM * M`` with scale ``sigma``.
    """
    # priors
    a = numpyro.sample("a", dist.Normal(0, 0.5))
    muB = numpyro.sample("muB", dist.Normal(0, 0.5))
    muM = numpyro.sample("muM", dist.Normal(0, 0.5))
    bB = numpyro.sample("bB", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    Rho_BM = numpyro.sample("Rho_BM", dist.LKJ(2, 2))
    Sigma_BM = numpyro.sample("Sigma_BM", dist.Exponential(1).expand([2]))

    # define B_merge as mix of observed and imputed values
    # mask(False) masks out B_impute's own Normal(0, 1) log-density; the
    # imputed values are scored by the "MB" and "K" sites below instead.
    B_impute = numpyro.sample(
        "B_impute", dist.Normal(0, 1).expand([int(jnp.isnan(B).sum())]).mask(False)
    )
    # NOTE(review): ops.index_update was removed from later JAX releases,
    # where `jnp.asarray(B).at[idx].set(B_impute)` is the replacement --
    # confirm the installed version before switching.
    B_merge = ops.index_update(B, jnp.nonzero(jnp.isnan(B))[0], B_impute)

    # M and B correlation
    MB = jnp.stack([M, B_merge], axis=1)
    cov = jnp.outer(Sigma_BM, Sigma_BM) * Rho_BM
    numpyro.sample("MB", dist.MultivariateNormal(jnp.stack([muM, muB]), cov), obs=MB)

    # K as function of B and M
    mu = a + bB * B_merge + bM * M
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# num_warmup / num_samples are passed by keyword: newer NumPyro releases make
# them keyword-only, and older releases accept the same names.
m15_7 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_7.run(random.PRNGKey(0), **dat_list)
# group_by_chain=True keeps the chain dimension in the samples passed to
# print_summary.
post = m15_7.get_samples(group_by_chain=True)
print_summary({k: v for k, v in post.items() if k in ["bM", "bB", "Rho_BM"]})
Out[22]:
                 mean       std    median      5.0%     95.0%     n_eff     r_hat
Rho_BM[0,0]      1.00      0.00      1.00      1.00      1.00       nan       nan
Rho_BM[0,1]      0.61      0.14      0.63      0.39      0.81   1384.32      1.00
Rho_BM[1,0]      0.61      0.14      0.63      0.39      0.81   1384.32      1.00
Rho_BM[1,1]      1.00      0.00      1.00      1.00      1.00   1599.26      1.00
         bB      0.59      0.26      0.61      0.13      0.97    937.14      1.00
         bM     -0.65      0.22     -0.66     -1.01     -0.28   1130.60      1.00

Code 15.23

In [23]:
# Integer positions of the NaN entries in the standardized B vector.
B_is_missing = pd.isna(dat_list["B"])
B_missidx = B_is_missing.nonzero()[0]

Code 15.24

In [24]:
# Load the semicolon-separated Moralizing_gods data and display it.
Moralizing_gods = pd.read_csv(filepath_or_buffer="../data/Moralizing_gods.csv", sep=";")
Moralizing_gods
Out[24]:
polity year population moralizing_gods writing
0 Big Island Hawaii 1000 3.729643 NaN 0
1 Big Island Hawaii 1100 3.729643 NaN 0
2 Big Island Hawaii 1200 3.598340 NaN 0
3 Big Island Hawaii 1300 4.026240 NaN 0
4 Big Island Hawaii 1400 4.311767 NaN 0
... ... ... ... ... ...
859 Yemeni Coastal Plain 1400 6.763083 1.0 1
860 Yemeni Coastal Plain 1500 6.519621 1.0 1
861 Konya Plain 1600 7.447158 1.0 1
862 Yemeni Coastal Plain 1700 3.882606 1.0 1
863 Yemeni Coastal Plain 1800 3.882606 1.0 1

864 rows × 5 columns

Code 15.25

In [25]:
# Tally the moralizing_gods values, keeping NaN as its own category.
Moralizing_gods["moralizing_gods"].value_counts(dropna=False)
Out[25]:
NaN    528
1.0    319
0.0     17
Name: moralizing_gods, dtype: int64

Code 15.26

In [26]:
mg_values = Moralizing_gods["moralizing_gods"]
mg_missing = mg_values.isna()

# Marker per row: "." where the value equals 1, "o" otherwise, then "x" for NaN.
symbol = mg_values.eq(1).map({True: ".", False: "o"})
symbol[mg_missing] = "x"
# Colour per row: black for NaN, blue for everything else.
color = mg_missing.map({True: "k", False: "b"})

for pch in ["o", ".", "x"]:
    rows = symbol == pch
    plt.scatter(
        Moralizing_gods.year[rows],
        Moralizing_gods.population[rows],
        marker=pch,
        color=color[rows],
        # only the "o" marker is drawn hollow
        facecolor=("none" if pch == "o" else None),
        lw=1.5,
        alpha=0.7,
    )
plt.gca().set(xlabel="Time (year)", ylabel="Population size")
plt.show()
Out[26]:

Code 15.27

In [27]:
dmg = Moralizing_gods
# Cast to str so NaN becomes the label "nan" and is kept as a group, then
# cross-tabulate moralizing_gods against writing.
pair_counts = dmg.astype(str).groupby(["moralizing_gods", "writing"]).size()
pair_counts.unstack(fill_value=0)
Out[27]:
writing 0 1
moralizing_gods
0.0 16 1
1.0 9 310
nan 442 86

Code 15.28

In [28]:
dmg = Moralizing_gods
# Rows for the Big Island Hawaii polity, shown transposed and rounded to 3 places.
haw = dmg["polity"].eq("Big Island Hawaii")
hawaii_columns = ["year", "population", "writing", "moralizing_gods"]
dmg.loc[haw, hawaii_columns].transpose().round(3)
Out[28]:
0 1 2 3 4 5 6 7 8
year 1000.00 1100.00 1200.000 1300.000 1400.000 1500.000 1600.000 1700.000 1800.000
population 3.73 3.73 3.598 4.026 4.312 4.205 4.374 5.158 4.997
writing 0.00 0.00 0.000 0.000 0.000 0.000 0.000 0.000 0.000
moralizing_gods NaN NaN NaN NaN NaN NaN NaN NaN 1.000

Code 15.29

In [29]:
# Simulation constants.
N_houses = 100
alpha = 5
beta = -3
k = 0.5
r = 0.2

# Only the sample statements consume randomness, so only they need to sit
# inside the seed handler; their order is unchanged.
with numpyro.handlers.seed(rng_seed=9):
    cat = numpyro.sample("cat", dist.Bernoulli(k).expand([N_houses]))
    notes = numpyro.sample("notes", dist.Poisson(alpha + beta * cat))
    R_C = numpyro.sample("R_C", dist.Bernoulli(r).expand([N_houses]))

# Where R_C == 1 the cat indicator is replaced by the placeholder -9.
cat_obs = jnp.where(R_C == 1, -9, cat)

Code 15.30

In [30]:
# N is the number of houses. It was previously N_houses - 1, an off-by-one;
# the model below never reads N, so sampling is unaffected.
dat = dict(notes=notes, cat=cat_obs, RC=R_C, N=N_houses)


def model(N, RC, cat, notes):
    """Poisson model for notes with the cat indicator partly unobserved.

    Rows with ``RC == 0`` have an observed ``cat`` and use it directly. Rows
    with ``RC == 1`` contribute a two-component mixture that averages the
    Poisson likelihood over cat present (weight ``k``) and cat absent
    (weight ``1 - k``). ``N`` is accepted but not used.
    """
    # priors
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 0.5))

    # sneaking cat model
    k = numpyro.sample("k", dist.Beta(2, 2))
    numpyro.sample("cat|RC==0", dist.Bernoulli(k), obs=cat[RC == 0])

    # singing bird model
    # cat NA: log(k * Pois(notes | exp(a + b)) + (1 - k) * Pois(notes | exp(a)))
    custom_logprob = jnp.logaddexp(
        jnp.log(k) + dist.Poisson(jnp.exp(a + b)).log_prob(notes[RC == 1]),
        jnp.log(1 - k) + dist.Poisson(jnp.exp(a)).log_prob(notes[RC == 1]),
    )
    numpyro.factor("notes|RC==1", custom_logprob)
    # cat known present/absent:
    lambda_ = jnp.exp(a + b * cat[RC == 0])
    numpyro.sample("notes|RC==0", dist.Poisson(lambda_), obs=notes[RC == 0])


# num_warmup / num_samples are passed by keyword: newer NumPyro releases make
# them keyword-only, and older releases accept the same names.
m15_8 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_8.run(random.PRNGKey(0), **dat)

Code 15.31

In [31]:
def model(N, RC, cat, notes, link=False):
    """Poisson model for notes with the cat indicator partly unobserved.

    Rows with ``RC == 0`` use the observed ``cat``; rows with ``RC == 1``
    contribute a mixture over cat present (weight ``k``) and cat absent
    (weight ``1 - k``). With ``link=True`` three deterministic sites are also
    recorded for every house: ``lpC0`` and ``lpC1``, the unnormalized log
    probabilities of cat absent / present given the notes, and ``PrC1``,
    their normalized probability of cat present. ``N`` is accepted but not
    used.
    """
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 0.5))

    # sneaking cat model
    k = numpyro.sample("k", dist.Beta(2, 2))
    numpyro.sample("cat|RC==0", dist.Bernoulli(k), obs=cat[RC == 0])

    # singing bird model
    custom_logprob = jnp.logaddexp(
        jnp.log(k) + dist.Poisson(jnp.exp(a + b)).log_prob(notes[RC == 1]),
        jnp.log(1 - k) + dist.Poisson(jnp.exp(a)).log_prob(notes[RC == 1]),
    )
    numpyro.factor("notes|RC==1", custom_logprob)
    lambda_ = jnp.exp(a + b * cat[RC == 0])
    numpyro.sample("notes|RC==0", dist.Poisson(lambda_), obs=notes[RC == 0])

    if link:
        lpC0 = numpyro.deterministic(
            "lpC0", jnp.log(1 - k) + dist.Poisson(jnp.exp(a)).log_prob(notes)
        )
        # This site was previously also named "lpC0", which reused the site
        # name and left no "lpC1" site.
        lpC1 = numpyro.deterministic(
            "lpC1", jnp.log(k) + dist.Poisson(jnp.exp(a + b)).log_prob(notes)
        )
        numpyro.deterministic("PrC1", jnp.exp(lpC1) / (jnp.exp(lpC1) + jnp.exp(lpC0)))


# num_warmup / num_samples are passed by keyword: newer NumPyro releases make
# them keyword-only, and older releases accept the same names.
m15_9 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_9.run(random.PRNGKey(0), **dat)

Code 15.32

In [32]:
# Ten (x, y) pairs with y ~ Normal(x). Only the sample statements need the
# seed handler.
with numpyro.handlers.seed(rng_seed=100):
    x = numpyro.sample("x", dist.Normal().expand([10]))
    y = numpyro.sample("y", dist.Normal(loc=x))

# Append an eleventh case whose x is NaN and whose y is 100.
x = jnp.concatenate([x, jnp.array([jnp.nan])])
y = jnp.concatenate([y, jnp.array([100])])
d = {"x": x, "y": y}

Code 15.33

In [33]:
d = Primates301 = pd.read_csv("../data/Primates301.csv", sep=";")
# Index labels of rows where both brain and body are present.
cc = d.index[d[["brain", "body"]].notna().all(axis=1)]
# Scale each variable by its maximum so the largest value is 1.
B = d.loc[cc, "brain"]
M = d.loc[cc, "body"]
B = B.values / B.max()
M = M.values / M.max()

Code 15.34

In [34]:
# Each standard error is 10% of the corresponding scaled value.
Bse = 0.1 * B
Mse = 0.1 * M

Code 15.35

In [35]:
dat_list = dict(B=B, M=M)


def model(M, B):
    """Log-normal regression of B on log(M).

    ``B`` is LogNormal with location ``a + b * log(M)`` and scale ``sigma``;
    ``a`` and ``b`` have Normal(0, 1) priors and ``sigma`` an Exponential(1)
    prior.
    """
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + b * jnp.log(M)
    numpyro.sample("B", dist.LogNormal(mu, sigma), obs=B)


# num_warmup / num_samples are passed by keyword: newer NumPyro releases make
# them keyword-only, and older releases accept the same names.
m15H4 = MCMC(NUTS(model), num_warmup=500, num_samples=500)
m15H4.run(random.PRNGKey(0), **dat_list)
Out[35]:
sample: 100%|██████████| 1000/1000 [00:12<00:00, 80.62it/s, 7 steps of size 2.96e-01. acc. prob=0.93] 

Code 15.36

In [36]:
# Initial values: sites named M_true and B_true start at the M and B arrays
# in dat_list.
start = {"M_true": dat_list["M"], "B_true": dat_list["B"]}
init_strategy = init_to_value(values=start)

Code 15.37

In [37]:
d = Primates301 = pd.read_csv("../data/Primates301.csv", sep=";")
# Count of missing values in each column.
d.isnull().sum()
Out[37]:
name                     0
genus                    0
species                  0
subspecies             267
spp_id                   0
genus_id                 0
social_learning         98
research_effort        115
brain                  117
body                    63
group_size             114
gestation              161
weaning                185
longevity              181
sex_maturity           194
maternal_investment    197
dtype: int64

Code 15.38

In [38]:
# Index labels of rows where body is present; brain may still be NaN there.
cc = d.index[d["body"].notna()]
M = d.loc[cc, "body"]
M = M.values / M.max()
B = d.loc[cc, "brain"]
# Series.max skips NaN by default, so missing brain values stay NaN in B.
B = B.values / B.max()

Code 15.39

In [39]:
# Start all 56 elements of B_impute at 0.5 -- presumably 56 is the number of
# missing brain values among the body-complete rows (TODO confirm).
start = {"B_impute": jnp.full((56,), 0.5)}
init_strategy = init_to_value(values=start)

Comments

Comments powered by Disqus