Inference for Deep Gaussian Process models in Pyro

In this tutorial, I want to illustrate how to use Pyro's Gaussian Processes module to create and train some deep Gaussian Process models. For background on how to use this module, readers can check out the tutorials at http://pyro.ai/examples/.

The first part is a fun example that runs HMC on a 2-layer regression GP model, while the second part uses SVI to classify handwritten digits.

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns; sns.set()
from scipy.cluster.vq import kmeans2

import torch
import torch.nn as nn
from torch.distributions.transforms import AffineTransform
from torchvision import transforms

import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist
import pyro.infer as infer
import pyro.infer.mcmc as mcmc
from pyro.contrib import autoname
from pyro.contrib.examples.util import get_data_loader

pyro.set_rng_seed(0)

HMC with Heaviside data

Let's create a dataset from the Heaviside step function.

In [2]:
N = 25
X = torch.rand(N)
y = (X >= 0.5).float() + torch.randn(N) * 0.01
plt.plot(X.numpy(), y.numpy(), "kx");

We will make a 2-layer regression model. We use the name_count decorator to avoid conflicts between the parameter names of the two layers.

In [3]:
# an identity mean function is used, as in [3]
gpr1 = gp.models.GPRegression(X, None, gp.kernels.RBF(1), noise=torch.tensor(1e-3),
                              mean_function=lambda x: x)
gpr1.kernel.set_prior("variance", dist.Exponential(10))
gpr1.kernel.set_prior("lengthscale", dist.LogNormal(0.0, 1.0))

gpr2 = gp.models.GPRegression(torch.zeros(N), y, gp.kernels.RBF(1), noise=torch.tensor(1e-3))
gpr2.kernel.set_prior("variance", dist.Exponential(1))
gpr2.kernel.set_prior("lengthscale", dist.LogNormal(0.0, 1.0))

@autoname.name_count
def model():
    h_loc, h_var = gpr1.model()
    gpr2.X = pyro.sample("h", dist.Normal(h_loc, h_var.sqrt()))
    gpr2.model()

Now, we run HMC (with the NUTS kernel) to get 100 samples.

In [4]:
hmc_kernel = mcmc.NUTS(model)
posterior = mcmc.MCMC(hmc_kernel, num_samples=100).run()
Sample: 100%|██████████| 200/200 [01:01<00:00,  4.90it/s, step size=1.20e-01, acc. rate=0.975]

Because sample sites are named automatically by the autoname decorator, we use a prototype trace to inspect the site names. Then we plot the marginal distribution of each latent site.

In [5]:
sites = pyro.poutine.trace(model).get_trace().stochastic_nodes
for name, support in posterior.marginal(sites).support().items():
    if name == "h":
        continue
    sns.distplot(support)
    plt.title(name)
    plt.show()

Let's test whether the posterior can predict the Heaviside data. The first step is to make a predictive model.

In [6]:
@autoname.name_count
def predictive(X_new):
    # this sample statement will be replaced by a posterior sample `h`
    h = pyro.sample("h", dist.Normal(torch.zeros(N), 1))
    gpr1.y = h
    gpr2.X = h
    h_new_loc, h_new_var = gpr1(X_new)
    h_new = pyro.sample("h_new", dist.Normal(h_new_loc, h_new_var.sqrt()))
    y_loc, y_var = gpr2(h_new)
    pyro.sample("y", dist.Normal(y_loc, y_var.sqrt()))

We will get predictions from this predictive model by using samples from the posterior.

In [7]:
X_test = torch.linspace(-0.5, 1.5, 300)
posterior_predictive = infer.TracePredictive(predictive, posterior, num_samples=100).run(X_test)

# plot 50 predictions
for i in range(50):
    trace = posterior_predictive()  # get a random trace
    y_pred = trace.nodes["y"]["value"].detach()
    plt.plot(X_test.numpy(), y_pred.numpy(), 'r-')
plt.plot(X.numpy(), y.numpy(), "kx");

We can see that the model is able to do its job, but it does not do well near the change point $x = 0.5$. Maybe more data is needed. To avoid numerical issues when using more data, we might need to use the VariationalSparseGP model. Readers who are interested in this direction will find great explanations of many technical points in reference [1].
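For reference, here is a rough sketch of how the two layers might be defined with sparse variational approximations. It mirrors the pattern used in the MNIST model below; it is only a sketch and is not run in this post.

# sketch only (not run here): sparse counterparts of gpr1 and gpr2
Xu1 = X[:10].clone()  # a small set of inducing inputs taken from X
sgpr1 = gp.models.VariationalSparseGP(X, None, gp.kernels.RBF(1), Xu=Xu1,
                                      likelihood=None,
                                      mean_function=lambda x: x)
sgpr2 = gp.models.VariationalSparseGP(torch.zeros(N), y, gp.kernels.RBF(1),
                                      Xu=torch.zeros(10),
                                      likelihood=gp.likelihoods.Gaussian())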

SVI with MNIST data

First, we download the MNIST data.

In [8]:
train_loader = get_data_loader(dataset_name='MNIST',
                               data_dir='~/.data',
                               batch_size=1000,
                               is_training_set=True,
                               shuffle=True)
test_loader = get_data_loader(dataset_name='MNIST',
                              data_dir='~/.data',
                              batch_size=1000,
                              is_training_set=False,
                              shuffle=False)
downloading data
download complete.
downloading data
download complete.
In [9]:
X = train_loader.dataset.data.reshape(-1, 784).float() / 255
y = train_loader.dataset.targets

Now, we initialize the inducing inputs of the first layer using k-means on X. This is not strictly necessary; taking a random subset of X also works (see the sketch below).

In [10]:
Xu = torch.from_numpy(kmeans2(X.numpy(), 100, minit='points')[0])
# let's plot one of the inducing points
plt.imshow(Xu[0].reshape(28, 28));
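As mentioned above, a random subset of X would also work as an initialization; a minimal sketch:

# alternative initialization (sketch): take 100 random training images as inducing inputs
idx = torch.randperm(X.size(0))[:100]
Xu_random = X[idx].clone()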

In addition, as mentioned in the section "Further Model Details" of [2], a linear mean function is required. We follow the same approach here.

In [11]:
class LinearT(nn.Module):
    """Linear transform and transpose"""
    def __init__(self, dim_in, dim_out):
        super(LinearT, self).__init__()
        self.linear = nn.Linear(dim_in, dim_out, bias=False)

    def forward(self, x):
        return self.linear(x).t()

# compute the weight of the first layer's mean function;
# it is the PCA projection of X (from 784D to 30D).
_, _, V = np.linalg.svd(X.numpy(), full_matrices=False)
W = torch.from_numpy(V[:30, :])

mean_fn = LinearT(784, 30)
mean_fn.linear.weight.data = W
mean_fn.linear.weight.requires_grad_(False);
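As a quick sanity check (optional), the mean function maps a batch of N x 784 inputs to a 30 x N output, which matches the latent_shape=torch.Size([30]) convention used in the next cell:

# optional sanity check: LinearT transposes its output,
# so a batch of 5 images is mapped to a 30 x 5 tensor
out = mean_fn(X[:5])
assert out.shape == (30, 5)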

Now, we create a deep GP model by stacking two variational sparse GP layers. The first layer uses the mean function defined above, while the second layer uses a multi-class likelihood. Note that the inducing inputs of the second layer are initialized by applying the mean function to the inducing inputs of the first layer.

In [12]:
class DeepGP(gp.parameterized.Parameterized):
    def __init__(self, X, y, Xu, mean_fn):
        super(DeepGP, self).__init__()
        self.layer1 = gp.models.VariationalSparseGP(
            X,
            None,
            gp.kernels.RBF(784, variance=torch.tensor(2.), lengthscale=torch.tensor(2.)),
            Xu=Xu,
            likelihood=None,
            mean_function=mean_fn,
            latent_shape=torch.Size([30]))
        # make sure that the input for next layer is batch_size x 30
        h = mean_fn(X).t()
        hu = mean_fn(Xu).t()
        self.layer2 = gp.models.VariationalSparseGP(
            h,
            y,
            gp.kernels.RBF(30, variance=torch.tensor(2.), lengthscale=torch.tensor(2.)),
            Xu=hu,
            likelihood=gp.likelihoods.MultiClass(num_classes=10),
            latent_shape=torch.Size([10]))

    @autoname.name_count
    def model(self, X, y):
        self.layer1.set_data(X, None)
        h_loc, h_var = self.layer1.model()
        # approximate with a Monte Carlo sample (formula 15 of [1])
        h = dist.Normal(h_loc, h_var.sqrt())()
        self.layer2.set_data(h.t(), y)
        self.layer2.model()

    @autoname.name_count
    def guide(self, X, y):
        self.layer1.guide()
        self.layer2.guide()

    # make prediction
    def forward(self, X_new):
        # because the prediction is stochastic (due to the Monte Carlo sample of the hidden layer),
        # we make 100 predictions and take the most common one (as in [4])
        pred = []
        for _ in range(100):
            h_loc, h_var = self.layer1(X_new)
            h = dist.Normal(h_loc, h_var.sqrt())()
            f_loc, f_var = self.layer2(h.t())
            pred.append(f_loc.argmax(dim=0))
        return torch.stack(pred).mode(dim=0)[0]

During the early iterations of training, we want to place more weight on the mean function, which is the PCA of the input, and reduce the effect of the first layer's kernel. To achieve that, we force the inducing outputs of the first layer to be small by setting a small initial u_scale_tril.

In [13]:
deepgp = DeepGP(X, y, Xu, mean_fn)
deepgp.layer1.u_scale_tril = deepgp.layer1.u_scale_tril * 1e-5 
deepgp.layer1.set_constraint("u_scale_tril", torch.distributions.constraints.lower_cholesky)
deepgp.cuda()

optimizer = torch.optim.Adam(deepgp.parameters(), lr=0.01)
loss_fn = infer.TraceMeanField_ELBO().differentiable_loss

Now, we write some utilities to train and test our model, just as we would for other PyTorch models.

In [14]:
def train(train_loader, gpmodule, optimizer, loss_fn, epoch):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        data = data.reshape(-1, 784)
        optimizer.zero_grad()
        loss = loss_fn(gpmodule.model, gpmodule.guide, data, target)
        loss.backward()
        optimizer.step()
        idx = batch_idx + 1
        if idx % 10 == 0:
            print("Train Epoch: {:2d} [{:5d}/{} ({:2.0f}%)]\tLoss: {:.6f}"
                  .format(epoch, idx * len(data), len(train_loader.dataset),
                          100. * idx / len(train_loader), loss))

def test(test_loader, gpmodule):
    correct = 0
    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()
        data = data.reshape(-1, 784)
        pred = gpmodule(data)
        # compare prediction and target to compute accuracy
        correct += pred.eq(target).long().cpu().sum().item()

    print("\nTest set: Accuracy: {}/{} ({:.2f}%)\n"
          .format(correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))

Here I just run 20 epochs to illustrate the process.

In [15]:
for i in range(20):
    train(train_loader, deepgp, optimizer, loss_fn, i)
    with torch.no_grad():
        test(test_loader, deepgp)
Train Epoch:  0 [10000/60000 (17%)]	Loss: 215872.984375
Train Epoch:  0 [20000/60000 (33%)]	Loss: 214001.593750
Train Epoch:  0 [30000/60000 (50%)]	Loss: 200020.734375
Train Epoch:  0 [40000/60000 (67%)]	Loss: 191804.921875
Train Epoch:  0 [50000/60000 (83%)]	Loss: 175004.296875
Train Epoch:  0 [60000/60000 (100%)]	Loss: 145725.281250

Test set: Accuracy: 8307/10000 (83.07%)

Train Epoch:  1 [10000/60000 (17%)]	Loss: 96241.750000
Train Epoch:  1 [20000/60000 (33%)]	Loss: 65753.539062
Train Epoch:  1 [30000/60000 (50%)]	Loss: 59164.898438
Train Epoch:  1 [40000/60000 (67%)]	Loss: 53220.820312
Train Epoch:  1 [50000/60000 (83%)]	Loss: 52050.398438
Train Epoch:  1 [60000/60000 (100%)]	Loss: 49506.101562

Test set: Accuracy: 9336/10000 (93.36%)

Train Epoch:  2 [10000/60000 (17%)]	Loss: 46632.328125
Train Epoch:  2 [20000/60000 (33%)]	Loss: 46519.191406
Train Epoch:  2 [30000/60000 (50%)]	Loss: 46531.433594
Train Epoch:  2 [40000/60000 (67%)]	Loss: 46086.132812
Train Epoch:  2 [50000/60000 (83%)]	Loss: 44555.304688
Train Epoch:  2 [60000/60000 (100%)]	Loss: 44280.800781

Test set: Accuracy: 9463/10000 (94.63%)

Train Epoch:  3 [10000/60000 (17%)]	Loss: 42667.562500
Train Epoch:  3 [20000/60000 (33%)]	Loss: 42864.304688
Train Epoch:  3 [30000/60000 (50%)]	Loss: 43375.265625
Train Epoch:  3 [40000/60000 (67%)]	Loss: 41040.968750
Train Epoch:  3 [50000/60000 (83%)]	Loss: 41723.167969
Train Epoch:  3 [60000/60000 (100%)]	Loss: 43360.382812

Test set: Accuracy: 9517/10000 (95.17%)

Train Epoch:  4 [10000/60000 (17%)]	Loss: 40517.832031
Train Epoch:  4 [20000/60000 (33%)]	Loss: 36754.750000
Train Epoch:  4 [30000/60000 (50%)]	Loss: 39611.105469
Train Epoch:  4 [40000/60000 (67%)]	Loss: 38865.871094
Train Epoch:  4 [50000/60000 (83%)]	Loss: 37618.226562
Train Epoch:  4 [60000/60000 (100%)]	Loss: 36474.484375

Test set: Accuracy: 9554/10000 (95.54%)

Train Epoch:  5 [10000/60000 (17%)]	Loss: 37512.335938
Train Epoch:  5 [20000/60000 (33%)]	Loss: 37369.781250
Train Epoch:  5 [30000/60000 (50%)]	Loss: 37403.062500
Train Epoch:  5 [40000/60000 (67%)]	Loss: 35897.414062
Train Epoch:  5 [50000/60000 (83%)]	Loss: 34249.187500
Train Epoch:  5 [60000/60000 (100%)]	Loss: 35301.992188

Test set: Accuracy: 9579/10000 (95.79%)

Train Epoch:  6 [10000/60000 (17%)]	Loss: 34293.789062
Train Epoch:  6 [20000/60000 (33%)]	Loss: 34315.960938
Train Epoch:  6 [30000/60000 (50%)]	Loss: 31821.152344
Train Epoch:  6 [40000/60000 (67%)]	Loss: 31645.339844
Train Epoch:  6 [50000/60000 (83%)]	Loss: 32956.976562
Train Epoch:  6 [60000/60000 (100%)]	Loss: 30657.750000

Test set: Accuracy: 9598/10000 (95.98%)

Train Epoch:  7 [10000/60000 (17%)]	Loss: 31674.531250
Train Epoch:  7 [20000/60000 (33%)]	Loss: 31000.179688
Train Epoch:  7 [30000/60000 (50%)]	Loss: 29016.859375
Train Epoch:  7 [40000/60000 (67%)]	Loss: 31135.843750
Train Epoch:  7 [50000/60000 (83%)]	Loss: 28472.082031
Train Epoch:  7 [60000/60000 (100%)]	Loss: 28538.781250

Test set: Accuracy: 9594/10000 (95.94%)

Train Epoch:  8 [10000/60000 (17%)]	Loss: 28134.074219
Train Epoch:  8 [20000/60000 (33%)]	Loss: 28331.925781
Train Epoch:  8 [30000/60000 (50%)]	Loss: 28295.814453
Train Epoch:  8 [40000/60000 (67%)]	Loss: 28645.269531
Train Epoch:  8 [50000/60000 (83%)]	Loss: 28286.863281
Train Epoch:  8 [60000/60000 (100%)]	Loss: 25048.203125

Test set: Accuracy: 9612/10000 (96.12%)

Train Epoch:  9 [10000/60000 (17%)]	Loss: 25384.703125
Train Epoch:  9 [20000/60000 (33%)]	Loss: 25215.671875
Train Epoch:  9 [30000/60000 (50%)]	Loss: 25843.103516
Train Epoch:  9 [40000/60000 (67%)]	Loss: 25475.539062
Train Epoch:  9 [50000/60000 (83%)]	Loss: 24644.417969
Train Epoch:  9 [60000/60000 (100%)]	Loss: 24063.867188

Test set: Accuracy: 9624/10000 (96.24%)

Train Epoch: 10 [10000/60000 (17%)]	Loss: 24480.371094
Train Epoch: 10 [20000/60000 (33%)]	Loss: 22880.750000
Train Epoch: 10 [30000/60000 (50%)]	Loss: 22695.031250
Train Epoch: 10 [40000/60000 (67%)]	Loss: 23176.921875
Train Epoch: 10 [50000/60000 (83%)]	Loss: 23965.068359
Train Epoch: 10 [60000/60000 (100%)]	Loss: 24928.640625

Test set: Accuracy: 9641/10000 (96.41%)

Train Epoch: 11 [10000/60000 (17%)]	Loss: 21859.009766
Train Epoch: 11 [20000/60000 (33%)]	Loss: 18358.263672
Train Epoch: 11 [30000/60000 (50%)]	Loss: 21347.195312
Train Epoch: 11 [40000/60000 (67%)]	Loss: 20771.326172
Train Epoch: 11 [50000/60000 (83%)]	Loss: 24064.519531
Train Epoch: 11 [60000/60000 (100%)]	Loss: 20937.580078

Test set: Accuracy: 9655/10000 (96.55%)

Train Epoch: 12 [10000/60000 (17%)]	Loss: 19642.025391
Train Epoch: 12 [20000/60000 (33%)]	Loss: 17672.255859
Train Epoch: 12 [30000/60000 (50%)]	Loss: 18563.121094
Train Epoch: 12 [40000/60000 (67%)]	Loss: 18395.488281
Train Epoch: 12 [50000/60000 (83%)]	Loss: 18668.035156
Train Epoch: 12 [60000/60000 (100%)]	Loss: 17270.251953

Test set: Accuracy: 9663/10000 (96.63%)

Train Epoch: 13 [10000/60000 (17%)]	Loss: 18649.554688
Train Epoch: 13 [20000/60000 (33%)]	Loss: 16082.554688
Train Epoch: 13 [30000/60000 (50%)]	Loss: 18488.023438
Train Epoch: 13 [40000/60000 (67%)]	Loss: 16917.660156
Train Epoch: 13 [50000/60000 (83%)]	Loss: 17440.460938
Train Epoch: 13 [60000/60000 (100%)]	Loss: 16976.007812

Test set: Accuracy: 9662/10000 (96.62%)

Train Epoch: 14 [10000/60000 (17%)]	Loss: 17117.199219
Train Epoch: 14 [20000/60000 (33%)]	Loss: 15142.023438
Train Epoch: 14 [30000/60000 (50%)]	Loss: 17457.289062
Train Epoch: 14 [40000/60000 (67%)]	Loss: 15979.667969
Train Epoch: 14 [50000/60000 (83%)]	Loss: 16650.373047
Train Epoch: 14 [60000/60000 (100%)]	Loss: 15610.107422

Test set: Accuracy: 9674/10000 (96.74%)

Train Epoch: 15 [10000/60000 (17%)]	Loss: 16485.027344
Train Epoch: 15 [20000/60000 (33%)]	Loss: 13969.185547
Train Epoch: 15 [30000/60000 (50%)]	Loss: 14948.458984
Train Epoch: 15 [40000/60000 (67%)]	Loss: 14111.014648
Train Epoch: 15 [50000/60000 (83%)]	Loss: 16134.172852
Train Epoch: 15 [60000/60000 (100%)]	Loss: 15030.025391

Test set: Accuracy: 9678/10000 (96.78%)

Train Epoch: 16 [10000/60000 (17%)]	Loss: 15559.684570
Train Epoch: 16 [20000/60000 (33%)]	Loss: 14404.482422
Train Epoch: 16 [30000/60000 (50%)]	Loss: 16205.851562
Train Epoch: 16 [40000/60000 (67%)]	Loss: 15412.806641
Train Epoch: 16 [50000/60000 (83%)]	Loss: 12930.396484
Train Epoch: 16 [60000/60000 (100%)]	Loss: 13448.410156

Test set: Accuracy: 9686/10000 (96.86%)

Train Epoch: 17 [10000/60000 (17%)]	Loss: 13342.847656
Train Epoch: 17 [20000/60000 (33%)]	Loss: 14741.359375
Train Epoch: 17 [30000/60000 (50%)]	Loss: 14400.076172
Train Epoch: 17 [40000/60000 (67%)]	Loss: 15236.247070
Train Epoch: 17 [50000/60000 (83%)]	Loss: 15801.335938
Train Epoch: 17 [60000/60000 (100%)]	Loss: 14962.839844

Test set: Accuracy: 9708/10000 (97.08%)

Train Epoch: 18 [10000/60000 (17%)]	Loss: 14085.380859
Train Epoch: 18 [20000/60000 (33%)]	Loss: 15532.832031
Train Epoch: 18 [30000/60000 (50%)]	Loss: 12432.498047
Train Epoch: 18 [40000/60000 (67%)]	Loss: 14637.081055
Train Epoch: 18 [50000/60000 (83%)]	Loss: 14059.531250
Train Epoch: 18 [60000/60000 (100%)]	Loss: 14864.332031

Test set: Accuracy: 9697/10000 (96.97%)

Train Epoch: 19 [10000/60000 (17%)]	Loss: 14633.511719
Train Epoch: 19 [20000/60000 (33%)]	Loss: 14187.414062
Train Epoch: 19 [30000/60000 (50%)]	Loss: 13159.357422
Train Epoch: 19 [40000/60000 (67%)]	Loss: 13750.495117
Train Epoch: 19 [50000/60000 (83%)]	Loss: 12829.285156
Train Epoch: 19 [60000/60000 (100%)]	Loss: 14987.151367

Test set: Accuracy: 9702/10000 (97.02%)

In [4], the authors run a 2-layer deep GP for more than 300 epochs and achieve 97.94% accuracy. Although stacking many layers can improve the performance of Gaussian Processes, it seems to me that following the line of deep kernels is a more reliable approach. Kernels, which are usually underrated, are indeed the core of Gaussian Processes. As demonstrated in Pyro's Deep Kernel Learning example, we can achieve a state-of-the-art result without having to tune hyperparameters or use as many tricks as in the above example (e.g. fixing a linear mean function, reducing the effect of the first layer's kernel).
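For readers curious about that direction, a deep kernel in Pyro is essentially a standard kernel composed with a neural-network warping of the inputs. A minimal sketch, following the construction in Pyro's Deep Kernel Learning example (the small network here is just a placeholder feature extractor, not the CNN used in that example):

# sketch of a deep kernel: warp 784-D images into 10-D features, then apply an RBF kernel
extractor = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 10))
rbf = gp.kernels.RBF(input_dim=10, lengthscale=torch.ones(10))
deep_kernel = gp.kernels.Warping(rbf, iwarping_fn=extractor)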

References

[1] MCMC for Variationally Sparse Gaussian Processes (arXiv)
James Hensman, Alexander G. de G. Matthews, Maurizio Filippone, Zoubin Ghahramani

[2] Doubly Stochastic Variational Inference for Deep Gaussian Processes (arXiv)
Hugh Salimbeni, Marc Peter Deisenroth

[3] https://github.com/ICL-SML/Doubly-Stochastic-DGP/blob/master/demos/demo_step_function.ipynb

[4] https://github.com/ICL-SML/Doubly-Stochastic-DGP/blob/master/demos/demo_mnist.ipynb
