Chapter 9. Markov Chain Monte Carlo
In [ ]:
!pip install -q numpyro arviz
In [0]:
import inspect
import math
import os
import warnings
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import jax.numpy as jnp
from jax import lax, random
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_value
from numpyro.infer.autoguide import AutoLaplaceApproximation
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
az.style.use("arviz-darkgrid")
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
category.__name__, message
)
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)
Code 9.1¶
In [1]:
num_weeks = int(1e5)
positions = jnp.repeat(0, num_weeks)
current = 10


def body_fn(i, val):
    """One week of the island walk, written as a ``lax.fori_loop`` body.

    ``val`` is the carried pair ``(positions, current)``; the updated pair is
    returned after recording the current island and deciding whether to move.
    """
    history, island = val
    # log where he is this week
    history = history.at[i].set(island)
    # a coin flip turns into a step of -1 or +1
    coin = dist.Bernoulli(0.5).sample(random.fold_in(random.PRNGKey(0), i))
    candidate = island + (coin * 2 - 1)
    # wrap around the ring of islands 1..10
    candidate = jnp.where(candidate < 1, 10, candidate)
    candidate = jnp.where(candidate > 10, 1, candidate)
    # move with probability candidate / island
    move_prob = candidate / island
    draw = dist.Uniform().sample(random.fold_in(random.PRNGKey(1), i))
    island = jnp.where(draw < move_prob, candidate, island)
    return history, island


positions, current = lax.fori_loop(0, num_weeks, body_fn, (positions, current))
Code 9.2¶
In [2]:
# First 100 recorded positions of the walk; label the axes so the figure
# stands on its own.
plt.plot(range(1, 101), positions[:100], "o", mfc="none")
plt.xlabel("week")
plt.ylabel("island")
plt.show()
Out[2]:
Code 9.3¶
In [3]:
# How many weeks were spent on each island (one bin per island 1..10);
# label the axes so the figure stands on its own.
plt.hist(np.asarray(positions), bins=range(1, 12), rwidth=0.1, align="left")
plt.xlabel("island")
plt.ylabel("number of weeks")
plt.show()
Out[3]:
Code 9.4¶
In [4]:
D = 10
T = int(1e3)
# T draws from a D-dimensional standard normal
Y = dist.MultivariateNormal(jnp.repeat(0, D), jnp.identity(D)).sample(
    random.PRNGKey(0), (T,)
)


def rad_dist(Y):
    """Euclidean distance of the point ``Y`` from the origin."""
    return jnp.sqrt(jnp.sum(Y**2))


# distance from the origin of every draw (one row of Y at a time)
Rd = lax.map(rad_dist, Y)
az.plot_kde(Rd, bw=0.18)
plt.show()
Out[4]:
Code 9.5¶
In [5]:
# U needs to return neg-log-probability
def U(q, a=0, b=1, k=0, d=1):
    """Negative log joint density at ``q = [muy, mux]``.

    Sums the Normal(., 1) log-likelihoods of the module-level ``y`` and ``x``
    with the Normal(a, b) prior on ``muy`` and the Normal(k, d) prior on ``mux``.
    """
    muy, mux = q[0], q[1]
    log_joint = (
        jnp.sum(dist.Normal(muy, 1).log_prob(y))
        + jnp.sum(dist.Normal(mux, 1).log_prob(x))
        + dist.Normal(a, b).log_prob(muy)
        + dist.Normal(k, d).log_prob(mux)
    )
    return -log_joint
Code 9.6¶
In [6]:
# gradient function
# need vector of partial derivatives of U with respect to vector q
def U_gradient(q, a=0, b=1, k=0, d=1):
    """Gradient of the potential energy ``U`` with respect to ``q = [muy, mux]``.

    Uses the module-level data ``y`` and ``x``. ``a``/``b`` are the prior mean
    and sd of ``muy``; ``k``/``d`` are the prior mean and sd of ``mux``, matching
    the priors used in ``U``.

    Returns a length-2 array ``[dU/dmuy, dU/dmux]``.
    """
    muy = q[0]
    mux = q[1]
    G1 = jnp.sum(y - muy) + (a - muy) / b**2  # d(log-prob)/dmuy
    # the prior on mux is Normal(k, d), so its term is scaled by d**2 (not b**2)
    G2 = jnp.sum(x - mux) + (k - mux) / d**2  # d(log-prob)/dmux
    return jnp.stack([-G1, -G2])  # negative bc energy is neg-log-prob
# test data
with numpyro.handlers.seed(rng_seed=7):
    y = numpyro.sample("y", dist.Normal().expand([50]))
    x = numpyro.sample("x", dist.Normal().expand([50]))
# centre each sample on its mean and divide by its standard deviation
x, y = ((v - jnp.mean(v)) / jnp.std(v) for v in (x, y))
Code 9.7¶
In [7]:
def HMC2(U, grad_U, epsilon, L, current_q, rng):
    """Run one Hamiltonian Monte Carlo transition and record its trajectory.
    U is the potential energy (called as U(q)); grad_U is its gradient
    (called as grad_U(q)). epsilon is the leapfrog step size and L the number
    of leapfrog steps. current_q is the starting position and rng a PRNG key
    (folded with 0 for the momentum draw and with 1 for the accept draw).
    Returns a dict with "q" (new position), "traj" and "ptraj" (position and
    momentum at each of the L + 1 trajectory points), "accept" (1 or 0) and
    "dH" (proposed minus current total energy).
    """
    # NOTE: the two blank lines in this body split it into the three sections
    # that Code 9.8-9.10 print via split("\n\n"); keep the docstring and the
    # comments free of blank lines and do not add or remove blank lines.
    q = current_q
    # random flick - p is momentum
    p = dist.Normal(0, 1).sample(random.fold_in(rng, 0), (q.shape[0],))
    current_p = p
    # Make a half step for momentum at the beginning
    p = p - epsilon * grad_U(q) / 2
    # initialize bookkeeping - saves trajectory
    # (both arrays start as NaN-filled (L + 1, len(q)) arrays)
    qtraj = jnp.full((L + 1, q.shape[0]), jnp.nan)
    ptraj = qtraj
    qtraj = qtraj.at[0].set(current_q)
    ptraj = ptraj.at[0].set(p)

    # Alternate full steps for position and momentum
    for i in range(L):
        q = q + epsilon * p  # Full step for the position
        # Make a full step for the momentum, except at end of trajectory
        if i != (L - 1):
            p = p - epsilon * grad_U(q)
            ptraj = ptraj.at[i + 1].set(p)
        qtraj = qtraj.at[i + 1].set(q)

    # Make a half step for momentum at the end
    p = p - epsilon * grad_U(q) / 2
    # the last momentum row is written here, after the final half step
    ptraj = ptraj.at[L].set(p)
    # Negate momentum at end of trajectory to make the proposal symmetric
    p = -p
    # Evaluate potential and kinetic energies at start and end of trajectory
    current_U = U(current_q)
    current_K = jnp.sum(current_p**2) / 2
    proposed_U = U(q)
    proposed_K = jnp.sum(p**2) / 2
    # Accept or reject the state at end of trajectory, returning either
    # the position at the end of the trajectory or the initial position
    accept = 0
    runif = dist.Uniform().sample(random.fold_in(rng, 1))
    # accept when a uniform draw falls below exp(-(change in total energy))
    if runif < jnp.exp(current_U - proposed_U + current_K - proposed_K):
        new_q = q  # accept
        accept = 1
    else:
        new_q = current_q  # reject
    return {
        "q": new_q,
        "traj": qtraj,
        "ptraj": ptraj,
        "accept": accept,
        "dH": proposed_U + proposed_K - (current_U + current_K),
    }
# Draw a few HMC trajectories produced by HMC2, starting from Q["q"].
Q = {}
Q["q"] = jnp.array([-0.1, 0.2])
pr = 0.31
plt.subplot(ylabel="muy", xlabel="mux", xlim=(-pr, pr), ylim=(-pr, pr))
step = 0.03
L = 11  # 0.03/28 for U-turns --- 11 for working example
n_samples = 4
path_col = (0, 0, 0, 0.5)
# faint reference circles centred on the origin
for r in 0.075 * jnp.arange(2, 6):
    plt.gca().add_artist(plt.Circle((0, 0), r, alpha=0.2, fill=False))
# mark the starting position
plt.scatter(Q["q"][0], Q["q"][1], c="k", marker="x", zorder=4)
for i in range(n_samples):
    Q = HMC2(U, U_gradient, step, L, Q["q"], random.fold_in(random.PRNGKey(0), i))
    # "traj" has L + 1 rows (indices 0..L): row L is the end of the trajectory
    # and row L - 1 is the point just before it.
    q_prev = Q["traj"][L - 1]
    q_end = Q["traj"][L]
    if n_samples < 10:
        for j in range(L):
            # line width grows with the kinetic energy at the segment start
            K0 = jnp.sum(Q["ptraj"][j] ** 2) / 2
            plt.plot(
                Q["traj"][j : j + 2, 0],
                Q["traj"][j : j + 2, 1],
                c=path_col,
                lw=1 + 2 * K0,
            )
        plt.scatter(Q["traj"][:, 0], Q["traj"][:, 1], c="white", s=5, zorder=3)
        # for fancy arrows: arrow on the last segment, then the sample number
        plt.annotate(
            "",
            (q_prev[0], q_prev[1]),
            (q_end[0], q_end[1]),
            arrowprops={"arrowstyle": "<-"},
        )
        plt.annotate(
            str(i + 1),
            (q_end[0], q_end[1]),
            xytext=(3, 3),
            textcoords="offset points",
        )
    # End point of the trajectory, red when the energy error |dH| exceeds 0.1.
    # Index L is the last valid row; L + 1 is out of range and only worked
    # because JAX clamps out-of-bounds reads to the last element.
    plt.scatter(
        q_end[0],
        q_end[1],
        c=("red" if jnp.abs(Q["dH"]) > 0.1 else "black"),
        zorder=4,
    )
Out[7]:
Code 9.8¶
In [8]:
# (list of source lines, starting line number) of HMC2
source_HMC2 = inspect.getsourcelines(HMC2)
# first blank-line-separated section of the function source
print("".join(source_HMC2[0]).split("\n\n")[0])
Out[8]:
Code 9.9¶
In [9]:
# second blank-line-separated section of the HMC2 source
print("".join(source_HMC2[0]).split("\n\n")[1])
Out[9]:
Code 9.10¶
In [10]:
# third blank-line-separated section of the HMC2 source
print("".join(source_HMC2[0]).split("\n\n")[2])
Out[10]:
Code 9.11¶
In [11]:
rugged = pd.read_csv("../data/rugged.csv", sep=";")
d = rugged
# log of rgdppc_2000 (added to `rugged` too, since `d` is the same frame)
d["log_gdp"] = d["rgdppc_2000"].apply(math.log)
# keep only the rows where rgdppc_2000 is present
has_gdp = d["rgdppc_2000"].notnull()
dd = d[has_gdp].copy()
# rescale: log GDP relative to its mean, ruggedness relative to its maximum
dd["log_gdp_std"] = dd["log_gdp"] / dd["log_gdp"].mean()
dd["rugged_std"] = dd["rugged"] / dd["rugged"].max()
# index variable: 0 where cont_africa == 1, otherwise 1
dd["cid"] = jnp.where(dd["cont_africa"].values == 1, 0, 1)
Code 9.12¶
In [12]:
def model(cid, rugged_std, log_gdp_std=None):
    """Linear model of ``log_gdp_std`` on ``rugged_std`` with an intercept and a
    slope per ``cid`` group; ``mu`` is recorded as a deterministic site."""
    intercepts = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
    slopes = numpyro.sample("b", dist.Normal(0, 0.3).expand([2]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    centered = rugged_std - 0.215
    mu = numpyro.deterministic("mu", intercepts[cid] + slopes[cid] * centered)
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


# Laplace approximation fitted with SVI
m8_3 = AutoLaplaceApproximation(model)
m8_3_data = {
    "cid": dd.cid.values,
    "rugged_std": dd.rugged_std.values,
    "log_gdp_std": dd.log_gdp_std.values,
}
svi = SVI(model, m8_3, optim.Adam(0.1), Trace_ELBO(), **m8_3_data)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_3 = svi_result.params
post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, sample_shape=(1000,))
# summarise everything except the deterministic "mu" site
post_without_mu = dict(post)
post_without_mu.pop("mu", None)
print_summary(post_without_mu, 0.89, False)
Out[12]:
Out[12]:
Code 9.13¶
In [13]:
# only the columns the model needs, as plain arrays
dat_slim = {
    col: dd[col].values for col in ("log_gdp_std", "rugged_std", "cid")
}
# peek at the first five values of each entry
{name: values[:5] for name, values in dat_slim.items()}
Out[13]:
Code 9.14¶
In [14]:
def model(cid, rugged_std, log_gdp_std):
    """Linear model of ``log_gdp_std`` on ``rugged_std`` with an intercept and a
    slope per ``cid`` group."""
    intercepts = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
    slopes = numpyro.sample("b", dist.Normal(0, 0.3).expand([2]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    centered = rugged_std - 0.215
    mu = intercepts[cid] + slopes[cid] * centered
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


# one NUTS chain: 500 warmup + 500 sampling iterations
m9_1 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=1,
)
m9_1.run(random.PRNGKey(0), **dat_slim)
Out[14]:
Code 9.15¶
In [15]:
# posterior summary with 89% intervals
m9_1.print_summary(prob=0.89)
Out[15]:
Code 9.16¶
In [16]:
def model(cid, rugged_std, log_gdp_std):
    """Linear model of ``log_gdp_std`` on ``rugged_std`` with an intercept and a
    slope per ``cid`` group."""
    intercepts = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
    slopes = numpyro.sample("b", dist.Normal(0, 0.3).expand([2]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    centered = rugged_std - 0.215
    mu = intercepts[cid] + slopes[cid] * centered
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


# four NUTS chains: 500 warmup + 500 sampling iterations each
m9_1 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m9_1.run(random.PRNGKey(0), **dat_slim)
Out[16]:
Out[16]:
Out[16]:
Out[16]:
Code 9.17¶
In [17]:
# show the source of the model function held by the sampler
print(inspect.getsource(m9_1.sampler.model))
Out[17]:
Code 9.18¶
In [18]:
# posterior summary with 89% intervals
m9_1.print_summary(prob=0.89)
Out[18]:
Code 9.19¶
In [19]:
# pairwise plot of the m9_1 samples
m9_1_pair_data = az.from_numpyro(m9_1)
az.plot_pair(m9_1_pair_data)
plt.show()
Out[19]:
Code 9.20¶
In [20]:
# trace plot of the m9_1 samples
m9_1_trace_data = az.from_numpyro(m9_1)
az.plot_trace(m9_1_trace_data)
plt.show()
Out[20]:
Code 9.21¶
In [21]:
# rank plot of the m9_1 samples
m9_1_rank_data = az.from_numpyro(m9_1)
az.plot_rank(m9_1_rank_data)
plt.show()
Out[21]:
Code 9.22¶
In [22]:
y = jnp.array([-1, 1])


def model(y):
    """Normal model for ``y`` with a Normal(0, 1000) prior on the mean and an
    Exponential(0.0001) prior on the scale."""
    alpha = numpyro.sample("alpha", dist.Normal(0, 1000))
    sigma = numpyro.sample("sigma", dist.Exponential(0.0001))
    # the mean is just alpha
    numpyro.sample("y", dist.Normal(alpha, sigma), obs=y)


m9_2_kernel = NUTS(model, target_accept_prob=0.95)
m9_2 = MCMC(m9_2_kernel, num_warmup=500, num_samples=500, num_chains=3)
m9_2.run(random.PRNGKey(11), y=y)
Out[22]:
Out[22]:
Out[22]:
Code 9.23¶
In [23]:
# posterior summary with 89% intervals
m9_2.print_summary(prob=0.89)
Out[23]:
Code 9.24¶
In [24]:
y = jnp.array([-1, 1])


def model(y):
    """Normal model for ``y`` with a Normal(1, 10) prior on the mean and an
    Exponential(1) prior on the scale."""
    alpha = numpyro.sample("alpha", dist.Normal(1, 10))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # the mean is just alpha
    numpyro.sample("y", dist.Normal(alpha, sigma), obs=y)


m9_3 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=3,
)
m9_3.run(random.PRNGKey(11), y=y)
m9_3.print_summary(prob=0.89)
Out[24]:
Out[24]:
Out[24]:
Out[24]:
Code 9.25¶
In [25]:
# 100 draws from a standard normal
y = dist.Normal(0, 1).sample(random.PRNGKey(41), sample_shape=(100,))
Code 9.26¶
In [26]:
def model(y):
    """Normal model for ``y`` whose mean is the sum ``a1 + a2`` of two parameters
    that each get a Normal(0, 1000) prior; the scale gets Exponential(1)."""
    a1 = numpyro.sample("a1", dist.Normal(0, 1000))
    a2 = numpyro.sample("a2", dist.Normal(0, 1000))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a1 + a2
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)


m9_4 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=3)
# also collect num_steps so the transitions that hit the step cap can be counted
m9_4.run(random.PRNGKey(384), extra_fields=["num_steps"], y=y)
# 89% intervals, consistent with every other summary in this notebook
m9_4.print_summary(0.89)
print(
    "There were {} transitions that exceeded the maximum treedepth.".format(
        (m9_4.get_extra_fields()["num_steps"] + 1 == 2**10).sum()
    )
)
Out[26]:
Out[26]:
Out[26]:
Out[26]:
Code 9.27¶
In [27]:
def model(y):
    """Normal model for ``y`` whose mean is ``a1 + a2``, with Normal(0, 10) priors
    on both parameters and an Exponential(1) prior on the scale."""
    first = numpyro.sample("a1", dist.Normal(0, 10))
    second = numpyro.sample("a2", dist.Normal(0, 10))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("y", dist.Normal(first + second, sigma), obs=y)


m9_5 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=3,
)
m9_5.run(random.PRNGKey(0), y=y)
m9_5.print_summary(prob=0.89)
Out[27]:
Out[27]:
Out[27]:
Out[27]:
Code 9.28¶
In [28]:
def model(y):
    """Sample ``a`` from Normal(0, 1) and ``b`` from Cauchy(0, 1); ``y`` is unused."""
    numpyro.sample("a", dist.Normal(0, 1))
    numpyro.sample("b", dist.Cauchy(0, 1))


# start both parameters at zero
mp_init = init_to_value(values={"a": 0.0, "b": 0.0})
kernel = NUTS(model, init_strategy=mp_init)
mp = MCMC(kernel, num_warmup=100, num_samples=9900)
mp.run(random.PRNGKey(0), y=1)
Out[28]:
Code 9.29¶
In [29]:
N = 100  # number of individuals
# total height of every individual
height = dist.Normal(10, 2).sample(random.PRNGKey(0), (N,))
# leg length as a proportion of height
leg_prop = dist.Uniform(0.4, 0.5).sample(random.PRNGKey(1), (N,))
# left and right legs: proportion of height plus independent measurement noise
leg_left = leg_prop * height + dist.Normal(0, 0.02).sample(random.PRNGKey(2), (N,))
leg_right = leg_prop * height + dist.Normal(0, 0.02).sample(random.PRNGKey(3), (N,))
# gather into one data frame
d = pd.DataFrame({"height": height, "leg_left": leg_left, "leg_right": leg_right})


def model(leg_left, leg_right, height):
    """Regress ``height`` on both leg lengths with Normal priors on the
    intercept and both slopes and an Exponential(1) prior on the scale."""
    intercept = numpyro.sample("a", dist.Normal(10, 100))
    slope_left = numpyro.sample("bl", dist.Normal(2, 10))
    slope_right = numpyro.sample("br", dist.Normal(2, 10))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = intercept + slope_left * leg_left + slope_right * leg_right
    numpyro.sample("height", dist.Normal(mu, sigma), obs=height)


m5_8s_init = {"a": 10.0, "bl": 0.0, "br": 0.1, "sigma": 1.0}
kernel = NUTS(model, init_strategy=init_to_value(values=m5_8s_init))
m5_8s = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=4)
# pass every column of d to the model under its column name
m5_8s_data = dict(zip(d.columns, d.T.values))
m5_8s.run(random.PRNGKey(0), **m5_8s_data)
Out[29]:
Out[29]:
Out[29]:
Out[29]:
Code 9.30¶
In [30]:
def model(leg_left, leg_right, height):
    """Regress ``height`` on both leg lengths; the right-leg slope ``br`` gets a
    Normal(2, 10) prior truncated below at 0."""
    intercept = numpyro.sample("a", dist.Normal(10, 100))
    slope_left = numpyro.sample("bl", dist.Normal(2, 10))
    slope_right = numpyro.sample("br", dist.TruncatedNormal(2, 10, low=0))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = intercept + slope_left * leg_left + slope_right * leg_right
    numpyro.sample("height", dist.Normal(mu, sigma), obs=height)


m5_8s2_init = {"a": 10.0, "bl": 0.0, "br": 0.1, "sigma": 1.0}
kernel = NUTS(model, init_strategy=init_to_value(values=m5_8s2_init))
m5_8s2 = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=4)
# pass every column of d to the model under its column name
m5_8s2_data = dict(zip(d.columns, d.T.values))
m5_8s2.run(random.PRNGKey(0), **m5_8s2_data)
Out[30]:
Out[30]:
Out[30]:
Out[30]:
Comments
Comments powered by Disqus