Chapter 8. Markov Chain Monte Carlo
In [0]:
import math
import pandas as pd
import seaborn as sns
import torch
from torch.distributions import transforms
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS
from rethinking import (LM, MAP, coef, extract_samples, glimmer,
link, precis, replicate, sim, vcov)
Code 8.1¶
In [1]:
# Metropolis walk over a ring of 10 islands (0..9); island k has relative
# population k + 1, so a proposed move is accepted with probability
# (proposal + 1) / (current + 1) (always accepted when that ratio is >= 1).
num_weeks = int(1e5)
positions = torch.zeros(num_weeks, dtype=torch.long)
current = 9
for i in range(num_weeks):
    positions[i] = current  # island occupied during week i
    # Fair coin: draw index 0 or 1, mapped to a step of -1 or +1.
    sample = torch.multinomial(torch.ones(2), num_samples=1)
    # Wrap around the ring: -1 becomes 9 and 10 becomes 0.
    proposal = (current + torch.tensor([-1, 1])[sample].item()) % 10
    prob_move = (proposal + 1) / (current + 1)
    if torch.rand(1) < prob_move:
        current = proposal
Code 8.2¶
In [2]:
# Load the terrain-ruggedness data and add a log-GDP column.
rugged = pd.read_csv("../data/rugged.csv", sep=";")
# Work on a copy so the raw `rugged` frame is never mutated (keeps the cell idempotent).
d = rugged.copy()
# math.log(nan) returns nan, so rows with missing GDP pass through here and are dropped below.
d["log_gdp"] = d["rgdppc_2000"].apply(math.log)
# Keep only countries with a recorded 2000 GDP. The .copy() makes `dd` independent of `d`,
# so later edits to `dd` (e.g. reassigning its index) cannot leak back into `d`.
dd = d[d["rgdppc_2000"].notnull()].copy()
Code 8.3¶
In [3]:
def model(rugged, cont_africa, log_gdp):
    """Gaussian regression of log GDP on ruggedness, Africa membership and their interaction.

    Priors: a ~ Normal(0, 100); bR, bA, bAR ~ Normal(0, 10); sigma ~ Uniform(0, 10).
    `log_gdp` is passed as the observed value of the "log_gdp" site.
    """
    intercept = pyro.sample("a", dist.Normal(0, 100))
    slope_rugged = pyro.sample("bR", dist.Normal(0, 10))
    slope_africa = pyro.sample("bA", dist.Normal(0, 10))
    slope_interaction = pyro.sample("bAR", dist.Normal(0, 10))
    linear_pred = (intercept + slope_rugged * rugged + slope_africa * cont_africa
                   + slope_interaction * rugged * cont_africa)
    noise_sd = pyro.sample("sigma", dist.Uniform(0, 10))
    with pyro.plate("plate"):
        pyro.sample("log_gdp", dist.Normal(linear_pred, noise_sd), obs=log_gdp)
# Re-number rows 0..n-1 after dropping countries with missing GDP.
dd.index = range(dd.shape[0])
# Convert through `.values` so torch receives a plain NumPy array. Handing torch the
# Series itself may make it read elements via Series.__getitem__, which is
# label-based and therefore depends on the index rather than on positions.
dd_rugged = torch.tensor(dd["rugged"].values, dtype=torch.float)
dd_cont_africa = torch.tensor(dd["cont_africa"].values, dtype=torch.float)
dd_log_gdp = torch.tensor(dd["log_gdp"].values, dtype=torch.float)
# MAP fit of the interaction model (rethinking port), then its summary table.
m8_1 = MAP(model).run(dd_rugged, dd_cont_africa, dd_log_gdp)
precis(m8_1)
Out[3]:
Code 8.4¶
In [4]:
# Keep only the three columns the model uses, then inspect dtypes / non-null counts and the first rows.
dd_trim = dd.loc[:, ["log_gdp", "rugged", "cont_africa"]]
dd_trim.info()
dd_trim.head()
Out[4]:
Out[4]:
Code 8.5¶
In [5]:
def model(rugged, cont_africa, log_gdp):
    """Interaction regression for log GDP, written for NUTS sampling.

    Priors: a ~ Normal(0, 100); bR, bA, bAR ~ Normal(0, 10); sigma ~ Uniform(0, 10).
    """
    # Coefficient sites are sampled in the fixed order a, bR, bA, bAR.
    coef = {}
    for site, prior in (("a", dist.Normal(0, 100)),
                        ("bR", dist.Normal(0, 10)),
                        ("bA", dist.Normal(0, 10)),
                        ("bAR", dist.Normal(0, 10))):
        coef[site] = pyro.sample(site, prior)
    mu = (coef["a"] + coef["bR"] * rugged + coef["bA"] * cont_africa
          + coef["bAR"] * rugged * cont_africa)
    sigma = pyro.sample("sigma", dist.Uniform(0, 10))
    with pyro.plate("plate"):
        pyro.sample("log_gdp", dist.Normal(mu, sigma), obs=log_gdp)
# Fit the interaction model with the NUTS sampler (single chain, 1000 samples).
kernel = NUTS(model)
m8_1stan = MCMC(kernel, num_samples=1000)
# Construct first, then run (as the 4-chain cell does): in Pyro's current MCMC API
# run() returns None, so chaining `.run(...)` onto the constructor would bind
# m8_1stan to None instead of the sampler object that precis/extract_samples need.
m8_1stan.run(dd_rugged, dd_cont_africa, dd_log_gdp);
Out[5]:
Out[5]:
Code 8.6¶
In [6]:
# Posterior summary table for the single-chain NUTS fit m8_1stan (rethinking-port `precis`).
precis(m8_1stan)
Out[6]:
Code 8.7¶
In [7]:
def model(rugged, cont_africa, log_gdp):
    """Log-GDP regression with a ruggedness x Africa interaction term.

    Priors: a ~ Normal(0, 100); bR, bA, bAR ~ Normal(0, 10); sigma ~ Uniform(0, 10).
    """
    a_ = pyro.sample("a", dist.Normal(0, 100))
    b_r = pyro.sample("bR", dist.Normal(0, 10))
    b_a = pyro.sample("bA", dist.Normal(0, 10))
    b_ar = pyro.sample("bAR", dist.Normal(0, 10))
    sd = pyro.sample("sigma", dist.Uniform(0, 10))
    center = a_ + b_r * rugged + b_a * cont_africa + b_ar * rugged * cont_africa
    with pyro.plate("plate"):
        pyro.sample("log_gdp", dist.Normal(center, sd), obs=log_gdp)
# Same interaction model, sampled with NUTS using num_chains=4 and num_samples=1000.
kernel = NUTS(model)
m8_1stan_4chains = MCMC(kernel, num_samples=1000, num_chains=4)
# Run as a separate statement so m8_1stan_4chains keeps referring to the MCMC object.
m8_1stan_4chains.run(dd_rugged, dd_cont_africa, dd_log_gdp)
precis(m8_1stan_4chains)
Out[7]:
Out[7]:
Out[7]:
Out[7]:
Out[7]:
Code 8.8¶
In [8]:
# Pull the posterior draws out of the NUTS fit and show the first five draws of each latent.
post = extract_samples(m8_1stan)
dict((name, post[name][:5]) for name in post)
Out[8]:
Code 8.9¶
In [9]:
# Pairwise scatter plots of the posterior draws, with KDEs on the diagonal.
sns.pairplot(
    data=pd.DataFrame(post),
    diag_kind="kde",
    plot_kws=dict(edgecolor="none", alpha=0.2),
);
Out[9]:
Code 8.10¶
In [10]:
# Re-extract the draws from the fit and draw the same pairwise posterior plot.
post = extract_samples(m8_1stan)
sns.pairplot(data=pd.DataFrame(data=post), diag_kind="kde",
             plot_kws=dict(edgecolor="none", alpha=0.2));
Out[10]:
Code 8.11¶
In [11]:
# Information-criterion summary of the single-chain NUTS fit.
# NOTE(review): in Pyro's TracePosterior-based MCMC this reports WAIC and the effective
# number of parameters — confirm for the installed Pyro version.
m8_1stan.information_criterion()
Out[11]:
Code 8.12¶
In [12]:
# One trace plot per latent variable, titled with that variable's effective sample size.
post = extract_samples(m8_1stan)
precis_df = precis(m8_1stan)
for latent in post:
    draws = post[latent]
    # Size the x-axis from the actual number of draws rather than a hard-coded 1000,
    # and pass x/y by keyword: positional data arguments to lineplot were deprecated
    # in seaborn 0.11 and are rejected by later releases.
    ax = sns.lineplot(x=list(range(len(draws))), y=draws)
    n_eff = int(precis_df.loc[latent, "n_eff"])
    ax.set(title="{} | n_eff = {}".format(latent, n_eff))
    sns.mpl.pyplot.show()
Out[12]:
Out[12]:
Out[12]:
Out[12]:
Out[12]:
Code 8.13¶
In [ ]:
y = torch.tensor([-1.0, 1.0])  # two observations: -1 and +1 (float32)
def model(y):
    """Gaussian likelihood for `y` with very wide priors.

    alpha ~ Normal(0, 1e6) is the mean; sigma ~ HalfCauchy(200000) is the scale.
    """
    alpha = pyro.sample("alpha", dist.Normal(0, 1000000))
    sigma = pyro.sample("sigma", dist.HalfCauchy(200000))
    with pyro.plate("plate"):
        pyro.sample("y", dist.Normal(alpha, sigma), obs=y)
kernel = NUTS(model, jit_compile=True)
# Start the sampler at alpha=0, sigma=1 by overwriting the values in a recorded trace.
trace = pyro.poutine.trace(model).get_trace(y)
trace.nodes["alpha"]["value"] = torch.tensor(0.)
trace.nodes["sigma"]["value"] = torch.tensor(1.)
kernel.initial_trace = trace
m8_2 = MCMC(kernel, num_samples=50, warmup_steps=100, num_chains=2)
# Run separately so m8_2 stays bound to the MCMC object (run() returns None in newer Pyro).
m8_2.run(y);
Code 8.14¶
In [ ]:
# Posterior summary for m8_2: two observations, very wide Normal/HalfCauchy priors, 2 chains of 50 draws.
precis(m8_2)
Code 8.15¶
In [ ]:
def model(y):
    """Gaussian likelihood for `y` with weakly informative priors.

    alpha ~ Normal(1, 10) is the mean; sigma ~ HalfCauchy(1) is the scale.
    """
    alpha = pyro.sample("alpha", dist.Normal(1, 10))
    sigma = pyro.sample("sigma", dist.HalfCauchy(1))
    with pyro.plate("plate"):
        pyro.sample("y", dist.Normal(alpha, sigma), obs=y)
kernel = NUTS(model)
# Start the sampler at alpha=0, sigma=1 by overwriting the values in a recorded trace.
trace = pyro.poutine.trace(model).get_trace(y)
trace.nodes["alpha"]["value"] = torch.tensor(0.)
trace.nodes["sigma"]["value"] = torch.tensor(1.)
kernel.initial_trace = trace
m8_3 = MCMC(kernel, num_samples=3000, warmup_steps=1000, num_chains=2)
# Run separately so m8_3 stays bound to the MCMC object (run() returns None in newer Pyro).
m8_3.run(y)
precis(m8_3)
Code 8.16¶
In [ ]:
# Running mean of 10,000 Cauchy(0, 5) draws. The Cauchy distribution has no defined
# mean, so this cumulative average keeps jumping instead of settling down.
y = dist.Cauchy(0, 5).sample(torch.Size([int(1e4)]))
# mu[i] = (y[0] + ... + y[i]) / (i + 1): the prefix must include draw i itself.
# A cumulative sum computes every prefix mean in one O(n) pass.
mu = torch.cumsum(y, dim=0) / torch.arange(1, y.size(0) + 1, dtype=y.dtype)
# Line plot of the running mean against the number of draws.
sns.lineplot(x=list(range(1, y.size(0) + 1)), y=mu.numpy());
Code 8.17¶
In [ ]:
y = torch.randn((100,))  # 100 draws from a standard normal
Code 8.18¶
In [ ]:
def model(y):
    """Gaussian model whose mean is the sum a1 + a2, each with a Normal(0, 1e6) prior.

    The likelihood depends on a1 and a2 only through their sum.
    """
    first = pyro.sample("a1", dist.Normal(0, 1000000))
    second = pyro.sample("a2", dist.Normal(0, 1000000))
    scale = pyro.sample("sigma", dist.HalfCauchy(1))
    with pyro.plate("plate"):
        pyro.sample("y", dist.Normal(first + second, scale), obs=y)
kernel = NUTS(model, jit_compile=True, max_tree_depth=7)
# Start all three parameters at fixed values by overwriting a recorded trace.
trace = pyro.poutine.trace(model).get_trace(y)
trace.nodes["a1"]["value"] = torch.tensor(0.)
trace.nodes["a2"]["value"] = torch.tensor(0.)
trace.nodes["sigma"]["value"] = torch.tensor(1.)
kernel.initial_trace = trace
m8_4 = MCMC(kernel, num_samples=3000, warmup_steps=1000, num_chains=2)
# Run separately so m8_4 stays bound to the MCMC object (run() returns None in newer Pyro).
m8_4.run(y)
precis(m8_4)
Code 8.19¶
In [ ]:
def model(y):
    """Sum-of-two-intercepts Gaussian model with Normal(0, 10) priors on a1 and a2.

    The narrower priors pin down a1 and a2 individually even though only a1 + a2
    enters the likelihood.
    """
    first = pyro.sample("a1", dist.Normal(0, 10))
    second = pyro.sample("a2", dist.Normal(0, 10))
    scale = pyro.sample("sigma", dist.HalfCauchy(1))
    with pyro.plate("plate"):
        pyro.sample("y", dist.Normal(first + second, scale), obs=y)
kernel = NUTS(model, jit_compile=True)
# Start all three parameters at fixed values by overwriting a recorded trace.
trace = pyro.poutine.trace(model).get_trace(y)
trace.nodes["a1"]["value"] = torch.tensor(0.)
trace.nodes["a2"]["value"] = torch.tensor(0.)
trace.nodes["sigma"]["value"] = torch.tensor(1.)
kernel.initial_trace = trace
# Bind to m8_5 so this regularized fit does not overwrite the m8_4 fit,
# and the two can still be compared.
m8_5 = MCMC(kernel, num_samples=3000, warmup_steps=1000, num_chains=2)
# Run separately so m8_5 stays bound to the MCMC object (run() returns None in newer Pyro).
m8_5.run(y)
precis(m8_5)
Code 8.20¶
In [ ]:
def model(y):
    """Prior-only model: `y` is accepted (the sampler passes it) but never observed.

    Samples a ~ Normal(0, 1) and b ~ Cauchy(0, 1). `b` uses the full, two-sided Cauchy.
    The sampler is started at b = 0, which sits on the boundary of a HalfCauchy's
    positive support. Under that prior the start value would map to log(0) = -inf in
    the sampler's unconstrained space.
    """
    pyro.sample("a", dist.Normal(0, 1))
    pyro.sample("b", dist.Cauchy(0, 1))
y = torch.tensor(1.)  # dummy data; the model does not observe it
kernel = NUTS(model)
# Start both parameters at 0 by overwriting the values in a recorded trace.
trace = pyro.poutine.trace(model).get_trace(y)
trace.nodes["a"]["value"] = torch.tensor(0.)
trace.nodes["b"]["value"] = torch.tensor(0.)
kernel.initial_trace = trace
mp = MCMC(kernel, num_samples=int(1e4), warmup_steps=100)
# Run separately so mp stays bound to the MCMC object (run() returns None in newer Pyro).
mp.run(y)
precis(mp)
Code 8.21¶
In [ ]:
N = 100 # number of individuals
height = torch.empty(N).normal_(10, 2) # sim total height of each
leg_prop = torch.empty(N).uniform_(0.4, 0.5) # leg as proportion of height
# sim left leg as proportion + error
leg_left = leg_prop * height + torch.empty(N).normal_(0, 0.02)
# sim right leg as proportion + error
leg_right = leg_prop * height + torch.empty(N).normal_(0, 0.02)
# combine into data dict; keys match the leg models' parameter names, so it is passed as **d
# NOTE: this rebinds `d`, which earlier held the rugged DataFrame.
d = {"height": height, "leg_left": leg_left, "leg_right": leg_right}
Code 8.22¶
In [ ]:
def model(leg_left, leg_right, height):
    """Regress height on both leg lengths.

    Priors: a ~ Normal(10, 100); bl, br ~ Normal(2, 10); sigma ~ HalfCauchy(1).
    """
    intercept = pyro.sample("a", dist.Normal(10, 100))
    coef_left = pyro.sample("bl", dist.Normal(2, 10))
    coef_right = pyro.sample("br", dist.Normal(2, 10))
    expected = intercept + coef_left * leg_left + coef_right * leg_right
    spread = pyro.sample("sigma", dist.HalfCauchy(1))
    with pyro.plate("plate"):
        pyro.sample("height", dist.Normal(expected, spread), obs=height)
kernel = NUTS(model, jit_compile=True)
# Start the chains at a=10, bl=br=0, sigma=1 by overwriting a recorded trace.
trace = pyro.poutine.trace(model).get_trace(**d)
trace.nodes["a"]["value"] = torch.tensor(10.)
trace.nodes["bl"]["value"] = torch.tensor(0.)
trace.nodes["br"]["value"] = torch.tensor(0.)
trace.nodes["sigma"]["value"] = torch.tensor(1.)
kernel.initial_trace = trace
m5_8s = MCMC(kernel, num_samples=1000, num_chains=4)
# Run separately so m5_8s stays bound to the MCMC object (run() returns None in newer Pyro).
m5_8s.run(**d);
Code 8.23¶
In [ ]:
def model(leg_left, leg_right, height):
    """Two-leg height regression with the right-leg coefficient constrained positive.

    br is exp() of a Normal(2, 10) draw (a log-normal prior), so br > 0.
    Other priors: a ~ Normal(10, 100); bl ~ Normal(2, 10); sigma ~ HalfCauchy(1).
    """
    intercept = pyro.sample("a", dist.Normal(10, 100))
    coef_left = pyro.sample("bl", dist.Normal(2, 10))
    positive_prior = dist.TransformedDistribution(
        dist.Normal(2, 10), transforms.ExpTransform())
    coef_right = pyro.sample("br", positive_prior)
    expected = intercept + coef_left * leg_left + coef_right * leg_right
    spread = pyro.sample("sigma", dist.HalfCauchy(1))
    with pyro.plate("plate"):
        pyro.sample("height", dist.Normal(expected, spread), obs=height)
kernel = NUTS(model, jit_compile=True)
# Start the chains at fixed values by overwriting a recorded trace.
trace = pyro.poutine.trace(model).get_trace(**d)
trace.nodes["a"]["value"] = torch.tensor(10.)
trace.nodes["bl"]["value"] = torch.tensor(0.)
# br lives on the positive scale, so its start value is exp(0) = 1.
trace.nodes["br"]["value"] = transforms.ExpTransform()(torch.tensor(0.))
trace.nodes["sigma"]["value"] = torch.tensor(1.)
kernel.initial_trace = trace
m5_8s2 = MCMC(kernel, num_samples=1000, num_chains=4)
# Run separately so m5_8s2 stays bound to the MCMC object (run() returns None in newer Pyro).
m5_8s2.run(**d);