Chapter 4. Linear Models

In [0]:
import pandas as pd
import seaborn as sns
import torch

import pyro
import pyro.distributions as dist
import pyro.ops.stats as stats

from rethinking import MAP, coef, extract_samples, link, precis, sim, vcov

Code 4.1

In [1]:
# Final position of 1000 random walks, each the sum of 16 Uniform(-1, 1) steps.
pos = torch.empty(1000, 16).uniform_(-1, 1).sum(dim=1)

Code 4.2

In [2]:
# Product of 12 random growth factors, each 1 plus a Uniform(0, 0.1) draw.
(1 + torch.empty(12).uniform_(0, 0.1)).prod()
Out[2]:
tensor(1.8004)

Code 4.3

In [3]:
# 10,000 replicates of the product of 12 growth factors (1 + Uniform(0, 0.1)).
growth = (1 + torch.empty(10000, 12).uniform_(0, 0.1)).prod(1)
# Density estimate of the simulated products (no histogram bars).
sns.distplot(growth, hist=False)
# Overlay the density of a Normal with the same mean and standard deviation.
ax = sns.lineplot(growth, dist.Normal(growth.mean(),
                                      growth.std()).log_prob(growth).exp())
# Draw the second line on the axes (the overlay) dashed.
ax.lines[1].set_linestyle("--")
Out[3]:

Code 4.4

In [4]:
# Products of 12 growth factors: large effects (up to +50%) vs. small effects (up to +1%).
big = torch.empty(10000, 12).uniform_(0, 0.5).add(1).prod(dim=1)
small = torch.empty(10000, 12).uniform_(0, 0.01).add(1).prod(dim=1)

Code 4.5

In [5]:
# Log of the product of 12 large growth factors (1 + Uniform(0, 0.5)), 10,000 replicates.
log_big = torch.empty(10000, 12).uniform_(0, 0.5).add(1).prod(dim=1).log()

Code 4.6

In [6]:
# Grid approximation for a binomial proportion: w = 6 successes in n = 9 trials.
w, n = 6., 9
# 1000 evenly spaced candidate probabilities on [0, 1].
p_grid = torch.linspace(start=0, end=1, steps=1000)
# Unnormalized posterior: Binomial likelihood of w at each grid value times
# the Uniform(0, 1) prior density (both obtained by exponentiating log_prob).
posterior = (dist.Binomial(n, p_grid).log_prob(torch.tensor(w)).exp()
             * dist.Uniform(0, 1).log_prob(p_grid).exp())
# Normalize so the grid values sum to 1.
posterior = posterior / posterior.sum()

Code 4.7

In [7]:
# Load the Howell1 data; the file is semicolon-delimited.
howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
# Short alias bound to the same DataFrame object (not a copy).
d = howell1

Code 4.8

In [8]:
# Print column dtypes and non-null counts, then display the first rows.
d.info()
d.head()
Out[8]:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 544 entries, 0 to 543
Data columns (total 4 columns):
height    544 non-null float64
weight    544 non-null float64
age       544 non-null float64
male      544 non-null int64
dtypes: float64(3), int64(1)
memory usage: 17.1 KB
Out[8]:
height weight age male
0 151.765 47.825606 63.0 1
1 139.700 36.485807 63.0 0
2 136.525 31.864838 65.0 0
3 156.845 53.041915 41.0 1
4 145.415 41.276872 51.0 0

Code 4.9

In [9]:
# First rows of the height column only.
d["height"].head()
Out[9]:
0    151.765
1    139.700
2    136.525
3    156.845
4    145.415
Name: height, dtype: float64

Code 4.10

In [10]:
# Keep only adults (age >= 18).
d2 = d[d["age"] >= 18]
# Convert through the underlying NumPy array rather than the Series itself:
# the filtered Series keeps its original (now gappy) index labels, and
# torch.tensor() treats a Series as a generic Python sequence, where integer
# access is a label lookup. Going through .values is faster and does not
# depend on which labels survived the filter.
d2_height = torch.tensor(d2["height"].values, dtype=torch.float)

Code 4.11

In [11]:
# Plot the Normal(178, 20) density on 101 evenly spaced points from 100 to 250.
x = torch.linspace(100, 250, 101)
sns.lineplot(x, dist.Normal(178, 20).log_prob(x).exp());
Out[11]:

Code 4.12

In [12]:
# Plot the Uniform(0, 50) density on a grid that extends past both ends of
# its support (-10 to 60).
x = torch.linspace(-10, 60, 101)
# validate_args=False: presumably so log_prob can be evaluated at the
# out-of-support grid points without raising — TODO confirm.
sns.lineplot(x, dist.Uniform(0, 50, validate_args=False).log_prob(x).exp());
Out[12]:

Code 4.13

In [13]:
# Prior predictive simulation: draw 10,000 (mu, sigma) pairs from
# Normal(178, 20) and Uniform(0, 50), then one height from Normal(mu, sigma)
# for each pair, and plot the resulting distribution of heights.
sample_mu = torch.empty(int(1e4)).normal_(178, 20)
sample_sigma = torch.empty(int(1e4)).uniform_(0, 50)
prior_h = dist.Normal(sample_mu, sample_sigma).sample()
sns.distplot(prior_h);
Out[13]:

Code 4.14

In [14]:
# Grid approximation of the joint posterior over (mu, sigma).
mu_list = torch.linspace(start=140, end=160, steps=200)
sigma_list = torch.linspace(start=4, end=9, steps=200)
# Flattened 200 x 200 grid: mu varies fastest, sigma is constant within each
# run of 200 consecutive entries.
post = {"mu": mu_list.expand(200, 200).reshape(-1),
        "sigma": sigma_list.expand(200, 200).t().reshape(-1)}
# Log-likelihood of each observed height at each grid point (heights are
# broadcast along dim 0), summed over the observations.
post_LL = dist.Normal(post["mu"],
                      post["sigma"]).log_prob(d2_height.unsqueeze(1)).sum(0)
# Add the log-priors to obtain the unnormalized log-posterior.
post_prod = (post_LL + dist.Normal(178, 20).log_prob(post["mu"])
             + dist.Uniform(0, 50).log_prob(post["sigma"]))
# Subtract the maximum before exponentiating so the largest value maps to 1.
# Tensor.max() reduces in one call; the builtin max() would iterate the
# 40,000-element tensor one Python object at a time.
post_prob = (post_prod - post_prod.max()).exp()

Code 4.15

In [15]:
# Contour plot of the relative posterior probability over the (mu, sigma)
# grid; the flat vectors are reshaped back to 200 x 200.
_, ax = sns.mpl.pyplot.subplots()
ax.contour(post["mu"].reshape(200, 200), post["sigma"].reshape(200, 200),
           post_prob.reshape(200, 200));
Out[15]:

Code 4.16

In [16]:
# Heat map of the relative posterior probability. origin="lower" puts the
# first row at the bottom; extent labels the axes with the mu (140..160) and
# sigma (4..9) ranges of the grid.
_, ax = sns.mpl.pyplot.subplots()
ax.imshow(post_prob.reshape(200, 200), origin="lower",
          extent=(140, 160, 4, 9), aspect="auto")
# Turn off grid lines on top of the image.
ax.grid(False)
Out[16]:

Code 4.17

In [17]:
# Draw 10,000 grid indices with replacement, weighted by post_prob, then look
# up the mu and sigma values at those indices.
sample_rows = torch.multinomial(input=post_prob, num_samples=int(1e4),
                                replacement=True)
sample_mu = post["mu"][sample_rows]
sample_sigma = post["sigma"][sample_rows]

Code 4.18

In [18]:
# Scatter of the sampled (mu, sigma) pairs; low alpha makes denser regions
# appear darker.
ax = sns.scatterplot(sample_mu, sample_sigma, s=64, alpha=0.1, edgecolor="none")
ax.set(xlabel="sample.mu", ylabel="sample.sigma");
Out[18]:

Code 4.19

In [19]:
# Distribution of the sampled mu values.
sns.distplot(sample_mu)
# Show the first figure before drawing the second plot.
sns.mpl.pyplot.show()
# Distribution of the sampled sigma values.
sns.distplot(sample_sigma);
Out[19]:
Out[19]:

Code 4.20

In [20]:
# 89% highest posterior density intervals of the sampled mu and sigma.
print(stats.hpdi(sample_mu, 0.89))
print(stats.hpdi(sample_sigma, 0.89))
Out[20]:
tensor([153.8694, 155.1759])
tensor([7.3166, 8.2462])

Code 4.21

In [21]:
# Draw a subsample of 20 heights from the adult heights.
d3 = stats.resample(d2_height, num_samples=20)

Code 4.22

In [22]:
# Repeat the grid approximation using only the 20-height subsample d3.
mu_list = torch.linspace(start=150, end=170, steps=200)
sigma_list = torch.linspace(start=4, end=20, steps=200)
# Flattened 200 x 200 grid: mu varies fastest, sigma is constant within each
# run of 200 consecutive entries.
post2 = {"mu": mu_list.expand(200, 200).reshape(-1),
         "sigma": sigma_list.expand(200, 200).t().reshape(-1)}
# Log-likelihood summed over the subsample at every grid point.
post2_LL = dist.Normal(post2["mu"], post2["sigma"]).log_prob(d3.unsqueeze(1)).sum(0)
# Add the log-priors to obtain the unnormalized log-posterior.
post2_prod = (post2_LL + dist.Normal(178, 20).log_prob(post2["mu"])
              + dist.Uniform(0, 50).log_prob(post2["sigma"]))
# Subtract the maximum before exponentiating so the largest value maps to 1.
# Tensor.max() reduces in one call; the builtin max() would iterate the
# 40,000-element tensor one Python object at a time.
post2_prob = (post2_prod - post2_prod.max()).exp()
# Draw 10,000 grid indices with replacement, weighted by post2_prob.
sample2_rows = torch.multinomial(input=post2_prob, num_samples=int(1e4),
                                 replacement=True)
sample2_mu = post2["mu"][sample2_rows]
sample2_sigma = post2["sigma"][sample2_rows]
# Scatter of the sampled (mu, sigma) pairs.
ax = sns.scatterplot(sample2_mu, sample2_sigma, s=80, alpha=0.1, edgecolor="none")
ax.set(xlabel="mu", ylabel="sigma");
Out[22]:

Code 4.23

In [23]:
# Density estimate of the sampled sigma values (no histogram bars).
sns.distplot(sample2_sigma, hist=False)
# Overlay a Normal density with the same mean and standard deviation.
ax = sns.lineplot(sample2_sigma, dist.Normal(sample2_sigma.mean(), sample2_sigma.std())
                  .log_prob(sample2_sigma).exp())
# Draw the second line on the axes (the overlay) dashed.
ax.lines[1].set_linestyle("--")
Out[23]:

Code 4.24

In [24]:
# Reload the data and rebuild the adults-only frame (age >= 18).
howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d = howell1
d2 = d[d["age"] >= 18]

Code 4.25

In [25]:
def flist(height):
    """Pyro model: heights are Normal(mu, sigma).

    Sample sites:
        "mu"     ~ Normal(178, 20)
        "sigma"  ~ Uniform(0, 50)
        "height" ~ Normal(mu, sigma), observed as ``height`` inside the plate.

    Args:
        height: observed heights, passed as ``obs`` to the "height" site.
    """
    mu_prior = dist.Normal(178, 20)
    sigma_prior = dist.Uniform(0, 50)
    mu = pyro.sample("mu", mu_prior)
    sigma = pyro.sample("sigma", sigma_prior)
    with pyro.plate("plate"):
        likelihood = dist.Normal(mu, sigma)
        pyro.sample("height", likelihood, obs=height)

Code 4.26

In [26]:
# Convert through the underlying NumPy array rather than the Series itself:
# d2 is a filtered frame whose index has gaps, and torch.tensor() treats a
# Series as a generic Python sequence (integer access is a label lookup).
d2_height = torch.tensor(d2["height"].values, dtype=torch.float)
# Fit the flist model to the adult heights with the MAP helper from the
# local `rethinking` module.
m4_1 = MAP(flist).run(d2_height)

Code 4.27

In [27]:
# Summary table (mean, std dev, 89% interval bounds) for each parameter of m4_1.
precis(m4_1)
Out[27]:
Mean StdDev |0.89 0.89|
mu 154.62 0.42 153.99 155.31
sigma 7.74 0.29 7.27 8.20

Code 4.28

In [28]:
# Candidate starting values: the sample mean and standard deviation of the
# adult heights. NOTE(review): `start` is not passed to any fit in the cells
# shown here.
start = {"mu": d2_height.mean(), "sigma": d2_height.std()}

Code 4.29

In [29]:
def model(height):
    """Height model with a very narrow Normal(178, 0.1) prior on ``mu``.

    Sample sites:
        "mu"     ~ Normal(178, 0.1)
        "sigma"  ~ Uniform(0, 50)
        "height" ~ Normal(mu, sigma), observed as ``height`` inside the plate.

    Args:
        height: observed heights, passed as ``obs`` to the "height" site.
    """
    mu = pyro.sample("mu", dist.Normal(loc=178, scale=0.1))
    sigma = pyro.sample("sigma", dist.Uniform(0, 50))
    height_dist = dist.Normal(loc=mu, scale=sigma)
    with pyro.plate("plate"):
        pyro.sample("height", height_dist, obs=height)

# Fit the narrow-prior model to the adult heights and show its summary table.
m4_2 = MAP(model).run(d2_height)
precis(m4_2)
Out[29]:
Mean StdDev |0.89 0.89|
mu 177.87 0.10 177.71 178.03
sigma 24.55 0.92 23.10 26.03

Code 4.30

In [30]:
# Variance-covariance matrix of the m4_1 parameters (2 x 2 here: mu, sigma).
vcov(m4_1)
Out[30]:
tensor([[ 0.1730, -0.0003],
        [-0.0003,  0.0854]])

Code 4.31

In [31]:
# Parameter variances: the diagonal of the variance-covariance matrix.
print(vcov(m4_1).diag())
cov = vcov(m4_1)
# Correlation matrix: divide each covariance by the square root of the
# product of the two corresponding variances (outer product of the diagonal).
print(cov / cov.diag().ger(cov.diag()).sqrt())
Out[31]:
tensor([0.1730, 0.0854])
tensor([[ 1.0000, -0.0025],
        [-0.0025,  1.0000]])

Code 4.32

In [32]:
# Extract posterior samples from m4_1 (a mapping from latent name to a tensor
# of samples) and show the first five draws of each latent.
post = extract_samples(m4_1)
{latent: post[latent][:5] for latent in post}
Out[32]:
{'mu': tensor([154.3505, 153.8552, 154.8622, 154.0932, 154.5948]),
 'sigma': tensor([8.0512, 8.2835, 7.6042, 7.9720, 7.8379])}

Code 4.33

In [33]:
# Summary table computed from the extracted samples rather than the fit object.
precis(post)
Out[33]:
Mean StdDev |0.89 0.89|
mu 154.62 0.42 153.99 155.31
sigma 7.74 0.29 7.27 8.20

Code 4.34

In [34]:
# Draw 10,000 samples from a multivariate Normal whose mean vector is the
# stacked coefficient values of m4_1 and whose covariance is vcov(m4_1).
# Note: this rebinds `post` to a single tensor of draws instead of a dict.
post = dist.MultivariateNormal(torch.stack(list(coef(m4_1).values())),
                               vcov(m4_1)).sample(torch.Size([int(1e4)]))

Code 4.35

In [35]:
def model(height):
    """Height model parameterized by the log of the standard deviation.

    Sample sites:
        "mu"        ~ Normal(178, 20)
        "log_sigma" ~ Normal(2, 10)
        "height"    ~ Normal(mu, exp(log_sigma)), observed as ``height``
                      inside the plate.

    Args:
        height: observed heights, passed as ``obs`` to the "height" site.
    """
    mu = pyro.sample("mu", dist.Normal(178, 20))
    log_sigma = pyro.sample("log_sigma", dist.Normal(2, 10))
    # Exponentiate so the scale passed to Normal is always positive.
    sigma = log_sigma.exp()
    with pyro.plate("plate"):
        pyro.sample("height", dist.Normal(mu, sigma), obs=height)

# Fit the log-sigma parameterization to the adult heights.
m4_1_logsigma = MAP(model).run(d2_height)

Code 4.36

In [36]:
# Extract samples and map log_sigma back to the sigma scale by exponentiating.
post = extract_samples(m4_1_logsigma)
sigma = post["log_sigma"].exp()

Code 4.37

In [37]:
# Scatter plot of height against weight for the adults-only frame.
sns.scatterplot("weight", "height", data=d2);
Out[37]:

Code 4.38

In [38]:
# Load the data again, since it was last loaded many cells earlier, and
# rebuild the adults-only frame (age >= 18).
howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d = howell1
d2 = d[d["age"] >= 18]

# fit model
def model(weight, height):
    """Linear regression of height on weight.

    Sample sites:
        "a"      ~ Normal(178, 100)   (intercept)
        "b"      ~ Normal(0, 10)      (slope)
        "sigma"  ~ Uniform(0, 50)
        "height" ~ Normal(a + b * weight, sigma), observed as ``height``
                   inside the plate.

    Args:
        weight: predictor values.
        height: observed heights, passed as ``obs`` to the "height" site.
    """
    intercept = pyro.sample("a", dist.Normal(178, 100))
    slope = pyro.sample("b", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.Uniform(0, 50))
    mu = intercept + slope * weight
    with pyro.plate("plate"):
        pyro.sample("height", dist.Normal(mu, sigma), obs=height)

# Convert through the underlying NumPy arrays rather than the Series
# themselves: d2 is a filtered frame whose index has gaps, and torch.tensor()
# treats a Series as a generic Python sequence (integer access is a label
# lookup), which is slow and fragile.
d2_weight = torch.tensor(d2["weight"].values, dtype=torch.float)
d2_height = torch.tensor(d2["height"].values, dtype=torch.float)
# Fit the linear model to the adult weights and heights.
m4_3 = MAP(model).run(d2_weight, d2_height)

Code 4.39

In [39]:
def model(weight, height):
    """Linear regression of height on weight.

    Sample sites:
        "a"      ~ Normal(178, 100)
        "b"      ~ Normal(0, 10)
        "sigma"  ~ Uniform(0, 50)
        "height" ~ Normal(a + b * weight, sigma), observed as ``height``
                   inside the plate.

    Args:
        weight: predictor values.
        height: observed heights, passed as ``obs`` to the "height" site.
    """
    a = pyro.sample("a", dist.Normal(178, 100))
    b = pyro.sample("b", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.Uniform(0, 50))
    expected_height = a + b * weight
    with pyro.plate("plate"):
        likelihood = dist.Normal(expected_height, sigma)
        pyro.sample("height", likelihood, obs=height)

# Refit with the version of the model that embeds the linear predictor
# directly in the likelihood; rebinds m4_3.
m4_3 = MAP(model).run(d2_weight, d2_height)

Code 4.40

In [40]:
# Summary table for a, b and sigma of the linear model.
precis(m4_3)
Out[40]:
Mean StdDev |0.89 0.89|
a 113.89 1.90 110.98 117.06
b 0.90 0.04 0.84 0.97
sigma 5.08 0.19 4.77 5.39

Code 4.41

In [41]:
# Same summary with the parameter correlation matrix appended as extra columns.
precis(m4_3, corr=True)
Out[41]:
Mean StdDev |0.89 0.89| a b sigma
a 113.89 1.90 110.98 117.06 1.00 -0.99 0.01
b 0.90 0.04 0.84 0.97 -0.99 1.00 -0.01
sigma 5.08 0.19 4.77 5.39 0.01 -0.01 1.00

Code 4.42

In [42]:
# Center the weights by subtracting their mean.
d2_weight_c = d2_weight - d2_weight.mean()

Code 4.43

In [43]:
# Fit the most recently defined `model` (weight, height) to the centered weights.
m4_4 = MAP(model).run(d2_weight_c, d2_height)

Code 4.44

In [44]:
# Summary and parameter correlations for the centered-weight fit.
precis(m4_4, corr=True)
Out[44]:
Mean StdDev |0.89 0.89| a b sigma
a 154.60 0.27 154.18 155.04 1.00 0.01 -0.01
b 0.91 0.04 0.84 0.97 0.01 1.00 0.02
sigma 5.08 0.19 4.77 5.39 -0.01 0.02 1.00

Code 4.45

In [45]:
# Adult data with the fitted line a + b * x from the m4_3 coefficients
# drawn in black over weights 30..65.
sns.scatterplot("weight", "height", data=d2)
x = torch.linspace(30, 65, 101)
sns.lineplot(x, (coef(m4_3)["a"] + coef(m4_3)["b"] * x), color="k");
Out[45]:

Code 4.46

In [46]:
# Extract posterior samples of a, b and sigma from m4_3.
post = extract_samples(m4_3)

Code 4.47

In [47]:
# First five draws of each latent; detach() returns tensors cut off from the
# autograd graph for display.
{latent: post[latent][:5].detach() for latent in post}
Out[47]:
{'a': tensor([113.5913, 115.2334, 114.4788, 118.5937, 115.5136]),
 'b': tensor([0.9214, 0.8755, 0.8890, 0.7998, 0.8732]),
 'sigma': tensor([5.0328, 5.1344, 5.0411, 4.9590, 5.1898])}

Code 4.48

In [48]:
# Refit the linear model using only the first N observations.
N = 10
dN = {"weight": d2_weight[:N], "height": d2_height[:N]}
# The dict keys match the model's parameter names, so it is unpacked as
# keyword arguments.
mN = MAP(model).run(**dN)

Code 4.49

In [49]:
# extract 20 samples from the posterior
# NOTE(review): `_categorical` is a private attribute of the fit object;
# presumably a categorical distribution over stored sample indices — confirm
# against the `rethinking` helper.
idx = mN._categorical.sample(torch.Size([20]))
post = {latent: samples[idx] for latent, samples in extract_samples(mN).items()}

# display raw data and sample size
ax = sns.scatterplot("weight", "height", data=dN)
ax.set(xlabel="weight", ylabel="height", title="N = {}".format(N))

# plot the lines, with transparency
# one regression line a + b * x per sampled (a, b) pair
x = torch.linspace(30, 65, 101)
for i in range(20):
    sns.lineplot(x, post["a"][i] + post["b"][i] * x, color="k", alpha=0.3)
Out[49]:

Code 4.50

In [50]:
# Value of the linear predictor a + b * 50 for every posterior sample of m4_3,
# i.e. the expected height at weight 50.
post = extract_samples(m4_3)
mu_at_50 = post["a"] + post["b"] * 50

Code 4.51

In [51]:
# Distribution of the expected height at weight 50.
ax = sns.distplot(mu_at_50)
ax.set(xlabel="mu|weight=50", ylabel="Density");
Out[51]:

Code 4.52

In [52]:
# 89% highest posterior density interval of the expected height at weight 50.
stats.hpdi(mu_at_50, prob=0.89)
Out[52]:
tensor([158.6436, 159.7386])

Code 4.53

In [53]:
# Call link() with no new data. The recorded output has shape (1000, 352);
# presumably one row per posterior sample and one column per observation in
# the fitted data — confirm against the `rethinking` helper.
mu = link(m4_3)
mu.shape, mu[:5, 0]
Out[53]:
(torch.Size([1000, 352]),
 tensor([157.2706, 156.9633, 157.1407, 157.3087, 157.4005]))

Code 4.54

In [54]:
# define sequence of weights to compute predictions for
# these values will be on the horizontal axis
# (25, 26, ..., 70: the end value 71 is exclusive, giving 46 points)
weight_seq = torch.arange(start=25., end=71, step=1)

# use link to compute mu
# for each sample from posterior
# and for each weight in weight_seq
# (the data key matches the model's `weight` parameter name)
mu = link(m4_3, data={"weight": weight_seq})
mu.shape, mu[:5, 0]
Out[54]:
(torch.Size([1000, 46]),
 tensor([135.8355, 135.6814, 136.9525, 137.5477, 137.2216]))

Code 4.55

In [55]:
# use visible=False to hide raw data
# (the call still creates the axes with the weight/height data attached)
sns.scatterplot("weight", "height", data=d2, visible=False)

# loop over samples and plot each mu value
# only the first 100 rows of mu are drawn
for i in range(100):
    sns.scatterplot(weight_seq, mu[i], color="royalblue", alpha=0.1)
Out[55]:

Code 4.56

In [56]:
# summarize the distribution of mu
# reduce over dim 0 (the sample dimension), giving one mean and one 89% HPDI
# per weight in weight_seq
mu_mean = mu.mean(0)
mu_HPDI = stats.hpdi(mu, prob=0.89, dim=0)

Code 4.57

In [57]:
# plot raw data
# fading out points to make line and interval more visible
sns.scatterplot("weight", "height", data=d2, alpha=0.5)

# plot the MAP line, aka the mean mu for each weight
ax = sns.lineplot(weight_seq, mu_mean, color="k")

# plot a shaded region for 89% HPDI
# mu_HPDI[0] and mu_HPDI[1] are the lower and upper bounds per weight
ax.fill_between(weight_seq, mu_HPDI[0], mu_HPDI[1], color="k", alpha=0.2);
Out[57]:

Code 4.58

In [58]:
post = extract_samples(m4_3)


def mu_link(weight):
    """Evaluate the linear predictor a + b * weight for every posterior sample.

    ``post["a"]`` and ``post["b"]`` get a trailing singleton dimension so they
    broadcast against ``weight``; for 1-D samples and a 1-D ``weight`` the
    result has one row per sample and one column per weight. The global
    ``post`` is looked up at call time.
    """
    return post["a"].unsqueeze(1) + post["b"].unsqueeze(1) * weight


# Weights 25, 26, ..., 70 (end value exclusive).
weight_seq = torch.arange(start=25., end=71, step=1)
mu = mu_link(weight_seq)
# Summaries over the sample dimension (dim 0): mean and 89% HPDI per weight.
mu_mean = mu.mean(0)
mu_HPDI = stats.hpdi(mu, prob=0.89, dim=0)

Code 4.59

In [59]:
# Simulate heights from m4_3 at each weight in weight_seq; the recorded
# output has shape (1000, 46), i.e. 46 columns for the 46 weights.
sim_height = sim(m4_3, data={"weight": weight_seq})
sim_height.shape, sim_height[:5, 0]
Out[59]:
(torch.Size([1000, 46]),
 tensor([137.2475, 137.3180, 127.3701, 139.6853, 132.9194]))

Code 4.60

In [60]:
# 89% percentile interval of the simulated heights, computed over dim 0.
height_PI = stats.pi(sim_height, prob=0.89, dim=0)

Code 4.61

In [61]:
# plot raw data
sns.scatterplot("weight", "height", data=d2, alpha=0.5)

# draw MAP line
ax = sns.lineplot(weight_seq, mu_mean, color="k")

# draw HPDI region for line
ax.fill_between(weight_seq, mu_HPDI[0], mu_HPDI[1], color="k", alpha=0.15)

# draw PI region for simulated heights
# both bands use the same colour and alpha, so their overlap appears darker
ax.fill_between(weight_seq, height_PI[0], height_PI[1], color="k", alpha=0.15);
Out[61]:

Code 4.62

In [62]:
# Repeat the simulation with n=10,000 and recompute the 89% percentile interval.
sim_height = sim(m4_3, data={"weight": weight_seq}, n=int(1e4))
height_PI = stats.pi(sim_height, prob=0.89, dim=0)

Code 4.63

In [63]:
def sim_fn(weight):
    """Draw one simulated height per posterior sample and per weight.

    Reads "a", "b" and "sigma" from the global ``post`` at call time. Each is
    given a trailing singleton dimension so it broadcasts against ``weight``.

    Args:
        weight: weights at which to simulate heights.

    Returns:
        A sample from Normal(a + b * weight, sigma).
    """
    intercept, slope, sd = (post[name].unsqueeze(1) for name in ("a", "b", "sigma"))
    return dist.Normal(loc=intercept + slope * weight, scale=sd).sample()

# Bind the global `post` that sim_fn reads, then simulate heights at
# weights 25..70 and compute their 89% percentile interval over dim 0.
post = extract_samples(m4_3)
weight_seq = torch.arange(start=25., end=71, step=1)
sim_height = sim_fn(weight_seq)
height_PI = stats.pi(sim_height, prob=0.89, dim=0)

Code 4.64

In [64]:
# Reload the full data set (all ages, no adult filter) and inspect it.
howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d = howell1
d.info()
d.head()
Out[64]:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 544 entries, 0 to 543
Data columns (total 4 columns):
height    544 non-null float64
weight    544 non-null float64
age       544 non-null float64
male      544 non-null int64
dtypes: float64(3), int64(1)
memory usage: 17.1 KB
Out[64]:
height weight age male
0 151.765 47.825606 63.0 1
1 139.700 36.485807 63.0 0
2 136.525 31.864838 65.0 0
3 156.845 53.041915 41.0 1
4 145.415 41.276872 51.0 0

Code 4.65

In [65]:
# Build the tensor from the column's underlying NumPy array; passing the
# Series itself makes torch.tensor() treat it as a generic Python sequence
# and convert it element by element.
weight = torch.tensor(d["weight"].values, dtype=torch.float)
# Standardize: subtract the mean and divide by the standard deviation.
weight_s = (weight - weight.mean()) / weight.std()

Code 4.66

In [66]:
def model(weight, weight2, height):
    """Quadratic regression of height on weight.

    Sample sites:
        "a"      ~ Normal(178, 100)
        "b1"     ~ Normal(0, 10)
        "b2"     ~ Normal(0, 10)
        "sigma"  ~ Uniform(0, 50)
        "height" ~ Normal(a + b1 * weight + b2 * weight2, sigma), observed as
                   ``height`` inside the plate.

    Args:
        weight: first-order predictor values.
        weight2: second-order predictor values (the caller supplies the squares).
        height: observed heights, passed as ``obs`` to the "height" site.
    """
    a = pyro.sample("a", dist.Normal(178, 100))
    b1 = pyro.sample("b1", dist.Normal(0, 10))
    b2 = pyro.sample("b2", dist.Normal(0, 10))
    # Accumulate the polynomial term by term, left to right.
    mu = a + b1 * weight
    mu = mu + b2 * weight2
    sigma = pyro.sample("sigma", dist.Uniform(0, 50))
    with pyro.plate("plate"):
        pyro.sample("height", dist.Normal(mu, sigma), obs=height)

# Squared standardized weight, used as the second-order predictor.
weight_s2 = weight_s ** 2
# Build the tensor from the column's underlying NumPy array; passing the
# Series itself makes torch.tensor() treat it as a generic Python sequence
# and convert it element by element.
height = torch.tensor(d["height"].values, dtype=torch.float)
# Fit the quadratic model on the full data set.
m4_5 = MAP(model).run(weight_s, weight_s2, height)

Code 4.67

In [67]:
# Summary table for a, b1, b2 and sigma of the quadratic model.
precis(m4_5)
Out[67]:
Mean StdDev |0.89 0.89|
a 146.65 0.37 146.07 147.24
b1 21.41 0.29 20.93 21.86
b2 -8.41 0.29 -8.87 -7.96
sigma 5.75 0.17 5.47 6.03

Code 4.68

In [68]:
# Prediction grid of 30 points on the standardized weight scale (-2.2 to 2).
weight_seq = torch.linspace(start=-2.2, end=2, steps=30)
# Keys match the model's `weight` and `weight2` parameter names.
pred_data = {"weight": weight_seq, "weight2": weight_seq ** 2}
# Mean of mu and its 89% percentile interval, both over dim 0.
mu = link(m4_5, data=pred_data)
mu_mean = mu.mean(0)
mu_PI = stats.pi(mu, prob=0.89, dim=0)
# 89% percentile interval of the simulated heights.
sim_height = sim(m4_5, data=pred_data)
height_PI = stats.pi(sim_height, prob=0.89, dim=0)

Code 4.69

In [69]:
# Raw data on the standardized weight scale, the mean curve in black, and two
# shaded bands: the interval for mu and the interval for simulated heights.
ax = sns.scatterplot(weight_s, height, alpha=0.5)
ax.set(xlabel="weight.s", ylabel="height")
sns.lineplot(weight_seq, mu_mean, color="k")
ax.fill_between(weight_seq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
ax.fill_between(weight_seq, height_PI[0], height_PI[1], color="k", alpha=0.2);
Out[69]:

Code 4.70

In [70]:
def model(weight, weight2, weight3, height):
    """Cubic regression of height on weight.

    Sample sites:
        "a"      ~ Normal(178, 100)
        "b1", "b2", "b3" ~ Normal(0, 10) each
        "sigma"  ~ Uniform(0, 50)
        "height" ~ Normal(a + b1 * weight + b2 * weight2 + b3 * weight3, sigma),
                   observed as ``height`` inside the plate.

    Args:
        weight: first-order predictor values.
        weight2: second-order predictor values.
        weight3: third-order predictor values.
        height: observed heights, passed as ``obs`` to the "height" site.
    """
    a = pyro.sample("a", dist.Normal(178, 100))
    b1 = pyro.sample("b1", dist.Normal(0, 10))
    b2 = pyro.sample("b2", dist.Normal(0, 10))
    b3 = pyro.sample("b3", dist.Normal(0, 10))
    # Accumulate the polynomial term by term, left to right.
    mu = a + b1 * weight
    mu = mu + b2 * weight2
    mu = mu + b3 * weight3
    sigma = pyro.sample("sigma", dist.Uniform(0, 50))
    with pyro.plate("plate"):
        pyro.sample("height", dist.Normal(mu, sigma), obs=height)

# Cubed standardized weight, then fit the cubic model on the full data set.
weight_s3 = weight_s ** 3
m4_6 = MAP(model).run(weight_s, weight_s2, weight_s3, height)

Code 4.71

In [71]:
# Scatter of height against standardized weight with the x tick marks
# removed (xticks=[]); fig and ax are kept so a later cell can add ticks.
fig, ax = sns.mpl.pyplot.subplots()
sns.scatterplot(weight_s, height, alpha=0.5)
ax.set(xlabel="weight", ylabel="height", xticks=[]);
Out[71]:

Code 4.72

In [72]:
# Tick positions on the standardized scale. They are floats so that the
# conversion back to the original scale is carried out in floating point
# whatever integer/float promotion rules the installed torch applies.
at = torch.tensor([-2., -1., 0., 1., 2.])
# Undo the standardization: z * std + mean gives weights in original units.
labels = at * weight.std() + weight.mean()
ax.set_xticks(at)
ax.set_xticklabels([round(label.item(), 1) for label in labels])
# Re-display the figure created in the earlier cell with its new tick labels.
fig
Out[72]:

Code 4.73

In [73]:
# Height against weight for the full data set (all ages), on the original scale.
sns.scatterplot("weight", "height", data=howell1, alpha=0.4);
Out[73]: