Preface
In [ ]:
!pip install -q numpyro arviz
In [0]:
import os
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.infer import SVI, Trace_ELBO
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
Code 0.1
In [1]:
# Echo the opening motto.
motto = "All models are wrong, but some are useful."
print(motto)
Out[1]:
Code 0.2
In [2]:
# [1, 2] -> [10, 20] -> logs -> summed -> exponentiated:
# exp(log(10) + log(20)) = 10 * 20, i.e. about 200 up to float rounding.
x = jnp.exp(jnp.sum(jnp.log(jnp.arange(1, 3) * 10)))
x
Out[2]:
Code 0.3
In [3]:
# Underflow demo: 0.01**200 is computed as a Python float and rounds to 0.0,
# so its log is -inf, while the algebraically equal 200 * log(0.01) stays finite.
for value in (jnp.log(0.01**200), 200 * jnp.log(0.01)):
    print(value)
Out[3]:
Code 0.4
In [4]:
# Load the data:
# car stopping distances in feet paired with speeds in mph
# (the classic `cars` dataset; `cars.info()` lists its columns and dtypes)
cars = pd.read_csv("../data/cars.csv", index_col=0)

# Optimiser settings for the SVI fit below.
LEARNING_RATE = 1.0
NUM_STEPS = 1000


# fit a linear regression of distance on speed
def model(speed, dist_):
    """Linear regression of stopping distance on speed.

    mu = a + b * speed, and the observed distances are modelled as
    Normal(mu, 1). `a` and `b` are declared with `numpyro.param`, so they
    are optimised directly rather than given priors.

    Args:
        speed: predictor values (here `cars.speed.values`).
        dist_: observed distances; the trailing underscore avoids shadowing
            the `numpyro.distributions` module imported as `dist`.
    """
    mu = numpyro.param("a", 0.0) + numpyro.param("b", 1.0) * speed
    numpyro.sample("dist", dist.Normal(mu, 1), obs=dist_)


# The guide is a no-op: the model has no latent sample sites, so minimising
# Trace_ELBO amounts to maximising the Normal log-likelihood over a and b
# (least squares, since the scale is fixed at 1). The extra keyword
# arguments are passed unchanged to both model and guide at every step.
svi = SVI(
    model,
    lambda speed, dist_: None,
    optim=optim.Adam(LEARNING_RATE),
    loss=Trace_ELBO(),
    speed=cars.speed.values,
    dist_=cars.dist.values,
)
svi_result = svi.run(random.PRNGKey(0), NUM_STEPS)
params = svi_result.params
# estimated coefficients from the model
print(params)
# plot residuals against speed
resid = cars.dist - (params["a"] + params["b"] * cars.speed.values)
az.plot_pair({"speed": cars.speed, "resid": resid})
plt.show()
Out[4]:
Out[4]:
Out[4]:
Code 0.5
# install numpyro, arviz, daft and networkx with pip
pip install \
    numpyro \
    arviz \
    daft \
    networkx
Comments
Comments powered by Disqus