Chapter 11. God Spiked the Integers

In [0]:
import math
import os
import warnings

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd

import jax.numpy as jnp
from jax import lax, random, vmap
from jax.scipy.special import expit

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO, log_likelihood
from numpyro.infer.autoguide import AutoLaplaceApproximation

if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
    category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_host_device_count(4)

Code 11.1

In [1]:
chimpanzees = pd.read_csv("../data/chimpanzees.csv", sep=";")
d = chimpanzees

Code 11.2

In [2]:
d["treatment"] = d.prosoc_left + 2 * d.condition

Code 11.3

In [3]:
d.reset_index().groupby(["condition", "prosoc_left", "treatment"]).count()["index"]
Out[3]:
condition  prosoc_left  treatment
0          0            0            126
           1            1            126
1          0            2            126
           1            3            126
Name: index, dtype: int64

Code 11.4

In [4]:
def model(pulled_left=None):
    a = numpyro.sample("a", dist.Normal(0, 10))
    logit_p = a
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m11_1 = AutoLaplaceApproximation(model)
svi = SVI(model, m11_1, optim.Adam(1), Trace_ELBO(), pulled_left=d.pulled_left.values)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p11_1 = svi.get_params(state)
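
AutoLaplaceApproximation fits a quadratic approximation to the posterior by maximizing the ELBO; the lax.scan above runs 1000 Adam steps. To inspect the fit, posterior draws can be pulled from the guide — a minimal sketch (the PRNG key and number of draws here are arbitrary):

post = m11_1.sample_posterior(random.PRNGKey(1), p11_1, sample_shape=(1000,))
print_summary(post, 0.89, False)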

Code 11.5

In [5]:
prior = Predictive(m11_1.model, num_samples=10000)(random.PRNGKey(1999))

Code 11.6

In [6]:
p = expit(prior["a"])
az.plot_kde(p, bw=0.3)
plt.show()
Out[6]: [figure: prior density of p = expit(a) under a ~ Normal(0, 10)]
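
A Normal(0, 10) prior on the logit scale is anything but flat on the probability scale: most of its mass piles up near 0 and 1. An illustrative check of how much prior mass sits in the extremes:

jnp.mean((p < 0.01) | (p > 0.99))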

Code 11.7

In [7]:
def model(treatment, pulled_left=None):
    a = numpyro.sample("a", dist.Normal(0, 1.5))
    b = numpyro.sample("b", dist.Normal(0, 10).expand([4]))
    logit_p = a + b[treatment]
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m11_2 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m11_2,
    optim.Adam(1),
    Trace_ELBO(),
    treatment=d.treatment.values,
    pulled_left=d.pulled_left.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p11_2 = svi.get_params(state)
prior = Predictive(model, num_samples=int(1e4))(
    random.PRNGKey(1999), treatment=0, pulled_left=0
)
p = vmap(lambda k: expit(prior["a"] + prior["b"][:, k]), 0, 1)(jnp.arange(4))
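
The vmap maps expit(a + b[k]) over the four treatments and stacks the results along axis 1; plain broadcasting gives the same array more directly, e.g.:

p_alt = expit(prior["a"][:, None] + prior["b"])  # same shape: (num_samples, 4)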

Code 11.8

In [8]:
az.plot_kde(jnp.abs(p[:, 0] - p[:, 1]), bw=0.3)
plt.show()
Out[8]: [figure: prior density of the absolute difference between two treatments]

Code 11.9

In [9]:
def model(treatment, pulled_left=None):
    a = numpyro.sample("a", dist.Normal(0, 1.5))
    b = numpyro.sample("b", dist.Normal(0, 0.5).expand([4]))
    logit_p = a + b[treatment]
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m11_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m11_3,
    optim.Adam(1),
    Trace_ELBO(),
    treatment=d.treatment.values,
    pulled_left=d.pulled_left.values,
)
init_state = svi.init(random.PRNGKey(0))
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(1000))
p11_3 = svi.get_params(state)
prior = Predictive(model, num_samples=int(1e4))(
    random.PRNGKey(1999), treatment=0, pulled_left=0
)
p = vmap(lambda k: expit(prior["a"] + prior["b"][:, k]), 0, 1)(jnp.arange(4))
jnp.mean(jnp.abs(p[:, 0] - p[:, 1]))
Out[9]:
DeviceArray(0.09770478, dtype=float32)

Code 11.10

In [10]:
# trimmed data list
dat_list = {
    "pulled_left": d.pulled_left.values,
    "actor": d.actor.values - 1,
    "treatment": d.treatment.values,
}

Code 11.11

In [11]:
def model(actor, treatment, pulled_left=None, link=False):
    a = numpyro.sample("a", dist.Normal(0, 1.5).expand([7]))
    b = numpyro.sample("b", dist.Normal(0, 0.5).expand([4]))
    logit_p = a[actor] + b[treatment]
    if link:
        numpyro.deterministic("p", expit(logit_p))
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m11_4 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m11_4.run(random.PRNGKey(0), **dat_list)
m11_4.print_summary(0.89)
Out[11]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]     -0.47      0.34     -0.47     -1.01      0.08    560.43      1.01
      a[1]      3.88      0.75      3.84      2.76      5.05   1871.58      1.00
      a[2]     -0.76      0.34     -0.76     -1.31     -0.23    559.21      1.01
      a[3]     -0.76      0.34     -0.76     -1.31     -0.26    618.35      1.01
      a[4]     -0.47      0.34     -0.46     -1.01      0.09    554.01      1.01
      a[5]      0.47      0.35      0.46     -0.09      1.01    561.33      1.01
      a[6]      1.95      0.44      1.95      1.32      2.72    921.75      1.00
      b[0]     -0.03      0.30     -0.03     -0.48      0.48    504.19      1.01
      b[1]      0.49      0.29      0.48      0.01      0.96    505.45      1.01
      b[2]     -0.37      0.30     -0.37     -0.85      0.08    471.35      1.01
      b[3]      0.39      0.29      0.38     -0.05      0.87    461.72      1.01

Number of divergences: 0

Code 11.12

In [12]:
post = m11_4.get_samples(group_by_chain=True)
p_left = expit(post["a"])
az.plot_forest({"p_left": p_left}, combined=True, hdi_prob=0.89)
plt.gca().set(xlim=(-0.01, 1.01))
plt.show()
Out[12]: [figure: forest plot of each actor's p_left with 89% HDI]

Code 11.13

In [13]:
labs = ["R/N", "L/N", "R/P", "L/P"]
az.plot_forest(
    m11_4.get_samples(group_by_chain=True), combined=True, var_names="b", hdi_prob=0.89,
)
plt.gca().set_yticklabels(labs[::-1])
plt.show()
Out[13]: [figure: forest plot of treatment effects b, labeled R/N, L/N, R/P, L/P]

Code 11.14

In [14]:
diffs = {
    "db13": post["b"][..., 0] - post["b"][..., 2],
    "db24": post["b"][..., 1] - post["b"][..., 3],
}
az.plot_forest(diffs, combined=True)
plt.show()
Out[14]: [figure: forest plot of the contrasts db13 and db24]
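
These contrasts live on the log-odds scale. A common follow-up summary is the posterior probability that each difference is positive — a minimal sketch:

{k: jnp.mean(v > 0).item() for k, v in diffs.items()}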

Code 11.15

In [15]:
pl = d.groupby(["actor", "treatment"])["pulled_left"].mean().unstack()
pl.iloc[0, :]
Out[15]:
treatment
0    0.333333
1    0.500000
2    0.277778
3    0.555556
Name: 1, dtype: float64

Code 11.16

In [16]:
ax = plt.subplot(
    xlim=(0.5, 28.5),
    ylim=(0, 1.05),
    xlabel="",
    ylabel="proportion left lever",
    xticks=[],
)
plt.yticks(ticks=[0, 0.5, 1], labels=[0, 0.5, 1])
ax.axhline(0.5, c="k", lw=1, ls="--")
for j in range(1, 8):
    ax.axvline((j - 1) * 4 + 4.5, c="k", lw=0.5)
for j in range(1, 8):
    ax.annotate(
        "actor {}".format(j),
        ((j - 1) * 4 + 2.5, 1.1),
        ha="center",
        va="center",
        annotation_clip=False,
    )
for j in [1] + list(range(3, 8)):
    ax.plot((j - 1) * 4 + jnp.array([1, 3]), pl.loc[j, [0, 2]], "b")
    ax.plot((j - 1) * 4 + jnp.array([2, 4]), pl.loc[j, [1, 3]], "b")
x = jnp.arange(1, 29).reshape(7, 4)
ax.scatter(
    x[:, [0, 1]].reshape(-1),
    pl.values[:, [0, 1]].reshape(-1),
    edgecolor="b",
    facecolor="w",
    zorder=3,
)
ax.scatter(
    x[:, [2, 3]].reshape(-1), pl.values[:, [2, 3]].reshape(-1), marker=".", c="b", s=80
)
yoff = 0.01
ax.annotate("R/N", (1, pl.loc[1, 0] - yoff), ha="center", va="top")
ax.annotate("L/N", (2, pl.loc[1, 1] + yoff), ha="center", va="bottom")
ax.annotate("R/P", (3, pl.loc[1, 2] - yoff), ha="center", va="top")
ax.annotate("L/P", (4, pl.loc[1, 3] + yoff), ha="center", va="bottom")
ax.set_title("observed proportions\n")
plt.show()
Out[16]: [figure: observed proportions of left pulls, by actor and treatment]

Code 11.17

In [17]:
dat = {"actor": jnp.repeat(jnp.arange(7), 4), "treatment": jnp.tile(jnp.arange(4), 7)}
pred = Predictive(m11_4.sampler.model, m11_4.get_samples(), return_sites=["p"])
p_post = pred(random.PRNGKey(1), link=True, **dat)["p"]
p_mu = jnp.mean(p_post, 0)
p_ci = jnp.percentile(p_post, q=(4.5, 95.5), axis=0)
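
p_post holds one column per actor-treatment combination, in the actor-major order laid out by dat, so reshaping the posterior means exposes the 7 x 4 structure — an illustrative view (labels follow Code 11.13):

pd.DataFrame(p_mu.reshape(7, 4), index=range(1, 8), columns=["R/N", "L/N", "R/P", "L/P"])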

Code 11.18

In [18]:
d["side"] = d.prosoc_left  # right 0, left 1
d["cond"] = d.condition  # no partner 0, partner 1

Code 11.19

In [19]:
dat_list2 = {
    "pulled_left": d.pulled_left.values,
    "actor": d.actor.values - 1,
    "side": d.side.values,
    "cond": d.cond.values,
}


def model(actor, side, cond, pulled_left=None):
    a = numpyro.sample("a", dist.Normal(0, 1.5).expand([7]))
    bs = numpyro.sample("bs", dist.Normal(0, 0.5).expand([2]))
    bc = numpyro.sample("bc", dist.Normal(0, 0.5).expand([2]))
    logit_p = a[actor] + bs[side] + bc[cond]
    numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)


m11_5 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m11_5.run(random.PRNGKey(0), **dat_list2)

Code 11.20

In [20]:
az.compare(
    {"m11.5": az.from_numpyro(m11_5), "m11.4": az.from_numpyro(m11_4)},
    ic="loo",
    scale="deviance",
)
Out[20]:
       rank      loo    p_loo    d_loo    weight       se      dse  warning loo_scale
m11.5     0  530.972  7.81572        0  0.626632  19.6065        0    False  deviance
m11.4     1  532.119  8.40718  1.14721  0.373368  19.4751  1.37715    False  deviance

Code 11.21

In [21]:
post = m11_4.get_samples()
post["log_lik"] = log_likelihood(m11_4.sampler.model, post, **dat_list)["pulled_left"]
{k: v.shape for k, v in post.items()}
Out[21]:
{'a': (2000, 7), 'b': (2000, 4), 'log_lik': (2000, 504)}
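
With the pointwise log-likelihood computed by hand, LOO can also be obtained directly from ArviZ — a sketch, reshaping the 2000 flat draws back into 4 chains of 500:

az.loo(
    az.from_dict(log_likelihood={"pulled_left": post["log_lik"].reshape(4, 500, -1)}),
    scale="deviance",
)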

Code 11.22

In [22]:
def m11_4_pe_code(params, log_lik=False):
    a_logprob = jnp.sum(dist.Normal(0, 1.5).log_prob(params["a"]))
    b_logprob = jnp.sum(dist.Normal(0, 0.5).log_prob(params["b"]))
    logit_p = params["a"][dat_list["actor"]] + params["b"][dat_list["treatment"]]
    pulled_left_logprob = dist.Binomial(logits=logit_p).log_prob(
        dat_list["pulled_left"]
    )
    if log_lik:
        return pulled_left_logprob
    return -(a_logprob + b_logprob + jnp.sum(pulled_left_logprob))


m11_4_pe = MCMC(
    NUTS(potential_fn=m11_4_pe_code), num_warmup=1000, num_samples=1000, num_chains=4
)
init_params = {"a": jnp.zeros((4, 7)), "b": jnp.zeros((4, 4))}
m11_4_pe.run(random.PRNGKey(0), init_params=init_params)
log_lik = vmap(lambda p: m11_4_pe_code(p, log_lik=True))(m11_4_pe.get_samples())
m11_4_pe_az = az.from_numpyro(m11_4_pe)
m11_4_pe_az.sample_stats["log_likelihood"] = (
    ("chain", "draw", "log_lik"),
    jnp.reshape(log_lik, (4, 1000, -1)),
)
az.compare(
    {"m11.4_pe": m11_4_pe_az, "m11.4": az.from_numpyro(m11_4)},
    ic="waic",
    scale="deviance",
)
Out[22]:
          rank     waic   p_waic     d_waic    weight       se       dse  warning waic_scale
m11.4_pe     0   532.05  8.37874          0  0.507703  18.7561         0    False   deviance
m11.4        1  532.113  8.40404  0.0631368  0.492297  18.8106  0.144271    False   deviance

Code 11.23

In [23]:
post = m11_4.get_samples()
jnp.mean(jnp.exp(post["b"][:, 3] - post["b"][:, 1]))
Out[23]:
DeviceArray(0.9352549, dtype=float32)
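
exp(b[3] - b[1]) is a proportional-odds ratio: switching from treatment 2 (L/N) to treatment 4 (L/P) multiplies the odds of pulling left by about 0.93. Expressed as a percent change in odds (illustrative):

(1 - 0.9352549) * 100  # ~6.5% reduction in the odds of pulling left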

Code 11.24

In [24]:
chimpanzees = pd.read_csv("../data/chimpanzees.csv", sep=";")
d = chimpanzees
d["treatment"] = d.prosoc_left + 2 * d.condition
d["side"] = d.prosoc_left  # right 0, left 1
d["cond"] = d.condition  # no partner 0, partner 1
d_aggregated = (
    d.groupby(["treatment", "actor", "side", "cond"])["pulled_left"].sum().reset_index()
)
d_aggregated.rename(columns={"pulled_left": "left_pulls"}, inplace=True)

Code 11.25

In [25]:
dat = dict(zip(d_aggregated.columns, d_aggregated.values.T))


def model(actor, treatment, left_pulls):
    a = numpyro.sample("a", dist.Normal(0, 1.5).expand([7]))
    b = numpyro.sample("b", dist.Normal(0, 0.5).expand([4]))
    logit_p = a[actor] + b[treatment]
    numpyro.sample("left_pulls", dist.Binomial(18, logits=logit_p), obs=left_pulls)


m11_6 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m11_6.run(
    random.PRNGKey(0),
    actor=dat["actor"] - 1,
    treatment=dat["treatment"],
    left_pulls=dat["left_pulls"],
)

Code 11.26

In [26]:
try:
    az.compare(
        {"m11.6": az.from_numpyro(m11_6), "m11.4": az.from_numpyro(m11_4)},
        ic="loo",
        scale="deviance",
    )
except Exception as e:
    warnings.warn("\n{}: {}".format(type(e).__name__, e))
Out[26]:
UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.7 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
UserWarning: 
ValueError: The number of observations should be the same across all models

Code 11.27

In [27]:
# deviance of aggregated 6-in-9
print(-2 * dist.Binomial(9, 0.2).log_prob(6))
# deviance of dis-aggregated
print(
    -2 * jnp.sum(dist.Bernoulli(0.2).log_prob(jnp.array([1, 1, 1, 1, 1, 1, 0, 0, 0])))
)
Out[27]:
11.790477
20.652117
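
The two deviances describe the same observations and differ only by the aggregated binomial's multiplicity constant, 2 * log(9 choose 6). A quick check that this accounts for the whole gap:

print(2 * math.log(math.comb(9, 6)))  # 8.8616..., exactly 20.652117 - 11.790477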

Code 11.28

In [28]:
UCBadmit = pd.read_csv("../data/UCBadmit.csv", sep=";")
d = UCBadmit

Code 11.29

In [29]:
dat_list = dict(
    admit=d.admit.values,
    applications=d.applications.values,
    gid=(d["applicant.gender"] != "male").astype(int).values,
)


def model(gid, applications, admit=None):
    a = numpyro.sample("a", dist.Normal(0, 1.5).expand([2]))
    logit_p = a[gid]
    numpyro.sample("admit", dist.Binomial(applications, logits=logit_p), obs=admit)


m11_7 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m11_7.run(random.PRNGKey(0), **dat_list)
m11_7.print_summary(0.89)
Out[29]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]     -0.22      0.04     -0.22     -0.28     -0.16   1726.89      1.00
      a[1]     -0.83      0.05     -0.83     -0.90     -0.75   1791.29      1.00

Number of divergences: 0

Code 11.30

In [30]:
post = m11_7.get_samples()
diff_a = post["a"][:, 0] - post["a"][:, 1]
diff_p = expit(post["a"][:, 0]) - expit(post["a"][:, 1])
print_summary({"diff_a": diff_a, "diff_p": diff_p}, 0.89, False)
Out[30]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
    diff_a      0.61      0.06      0.61      0.51      0.71   1852.99      1.00
    diff_p      0.14      0.01      0.14      0.12      0.16   1851.59      1.00

Code 11.31

In [31]:
post = m11_7.get_samples()
admit_pred = Predictive(m11_7.sampler.model, post)(
    random.PRNGKey(2), gid=dat_list["gid"], applications=dat_list["applications"]
)["admit"]
admit_rate = admit_pred / d.applications.values
plt.errorbar(
    range(1, 13),
    jnp.mean(admit_rate, 0),
    jnp.std(admit_rate, 0) / 2,
    fmt="o",
    c="k",
    mfc="none",
    ms=7,
    elinewidth=1,
)
plt.plot(range(1, 13), jnp.percentile(admit_rate, 5.5, 0), "k+")
plt.plot(range(1, 13), jnp.percentile(admit_rate, 94.5, 0), "k+")
# draw lines connecting points from same dept
for i in range(1, 7):
    x = 1 + 2 * (i - 1)
    y1 = d.admit.iloc[x - 1] / d.applications.iloc[x - 1]
    y2 = d.admit.iloc[x] / d.applications.iloc[x]
    plt.plot((x, x + 1), (y1, y2), "bo-")
    plt.annotate(
        d.dept.iloc[x], (x + 0.5, (y1 + y2) / 2 + 0.05), ha="center", color="royalblue"
    )
plt.gca().set(ylim=(0, 1), xticks=range(1, 13), ylabel="admit", xlabel="case")
plt.show()
Out[31]: [figure: posterior predictive admission rates vs. observed data, by department]

Code 11.32

In [32]:
dat_list["dept_id"] = jnp.repeat(jnp.arange(6), 2)


def model(gid, dept_id, applications, admit=None):
    a = numpyro.sample("a", dist.Normal(0, 1.5).expand([2]))
    delta = numpyro.sample("delta", dist.Normal(0, 1.5).expand([6]))
    logit_p = a[gid] + delta[dept_id]
    numpyro.sample("admit", dist.Binomial(applications, logits=logit_p), obs=admit)


m11_8 = MCMC(NUTS(model), num_warmup=2000, num_samples=2000, num_chains=4)
m11_8.run(random.PRNGKey(0), **dat_list)
m11_8.print_summary(0.89)
Out[32]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]     -0.51      0.53     -0.49     -1.33      0.35    550.03      1.01
      a[1]     -0.41      0.53     -0.39     -1.24      0.44    552.84      1.01
  delta[0]      1.09      0.53      1.07      0.26      1.95    554.06      1.01
  delta[1]      1.05      0.53      1.03      0.19      1.89    562.82      1.01
  delta[2]     -0.17      0.53     -0.19     -1.01      0.67    555.05      1.01
  delta[3]     -0.20      0.53     -0.22     -1.05      0.63    558.08      1.01
  delta[4]     -0.64      0.53     -0.66     -1.50      0.20    566.64      1.01
  delta[5]     -2.20      0.55     -2.21     -3.06     -1.32    574.37      1.01

Number of divergences: 0

Code 11.33

In [33]:
post = m11_8.get_samples()
diff_a = post["a"][:, 0] - post["a"][:, 1]
diff_p = expit(post["a"][:, 0]) - expit(post["a"][:, 1])
print_summary({"diff_a": diff_a, "diff_p": diff_p}, 0.89, False)
Out[33]:
                mean       std    median      5.5%     94.5%     n_eff     r_hat
    diff_a     -0.10      0.08     -0.10     -0.23      0.03   8764.59      1.00
    diff_p     -0.02      0.02     -0.02     -0.05      0.01   7600.90      1.00

Code 11.34

In [34]:
pg = jnp.stack(
    [
        d.applications[dat_list["dept_id"] == k].values
        / d.applications[dat_list["dept_id"] == k].sum()
        for k in range(6)
    ],
    axis=0,
).T
pg = pd.DataFrame(pg, index=["male", "female"], columns=d.dept.unique())
pg.round(2)
Out[34]:
           A     B     C     D     E     F
male    0.88  0.96  0.35  0.53  0.33  0.52
female  0.12  0.04  0.65  0.47  0.67  0.48

Code 11.35

In [35]:
y = dist.Binomial(1000, 1 / 1000).sample(random.PRNGKey(0), (int(1e5),))
jnp.mean(y), jnp.var(y)
Out[35]:
(DeviceArray(0.99553, dtype=float32), DeviceArray(1.0001298, dtype=float32))
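
With n large and p small, Binomial(n, p) is well approximated by Poisson(np); here np = 1, which is why the mean and variance nearly coincide. The same check against a Poisson(1) directly (illustrative):

y2 = dist.Poisson(1).sample(random.PRNGKey(0), (int(1e5),))
jnp.mean(y2), jnp.var(y2)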

Code 11.36

In [36]:
Kline = pd.read_csv("../data/Kline.csv", sep=";")
d = Kline
d
Out[36]:
      culture  population contact  total_tools  mean_TU
0    Malekula        1100     low           13      3.2
1     Tikopia        1500     low           22      4.7
2  Santa Cruz        3600     low           24      4.0
3         Yap        4791    high           43      5.0
4    Lau Fiji        7400    high           33      5.0
5   Trobriand        8000    high           19      4.0
6       Chuuk        9200    high           40      3.8
7       Manus       13000     low           28      6.6
8       Tonga       17500    high           55      5.4
9      Hawaii      275000     low           71      6.6

Code 11.37

In [37]:
d["P"] = d.population.apply(math.log).pipe(lambda x: (x - x.mean()) / x.std())
d["contact_id"] = (d.contact == "high").astype(int)

Code 11.38

In [38]:
x = jnp.linspace(0, 100, 200)
plt.plot(x, jnp.exp(dist.LogNormal(0, 10).log_prob(x)))
plt.show()
Out[38]: [figure: LogNormal(0, 10) prior density for the Poisson rate]

Code 11.39

In [39]:
a = dist.Normal(0, 10).sample(random.PRNGKey(0), (int(1e4),))
lambda_ = jnp.exp(a)
jnp.mean(lambda_)
Out[39]:
DeviceArray(1.172584e+12, dtype=float32)
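
The enormous sample mean is expected: a LogNormal's mean is exp(mu + sigma^2 / 2), so a Normal(0, 10) prior on the log scale implies a mean rate of exp(50). A worked check:

jnp.exp(0 + 10 ** 2 / 2)  # ≈ 5.2e21 — an absurd prior expectation for a count rate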

Code 11.40

In [40]:
x = jnp.linspace(0, 100, 200)
plt.plot(x, jnp.exp(dist.LogNormal(3, 0.5).log_prob(x)))
plt.show()
Out[40]: [figure: LogNormal(3, 0.5) prior density for the Poisson rate]

Code 11.41

In [41]:
N = 100
a = dist.Normal(3, 0.5).sample(random.PRNGKey(0), (N,))
b = dist.Normal(0, 10).sample(random.PRNGKey(1), (N,))
plt.subplot(xlim=(-2, 2), ylim=(0, 100))
x = jnp.linspace(-2, 2, 100)
for i in range(N):
    plt.plot(x, jnp.exp(a[i] + b[i] * x), c="k", alpha=0.5)
Out[41]: [figure: prior predictive curves exp(a + b*x) with b ~ Normal(0, 10)]

Code 11.42

In [42]:
with numpyro.handlers.seed(rng_seed=10):
    N = 100
    a = numpyro.sample("a", dist.Normal(3, 0.5).expand([N]))
    b = numpyro.sample("a", dist.Normal(0, 0.2).expand([N]))
    plt.subplot(xlim=(-2, 2), ylim=(0, 100))
    x = jnp.linspace(-2, 2, 100)
    for i in range(N):
        plt.plot(x, jnp.exp(a[i] + b[i] * x), c="k", alpha=0.5)
Out[42]: [figure: prior predictive curves exp(a + b*x) with b ~ Normal(0, 0.2)]

Code 11.43

In [43]:
x_seq = jnp.linspace(jnp.log(100), jnp.log(200000), num=100)
lambda_ = vmap(lambda x: jnp.exp(a + b * x), out_axes=1)(x_seq)
plt.subplot(
    xlim=(jnp.min(x_seq).item(), jnp.max(x_seq).item()),
    ylim=(0, 500),
    xlabel="log population",
    ylabel="total tools",
)
for i in range(N):
    plt.plot(x_seq, lambda_[i], c="k", alpha=0.5)
Out[43]: [figure: prior predictive tool counts against log population]

Code 11.44

In [44]:
plt.subplot(
    xlim=(jnp.min(jnp.exp(x_seq)).item(), jnp.max(jnp.exp(x_seq)).item()),
    ylim=(0, 500),
    xlabel="population",
    ylabel="total tools",
)
for i in range(N):
    plt.plot(jnp.exp(x_seq), lambda_[i], c="k", alpha=0.5)
Out[44]: [figure: prior predictive tool counts against population on the natural scale]