# Chapter 9. Markov Chain Monte Carlo

In [0]:
```import inspect
import math
import os
import warnings

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd

import jax.numpy as jnp
from jax import lax, ops, random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import ELBO, MCMC, NUTS, SVI, init_to_value
from numpyro.infer.autoguide import AutoLaplaceApproximation

if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
az.style.use("arviz-darkgrid")
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
category.__name__, message
)
numpyro.set_host_device_count(4)
```

#### Code 9.1¶

In [1]:
```num_weeks = int(1e5)
positions = jnp.repeat(0, num_weeks)
current = 10

def body_fn(i, val):
positions, current = val
# record current position
positions = ops.index_update(positions, i, current)

# flip coin to generate proposal
bern = dist.Bernoulli(0.5).sample(random.fold_in(random.PRNGKey(0), i))
proposal = current + (bern * 2 - 1)
# now make sure he loops around the archipelago
proposal = jnp.where(proposal < 1, 10, proposal)
proposal = jnp.where(proposal > 10, 1, proposal)

# move?
prob_move = proposal / current
unif = dist.Uniform().sample(random.fold_in(random.PRNGKey(1), i))
current = jnp.where(unif < prob_move, proposal, current)
return positions, current

positions, current = lax.fori_loop(0, num_weeks, body_fn, (positions, current))
```

#### Code 9.2¶

In [2]:
```plt.plot(range(1, 101), positions[:100], "o", mfc="none")
plt.show()
```
Out[2]:

#### Code 9.3¶

In [3]:
```plt.hist(positions, bins=range(1, 12), rwidth=0.1, align="left")
plt.show()
```
Out[3]:

#### Code 9.4¶

In [4]:
```D = 10
T = int(1e3)
Y = dist.MultivariateNormal(jnp.repeat(0, D), jnp.identity(D)).sample(
random.PRNGKey(0), (T,)
)
rad_dist = lambda Y: jnp.sqrt(jnp.sum(Y ** 2))
Rd = lax.map(lambda i: rad_dist(Y[i]), jnp.arange(T))
az.plot_kde(Rd, bw=0.18)
plt.show()
```
Out[4]:

#### Code 9.5¶

In [5]:
```# U needs to return neg-log-probability
def U(q, a=0, b=1, k=0, d=1):
muy = q[0]
mux = q[1]
logprob_y = jnp.sum(dist.Normal(muy, 1).log_prob(y))
logprob_x = jnp.sum(dist.Normal(mux, 1).log_prob(x))
logprob_muy = dist.Normal(a, b).log_prob(muy)
logprob_mux = dist.Normal(k, d).log_prob(mux)
U = logprob_y + logprob_x + logprob_muy + logprob_mux
return -U
```

#### Code 9.6¶

In [6]:
```# gradient function
# need vector of partial derivatives of U with respect to vector q
def U_gradient(q, a=0, b=1, k=0, d=1):
muy = q[0]
mux = q[1]
G1 = jnp.sum(y - muy) + (a - muy) / b ** 2  # dU/dmuy
G2 = jnp.sum(x - mux) + (k - mux) / b ** 2  # dU/dmux
return jnp.stack([-G1, -G2])  # negative bc energy is neg-log-prob

# test data
with numpyro.handlers.seed(rng_seed=7):
y = numpyro.sample("y", dist.Normal().expand([50]))
x = numpyro.sample("x", dist.Normal().expand([50]))
x = (x - jnp.mean(x)) / jnp.std(x)
y = (y - jnp.mean(y)) / jnp.std(y)
```

#### Code 9.7¶

In [7]:
```def HMC2(U, grad_U, epsilon, L, current_q, rng):
q = current_q
# random flick - p is momentum
p = dist.Normal(0, 1).sample(random.fold_in(rng, 0), (q.shape[0],))
current_p = p
# Make a half step for momentum at the beginning
p = p - epsilon * grad_U(q) / 2
# initialize bookkeeping - saves trajectory
qtraj = jnp.full((L + 1, q.shape[0]), jnp.nan)
ptraj = qtraj
qtraj = ops.index_update(qtraj, 0, current_q)
ptraj = ops.index_update(ptraj, 0, p)

# Alternate full steps for position and momentum
for i in range(L):
q = q + epsilon * p  # Full step for the position
# Make a full step for the momentum, except at end of trajectory
if i != (L - 1):
p = p - epsilon * grad_U(q)
ptraj = ops.index_update(ptraj, i + 1, p)
qtraj = ops.index_update(qtraj, i + 1, q)

# Make a half step for momentum at the end
p = p - epsilon * grad_U(q) / 2
ptraj = ops.index_update(ptraj, L, p)
# Negate momentum at end of trajectory to make the proposal symmetric
p = -p
# Evaluate potential and kinetic energies at start and end of trajectory
current_U = U(current_q)
current_K = jnp.sum(current_p ** 2) / 2
proposed_U = U(q)
proposed_K = jnp.sum(p ** 2) / 2
# Accept or reject the state at end of trajectory, returning either
# the position at the end of the trajectory or the initial position
accept = 0
runif = dist.Uniform().sample(random.fold_in(rng, 1))
if runif < jnp.exp(current_U - proposed_U + current_K - proposed_K):
new_q = q  # accept
accept = 1
else:
new_q = current_q  # reject
return {
"q": new_q,
"traj": qtraj,
"ptraj": ptraj,
"accept": accept,
"dH": proposed_U + proposed_K - (current_U + current_K),
}

Q = {}
Q["q"] = jnp.array([-0.1, 0.2])
pr = 0.31
plt.subplot(ylabel="muy", xlabel="mux", xlim=(-pr, pr), ylim=(-pr, pr))
step = 0.03
L = 11  # 0.03/28 for U-turns --- 11 for working example
n_samples = 4
path_col = (0, 0, 0, 0.5)
for r in 0.075 * jnp.arange(2, 6):
plt.scatter(Q["q"][0], Q["q"][1], c="k", marker="x", zorder=4)
for i in range(n_samples):
Q = HMC2(U, U_gradient, step, L, Q["q"], random.fold_in(random.PRNGKey(0), i))
if n_samples < 10:
for j in range(L):
K0 = jnp.sum(Q["ptraj"][j] ** 2) / 2
plt.plot(
Q["traj"][j : j + 2, 0],
Q["traj"][j : j + 2, 1],
c=path_col,
lw=1 + 2 * K0,
)
plt.scatter(Q["traj"][:, 0], Q["traj"][:, 1], c="white", s=5, zorder=3)
# for fancy arrows
dx = Q["traj"][L, 0] - Q["traj"][L - 1, 0]
dy = Q["traj"][L, 1] - Q["traj"][L - 1, 1]
d = jnp.sqrt(dx ** 2 + dy ** 2)
plt.annotate(
"",
(Q["traj"][L - 1, 0], Q["traj"][L - 1, 1]),
(Q["traj"][L, 0], Q["traj"][L, 1]),
arrowprops={"arrowstyle": "<-"},
)
plt.annotate(
str(i + 1),
(Q["traj"][L, 0], Q["traj"][L, 1]),
xytext=(3, 3),
textcoords="offset points",
)
plt.scatter(
Q["traj"][L + 1, 0],
Q["traj"][L + 1, 1],
c=("red" if jnp.abs(Q["dH"]) > 0.1 else "black"),
zorder=4,
)
```
Out[7]:

#### Code 9.8¶

In [8]:
```source_HMC2 = inspect.getsourcelines(HMC2)
print("".join("".join(source_HMC2[0]).split("\n\n")[0]))
```
Out[8]:
```def HMC2(U, grad_U, epsilon, L, current_q, rng):
q = current_q
# random flick - p is momentum
p = dist.Normal(0, 1).sample(random.fold_in(rng, 0), (q.shape[0],))
current_p = p
# Make a half step for momentum at the beginning
p = p - epsilon * grad_U(q) / 2
# initialize bookkeeping - saves trajectory
qtraj = jnp.full((L + 1, q.shape[0]), jnp.nan)
ptraj = qtraj
qtraj = ops.index_update(qtraj, 0, current_q)
ptraj = ops.index_update(ptraj, 0, p)
```

#### Code 9.9¶

In [9]:
```print("".join("".join(source_HMC2[0]).split("\n\n")[1]))
```
Out[9]:
```    # Alternate full steps for position and momentum
for i in range(L):
q = q + epsilon * p  # Full step for the position
# Make a full step for the momentum, except at end of trajectory
if i != (L - 1):
p = p - epsilon * grad_U(q)
ptraj = ops.index_update(ptraj, i + 1, p)
qtraj = ops.index_update(qtraj, i + 1, q)
```

#### Code 9.10¶

In [10]:
```print("".join("".join(source_HMC2[0]).split("\n\n")[2]))
```
Out[10]:
```    # Make a half step for momentum at the end
p = p - epsilon * grad_U(q) / 2
ptraj = ops.index_update(ptraj, L, p)
# Negate momentum at end of trajectory to make the proposal symmetric
p = -p
# Evaluate potential and kinetic energies at start and end of trajectory
current_U = U(current_q)
current_K = jnp.sum(current_p ** 2) / 2
proposed_U = U(q)
proposed_K = jnp.sum(p ** 2) / 2
# Accept or reject the state at end of trajectory, returning either
# the position at the end of the trajectory or the initial position
accept = 0
runif = dist.Uniform().sample(random.fold_in(rng, 1))
if runif < jnp.exp(current_U - proposed_U + current_K - proposed_K):
new_q = q  # accept
accept = 1
else:
new_q = current_q  # reject
return {
"q": new_q,
"traj": qtraj,
"ptraj": ptraj,
"accept": accept,
"dH": proposed_U + proposed_K - (current_U + current_K),
}

```

#### Code 9.11¶

In [11]:
```rugged = pd.read_csv("../data/rugged.csv", sep=";")
d = rugged
d["log_gdp"] = d["rgdppc_2000"].apply(math.log)
dd = d[d["rgdppc_2000"].notnull()].copy()
dd["log_gdp_std"] = dd.log_gdp / dd.log_gdp.mean()
dd["rugged_std"] = dd.rugged / dd.rugged.max()
dd["cid"] = jnp.where(dd.cont_africa.values == 1, 0, 1)
```

#### Code 9.12¶

In [12]:
```def model(cid, rugged_std, log_gdp_std=None):
a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
b = numpyro.sample("b", dist.Normal(0, 0.3).expand([2]))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = numpyro.deterministic("mu", a[cid] + b[cid] * (rugged_std - 0.215))
numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)

m8_3 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m8_3,
ELBO(),
cid=dd.cid.values,
rugged_std=dd.rugged_std.values,
log_gdp_std=dd.log_gdp_std.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p8_3 = svi.get_params(state)
post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, (1000,))
print_summary({k: v for k, v in post.items() if k != "mu"}, 0.89, False)
```
Out[12]:
```                mean       std    median      5.5%     94.5%     n_eff     r_hat
a[0]      0.89      0.02      0.89      0.86      0.91   1009.20      1.00
a[1]      1.05      0.01      1.05      1.04      1.07    755.33      1.00
b[0]      0.13      0.07      0.13      0.01      0.24   1045.06      1.00
b[1]     -0.15      0.06     -0.14     -0.23     -0.05   1003.36      1.00
sigma      0.11      0.01      0.11      0.10      0.12    810.01      1.00

```

#### Code 9.13¶

In [13]:
```dat_slim = {
"log_gdp_std": dd.log_gdp_std.values,
"rugged_std": dd.rugged_std.values,
"cid": dd.cid.values,
}
{k: v[:5] for k, v in dat_slim.items()}
```
Out[13]:
```{'log_gdp_std': array([0.87971187, 0.9647547 , 1.1662705 , 1.10448536, 0.91490375]),
'rugged_std': array([0.13834247, 0.55256369, 0.12399226, 0.12495969, 0.43340858]),
'cid': array([0, 1, 1, 1, 1], dtype=int32)}```

#### Code 9.14¶

In [14]:
```def model(cid, rugged_std, log_gdp_std):
a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
b = numpyro.sample("b", dist.Normal(0, 0.3).expand([2]))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a[cid] + b[cid] * (rugged_std - 0.215)
numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)

m9_1 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=1)
m9_1.run(random.PRNGKey(0), **dat_slim)
```
Out[14]:
```sample: 100%|██████████| 1000/1000 [00:06<00:00, 161.87it/s, 3 steps of size 8.42e-01. acc. prob=0.87]
```

#### Code 9.15¶

In [15]:
```m9_1.print_summary(0.89)
```
Out[15]:
```                mean       std    median      5.5%     94.5%     n_eff     r_hat
a[0]      0.89      0.02      0.89      0.86      0.91    427.40      1.00
a[1]      1.05      0.01      1.05      1.04      1.07    515.67      1.00
b[0]      0.13      0.07      0.13      0.02      0.24    654.75      1.00
b[1]     -0.14      0.05     -0.14     -0.23     -0.06    429.49      1.00
sigma      0.11      0.01      0.11      0.10      0.12    534.83      1.00

Number of divergences: 0
```

#### Code 9.16¶

In [16]:
```def model(cid, rugged_std, log_gdp_std):
a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
b = numpyro.sample("b", dist.Normal(0, 0.3).expand([2]))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a[cid] + b[cid] * (rugged_std - 0.215)
numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)

m9_1 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m9_1.run(random.PRNGKey(0), **dat_slim)
```

#### Code 9.17¶

In [17]:
```print("".join(inspect.getsourcelines(m9_1.sampler.model)[0]))
```
Out[17]:
```def model(cid, rugged_std, log_gdp_std):
a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
b = numpyro.sample("b", dist.Normal(0, 0.3).expand([2]))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a[cid] + b[cid] * (rugged_std - 0.215)
numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)

```

#### Code 9.18¶

In [18]:
```m9_1.print_summary(0.89)
```
Out[18]:
```                mean       std    median      5.5%     94.5%     n_eff     r_hat
a[0]      0.89      0.02      0.89      0.86      0.91   2213.25      1.00
a[1]      1.05      0.01      1.05      1.03      1.07   2638.31      1.00
b[0]      0.13      0.08      0.13      0.01      0.25   2350.25      1.00
b[1]     -0.14      0.06     -0.14     -0.23     -0.06   2399.93      1.00
sigma      0.11      0.01      0.11      0.10      0.12   2465.49      1.00

Number of divergences: 0
```

#### Code 9.19¶

In [19]:
```az.plot_pair(az.from_numpyro(m9_1))
plt.show()
```
Out[19]: