Chapter 6. Overfitting, Regularization, and Information Criteria

In [0]:
import pandas as pd
import seaborn as sns
import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
import pyro.ops.stats as stats

from rethinking import (LM, MAP, coef, compare, ensemble, extract_samples,
                        link, precis, replicate)

Code 6.1

In [1]:
sppnames = ["afarensis", "africanus", "habilis", "boisei", "rudolfensis", "ergaster",
            "sapiens"]
brainvolcc = torch.tensor([438., 452, 612, 521, 752, 871, 1350])
masskg = torch.tensor([37.0, 35.5, 34.5, 41.5, 55.5, 61.0, 53.5])
d = pd.DataFrame({"species": sppnames, "brain": brainvolcc, "mass": masskg})

Code 6.2

In [2]:
m6_1 = LM("brain ~ mass", data=d)
m6_1.model = pyro.do(m6_1.model, data={"sigma": 1})
m6_1.run();

Code 6.3

In [3]:
resid = brainvolcc - (coef(m6_1)["Intercept"] + coef(m6_1)["mass"] * masskg)
1 - resid.var() / brainvolcc.var()
Out[3]:
tensor(0.4902)

Code 6.4

In [4]:
m6_2 = LM("brain ~ mass + I(mass ** 2)", data=d).run()

Code 6.5

In [5]:
m6_3 = LM("brain ~ mass + I(mass ** 2) + I(mass ** 3)", data=d).run()
m6_4 = LM("brain ~ mass + I(mass ** 2) + I(mass ** 3) + I(mass ** 4)", data=d).run()
torch.set_default_dtype(torch.double)
m6_5 = LM("brain ~ mass + I(mass ** 2) + I(mass ** 3) + I(mass ** 4) "
          "+ I(mass ** 5)", data=d)
m6_5.model = pyro.do(m6_5.model, data={"sigma": 1})
m6_5.run()
m6_6 = LM("brain ~ mass + I(mass ** 2) + I(mass ** 3) + I(mass ** 4) "
          "+ I(mass ** 5) + I(mass ** 6)",
          data=d)
m6_6.model = pyro.do(m6_6.model, data={"sigma": 1})
m6_6.run()
torch.set_default_dtype(torch.float)

Code 6.6

In [6]:
m6_7 = LM("brain ~ 1", data=d).run()

Code 6.7

d_new = d.drop(index=i)
d_new.index = range(d_new.shape[0])

Code 6.8

In [7]:
sns.scatterplot("mass", "brain", data=d)
for i in range(d.shape[0]):
    d_new = d.drop(index=i)
    d_new.index = range(d_new.shape[0])
    m0 = LM("brain ~ mass", data=d_new)
    m0.model = pyro.do(m0.model, data={"sigma": 1})
    m0.run()
    x = torch.linspace(34, 62, 101)
    sns.lineplot(x, coef(m0)["Intercept"] + coef(m0)["mass"] * x, color="k", alpha=0.5)
Out[7]:

Code 6.9

In [8]:
p = torch.tensor([0.3, 0.7])
-(p * p.log()).sum()
Out[8]:
tensor(0.6109)

Code 6.10

In [9]:
# fit model with lm
m6_1 = LM("brain ~ mass", data=d).run()
m6_1.model = pyro.do(m6_1.model, data={"sigma": 1})
m6_1.run()

# compute deviance by cheating
coefs = coef(m6_1)
resid = brainvolcc - (coefs["Intercept"] + coefs["mass"] * masskg)
sigma = resid.pow(2).mean().sqrt()
coefs["sigma"] = sigma
model = pyro.do(LM("brain ~ mass", data=d, centering=False).model, data=coefs)
(-2) * pyro.poutine.trace(model).get_trace().log_prob_sum()
Out[9]:
tensor(94.9250)

Code 6.11

In [10]:
# standardize the mass before fitting
mass_s = (masskg - masskg.mean()) / masskg.std()

def model(mass, brain):
    a = pyro.param("a", brainvolcc.mean())
    b = pyro.param("b", torch.tensor(0.))
    mu = a + b * mass
    sigma = pyro.param("sigma", brainvolcc.std(), constraints.positive)
    with pyro.plate("plate"):
        pyro.sample("brain", dist.Normal(mu, sigma), obs=brain)

m6_8 = MAP(model).run(mass_s, brainvolcc)

# extract MAP estimates
theta = list(coef(m6_8).values())

# compute deviance
dev = (-2) * dist.Normal(theta[0] + theta[1] * mass_s,
                         theta[2]).log_prob(brainvolcc).sum()
dev
Out[10]:
tensor(94.9536)

Code 6.12

In [11]:
def model(mm, y, b_sigma):
        a = pyro.param("a", torch.tensor(0.))
        Bvec = a.unsqueeze(0)
        k = mm.size(1)
        if k > 1:
            with pyro.plate("plate_b", k - 1):
                b = pyro.sample("b", dist.Normal(0, b_sigma))
            Bvec = torch.cat([Bvec, b])
        mu = mm.matmul(Bvec)
        with pyro.plate("plate"):
            pyro.sample("y", dist.Normal(mu, 1), obs=y)

def sim_train_test(N=20, k=3, rho=[0.15, -0.4], b_sigma=100):
    n_dim = max(k, 3)
    Rho = torch.eye(n_dim)
    Rho[1:len(rho) + 1, 0] = torch.tensor(rho)
    Rho[0, 1:len(rho) + 1] = torch.tensor(rho)

    X_train = dist.MultivariateNormal(torch.zeros(n_dim), Rho).sample(torch.Size([N]))
    mm_train = torch.ones(N, 1)
    if k > 1:
        mm_train = torch.cat([mm_train, X_train[:, 1:k]], dim=1)

    m = MAP(model, start={"b": torch.zeros(k - 1)}, num_samples=10)
    m.run(mm_train, X_train[:, 0], b_sigma)
    coefs = torch.cat([c.reshape(-1) for c in coef(m).values()])
    dev_train = (-2) * dist.Normal(mm_train.matmul(coefs),
                                   1).log_prob(X_train[:, 0]).sum()

    X_test = dist.MultivariateNormal(torch.zeros(n_dim), Rho).sample(torch.Size([N]))
    mm_test = torch.ones(N, 1)
    if k > 1:
        mm_test = torch.cat([mm_test, X_test[:, 1:k]], dim=1)
    dev_test = (-2) * dist.Normal(mm_test.matmul(coefs),
                                  1).log_prob(X_test[:, 0]).sum()
    return torch.stack([dev_train, dev_test])

def dev_fn(N, k):
    print(k)
    r = torch.stack(replicate(int(1e4), sim_train_test, (N, k)), dim=1)
    return torch.stack([r[0].mean(), r[1].mean(), r[0].std(), r[1].std()])

N = 20
kseq = range(1, 6)
dev = torch.stack([dev_fn(N, k) for k in kseq], dim=1)
Out[11]:
1
2
3
4
5

Code 6.13

r = torch.stack(replicate(int(1e4), sim_train_test, (N, k), mc_cores=4), dim=1)

Code 6.14

In [12]:
ax = sns.scatterplot(torch.arange(1, 6), dev[0], s=80, color="b")
ax.set(xlabel="number of parameters", ylabel="deviance", title="N = {}".format(N))
sns.scatterplot(torch.arange(1.1, 6), dev[1], s=80, color="k")
pts_int = (dev[0] - dev[2], dev[0] + dev[2])
pts_out = (dev[1] - dev[3], dev[1] + dev[3])
ax.vlines(torch.arange(1, 6), pts_int[0], pts_int[1], color="b")
ax.vlines(torch.arange(1.1, 6), pts_out[0], pts_out[1], color="k")
ax.annotate("in", (2, dev[0][1]), xytext=(-25, -5),
            textcoords="offset pixels", color="b")
ax.annotate("out", (2.1, dev[1][1]), xytext=(10, -5), textcoords="offset pixels")
ax.annotate("-1SD", (2.1, pts_out[1][1]), xytext=(10, -5),
            textcoords="offset pixels", fontsize=12)
ax.annotate("+1SD", (2.1, pts_out[0][1]), xytext=(10, -5),
            textcoords="offset pixels", fontsize=12);
Out[12]:

Code 6.15

In [13]:
cars = pd.read_csv("../data/cars.csv", sep=",")
speed = torch.tensor(cars["speed"], dtype=torch.float)
distance = torch.tensor(cars["dist"], dtype=torch.float)

def model(speed, distance):
    a = pyro.sample("a", dist.Normal(0, 100))
    b = pyro.sample("b", dist.Normal(0, 10))
    mu = a + b * speed
    sigma = pyro.sample("sigma", dist.Uniform(0, 30))
    with pyro.plate("plate"):
        pyro.sample("dist", dist.Normal(mu, sigma), obs=distance)

m = MAP(model).run(speed, distance)
idx = m._categorical.sample(torch.Size([1000]))
post = {latent: samples[idx] for latent, samples in extract_samples(m).items()}

Code 6.16

In [14]:
n_samples = 1000

def ll_fn(s):
    mu = post["a"][s] + post["b"][s] * speed
    return dist.Normal(mu, post["sigma"][s]).log_prob(distance)

ll = torch.stack([ll_fn(s) for s in range(n_samples)], dim=1)

Code 6.17

In [15]:
n_cases = cars.shape[0]
lppd = ll.logsumexp(1) - torch.tensor(float(n_samples)).log()

Code 6.18

In [16]:
pWAIC = ll.var(1)

Code 6.19

In [17]:
-2 * (lppd.sum() - pWAIC.sum())
Out[17]:
tensor(421.0318)

Code 6.20

In [18]:
waic_vec = -2 * (lppd - pWAIC)
(n_cases * waic_vec.var()).sqrt()
Out[18]:
tensor(14.3545)

Code 6.21

In [19]:
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk.dropna().copy()
d["neocortex"] = d["neocortex.perc"] / 100
d.shape
Out[19]:
(17, 9)

Code 6.22

In [20]:
kcal_per_g = torch.tensor(d["kcal.per.g"], dtype=torch.float)
a_start = kcal_per_g.mean()
sigma_start = kcal_per_g.std().log()

def model(kcal_per_g):
    a = pyro.param("a", a_start)
    log_sigma = pyro.param("log.sigma", sigma_start)
    with pyro.plate("plate"):
        pyro.sample("kcal.per.g", dist.Normal(a, log_sigma.exp()), obs=kcal_per_g)

m6_11 = MAP(model).run(kcal_per_g)

def model(neocortex, kcal_per_g):
    a = pyro.param("a", a_start)
    bn = pyro.param("bn", torch.tensor(0.))
    mu = a + bn * neocortex
    log_sigma = pyro.param("log.sigma", sigma_start)
    with pyro.plate("plate"):
        pyro.sample("kcal.per.g", dist.Normal(mu, log_sigma.exp()), obs=kcal_per_g)

neocortex = torch.tensor(d["neocortex"], dtype=torch.float)
m6_12 = MAP(model).run(neocortex, kcal_per_g)

def model(mass, kcal_per_g):
    a = pyro.param("a", a_start)
    bm = pyro.param("bm", torch.tensor(0.))
    mu = a + bm * mass.log()
    log_sigma = pyro.param("log.sigma", sigma_start)
    with pyro.plate("plate"):
        pyro.sample("kcal.per.g", dist.Normal(mu, log_sigma.exp()), obs=kcal_per_g)

mass = torch.tensor(d["mass"], dtype=torch.float)
m6_13 = MAP(model).run(mass, kcal_per_g)

def model(neocortex, mass, kcal_per_g):
    a = pyro.param("a", a_start)
    bn = pyro.param("bn", torch.tensor(0.))
    bm = pyro.param("bm", torch.tensor(0.))
    mu = a + bn * neocortex + bm * mass.log()
    log_sigma = pyro.param("log.sigma", sigma_start)
    with pyro.plate("plate"):
        pyro.sample("kcal.per.g", dist.Normal(mu, log_sigma.exp()), obs=kcal_per_g)

m6_14 = MAP(model).run(neocortex, mass, kcal_per_g)

Code 6.23

In [21]:
with torch.no_grad():
    print(m6_14.information_criterion())
Out[21]:
OrderedDict([('waic', tensor(-15.1608)), ('p_waic', tensor(4.7740))])

Code 6.24

In [22]:
milk_models = compare({"m6.11": m6_11, "m6.12": m6_12, "m6.13": m6_13, "m6.14": m6_14})
milk_models.round(2)
Out[22]:
WAIC pWAIC dWAIC weight SE dSE
m6.14 -15.16 4.77 0.00 0.93 7.73 0.00
m6.11 -8.22 1.85 6.94 0.03 4.67 7.23
m6.13 -8.01 2.97 7.16 0.03 5.81 5.26
m6.12 -6.45 2.79 8.71 0.01 4.37 7.52

Code 6.25

In [23]:
ax = sns.pointplot(milk_models["WAIC"], milk_models.index, join=False)
ax.errorbar(milk_models["WAIC"], milk_models.index, xerr=milk_models["SE"], fmt="none")
sns.pointplot(milk_models["WAIC"] - 2 * milk_models["pWAIC"],
              milk_models.index, color="k", join=False)
ax.errorbar(milk_models["WAIC"][1:], [0.5, 1.5, 2.5],
            xerr=milk_models["dSE"][1:], fmt="^")
ax.axvline(milk_models["WAIC"][0], c="k", alpha=0.2)
ax.yaxis.grid(True)
Out[23]:

Code 6.26

In [24]:
dWAIC = milk_models.loc["m6.11", "dWAIC"]
dSE = milk_models.loc["m6.11", "dSE"]
diff = torch.empty(int(1e5)).normal_(dWAIC, dSE)
(diff < 0).float().sum() / 1e5
Out[24]:
tensor(0.1719)

Code 6.27

In [25]:
pd.concat([precis(m6_11), precis(m6_12), precis(m6_13), precis(m6_14)],
           keys=["m6.11", "m6.12", "m6.13", "m6.14"]).unstack(level=0)["Mean"]
Out[25]:
m6.11 m6.12 m6.13 m6.14
a 0.66 0.35 0.71 -1.09
bm NaN NaN -0.03 -0.10
bn NaN 0.45 NaN 2.80
log.sigma -1.78 -1.79 -1.85 -2.17

Code 6.28

In [26]:
coeftab = pd.concat([precis(m6_11), precis(m6_12), precis(m6_13), precis(m6_14)],
                    keys=["m6.11", "m6.12", "m6.13", "m6.14"]).unstack(level=0)
mean = coeftab["Mean"].iloc[:, ::-1].stack(dropna=False)
error = mean - coeftab["|0.89"].iloc[:, ::-1].stack(dropna=False)
ax = sns.pointplot(mean, mean.index, join=False, facecolors="none")
ax.errorbar(mean, range(mean.shape[0]), xerr=error, fmt="none", c="k")
ax.set(xlabel="Estimate")
for i in range(1, coeftab.shape[0]):
    ax.axhline(i * coeftab["Mean"].shape[1] - 0.5, c="k", alpha=0.2, linestyle="--")
ax.yaxis.grid(True)
Out[26]:

Code 6.29

In [27]:
# compute counterfactual predictions
# neocortex from 0.5 to 0.8
nc_seq = torch.linspace(start=0.5, end=0.8, steps=30)
d_predict = {
    "kcal_per_g": None,  # empty outcome
    "neocortex": nc_seq,  # sequence of neocortex
    "mass": torch.tensor(4.5).repeat(30)  # average mass
}
pred_m6_14 = link(m6_14, data=d_predict)
mu = pred_m6_14.mean(0)
mu_PI = stats.pi(pred_m6_14, 0.89, dim=0)

# plot it all
fig, ax = sns.mpl.pyplot.subplots()
sns.scatterplot("neocortex", "kcal.per.g", data=d, s=80)
sns.lineplot(nc_seq, mu, color="k", linestyle="--")
ax.lines[-1].set_linestyle("--")
sns.lineplot(nc_seq, mu_PI[0], color="k", linestyle="--")
ax.lines[-1].set_linestyle("--")
sns.lineplot(nc_seq, mu_PI[1], color="k", linestyle="--")
ax.lines[-1].set_linestyle("--")
Out[27]:

Code 6.30

In [28]:
milk_ensemble = ensemble({"m6.11": m6_11, "m6.12": m6_12,
                          "m6.13": m6_13, "m6.14": m6_14}, data=d_predict)
mu = milk_ensemble["link"].mean(0)
mu_PI = stats.pi(milk_ensemble["link"], 0.89, dim=0)
sns.lineplot(nc_seq, mu, color="k", ax=ax)
ax.fill_between(nc_seq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
fig
Out[28]:

Code 6.31

In [29]:
howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d = howell1
d["age"] = (d["age"] - d["age"].mean()) / d["age"].std()
torch.manual_seed(1000)
i = torch.multinomial(torch.ones(d.shape[0]), num_samples=(d.shape[0] // 2)).tolist()
d1 = d.iloc[i]
d2 = d.drop(index=i)