Chapter 15. Missing Data and Other Opportunities
In [ ]:
!pip install -q numpyro arviz
In [0]:
import math
import os
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import jax.numpy as jnp
from jax import ops, random, vmap
from jax.scipy.special import expit
import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import print_summary
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, init_to_value
# Render inline figures as SVG when an "SVG" environment variable is set.
if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
# Use ArviZ's dark-grid style for all plots.
az.style.use("arviz-darkgrid")
# Run JAX on the CPU and expose 4 host devices; the MCMC runs below use num_chains=4.
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)
Code 15.1¶
In [1]:
def sim_pancake(seed):
    """Draw one of three pancakes uniformly and return its two sides in random order.

    Sides are coded 0/1; the 1/1 pancake is the "BB" one. Two deterministic keys
    are derived from ``seed``: one picks the pancake, the other shuffles its sides.
    """
    choice_key = random.PRNGKey(2 * seed)
    shuffle_key = random.PRNGKey(2 * seed + 1)
    pancake = dist.Categorical(logits=jnp.ones(3)).sample(choice_key)
    # column j holds the two sides of pancake j: 1/1, 1/0, 0/0
    side_table = jnp.array([[1, 1, 0], [1, 0, 0]])
    return random.permutation(shuffle_key, side_table[:, pancake])


# Simulate 10,000 pancakes; out_axes=1 stacks them column-wise -> shape (2, 10000).
pancakes = vmap(sim_pancake, out_axes=1)(jnp.arange(10000))
up, down = pancakes

# Among pancakes showing a 1 on top, the fraction that also have 1 underneath.
up_is_1 = up == 1
num_11_10 = up_is_1.sum()
num_11 = (up_is_1 & (down == 1)).sum()
num_11 / num_11_10
Out[1]:
Code 15.2¶
In [2]:
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce

# Scatter of divorce rate against median age at marriage (open circles).
ax = az.plot_pair(
    d[["MedianAgeMarriage", "Divorce"]].to_dict(orient="list"),
    scatter_kwargs=dict(s=15, facecolors="none"),
)
ax.set(ylim=(4, 15), xlabel="Median age marriage", ylabel="Divorce rate")

# +/- one standard error of the measured divorce rate, drawn as vertical segments.
# A single vectorized call on the same Axes replaces the per-row plt.plot loop.
ax.vlines(
    d["MedianAgeMarriage"],
    d["Divorce"] - d["Divorce SE"],
    d["Divorce"] + d["Divorce SE"],
    colors="k",
)
plt.show()
Out[2]:
Code 15.3¶
In [3]:
# Standardized data for m15.1. D_sd is the divorce standard error rescaled to the
# same units as the standardized divorce rate.
dlist = {
    "D_obs": ((d.Divorce - d.Divorce.mean()) / d.Divorce.std()).values,
    "D_sd": d["Divorce SE"].values / d.Divorce.std(),
    "M": ((d.Marriage - d.Marriage.mean()) / d.Marriage.std()).values,
    "A": (
        (d.MedianAgeMarriage - d.MedianAgeMarriage.mean()) / d.MedianAgeMarriage.std()
    ).values,
    "N": d.shape[0],
}


def model(A, M, D_sd, D_obs, N):
    """m15.1: the true divorce rate is latent and observed with known error D_sd.

    The latent ``D_true`` is regressed on A and M; ``N`` is accepted but unused.
    """
    intercept = numpyro.sample("a", dist.Normal(0, 0.2))
    beta_A = numpyro.sample("bA", dist.Normal(0, 0.5))
    beta_M = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    D_true = numpyro.sample(
        "D_true", dist.Normal(intercept + beta_A * A + beta_M * M, sigma)
    )
    # measurement model: observed rate scatters around the latent one
    numpyro.sample("D_obs", dist.Normal(D_true, D_sd), obs=D_obs)


m15_1 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_1.run(random.PRNGKey(0), **dlist)
Out[3]:
Out[3]:
Out[3]:
Out[3]:
Code 15.4¶
In [4]:
# Posterior summary for m15.1 with 89% credible intervals.
m15_1.print_summary(prob=0.89)
Out[4]:
Code 15.5¶
In [5]:
# Standardized data for m15.2: both divorce (D) and marriage (M) carry known
# standard errors, rescaled to standardized units.
dlist = {
    "D_obs": ((d.Divorce - d.Divorce.mean()) / d.Divorce.std()).values,
    "D_sd": d["Divorce SE"].values / d.Divorce.std(),
    "M_obs": ((d.Marriage - d.Marriage.mean()) / d.Marriage.std()).values,
    "M_sd": d["Marriage SE"].values / d.Marriage.std(),
    "A": (
        (d.MedianAgeMarriage - d.MedianAgeMarriage.mean()) / d.MedianAgeMarriage.std()
    ).values,
    "N": d.shape[0],
}


def model(A, M_sd, M_obs, D_sd, D_obs, N):
    """m15.2: measurement error on both the outcome D and the predictor M.

    ``M_est`` (length N) and ``D_est`` are latent true values; each is tied to
    its observation through a Normal with the supplied standard error.
    """
    intercept = numpyro.sample("a", dist.Normal(0, 0.2))
    beta_A = numpyro.sample("bA", dist.Normal(0, 0.5))
    beta_M = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # latent marriage rates and their measurement model
    M_est = numpyro.sample("M_est", dist.Normal(0, 1).expand([N]))
    numpyro.sample("M_obs", dist.Normal(M_est, M_sd), obs=M_obs)
    # latent divorce rates regressed on A and the latent M, then measured
    D_est = numpyro.sample(
        "D_est", dist.Normal(intercept + beta_A * A + beta_M * M_est, sigma)
    )
    numpyro.sample("D_obs", dist.Normal(D_est, D_sd), obs=D_obs)


m15_2 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_2.run(random.PRNGKey(0), **dlist)
Out[5]:
Out[5]:
Out[5]:
Out[5]:
Code 15.6¶
In [6]:
# Posterior means of the latent divorce/marriage values from m15.2.
post = m15_2.get_samples()
D_est = post["D_est"].mean(axis=0)
M_est = post["M_est"].mean(axis=0)

# Observed values (blue) vs. posterior-mean estimates (open black circles).
plt.plot(dlist["M_obs"], dlist["D_obs"], "bo", alpha=0.5)
plt.gca().set(xlabel="marriage rate (std)", ylabel="divorce rate (std)")
plt.plot(M_est, D_est, "ko", mfc="none")
# Segment from each observation to its estimate shows the shrinkage.
for m_obs, d_obs, m_hat, d_hat in zip(dlist["M_obs"], dlist["D_obs"], M_est, D_est):
    plt.plot([m_obs, m_hat], [d_obs, d_hat], "k-", lw=1)
Out[6]:
Code 15.7¶
In [7]:
# Simulate N cases where A influences M (negatively) and D, and A itself is
# only seen through a noisy measurement A_obs.
N = 500
A = dist.Normal().sample(random.PRNGKey(0), sample_shape=(N,))
M = dist.Normal(loc=-A).sample(random.PRNGKey(1))
D = dist.Normal(loc=A).sample(random.PRNGKey(2))
A_obs = dist.Normal(loc=A).sample(random.PRNGKey(3))
Code 15.8¶
In [8]:
# Simulate N cases: S ~ Normal(0, 1), and H ~ Binomial(10, expit(S)).
N = 100
S = dist.Normal().sample(random.PRNGKey(0), sample_shape=(N,))
H = dist.Binomial(total_count=10, probs=expit(S)).sample(random.PRNGKey(1))
Code 15.9¶
In [9]:
# Missing completely at random: D is a fair coin flip, independent of S and H.
D = dist.Bernoulli(probs=0.5).sample(random.PRNGKey(2), sample_shape=(N,))
# Hm is H with the D == 1 entries blanked out as NaN.
Hm = jnp.where(D == 1, jnp.nan, H)
Code 15.10¶
In [10]:
# Missingness now depends on the predictor: D == 1 exactly when S > 0.
s_positive = S > 0
D = jnp.where(s_positive, 1, 0)
Hm = jnp.where(D == 1, jnp.nan, H)
Code 15.11¶
In [11]:
# Simulate with an extra variable X that affects H and also drives missingness.
N = 1000
with numpyro.handlers.seed(rng_seed=501):
    # draw order (X, S, H) fixes which random numbers each site receives
    X = numpyro.sample("X", dist.Normal().expand([N]))
    S = numpyro.sample("S", dist.Normal().expand([N]))
    H = numpyro.sample("H", dist.Binomial(10, logits=2 + S - 2 * X))
# H is lost (NaN in Hm) exactly when X > 1.
D = jnp.where(X > 1, 1, 0)
Hm = jnp.where(D == 1, jnp.nan, H)
Code 15.12¶
In [12]:
dat_list = {"H": H, "S": S}


def model(S, H):
    """m15.3: binomial regression of H (out of 10) on S with a logit link."""
    intercept = numpyro.sample("a", dist.Normal(0, 1))
    slope = numpyro.sample("bS", dist.Normal(0, 0.5))
    numpyro.sample("H", dist.Binomial(10, logits=intercept + slope * S), obs=H)


m15_3 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_3.run(random.PRNGKey(0), **dat_list)
m15_3.print_summary()
Out[12]:
Out[12]:
Out[12]:
Out[12]:
Out[12]:
Code 15.13¶
In [13]:
# Complete-case data: keep only rows with D == 0 (the rows whose H is not NaN in Hm).
keep = D == 0
dat_list0 = {"H": H[keep], "S": S[keep]}


def model(S, H):
    """m15.4: same binomial regression as m15.3, fitted to the complete cases."""
    intercept = numpyro.sample("a", dist.Normal(0, 1))
    slope = numpyro.sample("bS", dist.Normal(0, 0.5))
    numpyro.sample("H", dist.Binomial(10, logits=intercept + slope * S), obs=H)


m15_4 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_4.run(random.PRNGKey(0), **dat_list0)
m15_4.print_summary()
Out[13]:
Out[13]:
Out[13]:
Out[13]:
Out[13]:
Code 15.14¶
In [14]:
# H goes missing when X lies strictly inside (-1, 1), i.e. |X| < 1.
D = jnp.where((X > -1) & (X < 1), 1, 0)
Code 15.15¶
In [15]:
# Missingness that depends on the outcome itself: H values below 5 are lost.
N = 100
S = dist.Normal().sample(random.PRNGKey(0), sample_shape=(N,))
H = dist.Binomial(total_count=10, logits=S).sample(random.PRNGKey(1))
D = jnp.where(H < 5, 1, 0)
Hm = jnp.where(D == 1, jnp.nan, H)
Code 15.16¶
In [16]:
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk  # alias: the new columns below are added to `milk` as well
# neocortex as a proportion instead of a percent, and log of body mass
d["neocortex.prop"] = d["neocortex.perc"].div(100)
d["logmass"] = d["mass"].map(math.log)
Code 15.17¶
In [17]:
# Standardize K, B and M. B keeps NaN where neocortex is missing
# (pandas mean/std skip NaN).
dat_list = {
    key: ((d[col] - d[col].mean()) / d[col].std()).values
    for key, col in (("K", "kcal.per.g"), ("B", "neocortex.prop"), ("M", "logmass"))
}


def model(B, M, K):
    """m15.5: regress K on B and M while imputing the NaN entries of B.

    Each missing B becomes a parameter in ``B_impute``. The merged B vector
    (observed plus imputed values) is scored by the ``B ~ Normal(nu, sigma_B)``
    site and then used as a predictor of K.
    """
    intercept = numpyro.sample("a", dist.Normal(0, 0.5))
    nu = numpyro.sample("nu", dist.Normal(0, 0.5))
    beta_B = numpyro.sample("bB", dist.Normal(0, 0.5))
    beta_M = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma_B = numpyro.sample("sigma_B", dist.Exponential(1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # One free parameter per NaN in B. mask(False) drops this placeholder prior
    # from the density, so the "B" site below is what scores the imputed values.
    missing_pos = np.flatnonzero(np.isnan(B))
    B_impute = numpyro.sample(
        "B_impute", dist.Normal(0, 1).expand([missing_pos.size]).mask(False)
    )
    B_merged = jnp.asarray(B).at[missing_pos].set(B_impute)
    numpyro.sample("B", dist.Normal(nu, sigma_B), obs=B_merged)
    numpyro.sample(
        "K", dist.Normal(intercept + beta_B * B_merged + beta_M * M, sigma), obs=K
    )


m15_5 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_5.run(random.PRNGKey(0), **dat_list)
Out[17]:
Out[17]:
Out[17]:
Out[17]:
Code 15.18¶
In [18]:
# Posterior summary for m15.5 (including the imputed B values) at 89%.
m15_5.print_summary(prob=0.89)
Out[18]:
Code 15.19¶
In [19]:
# Complete-case subset: only the rows where neocortex.prop was recorded.
obs_idx = d["neocortex.prop"].notna().to_numpy()
dat_list_obs = {key: dat_list[key][obs_idx] for key in ("K", "B", "M")}


def model(B, M, K):
    """m15.6: the m15.5 structure fitted to complete cases only (no imputation)."""
    intercept = numpyro.sample("a", dist.Normal(0, 0.5))
    nu = numpyro.sample("nu", dist.Normal(0, 0.5))
    beta_B = numpyro.sample("bB", dist.Normal(0, 0.5))
    beta_M = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma_B = numpyro.sample("sigma_B", dist.Exponential(1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("B", dist.Normal(nu, sigma_B), obs=B)
    numpyro.sample("K", dist.Normal(intercept + beta_B * B + beta_M * M, sigma), obs=K)


m15_6 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_6.run(random.PRNGKey(0), **dat_list_obs)
m15_6.print_summary(0.89)
Out[19]:
Out[19]:
Out[19]:
Out[19]:
Out[19]:
Code 15.20¶
In [20]:
# Compare bB and bM between the imputation model (m15.5) and the
# complete-case model (m15.6).
fits = {"m15.5": m15_5, "m15.6": m15_6}
az.plot_forest(
    [az.from_numpyro(fit) for fit in fits.values()],
    model_names=list(fits),
    var_names=["bB", "bM"],
    combined=True,
    hdi_prob=0.89,
)
plt.show()
Out[20]:
Code 15.21¶
In [21]:
# Posterior mean and 89% interval (5.5%-94.5%) of each imputed B value from m15.5.
post = m15_5.get_samples()
B_impute_mu = jnp.mean(post["B_impute"], 0)
B_impute_ci = jnp.percentile(post["B_impute"], q=jnp.array([5.5, 94.5]), axis=0)
# Positions of the NaN entries of B. This is the same ascending order the model used
# (np.nonzero(np.isnan(B))[0]) to place B_impute, so entry j of B_impute belongs to
# row miss_idx[j].
miss_idx = pd.isna(dat_list["B"]).nonzero()[0]

# B vs K: observed cases (filled) and imputed cases (open) with interval bars.
plt.plot(dat_list["B"], dat_list["K"], "o")
plt.gca().set(xlabel="neocortex percent (std)", ylabel="kcal milk (std)")
Ki = dat_list["K"][miss_idx]
plt.plot(B_impute_mu, Ki, "ko", mfc="none")
# Iterate over the actual missing cases instead of a hard-coded count.
for ci_j, k_j in zip(B_impute_ci.T, Ki):
    plt.plot(ci_j, [k_j, k_j], "k", lw=1)
plt.show()

# M vs B: same imputed values, now plotted against log body mass.
plt.plot(dat_list["M"], dat_list["B"], "o")
plt.gca().set(xlabel="log body mass (std)", ylabel="neocortex percent (std)")
Mi = dat_list["M"][miss_idx]
plt.plot(Mi, B_impute_mu, "ko", mfc="none")
for m_j, ci_j in zip(Mi, B_impute_ci.T):
    plt.plot([m_j, m_j], ci_j, "k", lw=1)
plt.show()
Out[21]:
Out[21]:
Code 15.22¶
In [22]:
def model(B, M, K):
    """m15.7: impute missing B while modelling the correlation between M and B.

    The observed M and the merged (observed plus imputed) B are jointly
    MultivariateNormal, so each imputed B is informed by its M value. K is
    then regressed on the merged B and on M.
    """
    # --- priors ---
    intercept = numpyro.sample("a", dist.Normal(0, 0.5))
    mu_B = numpyro.sample("muB", dist.Normal(0, 0.5))
    mu_M = numpyro.sample("muM", dist.Normal(0, 0.5))
    beta_B = numpyro.sample("bB", dist.Normal(0, 0.5))
    beta_M = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    corr = numpyro.sample("Rho_BM", dist.LKJ(2, 2))
    scales = numpyro.sample("Sigma_BM", dist.Exponential(1).expand([2]))
    # --- imputation: NaN slots of B are filled with the B_impute parameters ---
    missing_pos = np.flatnonzero(np.isnan(B))
    B_impute = numpyro.sample(
        "B_impute", dist.Normal(0, 1).expand([missing_pos.size]).mask(False)
    )
    B_merge = jnp.asarray(B).at[missing_pos].set(B_impute)
    # --- joint (M, B) model; covariance = outer(scales, scales) * correlation ---
    cov = scales[:, None] * scales[None, :] * corr
    numpyro.sample(
        "MB",
        dist.MultivariateNormal(jnp.stack([mu_M, mu_B]), cov),
        obs=jnp.stack([M, B_merge], axis=1),
    )
    # --- outcome model ---
    numpyro.sample(
        "K", dist.Normal(intercept + beta_B * B_merge + beta_M * M, sigma), obs=K
    )


m15_7 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_7.run(random.PRNGKey(0), **dat_list)
post = m15_7.get_samples(group_by_chain=True)
# Summarize only the slopes and the M-B correlation, keeping posterior key order.
wanted = {"bM", "bB", "Rho_BM"}
print_summary({name: draws for name, draws in post.items() if name in wanted})
Out[22]:
Out[22]:
Out[22]:
Out[22]:
Out[22]:
Code 15.23¶
In [23]:
# Integer positions of the missing (NaN) entries of the standardized B.
B_missidx = np.flatnonzero(pd.isna(dat_list["B"]))
Code 15.24¶
In [24]:
# Load the moralizing-gods data (semicolon-delimited) and display it.
Moralizing_gods = pd.read_csv("../data/Moralizing_gods.csv", delimiter=";")
Moralizing_gods
Out[24]:
Code 15.25¶
In [25]:
# Counts of each moralizing_gods value, with NaN counted as its own category.
Moralizing_gods["moralizing_gods"].value_counts(dropna=False)
Out[25]:
Code 15.26¶
In [26]:
# Marker per row: "." when moralizing_gods == 1, "x" when missing, "o" otherwise.
# Color: black for missing, blue for observed.
gods = Moralizing_gods.moralizing_gods
gods_missing = gods.isna()
symbol = gods.map(lambda v: "." if v == 1 else "o").mask(gods_missing, "x")
color = gods.map(lambda v: "k" if pd.isna(v) else "b")

for marker in ["o", ".", "x"]:
    selected = symbol == marker
    plt.scatter(
        Moralizing_gods.year[selected],
        Moralizing_gods.population[selected],
        marker=marker,
        color=color[selected],
        facecolor="none" if marker == "o" else None,
        lw=1.5,
        alpha=0.7,
    )
plt.gca().set(xlabel="Time (year)", ylabel="Population size")
plt.show()
Out[26]:
Code 15.27¶
In [27]:
dmg = Moralizing_gods
# Cross-tabulate moralizing_gods by writing. astype(str) turns NaN into a "nan"
# label, so missing values get their own row/column.
(
    dmg.astype(str)
    .groupby(["moralizing_gods", "writing"])
    .size()
    .unstack(fill_value=0)
)
Out[27]:
Code 15.28¶
In [28]:
dmg = Moralizing_gods
# Rows for the "Big Island Hawaii" polity, shown transposed (one column per row).
haw = dmg["polity"].eq("Big Island Hawaii")
columns = ["year", "population", "writing", "moralizing_gods"]
dmg.loc[haw, columns].transpose().round(3)
Out[28]:
Code 15.29¶
In [29]:
N_houses = 100
# Generative constants: note rate alpha without a cat, change beta with a cat,
# cat probability k, and probability r that the cat's presence goes unrecorded.
alpha, beta, k, r = 5, -3, 0.5, 0.2
with numpyro.handlers.seed(rng_seed=9):
    cat = numpyro.sample("cat", dist.Bernoulli(k).expand([N_houses]))
    notes = numpyro.sample("notes", dist.Poisson(alpha + beta * cat))
    R_C = numpyro.sample("R_C", dist.Bernoulli(r).expand([N_houses]))
# Unrecorded cat presence is coded with the placeholder -9.
cat_obs = jnp.where(R_C == 1, -9, cat)
Code 15.30¶
In [30]:
# N is the number of houses simulated in the previous cell (N_houses); the
# original N_houses - 1 misstated it. The model below does not use N.
dat = dict(notes=notes, cat=np.asarray(cat_obs), RC=np.asarray(R_C), N=N_houses)


def model(N, RC, cat, notes):
    """m15.8: Poisson model of notes where cat presence is sometimes unobserved.

    Args:
        N: number of houses (unused by the likelihood).
        RC: 0/1 array; 1 where the cat's presence was not recorded.
        cat: 0/1 cat presence, holding a placeholder (-9) where RC == 1.
        notes: note counts per house.
    """
    cat_known = RC == 0
    cat_unknown = RC == 1
    # priors
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 0.5))
    # sneaking cat model: recorded cat presence ~ Bernoulli(k)
    k = numpyro.sample("k", dist.Beta(2, 2))
    numpyro.sample("cat|RC==0", dist.Bernoulli(k), obs=cat[cat_known])
    # singing bird model, cat unknown: marginalize over cat present (prob k)
    # vs. absent (prob 1 - k) with a log-sum-exp mixture.
    notes_unknown = notes[cat_unknown]
    custom_logprob = jnp.logaddexp(
        jnp.log(k) + dist.Poisson(jnp.exp(a + b)).log_prob(notes_unknown),
        jnp.log(1 - k) + dist.Poisson(jnp.exp(a)).log_prob(notes_unknown),
    )
    numpyro.factor("notes|RC==1", custom_logprob)
    # singing bird model, cat known: log rate is a + b * cat
    lambda_ = jnp.exp(a + b * cat[cat_known])
    numpyro.sample("notes|RC==0", dist.Poisson(lambda_), obs=notes[cat_known])


m15_8 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_8.run(random.PRNGKey(0), **dat)
Out[30]:
Out[30]:
Out[30]:
Out[30]:
Code 15.31¶
In [31]:
def model(N, RC, cat, notes, link=False):
    """m15.9: the m15.8 model plus, optionally, per-house cat-presence probabilities.

    Args:
        N: number of houses (unused by the likelihood).
        RC: 0/1 array; 1 where the cat's presence was not recorded.
        cat: 0/1 cat presence, holding a placeholder where RC == 1.
        notes: note counts per house.
        link: when True, also record deterministic sites ``lpC0``, ``lpC1``
            (unnormalized log-probabilities of cat absent/present for every
            house) and ``PrC1`` (the normalized probability that a cat is present).
    """
    cat_known = RC == 0
    cat_unknown = RC == 1
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 0.5))
    # sneaking cat model
    k = numpyro.sample("k", dist.Beta(2, 2))
    numpyro.sample("cat|RC==0", dist.Bernoulli(k), obs=cat[cat_known])
    # singing bird model: mixture over cat present/absent when cat is unknown
    notes_unknown = notes[cat_unknown]
    custom_logprob = jnp.logaddexp(
        jnp.log(k) + dist.Poisson(jnp.exp(a + b)).log_prob(notes_unknown),
        jnp.log(1 - k) + dist.Poisson(jnp.exp(a)).log_prob(notes_unknown),
    )
    numpyro.factor("notes|RC==1", custom_logprob)
    lambda_ = jnp.exp(a + b * cat[cat_known])
    numpyro.sample("notes|RC==0", dist.Poisson(lambda_), obs=notes[cat_known])
    if link:
        lpC0 = numpyro.deterministic(
            "lpC0", jnp.log(1 - k) + dist.Poisson(jnp.exp(a)).log_prob(notes)
        )
        lpC1 = numpyro.deterministic(
            "lpC1", jnp.log(k) + dist.Poisson(jnp.exp(a + b)).log_prob(notes)
        )
        # exp(lpC1) / (exp(lpC1) + exp(lpC0)) == expit(lpC1 - lpC0). The expit
        # form avoids 0/0 = nan when both exponentials underflow.
        numpyro.deterministic("PrC1", expit(lpC1 - lpC0))


m15_9 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_9.run(random.PRNGKey(0), **dat)
Out[31]:
Out[31]:
Out[31]:
Out[31]:
Code 15.32¶
In [32]:
with numpyro.handlers.seed(rng_seed=100):
    x = numpyro.sample("x", dist.Normal().expand([10]))
    y = numpyro.sample("y", dist.Normal(x))
# Append an 11th case whose x is missing (NaN) and whose y is an extreme 100.
x = jnp.append(x, jnp.nan)
y = jnp.append(y, 100)
d = {"x": x, "y": y}
Code 15.33¶
In [33]:
Primates301 = pd.read_csv("../data/Primates301.csv", sep=";")
d = Primates301
# Row labels where both brain and body are observed.
cc = d.dropna(subset=["brain", "body"]).index
brain_cc = d.loc[cc, "brain"].to_numpy()
body_cc = d.loc[cc, "body"].to_numpy()
# Rescale each to its maximum, so the largest value becomes 1.
B = brain_cc / brain_cc.max()
M = body_cc / body_cc.max()
Code 15.34¶
In [34]:
# Assumed standard errors: 10% of each measured value.
SE_FRACTION = 0.1
Bse = SE_FRACTION * B
Mse = SE_FRACTION * M
Code 15.35¶
In [35]:
dat_list = {"B": B, "M": M}


def model(M, B):
    """m15H4: log-normal regression of brain (B) on log body (M), no measurement error."""
    intercept = numpyro.sample("a", dist.Normal(0, 1))
    slope = numpyro.sample("b", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("B", dist.LogNormal(intercept + slope * jnp.log(M), sigma), obs=B)


m15H4 = MCMC(NUTS(model), num_warmup=500, num_samples=500)
m15H4.run(random.PRNGKey(0), **dat_list)
Out[35]:
Code 15.36¶
In [36]:
# Start the latent "true" M and B at their observed values.
start = {"M_true": dat_list["M"], "B_true": dat_list["B"]}
init_strategy = init_to_value(values=start)
Code 15.37¶
In [37]:
# Reload the primate data and count missing values per column.
Primates301 = pd.read_csv("../data/Primates301.csv", sep=";")
d = Primates301
d.isnull().sum()
Out[37]:
Code 15.38¶
In [38]:
# Keep rows with an observed body value; brain may still contain NaN.
cc = d.index[d["body"].notna()]
body = d.loc[cc, "body"]
brain = d.loc[cc, "brain"]
# Rescale by each column's maximum (NaNs are skipped when taking the max).
M = body.to_numpy() / body.max()
B = brain.to_numpy() / brain.max()
Code 15.39¶
In [39]:
# Start every imputed brain value at 0.5. The number of imputed values is the
# number of NaNs in B, computed from the data instead of the hard-coded 56.
start = dict(B_impute=jnp.full(int(np.isnan(B).sum()), 0.5))
init_strategy = init_to_value(values=start)
Comments
Comments powered by Disqus