%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set(palette="bright")
import torch
import warnings; warnings.simplefilter("ignore", FutureWarning)
# this post assumes a Pyro version in dev branch (dated 2019-01-01):
# pip install git+https://github.com/uber/pyro@4e42613
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS
pyro.set_rng_seed(1)
Sampling Hidden Markov Model with Pyro
While reading some discussions on the PyMC discourse about the multimodal phenomenon of unsupervised hidden Markov models (HMMs), I decided to reimplement various Stan models in Pyro. The main reference we’ll use is the Stan User’s Guide.
As in the Stan user’s guide, we use the notation categories for latent states and words for observations. The following data information is taken from Stan’s example-models repository.
num_categories = 3
num_words = 10
num_supervised_data = 100
num_data = 600

transition_prior = torch.empty(num_categories).fill_(1.)
emission_prior = torch.empty(num_words).fill_(0.1)

transition_prob = dist.Dirichlet(transition_prior).sample(torch.Size([num_categories]))
emission_prob = dist.Dirichlet(emission_prior).sample(torch.Size([num_categories]))
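As a quick orientation (my own note, not part of the original post), the sampled tables have one transition row and one emission row per category:

# sanity check on shapes (added for illustration)
assert transition_prob.shape == (num_categories, num_categories)
assert emission_prob.shape == (num_categories, num_words)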
Next, we generate data randomly from the above transition and emission probabilities. In addition, the initial category is drawn from the equilibrium distribution of the Markov chain.
def equilibrium(mc_matrix):
    n = mc_matrix.size(0)
    return (torch.eye(n) - mc_matrix.t() + 1).inverse().matmul(torch.ones(n))

start_prob = equilibrium(transition_prob)
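A brief aside (my addition, not in the original post): `equilibrium` uses the standard trick of solving \((I - P^\top + \mathbf{1}\mathbf{1}^\top)\,\pi = \mathbf{1}\), whose solution is the stationary distribution satisfying \(\pi^\top P = \pi^\top\). A quick sanity check:

# the stationary distribution should be invariant under one transition step
assert torch.allclose(start_prob.matmul(transition_prob), start_prob, atol=1e-5)
# and it should sum to one
assert torch.allclose(start_prob.sum(), torch.tensor(1.), atol=1e-5)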
# simulate data
categories, words = [], []
for t in range(num_data):
    if t == 0 or t == num_supervised_data:
        category = dist.Categorical(start_prob).sample()
    else:
        category = dist.Categorical(transition_prob[category]).sample()
    word = dist.Categorical(emission_prob[category]).sample()
    categories.append(category)
    words.append(word)
categories, words = torch.stack(categories), torch.stack(words)
# split into supervised data and unsupervised data
supervised_categories = categories[:num_supervised_data]
supervised_words = words[:num_supervised_data]
unsupervised_words = words[num_supervised_data:]
To inspect the posterior samples drawn by Markov chain Monte Carlo, we’ll make a convenient plotting function.
def plot_posterior(mcmc):
    # get `transition_prob` samples from posterior
    trace_transition_prob = mcmc.get_samples()["transition_prob"]

    plt.figure(figsize=(10, 6))
    for i in range(num_categories):
        for j in range(num_categories):
            sns.distplot(trace_transition_prob[:, i, j], hist=False, kde_kws={"lw": 2},
                         label="transition_prob[{}, {}], true value = {:.2f}".format(
                             i, j, transition_prob[i, j]))
    plt.xlabel("Probability", fontsize=13)
    plt.ylabel("Frequency", fontsize=13)
    plt.title("Transition probability posterior", fontsize=15)
Supervised HMM
When we know all hidden states (categories), we can use a supervised HMM model. Implementing it in Pyro is quite straightforward (to get familiar with Pyro, please check out its tutorial page).
def supervised_hmm(categories, words):
    with pyro.plate("prob_plate", num_categories):
        transition_prob = pyro.sample("transition_prob", dist.Dirichlet(transition_prior))
        emission_prob = pyro.sample("emission_prob", dist.Dirichlet(emission_prior))

    category = categories[0]  # start with first category
    for t in range(len(words)):
        if t > 0:
            category = pyro.sample("category_{}".format(t), dist.Categorical(transition_prob[category]),
                                   obs=categories[t])
        pyro.sample("word_{}".format(t), dist.Categorical(emission_prob[category]), obs=words[t])
# enable jit_compile to improve the sampling speed
nuts_kernel = NUTS(supervised_hmm, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=100)
# we run MCMC to get the posterior
mcmc.run(supervised_categories, supervised_words)
# after that, we plot the posterior
plot_posterior(mcmc)
Sample: 100%|██████████| 200/200 [00:54, 3.70it/s, step size=5.18e-02, acc. prob=0.963]
We can see that MCMC gives a good posterior in this supervised context. Let’s see how things change for an unsupervised model.
Unsupervised HMM
In this case, we don’t know which categories generate the observed words. These hidden states (categories) are discrete latent variables. Although Pyro supports marginalizing out discrete latent variables, we won’t use that technique here because it is slow for HMMs. Instead, we will use the forward algorithm to reduce the time complexity.
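As a refresher (my own summary, not part of the original post), the forward algorithm computes the marginal likelihood of a word sequence \(w_{1:T}\) by carrying a vector of partial sums \(\alpha_t\) along the chain, so the cost is linear in \(T\) rather than exponential:

\[
\alpha_1(j) = p(w_1 \mid z_1 = j), \qquad
\alpha_t(j) = p(w_t \mid z_t = j) \sum_i \alpha_{t-1}(i)\, p(z_t = j \mid z_{t-1} = i), \qquad
p(w_{1:T}) \propto \sum_j \alpha_T(j),
\]

where \(z_t\) is the latent category at step \(t\). The code below gives every initial state weight one (it drops the initial-state distribution, which only changes the result by a constant factor), and it works in log space with `logsumexp` for numerical stability.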
def forward_log_prob(prev_log_prob, curr_word, transition_log_prob, emission_log_prob):
    log_prob = emission_log_prob[:, curr_word] + transition_log_prob + prev_log_prob.unsqueeze(dim=1)
    return log_prob.logsumexp(dim=0)
def unsupervised_hmm(words):
    with pyro.plate("prob_plate", num_categories):
        transition_prob = pyro.sample("transition_prob", dist.Dirichlet(transition_prior))
        emission_prob = pyro.sample("emission_prob", dist.Dirichlet(emission_prior))

    transition_log_prob = transition_prob.log()
    emission_log_prob = emission_prob.log()
    log_prob = emission_log_prob[:, words[0]]
    for t in range(1, len(words)):
        log_prob = forward_log_prob(log_prob, words[t], transition_log_prob, emission_log_prob)
    prob = log_prob.logsumexp(dim=0).exp()
    # a trick to inject an additional log_prob into the model's log_prob
    pyro.sample("forward_prob", dist.Bernoulli(prob), obs=torch.tensor(1.))
nuts_kernel = NUTS(unsupervised_hmm, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=100)
mcmc.run(unsupervised_words)
plot_posterior(mcmc)
Sample: 100%|██████████| 200/200 [05:24, 1.62s/it, step size=1.52e-01, acc. prob=0.860]
We can see that the posterior distributions are spread widely over the interval \([0, 1]\) (though they seem to favor the first half). This reflects the multimodal phenomenon mentioned at the beginning: without labels, the latent categories can be permuted without changing the likelihood, so the posterior has several symmetric modes. Such a posterior will not be useful for making further predictions.
Semi-supervised HMM
To fix the above issue, we will also use the supervised data during inference.
def semisupervised_hmm(supervised_categories, supervised_words, unsupervised_words):
    with pyro.plate("prob_plate", num_categories):
        transition_prob = pyro.sample("transition_prob", dist.Dirichlet(transition_prior))
        emission_prob = pyro.sample("emission_prob", dist.Dirichlet(emission_prior))

    category = supervised_categories[0]
    for t in range(len(supervised_words)):
        if t > 0:
            category = pyro.sample("category_{}".format(t), dist.Categorical(transition_prob[category]),
                                   obs=supervised_categories[t])
        pyro.sample("word_{}".format(t), dist.Categorical(emission_prob[category]),
                    obs=supervised_words[t])

    transition_log_prob = transition_prob.log()
    emission_log_prob = emission_prob.log()
    log_prob = emission_log_prob[:, unsupervised_words[0]]
    for t in range(1, len(unsupervised_words)):
        log_prob = forward_log_prob(log_prob, unsupervised_words[t],
                                    transition_log_prob, emission_log_prob)
    prob = log_prob.logsumexp(dim=0).exp()
    pyro.sample("forward_prob", dist.Bernoulli(prob), obs=torch.tensor(1.))
nuts_kernel = NUTS(semisupervised_hmm, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=100)
mcmc.run(supervised_categories, supervised_words, unsupervised_words)
plot_posterior(mcmc)
Sample: 100%|██████████| 200/200 [10:20, 3.10s/it, step size=9.60e-02, acc. prob=0.857]
The posterior is much better now, which means that the additional information from the supervised data has helped a lot!
Some takeaways
- When we don’t have much labeled data, consider using semi-supervised learning.
- Using specialized algorithms (the forward algorithm in this case) can significantly improve the speed of our models.
For a variational inference approach to HMMs, please check out this excellent example in the Pyro tutorial page.