Chapter 8. Markov Chain Monte Carlo

In [0]:
import math
import pandas as pd
import seaborn as sns
import torch
from torch.distributions import transforms

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS

from rethinking import (LM, MAP, coef, extract_samples, glimmer,
                        link, precis, replicate, sim, vcov)

Code 8.1

In [1]:
# Metropolis random walk over ten positions (0-9) arranged in a ring.
# Every week the current position is recorded, a neighbour is proposed by a
# fair coin flip, and the move is accepted with probability
# (proposal + 1) / (current + 1).
num_weeks = int(1e5)
positions = torch.zeros(num_weeks, dtype=torch.long)
current = 9
for i in range(num_weeks):
    # log where the walker is this week
    positions[i] = current

    # fair coin: index 0 means a step of -1, index 1 a step of +1
    sample = torch.multinomial(torch.ones(2), num_samples=1)
    step = torch.tensor([-1, 1])[sample].item()
    proposal = current + step
    # the positions form a ring, so stepping past either end wraps around
    if proposal < 0:
        proposal = 9
    elif proposal > 9:
        proposal = 0

    # accept the proposal when a uniform draw falls below the ratio
    prob_move = (proposal + 1) / (current + 1)
    if torch.rand(1) < prob_move:
        current = proposal

Code 8.2

In [2]:
# Read the semicolon-delimited data file. `d` is a second name for the same
# frame, so the column added below also appears on `rugged`.
d = rugged = pd.read_csv("../data/rugged.csv", sep=";")
# natural log of the rgdppc_2000 column, element by element
d["log_gdp"] = d["rgdppc_2000"].map(math.log)
# keep only the rows in which rgdppc_2000 is present
dd = d.loc[d["rgdppc_2000"].notnull()]

Code 8.3

In [3]:
def model(rugged, cont_africa, log_gdp):
    """Normal regression of ``log_gdp`` on ``rugged``, ``cont_africa`` and their product.

    Sample sites: ``a`` ~ Normal(0, 100); ``bR``, ``bA``, ``bAR`` ~ Normal(0, 10);
    ``sigma`` ~ Uniform(0, 10). ``log_gdp`` is supplied as the observed value of
    a Normal(mu, sigma) site inside the plate named "plate".
    """
    intercept = pyro.sample("a", dist.Normal(0, 100))
    coef_rugged = pyro.sample("bR", dist.Normal(0, 10))
    coef_africa = pyro.sample("bA", dist.Normal(0, 10))
    coef_interaction = pyro.sample("bAR", dist.Normal(0, 10))
    scale = pyro.sample("sigma", dist.Uniform(0, 10))
    # mean: intercept, the two main effects, and the interaction term
    mean = (intercept
            + coef_rugged * rugged
            + coef_africa * cont_africa
            + coef_interaction * rugged * cont_africa)
    likelihood = dist.Normal(mean, scale)
    with pyro.plate("plate"):
        pyro.sample("log_gdp", likelihood, obs=log_gdp)

# Renumber the rows 0..n-1 after the null filtering done earlier.
dd = dd.reset_index(drop=True)
# Build the tensors from the columns' underlying arrays, so the conversion
# does not go through label-based element lookup on the Series.
dd_rugged = torch.tensor(dd["rugged"].values, dtype=torch.float)
dd_cont_africa = torch.tensor(dd["cont_africa"].values, dtype=torch.float)
dd_log_gdp = torch.tensor(dd["log_gdp"].values, dtype=torch.float)
# Fit the model with the rethinking MAP helper and display its precis table.
m8_1 = MAP(model).run(dd_rugged, dd_cont_africa, dd_log_gdp)
precis(m8_1)
Out[3]:
Mean StdDev |0.89 0.89|
a 9.23 0.14 9.01 9.46
bR -0.20 0.08 -0.33 -0.09
bA -1.95 0.22 -2.30 -1.59
bAR 0.40 0.13 0.20 0.61
sigma 0.94 0.05 0.85 1.01

Code 8.4

In [4]:
# Keep just the three columns the model takes as arguments, then print the
# frame's structure and preview its first rows.
trim_columns = ["log_gdp", "rugged", "cont_africa"]
dd_trim = dd.loc[:, trim_columns]
dd_trim.info()
dd_trim.head()
Out[4]:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 170 entries, 0 to 169
Data columns (total 3 columns):
log_gdp        170 non-null float64
rugged         170 non-null float64
cont_africa    170 non-null int64
dtypes: float64(2), int64(1)
memory usage: 4.1 KB
Out[4]:
log_gdp rugged cont_africa
0 7.492609 0.858 1
1 8.216929 3.427 0
2 9.933263 0.769 0
3 9.407032 0.775 0
4 7.792343 2.688 0

Code 8.5

In [5]:
def model(rugged, cont_africa, log_gdp):
    """Interaction regression for ``log_gdp``, written as a Pyro model.

    Latent sites are ``a``, ``bR``, ``bA``, ``bAR`` (Normal priors) and
    ``sigma`` (Uniform(0, 10) prior). The outcome site ``log_gdp`` is
    Normal(mu, sigma) and is conditioned on the ``log_gdp`` argument.
    """
    alpha = pyro.sample("a", dist.Normal(0, 100))
    beta_r = pyro.sample("bR", dist.Normal(0, 10))
    beta_a = pyro.sample("bA", dist.Normal(0, 10))
    beta_ar = pyro.sample("bAR", dist.Normal(0, 10))
    # accumulate the linear predictor one term at a time
    mu = alpha + beta_r * rugged
    mu = mu + beta_a * cont_africa
    mu = mu + beta_ar * rugged * cont_africa
    sigma = pyro.sample("sigma", dist.Uniform(0, 10))
    with pyro.plate("plate"):
        pyro.sample("log_gdp", dist.Normal(mu, sigma), obs=log_gdp)

# Sample the model with a NUTS kernel: build the MCMC object first, then run
# it on the three data tensors (the same two-step form the Code 8.7 cell uses).
kernel = NUTS(model)
m8_1stan = MCMC(kernel, num_samples=1000)
m8_1stan.run(dd_rugged, dd_cont_africa, dd_log_gdp);
Out[5]:
Out[5]:

Code 8.6

In [6]:
# Summary table for the NUTS fit; the recorded output includes n_eff and r_hat columns.
precis(m8_1stan)
Out[6]:
Mean StdDev |0.89 0.89| n_eff r_hat
a 9.21 0.14 8.99 9.44 231.13 1.0
bR -0.20 0.08 -0.33 -0.07 237.89 1.0
bA -1.93 0.23 -2.27 -1.54 239.02 1.0
bAR 0.38 0.13 0.17 0.56 283.33 1.0
sigma 0.95 0.05 0.86 1.03 468.30 1.0

Code 8.7

In [7]:
def model(rugged, cont_africa, log_gdp):
    """Pyro model: ``log_gdp`` ~ Normal(mu, sigma) with an interaction mean.

    ``mu = a + bR * rugged + bA * cont_africa + bAR * rugged * cont_africa``.
    Priors: ``a`` ~ Normal(0, 100), each slope ~ Normal(0, 10),
    ``sigma`` ~ Uniform(0, 10).
    """
    a = pyro.sample("a", dist.Normal(0, 100))
    # the three slopes share one prior specification
    slope_prior = dist.Normal(0, 10)
    bR = pyro.sample("bR", slope_prior)
    bA = pyro.sample("bA", slope_prior)
    bAR = pyro.sample("bAR", slope_prior)
    sigma = pyro.sample("sigma", dist.Uniform(0, 10))
    mu = a + bR * rugged + bA * cont_africa + bAR * rugged * cont_africa
    outcome = dist.Normal(mu, sigma)
    with pyro.plate("plate"):
        pyro.sample("log_gdp", outcome, obs=log_gdp)

# Same model and data as the single run, now with four chains of 1000 draws.
kernel = NUTS(model)
chain_inputs = (dd_rugged, dd_cont_africa, dd_log_gdp)
m8_1stan_4chains = MCMC(kernel, num_chains=4, num_samples=1000)
m8_1stan_4chains.run(*chain_inputs)
precis(m8_1stan_4chains)
Out[7]:
Out[7]:
Out[7]:
Out[7]:
Out[7]:
Mean StdDev |0.89 0.89| n_eff r_hat
a 9.22 0.15 9.01 9.47 1063.48 1.0
bR -0.20 0.08 -0.34 -0.08 1080.07 1.0
bA -1.95 0.24 -2.35 -1.60 1060.95 1.0
bAR 0.39 0.14 0.17 0.61 1179.82 1.0
sigma 0.95 0.05 0.87 1.03 1778.89 1.0

Code 8.8

In [8]:
# Pull the posterior draws out of the fit and show the first five per latent.
post = extract_samples(m8_1stan)
dict((latent, post[latent][:5]) for latent in post)
Out[8]:
{'a': tensor([9.2374, 9.3048, 9.1149, 9.0362, 9.3176]),
 'bR': tensor([-0.2022, -0.3170, -0.1933, -0.1671, -0.1185]),
 'bA': tensor([-2.1425, -2.0199, -1.9134, -1.8863, -2.1787]),
 'bAR': tensor([0.4066, 0.4874, 0.3268, 0.4041, 0.3728]),
 'sigma': tensor([0.8741, 0.9995, 1.0940, 1.0770, 0.9178])}

Code 8.9

In [9]:
# Pairwise scatter plots of the draws, with KDEs on the diagonal; the trailing
# semicolon keeps the returned grid object out of the cell output.
post_df = pd.DataFrame(post)
scatter_style = dict(edgecolor="none", alpha=0.2)
sns.pairplot(post_df, diag_kind="kde", plot_kws=scatter_style);
Out[9]:

Code 8.10

In [10]:
# Re-extract the draws from the fit and plot every pair of latents.
post = extract_samples(m8_1stan)
sns.pairplot(
    data=pd.DataFrame(post),
    diag_kind="kde",
    plot_kws=dict(alpha=0.2, edgecolor="none"),
);
Out[10]:

Code 8.11

In [11]:
# Information criteria for the fit; the recorded output lists 'waic' and 'p_waic'.
m8_1stan.information_criterion()
Out[11]:
OrderedDict([('waic', tensor(469.6760)), ('p_waic', tensor(5.3495))])

Code 8.12

In [12]:
# Trace plot per latent: draw index on the x-axis, sampled value on the y-axis,
# with n_eff from the precis table shown in the title.
post = extract_samples(m8_1stan)
precis_df = precis(m8_1stan)
for latent in post:
    draws = post[latent]
    # take the x-axis length from the draws instead of assuming 1000 samples;
    # x/y are passed by keyword, which seaborn accepts in old and new releases
    ax = sns.lineplot(x=range(len(draws)), y=draws)
    ax.set(title="{}  |  n_eff = {}"
           .format(latent, int(precis_df.loc[latent, "n_eff"])))
    sns.mpl.pyplot.show()
Out[12]:
Out[12]:
Out[12]:
Out[12]:
Out[12]:

Code 8.13

In [ ]:
# The two observations, -1 and 1, passed to the model defined below.
y = torch.tensor([-1., 1])

def model(y):
    """Normal model for ``y`` with very wide priors.

    Sites: ``alpha`` ~ Normal(0, 1000000) is used directly as the mean and
    ``sigma`` ~ HalfCauchy(200000) as the scale; ``y`` is the observed value
    of the Normal site named "y".
    """
    center = pyro.sample("alpha", dist.Normal(0, 1000000))
    spread = pyro.sample("sigma", dist.HalfCauchy(200000))
    likelihood = dist.Normal(center, spread)
    with pyro.plate("plate"):
        pyro.sample("y", likelihood, obs=y)

# NUTS kernel with jit_compile enabled. The chain start is pinned by tracing
# the model once and overwriting the traced values of its latent sites.
kernel = NUTS(model, jit_compile=True)
trace = pyro.poutine.trace(model).get_trace(y)
for site, start in (("alpha", 0.), ("sigma", 1.)):
    trace.nodes[site]["value"] = torch.tensor(start)
kernel.initial_trace = trace
m8_2 = MCMC(kernel, num_samples=50, warmup_steps=100, num_chains=2)
m8_2.run(y);

Code 8.14

In [ ]:
# Summary table for the m8_2 fit.
precis(m8_2)

Code 8.15

In [ ]:
def model(y):
    """Normal model for ``y`` with ``alpha`` ~ Normal(1, 10) and ``sigma`` ~ HalfCauchy(1).

    The mean of the observed site "y" is the sampled ``alpha`` itself.
    """
    mu = pyro.sample("alpha", dist.Normal(1, 10))
    sigma = pyro.sample("sigma", dist.HalfCauchy(1))
    observation = dist.Normal(mu, sigma)
    with pyro.plate("plate"):
        pyro.sample("y", observation, obs=y)

# Sample with NUTS, starting every chain from alpha = 0 and sigma = 1.
kernel = NUTS(model)
trace = pyro.poutine.trace(model).get_trace(y)
start_values = {"alpha": 0., "sigma": 1.}
for site in start_values:
    trace.nodes[site]["value"] = torch.tensor(start_values[site])
kernel.initial_trace = trace
m8_3 = MCMC(kernel, num_samples=3000, warmup_steps=1000, num_chains=2)
m8_3.run(y)
precis(m8_3)

Code 8.16

In [ ]:
# Running mean of 10,000 draws from Cauchy(0, 5).
y = dist.Cauchy(0, 5).sample(torch.Size([int(1e4)]))
# mu[k] is the mean of the first k + 1 draws: the cumulative sum divided by
# the number of draws seen so far (1..n). This is a single vectorised pass.
mu = y.cumsum(0) / torch.arange(1, y.size(0) + 1, dtype=y.dtype)
# draw the running mean as a line over the draw index
sns.lineplot(x=range(y.size(0)), y=mu);

Code 8.17

In [ ]:
# 100 draws from a standard normal; rebinds `y` for the cells that follow.
y = torch.randn(100)

Code 8.18

In [ ]:
def model(y):
    """Normal model whose mean is the sum of two latent sites.

    ``a1`` and ``a2`` each have a Normal(0, 1000000) prior and enter the
    likelihood only through ``a1 + a2``; ``sigma`` ~ HalfCauchy(1).
    """
    first = pyro.sample("a1", dist.Normal(0, 1000000))
    second = pyro.sample("a2", dist.Normal(0, 1000000))
    scale = pyro.sample("sigma", dist.HalfCauchy(1))
    total = first + second
    with pyro.plate("plate"):
        pyro.sample("y", dist.Normal(total, scale), obs=y)

# NUTS with jit_compile enabled and max_tree_depth set to 7; chains start at
# a1 = 0, a2 = 0, sigma = 1.
kernel = NUTS(model, jit_compile=True, max_tree_depth=7)
trace = pyro.poutine.trace(model).get_trace(y)
for site, start in (("a1", 0.), ("a2", 0.), ("sigma", 1.)):
    trace.nodes[site]["value"] = torch.tensor(start)
kernel.initial_trace = trace
m8_4 = MCMC(kernel, num_samples=3000, warmup_steps=1000, num_chains=2)
m8_4.run(y)
precis(m8_4)

Code 8.19

In [ ]:
def model(y):
    """Sum-of-two-intercepts model with Normal(0, 10) priors on ``a1`` and ``a2``.

    The mean of the observed site "y" is ``a1 + a2``; ``sigma`` ~ HalfCauchy(1).
    """
    intercept_prior = dist.Normal(0, 10)
    a1 = pyro.sample("a1", intercept_prior)
    a2 = pyro.sample("a2", intercept_prior)
    sigma = pyro.sample("sigma", dist.HalfCauchy(1))
    likelihood = dist.Normal(a1 + a2, sigma)
    with pyro.plate("plate"):
        pyro.sample("y", likelihood, obs=y)

# Sample the model defined just above, starting at a1 = 0, a2 = 0, sigma = 1.
# NOTE(review): this rebinds `m8_4`, the name the preceding cell also assigns,
# so that earlier fit is no longer reachable - confirm whether a distinct name
# was intended.
kernel = NUTS(model, jit_compile=True)
trace = pyro.poutine.trace(model).get_trace(y)
start_values = {"a1": 0., "a2": 0., "sigma": 1.}
for site in start_values:
    trace.nodes[site]["value"] = torch.tensor(start_values[site])
kernel.initial_trace = trace
m8_4 = MCMC(kernel, num_samples=3000, warmup_steps=1000, num_chains=2)
m8_4.run(y)
precis(m8_4)

Code 8.20

In [ ]:
def model(y):
    """Prior-only model: two latent sites and no likelihood term.

    ``y`` is accepted but never used. ``a`` ~ Normal(0, 1) and
    ``b`` ~ Cauchy(0, 1).

    ``b`` uses the unbounded Cauchy rather than a half-Cauchy because the run
    script in this cell starts the chain at ``b = 0``. That value sits on the
    boundary of a half-Cauchy's support, where the log transform to
    unconstrained space is -inf, but it is an ordinary interior point for the
    two-sided Cauchy.
    """
    pyro.sample("a", dist.Normal(0, 1))
    pyro.sample("b", dist.Cauchy(0, 1))

# Sample the two-site model with NUTS, starting both sites at 0.
y = torch.tensor(1.)
kernel = NUTS(model)
trace = pyro.poutine.trace(model).get_trace(y)
for site in ("a", "b"):
    trace.nodes[site]["value"] = torch.tensor(0.)
kernel.initial_trace = trace
mp = MCMC(kernel, num_samples=int(1e4), warmup_steps=100)
mp.run(y)
precis(mp)

Code 8.21

In [ ]:
# Simulate heights together with left and right leg lengths that share the
# same underlying proportion of height and differ only by independent noise.
# NOTE: this rebinds `d`, which earlier in the notebook named the rugged frame.
N = 100  # number of individuals
height = torch.empty(N).normal_(10, 2)  # sim total height of each
leg_prop = torch.empty(N).uniform_(0.4, 0.5)  # leg as proportion of height
# sim left leg as proportion + error
leg_left = leg_prop * height + torch.empty(N).normal_(0, 0.02)
# sim right leg as proportion + error
leg_right = leg_prop * height + torch.empty(N).normal_(0, 0.02)
# combine into data dict
d = {"height": height, "leg_left": leg_left, "leg_right": leg_right}

Code 8.22

In [ ]:
def model(leg_left, leg_right, height):
    """Normal regression of ``height`` on both leg lengths.

    Sites: ``a`` ~ Normal(10, 100); ``bl``, ``br`` ~ Normal(2, 10);
    ``sigma`` ~ HalfCauchy(1). ``height`` is the observed value of the
    Normal(mu, sigma) site, with ``mu = a + bl * leg_left + br * leg_right``.
    """
    intercept = pyro.sample("a", dist.Normal(10, 100))
    coef_left = pyro.sample("bl", dist.Normal(2, 10))
    coef_right = pyro.sample("br", dist.Normal(2, 10))
    noise = pyro.sample("sigma", dist.HalfCauchy(1))
    mean = intercept + coef_left * leg_left + coef_right * leg_right
    likelihood = dist.Normal(mean, noise)
    with pyro.plate("plate"):
        pyro.sample("height", likelihood, obs=height)

# Four NUTS chains on the simulated leg data; `d` supplies the model's
# arguments by keyword. Chains start at a = 10, bl = 0, br = 0, sigma = 1.
kernel = NUTS(model, jit_compile=True)
trace = pyro.poutine.trace(model).get_trace(**d)
for site, start in (("a", 10.), ("bl", 0.), ("br", 0.), ("sigma", 1.)):
    trace.nodes[site]["value"] = torch.tensor(start)
kernel.initial_trace = trace
m5_8s = MCMC(kernel, num_samples=1000, num_chains=4)
m5_8s.run(**d);

Code 8.23

In [ ]:
def model(leg_left, leg_right, height):
    """Leg-length regression in which ``br`` is drawn from an exp-transformed Normal.

    Sites: ``a`` ~ Normal(10, 100); ``bl`` ~ Normal(2, 10); ``br`` is
    Normal(2, 10) pushed through ``ExpTransform``; ``sigma`` ~ HalfCauchy(1).
    ``height`` is the observed value of the Normal(mu, sigma) site.
    """
    a = pyro.sample("a", dist.Normal(10, 100))
    bl = pyro.sample("bl", dist.Normal(2, 10))
    # prior for br: exponentiate a Normal(2, 10) draw
    br_prior = dist.TransformedDistribution(
        dist.Normal(2, 10), transforms.ExpTransform())
    br = pyro.sample("br", br_prior)
    sigma = pyro.sample("sigma", dist.HalfCauchy(1))
    mu = a + bl * leg_left + br * leg_right
    outcome = dist.Normal(mu, sigma)
    with pyro.plate("plate"):
        pyro.sample("height", outcome, obs=height)

# Four NUTS chains for the model with the exp-transformed prior on br. The
# start for br is ExpTransform applied to 0, the others are plain constants.
kernel = NUTS(model, jit_compile=True)
trace = pyro.poutine.trace(model).get_trace(**d)
start_values = {
    "a": torch.tensor(10.),
    "bl": torch.tensor(0.),
    "br": transforms.ExpTransform()(torch.tensor(0.)),
    "sigma": torch.tensor(1.),
}
for site, start in start_values.items():
    trace.nodes[site]["value"] = start
kernel.initial_trace = trace
m5_8s2 = MCMC(kernel, num_samples=1000, num_chains=4)
m5_8s2.run(**d);