import inspect
import os
import re
import warnings

import pandas as pd
import seaborn as sns
import torch
import torch.multiprocessing as mp
from torch.distributions import transform_to, constraints

import pyro
import pyro.distributions as dist
import pyro.ops.stats as stats
import pyro.poutine as poutine
from pyro.contrib.autoguide import AutoLaplaceApproximation
from pyro.infer import TracePosterior, TracePredictive, Trace_ELBO
from pyro.infer.mcmc import MCMC
from pyro.ops.welford import WelfordCovariance

os.environ["CUDA_VISIBLE_DEVICES"] = ""
warnings.simplefilter("ignore", FutureWarning)

mp.set_sharing_strategy("file_system")
sns.set(font_scale=1.25, rc={"figure.figsize": (8, 6)})

pyro.enable_validation()
pyro.set_rng_seed(0)

class MAP(TracePosterior):
    def __init__(self, model, num_samples=10000, start={}):
        super(MAP, self).__init__()
        self.model = model
        self.num_samples = num_samples
        self.start = start

    def _traces(self, *args, **kwargs):
        pyro.clear_param_store()

        # find a good initial trace: keep the best of 10 prior draws
        model_trace = poutine.trace(self.model).get_trace(*args, **kwargs)
        best_log_prob = model_trace.log_prob_sum()
        for i in range(10):
            trace = poutine.trace(self.model).get_trace(*args, **kwargs)
            log_prob = trace.log_prob_sum()
            if log_prob > best_log_prob:
                best_log_prob = log_prob
                model_trace = trace

        # lift model: turn every param into a sample site with a vague prior
        model_trace = poutine.util.prune_subsample_sites(model_trace)
        prior, unpacked = {}, {}
        param_constraints = pyro.get_param_store().get_state()["constraints"]
        for name, node in model_trace.nodes.items():
            if node["type"] == "param":
                if param_constraints[name] is constraints.positive:
                    prior[name] = dist.HalfCauchy(200)
                else:
                    prior[name] = dist.Normal(0, 1000)
                unpacked[name] = pyro.param(name).unconstrained().clone().detach()
            elif name in self.start:
                unpacked[name] = self.start[name]
            elif node["type"] == "sample" and not node["is_observed"]:
                unpacked[name] = transform_to(node["fn"].support).inv(node["value"])
        lifted_model = poutine.lift(self.model, prior)

        # define guide, initialized at the packed unconstrained values
        packed = torch.cat([v.clone().detach().reshape(-1) for v in unpacked.values()])
        pyro.param("auto_loc", packed)
        delta_guide = AutoLaplaceApproximation(lifted_model)

        # train guide
        loc_param = pyro.param("auto_loc").unconstrained()
        optimizer = torch.optim.LBFGS((loc_param,), lr=0.1, max_iter=500, tolerance_grad=1e-3)
        loss_fn = Trace_ELBO().differentiable_loss

        def closure():
            optimizer.zero_grad()
            loss = loss_fn(lifted_model, delta_guide, *args, **kwargs)
            loss.backward()
            return loss

        optimizer.step(closure)
        guide = delta_guide.laplace_approximation(*args, **kwargs)

        # get posterior: replay the lifted model against guide samples
        for i in range(self.num_samples):
            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            model_poutine = poutine.trace(poutine.replay(lifted_model, trace=guide_trace))
            yield model_poutine.get_trace(*args, **kwargs), 1.0

    def run(self, *args, **kwargs):
        # optimization can fail from a bad random start; retry up to 10 times,
        # promoting warnings to errors so a degenerate fit is not silently kept
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            for i in range(10):
                try:
                    return super(MAP, self).run(*args, **kwargs)
                except Exception as e:
                    last_error = e
            raise last_error
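

# A minimal usage sketch for MAP; the model and data below are hypothetical,
# not part of this module:
#
#     x = torch.randn(50)
#     y = 2. + 3. * x + torch.randn(50)
#
#     def model(x, y):
#         a = pyro.sample("a", dist.Normal(0, 10))
#         b = pyro.sample("b", dist.Normal(0, 10))
#         sigma = pyro.sample("sigma", dist.HalfCauchy(2))
#         with pyro.plate("plate"):
#             return pyro.sample("y", dist.Normal(a + b * x, sigma), obs=y)
#
#     map_posterior = MAP(model, num_samples=1000).run(x, y)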


def _formula_to_predictors(formula, data):
    # parse an R-style formula "y ~ x1 + x2 + ..." into a response node and a
    # dict of predictor tensors; the "Intercept" entry is a boolean flag
    dtype = torch.get_default_dtype()
    y_name, expr_str = formula.split(" ~ ")
    y_node = {"name": y_name, "value": torch.tensor(data[y_name], dtype=dtype)}
    y_node["mean"] = y_node["value"].mean()

    fit_intercept = True
    predictors = {"Intercept": False}
    col_to_num = dict(zip(data.columns, range(data.shape[1])))
    expr_list = expr_str.split(" + ")
    for expr in expr_list:
        if expr == "0":
            # "y ~ 0 + x" suppresses the intercept
            fit_intercept = False
        elif expr.startswith("I("):
            # arithmetic term such as I(x**2): alias column names by position,
            # then evaluate the expression against the data frame
            org_expr = expr
            for col in col_to_num:
                expr = expr.replace(col, "c{}".format(col_to_num[col]))
            eval_expr = expr[1:]  # drop the leading "I", keep the parentheses
            eval_map = {"c{}".format(i): data.iloc[:, i] for i in range(data.shape[1])}
            predictors[org_expr] = torch.tensor(eval(eval_expr, eval_map), dtype=dtype)
        elif expr.startswith("C("):
            # categorical term C(col): one dummy indicator per level
            cat_col = expr[2:-1]
            for cat in data[cat_col].unique():
                predictors["C({}){}".format(cat_col, cat)] = torch.tensor(
                    data[cat_col] == cat, dtype=dtype)
        elif expr in data.columns:
            predictors[expr] = torch.tensor(data[expr], dtype=dtype)

    if fit_intercept:
        predictors["Intercept"] = True
    return y_node, predictors
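

# A sketch of the parsing, on a hypothetical data frame:
#
#     df = pd.DataFrame({"y": [1., 2., 3.], "x": [4., 5., 6.]})
#     y_node, predictors = _formula_to_predictors("y ~ x + I(x**2)", df)
#     # y_node["name"] == "y"; predictors has keys "Intercept", "x", "I(x**2)"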


class LM(MAP):
    def __init__(self, formula, data, num_samples=10000, start={}, centering=True):
        self.formula = formula
        self.y_node, self.predictors = _formula_to_predictors(formula, data)
        self._predictor_means = {name: predictor.mean() for name, predictor
                                 in self.predictors.items() if name != "Intercept"}
        self.centering = centering
        super(LM, self).__init__(self.model, num_samples, start)

    def model(self, data=None):
        if data is None:
            y_node, predictors = self.y_node, self.predictors.copy()
        else:
            y_node, predictors = _formula_to_predictors(self.formula, data)
        fit_intercept = predictors.pop("Intercept")

        mu = 0
        if fit_intercept:
            mu = mu + pyro.sample("Intercept", dist.Normal(y_node["mean"], 10))

        for name, predictor in predictors.items():
            coef = pyro.sample(name, dist.Normal(0, 10))
            if fit_intercept and self.centering:
                # "centering trick": subtract the predictor mean so that the
                # intercept and slopes are less correlated during optimization
                predictor = predictor - self._predictor_means[name]
            mu = mu + coef * predictor
        sigma = pyro.sample("sigma", dist.HalfCauchy(2))
        with pyro.plate("plate"):
            return pyro.sample(y_node["name"], dist.Normal(mu, sigma), obs=y_node["value"])

    def _get_centering_constant(self, coefs):
        # amount to subtract from the fitted intercept to undo the centering trick
        center = torch.tensor(0.)
        for name, predictor_mean in self._predictor_means.items():
            center = center + coefs[name] * predictor_mean
        return center
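

# A minimal usage sketch for LM, assuming a hypothetical data frame `df` with
# "height" and "weight" columns:
#
#     lm_posterior = LM("height ~ weight", df).run()
#     precis(lm_posterior)  # summary with the intercept corrected for centering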


def glimmer(formula, data):
    # print Pyro model code corresponding to an R-style formula
    y_node, predictors = _formula_to_predictors(formula, data)
    fit_intercept = predictors.pop("Intercept")
    print("def model({}):".format(", ".join(predictors.keys()) + ", {}".format(y_node["name"])))
    mu_terms = []
    if fit_intercept:
        print("    intercept = pyro.sample('Intercept', dist.Normal(0, 10))")
        mu_terms.append("intercept")
    for predictor in predictors:
        # derive a valid Python identifier for the coefficient name
        coef = predictor.replace("**", "_POW_").replace("*", "_MUL_").replace(" ", "")
        coef = re.sub(r"\W", "_", coef).strip("_")
        print("    b_{} = pyro.sample('{}', dist.Normal(0, 10))".format(coef, predictor))
        mu_terms.append("b_{} * {}".format(coef, predictor))
    print("    mu = " + " + ".join(mu_terms))
    print("    sigma = pyro.sample('sigma', dist.HalfCauchy(2))")
    print("    with pyro.plate('plate'):")
    print("        return pyro.sample('{}', dist.Normal(mu, sigma), obs={})"
          .format(y_node["name"], y_node["name"]))
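

# For example, on a hypothetical data frame with columns "y" and "x",
# glimmer("y ~ x", df) prints:
#
#     def model(x, y):
#         intercept = pyro.sample('Intercept', dist.Normal(0, 10))
#         b_x = pyro.sample('x', dist.Normal(0, 10))
#         mu = intercept + b_x * x
#         sigma = pyro.sample('sigma', dist.HalfCauchy(2))
#         with pyro.plate('plate'):
#             return pyro.sample('y', dist.Normal(mu, sigma), obs=y)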


def extract_samples(posterior):
    # pull flattened posterior samples for every latent site out of a posterior
    nodes = poutine.util.prune_subsample_sites(posterior.exec_traces[0]).stochastic_nodes
    node_supports = posterior.marginal(nodes).support(flatten=True)
    return {latent: samples.detach() for latent, samples in node_supports.items()}


def coef(posterior):
    mean = {}
    node_supports = extract_samples(posterior)
    for node, support in node_supports.items():
        mean[node] = support.mean(dim=0)
    # correct `Intercept` due to the "centering trick"
    if isinstance(posterior, LM) and "Intercept" in mean and posterior.centering:
        center = posterior._get_centering_constant(mean)
        mean["Intercept"] = mean["Intercept"] - center
    return mean


def vcov(posterior):
    # posterior variance-covariance matrix across all latent sites
    node_supports = extract_samples(posterior)
    packed_support = torch.cat([support.reshape(support.size(0), -1)
                                for support in node_supports.values()], dim=1)
    cov_scheme = WelfordCovariance(diagonal=False)
    for sample in packed_support:
        cov_scheme.update(sample)
    return cov_scheme.get_covariance(regularize=False)
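

# A usage sketch, continuing the hypothetical MAP example above:
#
#     coef(map_posterior)  # dict of posterior means, keyed by site name
#     vcov(map_posterior)  # covariance matrix over all flattened latents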


def precis(posterior, corr=False, digits=2):
    # summarize a posterior: mean, standard deviation, and 89% HPDI per site
    if isinstance(posterior, TracePosterior):
        node_supports = extract_samples(posterior)
    else:
        node_supports = posterior
    df = pd.DataFrame(columns=["Mean", "StdDev", "|0.89", "0.89|"])
    for node, support in node_supports.items():
        if support.dim() == 1:
            hpdi = stats.hpdi(support, prob=0.89)
            df.loc[node] = [support.mean().item(), support.std().item(),
                            hpdi[0].item(), hpdi[1].item()]
        else:
            # vector-valued site: report one row per component
            support = support.reshape(support.size(0), -1)
            mean = support.mean(0)
            std = support.std(0)
            hpdi = stats.hpdi(support, prob=0.89)
            for i in range(mean.size(0)):
                df.loc["{}[{}]".format(node, i)] = [mean[i].item(), std[i].item(),
                                                    hpdi[0, i].item(), hpdi[1, i].item()]
    # correct `Intercept` due to the "centering trick"
    if isinstance(posterior, LM) and "Intercept" in df.index and posterior.centering:
        center = posterior._get_centering_constant(df["Mean"].to_dict()).item()
        df.loc["Intercept", ["Mean", "|0.89", "0.89|"]] -= center

    if corr:
        cov = vcov(posterior)
        corr = cov / cov.diag().ger(cov.diag()).sqrt()
        for i, node in enumerate(df.index):
            df[node] = corr[:, i]

    if isinstance(posterior, MCMC):
        diagnostics = posterior.marginal(df.index.tolist()).diagnostics()
        df = pd.concat([df, pd.DataFrame(diagnostics).T.astype(float)], axis=1)

    return df.round(digits)
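

# `precis(posterior)` returns one row per latent (or per component of a vector
# latent) with columns Mean, StdDev, |0.89, 0.89|; with corr=True it appends
# one correlation column per row, and for an MCMC posterior it also appends
# the marginal diagnostics (e.g. effective sample size and r_hat).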


def link(posterior, data=None, n=1000):
    # posterior samples of the observation site's mean (the linear predictor)
    obs_node = posterior.exec_traces[0].observation_nodes[-1]
    mu = []
    if data is None:
        # resample traces from the fitted posterior itself
        for i in range(n):
            idx = posterior._categorical.sample().item()
            trace = posterior.exec_traces[idx]
            mu.append(trace.nodes[obs_node]["fn"].mean)
    else:
        # rerun the model on new data, replaying posterior samples
        data = {name: data[name] if name in data else None
                for name in inspect.signature(posterior.model).parameters}
        predictive = TracePredictive(poutine.lift(posterior.model, dist.Normal(0, 1)),
                                     posterior, n).run(**data)
        for trace in predictive.exec_traces:
            mu.append(trace.nodes[obs_node]["fn"].mean)
    return torch.stack(mu).detach()


def sim(posterior, data=None, n=1000):
    # posterior predictive samples of the observation site itself
    obs_node = posterior.exec_traces[0].observation_nodes[-1]
    obs = []
    if data is None:
        for i in range(n):
            idx = posterior._categorical.sample().item()
            trace = posterior.exec_traces[idx]
            obs.append(trace.nodes[obs_node]["fn"].sample())
    else:
        data = {name: data[name] if name in data else None
                for name in inspect.signature(posterior.model).parameters}
        predictive = TracePredictive(poutine.lift(posterior.model, dist.Normal(0, 1)),
                                     posterior, n).run(**data)
        for trace in predictive.exec_traces:
            obs.append(trace.nodes[obs_node]["value"])
    return torch.stack(obs).detach()
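

# A usage sketch; `lm_posterior` and `new_df` are hypothetical:
#
#     mu = link(lm_posterior, data={"data": new_df})      # samples of E[y | x]
#     y_tilde = sim(lm_posterior, data={"data": new_df})  # predictive draws
#
# `data` is keyed by the model's argument names; LM models take a single
# `data` argument, hence the {"data": new_df} keying above.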


def compare(posteriors):
    # WAIC comparison table across models
    post_ics = {}
    with torch.no_grad():
        for name in posteriors:
            post_ics[name] = posteriors[name].information_criterion(pointwise=True)
            n_cases = post_ics[name]["waic"].size(0)
        WAIC = {name: post_ics[name]["waic"].sum() for name in posteriors}
        pWAIC = {name: post_ics[name]["p_waic"].sum() for name in posteriors}
        SE = {name: (n_cases * post_ics[name]["waic"].var()).sqrt() for name in posteriors}
        table = pd.DataFrame({"WAIC": WAIC, "pWAIC": pWAIC}).sort_values(by="WAIC")
        table["dWAIC"] = table["WAIC"] - table.iloc[0, 0]
        # Akaike weights from the WAIC differences
        table["weight"] = torch.nn.functional.softmax(-0.5 * torch.tensor(table["dWAIC"]), dim=0)
        table["SE"] = pd.Series(SE)
        dSE = []
        for i in range(table.shape[0]):
            WAIC0 = post_ics[table.index[0]]["waic"]
            WAICi = post_ics[table.index[i]]["waic"]
            dSE.append((n_cases * (WAICi - WAIC0).var()).sqrt())
        table["dSE"] = dSE
    return table.astype(float)
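

# A usage sketch; m1 and m2 are hypothetical fitted posteriors:
#
#     compare({"m1": m1, "m2": m2})
#
# returns a data frame sorted by WAIC with columns
# WAIC, pWAIC, dWAIC, weight, SE, dSE.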


def ensemble(posteriors, data):
    # mix posterior predictions from several models in proportion to their
    # Akaike weights, allocating 1000 samples in total
    weighted_num = (compare(posteriors)["weight"] * 1000).astype(int)
    weighted_num.iloc[-1] -= (sum(weighted_num) - 1000)  # fix rounding so counts sum to 1000
    links = []
    sims = []
    for name in weighted_num.index:
        num_samples = weighted_num[name]
        links.append(link(posteriors[name], data, num_samples).reshape(num_samples, -1))
        sims.append(sim(posteriors[name], data, num_samples).reshape(num_samples, -1))
    num_data = max(l.size(1) for l in links)
    # broadcast models whose predictions are constant across cases
    links = [l.expand(-1, num_data) for l in links]
    sims = [s.expand(-1, num_data) for s in sims]
    return {"link": torch.cat(links), "sim": torch.cat(sims)}


def _worker(n, fn, fn_args, child_info=None):
    # child_info is (worker index, hand-off event, result queue) when running
    # as a subprocess; None when called directly in the parent
    if child_info is not None:
        idx, event, queue = child_info
        pyro.set_rng_seed(idx)  # give each worker its own seed
    result = []
    for i in range(n):
        item = fn(*fn_args)
        result.append(item)
        if child_info is not None:
            queue.put((idx, item))
            # block until the parent acknowledges this item before producing more
            event.wait()
            event.clear()
    return result


def replicate(n, fn, fn_args, mc_cores=None):
    # run fn(*fn_args) n times, spread across worker processes
    mc_cores = mp.cpu_count() - 1 if mc_cores is None else mc_cores
    queue = mp.Queue()
    events = [mp.Event() for i in range(mc_cores)]
    processes = []
    for i in range(mc_cores):
        # distribute the n replicates as evenly as possible
        n_i = n // mc_cores + (i < n % mc_cores)
        child_info = (i, events[i], queue)
        p = mp.Process(target=_worker, args=(n_i, fn, fn_args, child_info), daemon=True)
        p.start()
        processes.append(p)

    result = []
    for i in range(n):
        idx, item = queue.get()
        result.append(item)
        events[idx].set()  # let worker idx produce its next item

    for i in range(mc_cores):
        processes[i].join()
    return result
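

# A usage sketch; the zero-argument simulation below is hypothetical. Note
# that `fn` and its arguments must be picklable for torch.multiprocessing:
#
#     def coin_flips():
#         return torch.bernoulli(torch.full((10,), 0.5)).sum()
#
#     counts = replicate(1000, coin_flips, ())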