# rethinking.py (Source)

import inspect
import os
import re
import warnings

import pandas as pd
import seaborn as sns
import torch
import torch.multiprocessing as mp
from torch.distributions import transform_to, constraints

import pyro
import pyro.distributions as dist
import pyro.ops.stats as stats
import pyro.poutine as poutine
from pyro.contrib.autoguide import AutoLaplaceApproximation
from pyro.infer import TracePosterior, TracePredictive, Trace_ELBO
from pyro.infer.mcmc import MCMC
from pyro.ops.welford import WelfordCovariance

# Expose no CUDA devices to this process (empty device list).
# NOTE(review): this runs after `import torch`; presumably still effective
# because CUDA is initialised lazily -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Suppress every FutureWarning for the whole process.
warnings.simplefilter("ignore", FutureWarning)

# Select the "file_system" strategy for sharing tensors between processes.
mp.set_sharing_strategy("file_system")
# Global seaborn defaults: larger fonts and an 8x6 figure size.
sns.set(font_scale=1.25, rc={"figure.figsize": (8, 6)})

# Turn on Pyro's validation and fix the RNG seed at import time.
pyro.enable_validation()
pyro.set_rng_seed(0)


class MAP(TracePosterior):
    """Posterior obtained from a Laplace approximation around a fitted mode.

    The model is lifted (its ``pyro.param`` sites receive priors), an
    ``AutoLaplaceApproximation`` guide is fitted with L-BFGS, and
    ``num_samples`` traces are then generated from the resulting guide.

    Args:
        model: Pyro model callable.
        num_samples: number of traces yielded by ``_traces``.
        start: optional mapping from site name to an initial value. ``None``
            (the default) means "no user-provided initial values".
    """

    def __init__(self, model, num_samples=10000, start=None):
        super(MAP, self).__init__()
        self.model = model
        self.num_samples = num_samples
        # A fresh dict per instance: a ``{}`` default would be one object
        # shared by every MAP created without an explicit ``start``.
        self.start = {} if start is None else start

    def _traces(self, *args, **kwargs):
        """Fit the approximation and yield ``(trace, 1.0)`` pairs.

        ``args`` / ``kwargs`` are forwarded to the model on every call.
        Clears the global Pyro param store before fitting.
        """
        pyro.clear_param_store()

        # find good initial trace: keep the highest joint log-probability
        # among 11 runs of the model (one initial run plus ten retries)
        model_trace = poutine.trace(self.model).get_trace(*args, **kwargs)
        best_log_prob = model_trace.log_prob_sum()
        for _ in range(10):
            trace = poutine.trace(self.model).get_trace(*args, **kwargs)
            log_prob = trace.log_prob_sum()
            if log_prob > best_log_prob:
                best_log_prob = log_prob
                model_trace = trace

        # lift model: every param site gets a prior and contributes its
        # unconstrained value to the initial location
        model_trace = poutine.util.prune_subsample_sites(model_trace)
        prior, unpacked = {}, {}
        param_constraints = pyro.get_param_store().get_state()["constraints"]
        for name, node in model_trace.nodes.items():
            if node["type"] == "param":
                if param_constraints[name] is constraints.positive:
                    prior[name] = dist.HalfCauchy(200)
                else:
                    prior[name] = dist.Normal(0, 1000)
                unpacked[name] = pyro.param(name).unconstrained().clone().detach()
            elif name in self.start:
                # NOTE(review): start values are stored as given, whereas
                # sampled values below are mapped to unconstrained space --
                # confirm callers supply values on the unconstrained scale.
                unpacked[name] = self.start[name]
            elif node["type"] == "sample" and not node["is_observed"]:
                unpacked[name] = transform_to(node["fn"].support).inv(node["value"])
        lifted_model = poutine.lift(self.model, prior)

        # define guide: the flattened initial values are registered as
        # "auto_loc" before the guide is built
        packed = torch.cat([v.clone().detach().reshape(-1) for v in unpacked.values()])
        pyro.param("auto_loc", packed)
        delta_guide = AutoLaplaceApproximation(lifted_model)

        # train guide: a single L-BFGS step (up to 500 inner iterations)
        loc_param = pyro.param("auto_loc").unconstrained()
        optimizer = torch.optim.LBFGS((loc_param,), lr=0.1, max_iter=500, tolerance_grad=1e-3)
        loss_fn = Trace_ELBO().differentiable_loss

        def closure():
            optimizer.zero_grad()
            loss = loss_fn(lifted_model, delta_guide, *args, **kwargs)
            loss.backward()
            return loss

        optimizer.step(closure)
        guide = delta_guide.laplace_approximation(*args, **kwargs)

        # get posterior: replay the lifted model against fresh guide traces;
        # every trace is yielded with weight 1.0
        for _ in range(self.num_samples):
            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            model_poutine = poutine.trace(poutine.replay(lifted_model, trace=guide_trace))
            yield model_poutine.get_trace(*args, **kwargs), 1.0

    def run(self, *args, **kwargs):
        """Run the fit, retrying up to 10 times when an attempt fails.

        Warnings raised during an attempt are promoted to errors, so they
        trigger a retry as well. If every attempt fails, the exception from
        the last attempt is re-raised.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            for _ in range(10):
                try:
                    return super(MAP, self).run(*args, **kwargs)
                except Exception as e:
                    last_error = e
        raise last_error


def _formula_to_predictors(formula, data):
    """Parse a small R-style formula into a response node and predictors.

    Supported right-hand-side terms, separated by ``" + "``:

    * ``0``          -- drop the intercept;
    * ``I(<expr>)``  -- Python expression over column names, evaluated
      with ``eval``;
    * ``C(<col>)``   -- one 0/1 indicator per distinct value of ``<col>``;
    * ``<col>``      -- a column used as is.

    Any other term is ignored, as before.

    Args:
        formula: string of the form ``"y ~ term + term"`` (the spaces around
            ``~`` and ``+`` are required).
        data: pandas DataFrame holding the referenced columns.

    Returns:
        ``(y_node, predictors)``. ``y_node`` is a dict with keys ``"name"``,
        ``"value"`` (tensor) and ``"mean"`` (tensor mean of the response).
        ``predictors`` maps term names to tensors, plus the key
        ``"Intercept"`` whose value is a bool telling whether an intercept
        is fitted.
    """
    dtype = torch.get_default_dtype()

    def as_tensor(values):
        # Go through a plain list: handing a pandas Series straight to
        # ``torch.tensor`` indexes it by label (series[0], series[1], ...),
        # which breaks for frames whose index is not 0..n-1 (e.g. a
        # filtered DataFrame).
        if hasattr(values, "tolist"):
            values = values.tolist()
        return torch.tensor(values, dtype=dtype)

    y_name, expr_str = formula.split(" ~ ")
    y_value = as_tensor(data[y_name])
    y_node = {"name": y_name, "value": y_value, "mean": y_value.mean()}

    fit_intercept = True
    # "Intercept" is inserted first so that it leads the key order.
    predictors = {"Intercept": False}
    col_to_num = dict(zip(data.columns, range(data.shape[1])))
    for expr in expr_str.split(" + "):
        if expr == "0":
            fit_intercept = False
        elif expr.startswith("I("):
            # Column names are swapped for placeholders c0, c1, ... so that
            # names which are not valid identifiers can still be evaluated.
            # One regex pass with identifier boundaries keeps a short name
            # ("e") from matching inside a longer one ("weight") or inside
            # an already generated placeholder.
            inner = expr[1:]
            names = sorted((col for col in col_to_num if isinstance(col, str) and col),
                           key=len, reverse=True)
            if names:
                pattern = r"(?<!\w)(?:{})(?!\w)".format("|".join(re.escape(col) for col in names))
                inner = re.sub(pattern,
                               lambda match: "c{}".format(col_to_num[match.group(0)]),
                               inner)
            eval_map = {"c{}".format(i): data.iloc[:, i] for i in range(data.shape[1])}
            # SECURITY: ``eval`` runs arbitrary code taken from the formula;
            # only pass formulas from trusted sources.
            predictors[expr] = as_tensor(eval(inner, eval_map))
        elif expr.startswith("C("):
            cat_col = expr[2:-1]
            for cat in data[cat_col].unique():
                # Key is built from the term itself (previously hard-coded
                # as "C(d)"), so several categorical terms cannot collide.
                predictors["{}{}".format(expr, cat)] = as_tensor(data[cat_col] == cat)
        elif expr in data.columns:
            predictors[expr] = as_tensor(data[expr])

    if fit_intercept:
        predictors["Intercept"] = True
    return y_node, predictors


class LM(MAP):
    """Linear model with a Normal likelihood, described by a formula.

    Args:
        formula: formula string understood by ``_formula_to_predictors``.
        data: data passed to ``_formula_to_predictors`` together with
            ``formula``.
        num_samples: forwarded to ``MAP``.
        start: optional mapping of initial values, forwarded to ``MAP``.
            ``None`` (the default) means an empty mapping.
        centering: when true and an intercept is fitted, each predictor has
            its mean (computed from ``data``) subtracted inside the model.
    """

    def __init__(self, formula, data, num_samples=10000, start=None, centering=True):
        self.formula = formula
        self.y_node, self.predictors = _formula_to_predictors(formula, data)
        # "Intercept" holds a flag rather than a tensor, so it has no mean.
        self._predictor_means = {name: predictor.mean() for name, predictor
                                 in self.predictors.items() if name != "Intercept"}
        self.centering = centering
        # Normalise here so the parent always receives a dict, and no dict
        # object is shared between instances through a mutable default.
        start = {} if start is None else start
        super(LM, self).__init__(self.model, num_samples, start)

    def model(self, data=None):
        """Pyro model: ``y ~ Normal(mu, sigma)`` with ``mu`` linear in the predictors.

        Args:
            data: optional replacement data; when ``None`` the tensors
                built at construction time are used.

        Returns:
            The value of the observed response sample site.
        """
        if data is None:
            # Copy because "Intercept" is popped from the dict below.
            y_node, predictors = self.y_node, self.predictors.copy()
        else:
            y_node, predictors = _formula_to_predictors(self.formula, data)
        fit_intercept = predictors.pop("Intercept")

        mu = 0
        if fit_intercept:
            mu = mu + pyro.sample("Intercept", dist.Normal(y_node["mean"], 10))

        for name, predictor in predictors.items():
            coef = pyro.sample(name, dist.Normal(0, 10))
            if fit_intercept and self.centering:
                # use "centering trick": always subtract the means stored at
                # construction time, even when new data is supplied
                predictor = predictor - self._predictor_means[name]
            mu = mu + coef * predictor
        sigma = pyro.sample("sigma", dist.HalfCauchy(2))
        with pyro.plate("plate"):
            return pyro.sample(y_node["name"], dist.Normal(mu, sigma), obs=y_node["value"])

    def _get_centering_constant(self, coefs):
        """Return ``sum(coefs[name] * mean(predictor))`` over all predictors.

        ``coefs`` is a mapping indexed by predictor name. The result is
        the amount subtracted from the fitted intercept to undo centering.
        """
        center = torch.tensor(0.)
        for name, predictor_mean in self._predictor_means.items():
            center = center + coefs[name] * predictor_mean
        return center


def glimmer(formula, data):
    """Print Pyro model source code corresponding to ``formula``.

    The generated model takes one argument per predictor term plus the
    response, samples an optional intercept, one ``Normal(0, 10)``
    coefficient per predictor and a ``HalfCauchy(2)`` sigma, and observes the
    response under ``Normal(mu, sigma)``.

    Predictor names are printed verbatim as argument names; only the
    coefficient variable names (``b_...``) are sanitised.

    Args:
        formula: formula string understood by ``_formula_to_predictors``.
        data: data passed to ``_formula_to_predictors`` together with
            ``formula``.

    Returns:
        None; the code is written to standard output.
    """
    y_node, predictors = _formula_to_predictors(formula, data)
    fit_intercept = predictors.pop("Intercept")
    # Join all argument names at once so that a formula without predictors
    # does not produce a leading comma.
    print("def model({}):".format(", ".join(list(predictors) + [y_node["name"]])))

    terms = []
    if fit_intercept:
        print("    intercept = pyro.sample('Intercept', dist.Normal(0, 10))")
        terms.append("intercept")
    for predictor in predictors:
        coef = predictor.replace("**", "_POW_").replace("*", "_MUL_").replace(" ", "")
        coef = re.sub(r"\W", "_", coef).strip("_")
        print("    b_{} = pyro.sample('{}', dist.Normal(0, 10))".format(coef, predictor))
        terms.append("b_{} * {}".format(coef, predictor))
    # Terms are joined with " + ": previously consecutive predictor terms
    # were concatenated without a separator and an intercept-only model
    # ended in a dangling "+".
    print("    mu = " + (" + ".join(terms) if terms else "0"))
    print("    sigma = pyro.sample('sigma', dist.HalfCauchy(2))")
    print("    with pyro.plate('plate'):")
    print("        return pyro.sample('{}', dist.Normal(mu, sigma), obs={})"
          .format(y_node["name"], y_node["name"]))


def extract_samples(posterior):
    """Return ``{site name: detached samples}`` for the stochastic sites.

    The site names are read from the first execution trace of ``posterior``
    after subsample sites have been pruned; the samples come from the
    flattened support of the corresponding marginal.
    """
    first_trace = poutine.util.prune_subsample_sites(posterior.exec_traces[0])
    site_names = first_trace.stochastic_nodes
    supports = posterior.marginal(site_names).support(flatten=True)
    detached = {}
    for site, draws in supports.items():
        detached[site] = draws.detach()
    return detached


def coef(posterior):
    """Return the per-site mean of the samples drawn from ``posterior``.

    For an ``LM`` fitted with centering, the reported intercept is shifted
    back by the centering constant computed from the mean coefficients.
    """
    means = {site: draws.mean(dim=0)
             for site, draws in extract_samples(posterior).items()}
    # correct `intercept` due to "centering trick"
    needs_shift = isinstance(posterior, LM) and "Intercept" in means and posterior.centering
    if needs_shift:
        shift = posterior._get_centering_constant(means)
        means["Intercept"] = means["Intercept"] - shift
    return means


def vcov(posterior):
    """Return the covariance matrix of the samples drawn from ``posterior``.

    Every site's samples are flattened to two dimensions (sample index
    first) and concatenated column-wise; the covariance is accumulated one
    sample at a time with ``WelfordCovariance`` and returned without
    regularisation.
    """
    flattened = []
    for draws in extract_samples(posterior).values():
        flattened.append(draws.reshape(draws.size(0), -1))
    sample_matrix = torch.cat(flattened, dim=1)

    estimator = WelfordCovariance(diagonal=False)
    for row in sample_matrix:
        estimator.update(row)
    return estimator.get_covariance(regularize=False)


def precis(posterior, corr=False, digits=2):
    """Summarise posterior samples in a DataFrame.

    Args:
        posterior: a ``TracePosterior`` (samples are extracted from it), or
            any mapping from site name to a tensor of samples whose first
            dimension indexes the samples.
        corr: when true, append one column per row holding the correlations
            between the summarised quantities.
        digits: number of decimals the result is rounded to.

    Returns:
        DataFrame with one row per scalar quantity and the columns ``Mean``,
        ``StdDev``, ``|0.89`` and ``0.89|`` (bounds of the 89% HPDI), plus
        the correlation columns when ``corr`` is true and the marginal
        diagnostics when ``posterior`` is an ``MCMC`` instance.
    """
    if isinstance(posterior, TracePosterior):
        node_supports = extract_samples(posterior)
    else:
        node_supports = posterior
    df = pd.DataFrame(columns=["Mean", "StdDev", "|0.89", "0.89|"])
    for node, support in node_supports.items():
        if support.dim() == 1:
            hpdi = stats.hpdi(support, prob=0.89)
            df.loc[node] = [support.mean().item(), support.std().item(),
                            hpdi[0].item(), hpdi[1].item()]
        else:
            # Multi-dimensional sites are flattened; one row per element,
            # named "<site>[<flat index>]".
            support = support.reshape(support.size(0), -1)
            mean = support.mean(0)
            std = support.std(0)
            hpdi = stats.hpdi(support, prob=0.89)
            for i in range(mean.size(0)):
                df.loc["{}[{}]".format(node, i)] = [mean[i].item(), std[i].item(),
                                                    hpdi[0, i].item(), hpdi[1, i].item()]
    # correct `intercept` due to "centering trick"
    if isinstance(posterior, LM) and "Intercept" in df.index and posterior.centering:
        center = posterior._get_centering_constant(df["Mean"].to_dict()).item()
        df.loc["Intercept", ["Mean", "|0.89", "0.89|"]] -= center

    if corr:
        # The covariance is accumulated from the samples already in hand
        # (same procedure as ``vcov``), so a plain mapping of samples works
        # here too; calling ``vcov(posterior)`` required a posterior object.
        packed_support = torch.cat([support.reshape(support.size(0), -1)
                                    for support in node_supports.values()], dim=1)
        cov_scheme = WelfordCovariance(diagonal=False)
        for sample in packed_support:
            cov_scheme.update(sample)
        cov = cov_scheme.get_covariance(regularize=False)
        corr_matrix = cov / cov.diag().ger(cov.diag()).sqrt()
        # Columns of ``packed_support`` follow the same order as the rows
        # of ``df``, so column i belongs to the i-th row label.
        for i, node in enumerate(df.index):
            df[node] = corr_matrix[:, i]

    if isinstance(posterior, MCMC):
        diagnostics = posterior.marginal(df.index.tolist()).diagnostics()
        df = pd.concat([df, pd.DataFrame(diagnostics).T.astype(float)], axis=1)

    return df.round(digits)


def link(posterior, data=None, n=1000):
    """Return ``n`` stacked means of the observation site's distribution.

    The observation site is the last observation node of the first
    execution trace. Without ``data`` the traces are drawn from the
    posterior's own execution traces; with ``data`` they come from a
    ``TracePredictive`` run of the lifted model, where every model argument
    missing from ``data`` is passed as ``None``.
    """
    obs_site = posterior.exec_traces[0].observation_nodes[-1]
    if data is None:
        traces = []
        for _ in range(n):
            picked = posterior._categorical.sample().item()
            traces.append(posterior.exec_traces[picked])
    else:
        model_kwargs = {}
        for arg_name in inspect.signature(posterior.model).parameters:
            model_kwargs[arg_name] = data[arg_name] if arg_name in data else None
        lifted = poutine.lift(posterior.model, dist.Normal(0, 1))
        traces = TracePredictive(lifted, posterior, n).run(**model_kwargs).exec_traces
    means = [trace.nodes[obs_site]["fn"].mean for trace in traces]
    return torch.stack(means).detach()


def sim(posterior, data=None, n=1000):
    """Return ``n`` stacked simulated values of the observation site.

    The observation site is the last observation node of the first
    execution trace. Without ``data`` each value is sampled from the site's
    distribution in a randomly picked execution trace; with ``data`` the
    values are read from a ``TracePredictive`` run of the lifted model, where
    every model argument missing from ``data`` is passed as ``None``.
    """
    obs_site = posterior.exec_traces[0].observation_nodes[-1]

    if data is not None:
        model_kwargs = {arg_name: (data[arg_name] if arg_name in data else None)
                        for arg_name in inspect.signature(posterior.model).parameters}
        lifted = poutine.lift(posterior.model, dist.Normal(0, 1))
        predictive = TracePredictive(lifted, posterior, n).run(**model_kwargs)
        draws = [trace.nodes[obs_site]["value"] for trace in predictive.exec_traces]
        return torch.stack(draws).detach()

    draws = []
    for _ in range(n):
        # Pick a trace, then sample from it straight away.
        picked = posterior.exec_traces[posterior._categorical.sample().item()]
        draws.append(picked.nodes[obs_site]["fn"].sample())
    return torch.stack(draws).detach()


def compare(posteriors):
    """Build a WAIC comparison table for several fitted posteriors.

    Args:
        posteriors: mapping from model name to an object exposing
            ``information_criterion(pointwise=True)``, which must return a
            mapping with the tensors ``"waic"`` and ``"p_waic"``.

    Returns:
        DataFrame of floats indexed by model name, sorted by ascending WAIC,
        with the columns ``WAIC``, ``pWAIC``, ``dWAIC`` (difference to the
        best model), ``weight`` (softmax of ``-dWAIC / 2``), ``SE`` and
        ``dSE`` (standard error of the difference to the best model).

    Raises:
        ValueError: if ``posteriors`` is empty.
    """
    if not posteriors:
        raise ValueError("compare() needs at least one posterior")

    post_ics = {}
    with torch.no_grad():
        for name in posteriors:
            post_ics[name] = posteriors[name].information_criterion(pointwise=True)
    pointwise = {name: ics["waic"] for name, ics in post_ics.items()}
    # Number of cases is read from the last posterior, as before.
    n_cases = list(pointwise.values())[-1].size(0)

    # Plain Python floats keep the table numeric from the start instead of
    # storing tensor objects in DataFrame cells.
    WAIC = {name: float(values.sum()) for name, values in pointwise.items()}
    pWAIC = {name: float(post_ics[name]["p_waic"].sum()) for name in post_ics}
    SE = {name: float((n_cases * values.var()).sqrt()) for name, values in pointwise.items()}

    table = pd.DataFrame({"WAIC": WAIC, "pWAIC": pWAIC}).sort_values(by="WAIC")
    table["dWAIC"] = table["WAIC"] - table["WAIC"].iloc[0]
    # Convert through a list: handing the name-indexed Series to
    # ``torch.tensor`` would rely on integer positions being looked up in a
    # label index.
    d_waic = torch.tensor(table["dWAIC"].tolist())
    table["weight"] = torch.nn.functional.softmax(-0.5 * d_waic, dim=0).tolist()
    table["SE"] = pd.Series(SE)

    best = pointwise[table.index[0]]
    table["dSE"] = [float((n_cases * (pointwise[name] - best).var()).sqrt())
                    for name in table.index]
    return table.astype(float)


def ensemble(posteriors, data, n=1000):
    """Pool ``link`` and ``sim`` draws from several models by WAIC weight.

    Each model contributes a number of draws proportional to its weight in
    ``compare(posteriors)``; the last model in the comparison table absorbs
    the rounding remainder so that exactly ``n`` draws are returned.

    Args:
        posteriors: mapping from model name to fitted posterior.
        data: data forwarded to ``link`` and ``sim``.
        n: total number of draws across all models (default 1000).

    Returns:
        Dict with the keys ``"link"`` and ``"sim"``, each a tensor with
        ``n`` rows.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    weighted_num = (compare(posteriors)["weight"] * n).astype(int)
    weighted_num.iloc[-1] -= (sum(weighted_num) - n)
    links = []
    sims = []
    for name, num_samples in weighted_num.items():
        num_samples = int(num_samples)
        if num_samples <= 0:
            # A model whose weight rounds down to zero draws contributes
            # nothing; asking link/sim for zero draws would make them stack
            # an empty list.
            continue
        links.append(link(posteriors[name], data, num_samples).reshape(num_samples, -1))
        sims.append(sim(posteriors[name], data, num_samples).reshape(num_samples, -1))
    # Models may return a single column; widen those to the common number
    # of columns before concatenating.
    num_data = max(draws.size(1) for draws in links)
    links = [draws.expand(-1, num_data) for draws in links]
    sims = [draws.expand(-1, num_data) for draws in sims]
    return {"link": torch.cat(links), "sim": torch.cat(sims)}


def _worker(n, fn, fn_args, child_info=None):
    """Call ``fn(*fn_args)`` ``n`` times and return the list of results.

    Args:
        n: number of calls.
        fn: callable to run.
        fn_args: positional arguments for ``fn``.
        child_info: ``None`` to run in-process, or a tuple
            ``(idx, event, queue)``. In the latter case the Pyro RNG is
            seeded with ``idx``, every result is put on ``queue`` as
            ``(idx, item)``, and the worker then waits for ``event`` to be
            set before it continues.

    Returns:
        List with the ``n`` results, in call order.
    """
    if child_info is None:
        # No queue/event to talk to: previously this path failed on the
        # unbound names ``queue``, ``idx`` and ``event``.
        return [fn(*fn_args) for _ in range(n)]

    idx, event, queue = child_info
    pyro.set_rng_seed(idx)
    result = []
    for _ in range(n):
        item = fn(*fn_args)
        result.append(item)
        queue.put((idx, item))
        # Wait until the event is set, then reset it for the next item.
        # NOTE(review): presumably set by the consumer once it has taken the
        # item off the queue -- confirm.
        event.wait()
        event.clear()
    return result


def replicate(n, fn, fn_args, mc_cores=None):
    """Run ``fn(*fn_args)`` ``n`` times across worker processes.

    Args:
        n: total number of calls.
        fn: callable to run in the workers.
        fn_args: positional arguments for ``fn``.
        mc_cores: number of worker processes; defaults to the CPU count
            minus one. The value is clamped to the range ``[1, n]``.

    Returns:
        List with the ``n`` results, in the order they arrive on the queue
        (not necessarily call order).
    """
    if mc_cores is None:
        mc_cores = mp.cpu_count() - 1
    # At least one worker (a single-core machine would otherwise give 0 and
    # a division by zero below); never more workers than calls, since the
    # extra ones would have nothing to do.
    mc_cores = max(1, min(mc_cores, n))

    queue = mp.Queue()
    events = [mp.Event() for _ in range(mc_cores)]
    processes = []
    for i in range(mc_cores):
        # Spread the remainder over the first ``n % mc_cores`` workers.
        n_i = n // mc_cores + (i < n % mc_cores)
        child_info = (i, events[i], queue)
        p = mp.Process(target=_worker, args=(n_i, fn, fn_args, child_info), daemon=True)
        p.start()
        processes.append(p)

    result = []
    for _ in range(n):
        idx, item = queue.get()
        result.append(item)
        # Let the producing worker move on to its next item.
        events[idx].set()

    for p in processes:
        p.join()
    return result