# Chapter 4. Geocentric Models

In [0]:
```import os
import warnings

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
from scipy.interpolate import BSpline
from scipy.stats import gaussian_kde

import jax.numpy as jnp
from jax import lax, random, vmap

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import hpdi, print_summary
from numpyro.infer import ELBO, SVI, Predictive, init_to_value
from numpyro.infer.autoguide import AutoLaplaceApproximation

# Render inline figures as SVG when an "SVG" environment variable is set
# (presumably for documentation builds — confirm).
if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
# Shorten printed warnings to "Category: message" (drops file/line info).
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
    category.__name__, message
)
az.style.use("arviz-darkgrid")
```

#### Code 4.1¶

In [1]:
```pos = jnp.sum(dist.Uniform(-1, 1).sample(random.PRNGKey(0), (1000, 16)), -1)
```

#### Code 4.2¶

In [2]:
```jnp.prod(1 + dist.Uniform(0, 0.1).sample(random.PRNGKey(0), (12,)))
```
Out[2]:
`DeviceArray(1.7294353, dtype=float32)`

#### Code 4.3¶

In [3]:
```growth = jnp.prod(1 + dist.Uniform(0, 0.1).sample(random.PRNGKey(0), (1000, 12)), -1)
az.plot_density({"growth": growth}, hdi_prob=1)
x = jnp.sort(growth)
plt.plot(x, jnp.exp(dist.Normal(jnp.mean(x), jnp.std(x)).log_prob(x)), "--")
plt.show()
```
Out[3]:

#### Code 4.4¶

In [4]:
```big = jnp.prod(1 + dist.Uniform(0, 0.5).sample(random.PRNGKey(0), (1000, 12)), -1)
small = jnp.prod(1 + dist.Uniform(0, 0.01).sample(random.PRNGKey(0), (1000, 12)), -1)
```

#### Code 4.5¶

In [5]:
```log_big = jnp.log(
jnp.prod(1 + dist.Uniform(0, 0.5).sample(random.PRNGKey(0), (1000, 12)), -1)
)
```

#### Code 4.6¶

In [6]:
```w = 6
n = 9
p_grid = jnp.linspace(start=0, stop=1, num=100)
prob_binom = jnp.exp(dist.Binomial(n, p_grid).log_prob(w))
posterior = prob_binom * jnp.exp(dist.Uniform(0, 1).log_prob(p_grid))
posterior = posterior / jnp.sum(posterior)
```

#### Code 4.7¶

In [7]:
```Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d = Howell1
```

#### Code 4.8¶

In [8]:
```d.info()
```
Out[8]:
```<class 'pandas.core.frame.DataFrame'>
RangeIndex: 544 entries, 0 to 543
Data columns (total 4 columns):
#   Column  Non-Null Count  Dtype
---  ------  --------------  -----
0   height  544 non-null    float64
1   weight  544 non-null    float64
2   age     544 non-null    float64
3   male    544 non-null    int64
dtypes: float64(3), int64(1)
memory usage: 17.1 KB
```
Out[8]:
height weight age male
0 151.765 47.825606 63.0 1
1 139.700 36.485807 63.0 0
2 136.525 31.864838 65.0 0
3 156.845 53.041915 41.0 1
4 145.415 41.276872 51.0 0

#### Code 4.9¶

In [9]:
```print_summary(dict(zip(d.columns, d.T.values)), 0.89, False)
```
Out[9]:
```                mean       std    median      5.5%     94.5%     n_eff     r_hat
age     29.34     20.75     27.00      0.00     57.00    186.38      1.03
height    138.26     27.60    148.59     90.81    170.18    218.68      1.06
male      0.47      0.50      0.00      0.00      1.00    670.75      1.00
weight     35.61     14.72     40.06     11.37     55.71    305.62      1.05

```

#### Code 4.10¶

In [10]:
```d.height
```
Out[10]:
```0      151.765
1      139.700
2      136.525
3      156.845
4      145.415
...
539    145.415
540    162.560
541    156.210
542     71.120
543    158.750
Name: height, Length: 544, dtype: float64```

#### Code 4.11¶

In [11]:
```d2 = d[d.age >= 18]
```

#### Code 4.12¶

In [12]:
```x = jnp.linspace(100, 250, 101)
plt.plot(x, jnp.exp(dist.Normal(178, 20).log_prob(x)))
plt.show()
```
Out[12]:

#### Code 4.13¶

In [13]:
```x = jnp.linspace(-10, 60, 101)
plt.plot(x, jnp.exp(dist.Uniform(0, 50, validate_args=True).log_prob(x)))
plt.show()
```
Out[13]:
```UserWarning: Out-of-support values provided to log prob method. The value argument should be within the support.
```
Out[13]:

#### Code 4.14¶

In [14]:
```sample_mu = dist.Normal(178, 20).sample(random.PRNGKey(0), (int(1e4),))
sample_sigma = dist.Uniform(0, 50).sample(random.PRNGKey(1), (int(1e4),))
prior_h = dist.Normal(sample_mu, sample_sigma).sample(random.PRNGKey(2))
az.plot_kde(prior_h)
plt.show()
```
Out[14]:

#### Code 4.15¶

In [15]:
```sample_mu = dist.Normal(178, 100).sample(random.PRNGKey(0), (int(1e4),))
prior_h = dist.Normal(sample_mu, sample_sigma).sample(random.PRNGKey(2))
az.plot_kde(prior_h)
plt.show()
```
Out[15]:

#### Code 4.16¶

In [16]:
```mu_list = jnp.linspace(start=150, stop=160, num=100)
sigma_list = jnp.linspace(start=7, stop=9, num=100)
mesh = jnp.meshgrid(mu_list, sigma_list)
post = {"mu": mesh[0].reshape(-1), "sigma": mesh[1].reshape(-1)}
post["LL"] = vmap(
lambda mu, sigma: jnp.sum(dist.Normal(mu, sigma).log_prob(d2.height.values))
)(post["mu"], post["sigma"])
logprob_mu = dist.Normal(178, 20).log_prob(post["mu"])
logprob_sigma = dist.Uniform(0, 50).log_prob(post["sigma"])
post["prob"] = post["LL"] + logprob_mu + logprob_sigma
post["prob"] = jnp.exp(post["prob"] - jnp.max(post["prob"]))
```

#### Code 4.17¶

In [17]:
```plt.contour(
post["mu"].reshape(100, 100),
post["sigma"].reshape(100, 100),
post["prob"].reshape(100, 100),
)
plt.show()
```
Out[17]:

#### Code 4.18¶

In [18]:
```plt.imshow(
post["prob"].reshape(100, 100),
origin="lower",
extent=(150, 160, 7, 9),
aspect="auto",
)
plt.show()
```
Out[18]:

#### Code 4.19¶

In [19]:
```prob = post["prob"] / jnp.sum(post["prob"])
sample_rows = dist.Categorical(probs=prob).sample(random.PRNGKey(0), (int(1e4),))
sample_mu = post["mu"][sample_rows]
sample_sigma = post["sigma"][sample_rows]
```

#### Code 4.20¶

In [20]:
```plt.scatter(sample_mu, sample_sigma, s=64, alpha=0.1, edgecolor="none")
plt.show()
```
Out[20]:

#### Code 4.21¶

In [21]:
```az.plot_kde(sample_mu)
plt.show()
az.plot_kde(sample_sigma)
plt.show()
```
Out[21]:
Out[21]:

#### Code 4.22¶

In [22]:
```print(hpdi(sample_mu, 0.89))
print(hpdi(sample_sigma, 0.89))
```
Out[22]:
```[153.93939 155.15152]
[7.3232327 8.252525 ]
```

#### Code 4.23¶

In [23]:
```d3 = d2.height.sample(n=20)
```

#### Code 4.24¶

In [24]:
```mu_list = jnp.linspace(start=150, stop=170, num=200)
sigma_list = jnp.linspace(start=4, stop=20, num=200)
mesh = jnp.meshgrid(mu_list, sigma_list)
post2 = {"mu": mesh[0].reshape(-1), "sigma": mesh[1].reshape(-1)}
post2["LL"] = vmap(
lambda mu, sigma: jnp.sum(dist.Normal(mu, sigma).log_prob(d3.values))
)(post2["mu"], post2["sigma"])
logprob_mu = dist.Normal(178, 20).log_prob(post2["mu"])
logprob_sigma = dist.Uniform(0, 50).log_prob(post2["sigma"])
post2["prob"] = post2["LL"] + logprob_mu + logprob_sigma
post2["prob"] = jnp.exp(post2["prob"] - jnp.max(post2["prob"]))
prob = post2["prob"] / jnp.sum(post2["prob"])
sample2_rows = dist.Categorical(probs=prob).sample(random.PRNGKey(0), (int(1e4),))
sample2_mu = post2["mu"][sample2_rows]
sample2_sigma = post2["sigma"][sample2_rows]
plt.scatter(sample2_mu, sample2_sigma, s=64, alpha=0.1, edgecolor="none")
plt.show()
```
Out[24]:

#### Code 4.25¶

In [25]:
```az.plot_kde(sample2_sigma)
x = jnp.sort(sample2_sigma)
plt.plot(x, jnp.exp(dist.Normal(jnp.mean(x), jnp.std(x)).log_prob(x)), "--")
plt.show()
```
Out[25]:

#### Code 4.26¶

In [26]:
```Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d = Howell1
d2 = d[d["age"] >= 18]
```

#### Code 4.27¶

In [27]:
```def flist(height):
mu = numpyro.sample("mu", dist.Normal(178, 20))
sigma = numpyro.sample("sigma", dist.Uniform(0, 50))
numpyro.sample("height", dist.Normal(mu, sigma), obs=height)
```

#### Code 4.28¶

In [28]:
```m4_1 = AutoLaplaceApproximation(flist)
svi = SVI(flist, m4_1, optim.Adam(1), ELBO(), height=d2.height.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(2000))
p4_1 = svi.get_params(state)
```

Note: `lax.scan` is similar to `lax.fori_loop`, but it also collects the loss value at every step. Instead of using `lax.scan`, we can use a native Python loop after jitting the SVI update method:

```from jax import jit

update_fn = jit(svi.update)
state = init_state
loss = []
for i in range(1000):
state, loss_i = update_fn(state)
loss.append(loss_i)
```

Using `lax.scan` is a bit faster though.

#### Code 4.29¶

In [29]:
```samples = m4_1.sample_posterior(random.PRNGKey(1), p4_1, (1000,))
print_summary(samples, 0.89, False)
```
Out[29]:
```                mean       std    median      5.5%     94.5%     n_eff     r_hat
mu    154.60      0.40    154.60    154.00    155.28    995.06      1.00
sigma      7.76      0.30      7.76      7.33      8.26   1007.15      1.00

```

#### Code 4.30¶

In [30]:
```start = {"mu": d2.height.mean(), "sigma": d2.height.std()}
m4_1 = AutoLaplaceApproximation(flist, init_strategy=init_to_value(values=start))
svi = SVI(flist, m4_1, optim.Adam(0.1), ELBO(), height=d2.height.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(2000))
p4_1 = svi.get_params(state)
```

#### Code 4.31¶

In [31]:
```def model(height):
mu = numpyro.sample("mu", dist.Normal(178, 0.1))
sigma = numpyro.sample("sigma", dist.Uniform(0, 50))
numpyro.sample("height", dist.Normal(mu, sigma), obs=height)

m4_2 = AutoLaplaceApproximation(model)
svi = SVI(model, m4_2, optim.Adam(1), ELBO(), height=d2.height.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(2000))
p4_2 = svi.get_params(state)
samples = m4_2.sample_posterior(random.PRNGKey(1), p4_2, (1000,))
print_summary(samples, 0.89, False)
```
Out[31]:
```                mean       std    median      5.5%     94.5%     n_eff     r_hat
mu    177.86      0.10    177.86    177.72    178.03    995.05      1.00
sigma     24.57      0.94     24.60     23.01     25.96   1012.88      1.00

```

#### Code 4.32¶

In [32]:
```samples = m4_1.sample_posterior(random.PRNGKey(1), p4_1, (1000,))
vcov = jnp.cov(jnp.stack(list(samples.values()), axis=0))
vcov
```
Out[32]:
```DeviceArray([[0.1624806 , 0.00178086],
[0.00178086, 0.08732025]], dtype=float32)```

#### Code 4.33¶

In [33]:
```print(jnp.diagonal(vcov))
print(vcov / jnp.sqrt(jnp.outer(jnp.diagonal(vcov), jnp.diagonal(vcov))))
```
Out[33]:
```[0.1624806  0.08732025]
[[1.       0.014951]
[0.014951 1.      ]]
```

#### Code 4.34¶

In [34]:
```post = m4_1.sample_posterior(random.PRNGKey(1), p4_1, (int(1e4),))
{latent: list(post[latent][:6]) for latent in post}
```
Out[34]:
```{'mu': [154.24677, 154.48788, 154.9816, 154.2149, 155.49384, 154.82945],
'sigma': [7.559363, 7.3059254, 7.279783, 7.810844, 7.9050875, 7.97789]}```

#### Code 4.35¶

In [35]:
```print_summary(post, 0.89, False)
```
Out[35]:
```                mean       std    median      5.5%     94.5%     n_eff     r_hat
mu    154.61      0.41    154.61    153.94    155.25   9927.01      1.00
sigma      7.75      0.29      7.74      7.28      8.22   9502.46      1.00

```

#### Code 4.36¶

In [36]:
```samples_flat = jnp.stack(list(post.values()))
mu, sigma = jnp.mean(samples_flat), jnp.cov(samples_flat)
post = dist.MultivariateNormal(mu, sigma).sample(random.PRNGKey(0), (int(1e4),))
```

#### Code 4.37¶

In [37]:
```az.plot_pair(d2[["weight", "height"]].to_dict(orient="list"))
plt.show()
```
Out[37]:

#### Code 4.38¶

In [38]:
```with numpyro.handlers.seed(rng_seed=2971):
N = 100  # 100 lines
a = numpyro.sample("a", dist.Normal(178, 20).expand([N]))
b = numpyro.sample("b", dist.Normal(0, 10).expand([N]))
```

#### Code 4.39¶

In [39]:
```plt.subplot(
xlim=(d2.weight.min(), d2.weight.max()),
ylim=(-100, 400),
xlabel="weight",
ylabel="height",
)
plt.axhline(y=0, c="k", ls="--")
plt.axhline(y=272, c="k", ls="-", lw=0.5)
plt.title("b ~ Normal(0, 10)")
xbar = d2.weight.mean()
x = jnp.linspace(d2.weight.min(), d2.weight.max(), 101)
for i in range(N):
plt.plot(x, a[i] + b[i] * (x - xbar), "k", alpha=0.2)
plt.show()
```
Out[39]: