Structured Prediction Part 2: Implementing a Linear-Chain CRF
In this part of the series of posts on structured prediction with conditional random fields (CRFs), we are going to implement all the ingredients that were discussed in part 1. Recall that we discussed how to model the dependencies among labels in sequence prediction tasks with a linear-chain CRF. Now we will put a CRF on top of a neural network feature extractor and use it for part-of-speech (POS) tagging.
Everything below is inspired by this paper by Ma & Hovy, although we’ll train on a smaller dataset for time’s sake and skip some of the components, since this post is about educational value rather than performance. The implementation borrows many parts from the AllenNLP implementation, so if it’s simply a good implementation you’re looking for, take a look at that one. If you’d like to understand how it works from scratch, keep reading. Bear with me here: I discuss everything in considerable detail and try not to skip over anything, from batching and broadcasting to changing the forward recursion of BP for a cleaner implementation. If you’d rather read a succinct blog post, you might want to choose one of the other great options out there! Alternatively, you can skip the sections that are already clear to you, like batching and broadcasting. Let’s start!
To learn a model that can annotate examples like the one above with predicted POS tags, we need to extract useful features from the input sequence, which we will do with a bidirectional LSTM (motivated below). Then what we’ll implement in this post is:

- An Encoder model that holds our feature extractor (the bidirectional LSTM).
- A ChainCRF model that implements all the CRF methods, like belief propagation (BP) and Viterbi decoding.
The end-to-end model we get by walking through this post is depicted in the image below.
To train this model end-to-end we will use the negative log-likelihood (NLL) loss function, which is simply the negated log-likelihood that was given in part 1 of this series. Given some example input-output pairs \((\mathbf{x}^{(i)}, \mathbf{y}^{(i)})_{i=1}^N\), the NLL of the entire dataset is:
\[\begin{aligned} \text{NLL} = -\log \mathcal{L}(\boldsymbol{\theta}) &= -\sum_{i=1}^{N} \log p(\mathbf{y}^{(i)} \mid \mathbf{x}^{(i)}) \\ &= \sum_{i=1}^{N}\log\left(Z(\mathbf{x}^{(i)})\right) - \sum_{i=1}^{N}\left(\sum_{t=1}^m \boldsymbol{\theta}_1 f(y_t^{(i)}, \mathbf{x}^{(i)}, t) + \sum_{t=1}^{m-1} \boldsymbol{\theta}_2 f(y_t^{(i)}, y_{t+1}^{(i)})\right) \\ \end{aligned}\]Instead of maximizing the log-likelihood of our data, we will minimize the negative log-likelihood, which is equivalent.
We can use stochastic gradient descent (SGD) with automatic differentiation in PyTorch, meaning we only need the forward part of the BP algorithm. Recall that the forward recursion allows calculating the partition function \(Z(\mathbf{x})\), which we need for the NLL. The backward recursion allows calculating the marginals (which are needed for the gradients); PyTorch takes care of the latter calculation for us (<3 PyTorch). Also, since we’ll be using SGD on mini-batches, the above sum will go over \(B < N\) examples randomly sampled from the dataset, for batch size \(B\).
Let’s start with feature extraction and defining \(\boldsymbol{\theta}_1\), \(\boldsymbol{\theta}_2\), \(f(y_t, \mathbf{x}, t)\), and \(f(y_t, y_{t+1})\).
Preliminaries
If you want to run the code used in this post yourself, make sure to install PyTorch >= 1.7.0, Python 3.6, and TorchNLP >= 0.5.0.
Feature Extraction
There are some desiderata for our features. The function \(f(y_t, \mathbf{x}, t)\) signifies that we want each tag in the output sequence \(y_t\) to be informed about (i.e., depend on) the entire input sequence \(\mathbf{x}\), but also on the current word at time \(t\). Furthermore, \(f(y_t, y_{t+1})\) tells us that we want each next output tag \(y_{t+1}\) to depend on the previous tag \(y_{t}\). Parametrizing the former part, we take \(\boldsymbol{\theta}_1f(y_t, \mathbf{x}, t)\) to be the output of a bidirectional LSTM projected down to the right dimension with a linear layer:
\[\begin{aligned} \mathbf{\bar{H}} &= \text{biLSTM}(\mathbf{x}) \\ \mathbf{H} &= \mathbf{\bar{H}}\mathbf{W} + \mathbf{b} \end{aligned}\]In the above, \(\mathbf{x} \in \mathbb{R}^{m}\) is our input sequence, and \(\mathbf{\bar{H}} \in \mathbb{R}^{m \times 2d_h}\) holds the hidden vectors for each input word \(x_t\) stacked into a matrix, with \(d_h\) the hidden dimension of the LSTM (doubled because a bidirectional LSTM is essentially two LSTMs processing the input sequence from left-to-right and right-to-left). \(\mathbf{W} \in \mathbb{R}^{2d_h \times S}\) is a matrix of parameters that projects the output to the right dimension, namely \(S\) values for each input word \(x_t\). The \(t\)th row of \(\mathbf{H} \in \mathbb{R}^{m \times S}\) (let’s denote it by \(\mathbf{H}_{t*}\)) then holds the features for the \(t\)th word \(x_t\) in the input sequence. These values reflect \(\boldsymbol{\theta}_1f(y_t, \mathbf{x}, t)\) for each possible \(y_t\), since they depend on the entire input sequence \(\mathbf{x}\), but are specific to the current word at time \(t\).
\[\boldsymbol{\theta}_1f(y_t, \mathbf{x}, t) = \mathbf{H}_{t,y_t}\]Using PyTorch, we can code this up in a few lines. As is common in computational models of language, we will assign each input token a particular index and use dense word embeddings to represent each word in our input vocabulary. We reserve index 0 for the special token <PAD>, which we need later when we batch together examples of different lengths \(m\).
By the way, strictly speaking our biLSTM does not take \(\mathbf{x}\) as input but the embeddings \(\mathbf{E}_{x_t*}\), where \(\mathbf{E}\) is the matrix of embedding parameters of size \(I \times d_e\) (where \(I\) is the input vocabulary size, i.e., the number of unique input words). So if \(x_t = 3\) (index 3 in our vocabulary, which might map to the word book, for example), \(\mathbf{E}_{x_t*}\) takes out the corresponding embedding of size \(d_e\).
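A minimal sketch of what such an encoder might look like (class name, argument names, and hyperparameters here are illustrative, not the exact implementation):

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Sketch of the feature extractor: embeddings -> biLSTM -> linear projection."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):
        super().__init__()
        # E: [I, d_e]; index 0 is reserved for the <PAD> token
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.bilstm = nn.LSTM(embedding_dim, hidden_dim,
                              batch_first=True, bidirectional=True)
        # W and b: project the [.., 2 d_h] hidden states down to S values per word
        self.projection = nn.Linear(2 * hidden_dim, num_tags)

    def forward(self, inputs):
        embedded = self.embedding(inputs)   # [batch, m, d_e]
        hidden, _ = self.bilstm(embedded)   # H-bar: [batch, m, 2 d_h]
        return self.projection(hidden)      # H: [batch, m, S]
```

The output has one row of \(S\) scores per input word, exactly the \(\mathbf{H}\) matrix from the equations above.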
That’s all we need to do to extract features from our input sequences! Later in this post we will define \(\boldsymbol{\theta}_2\) and \(f(y_t, y_{t+1})\), which are part of the actual CRF. First, we’ll set up our Tagger module, which takes the encoder and the CRF (to implement later), and outputs the negative log-likelihood and a predicted tag sequence.
The Tagger
Below we can find the Tagger module. This is basically the class that implements the model as depicted in the image above. The most interesting parts are still missing (namely the ChainCRF module, but also the Vocabulary module), but we’ll get to those. The forward pass of the Tagger takes an input sequence, a target sequence, and an input mask (we’ll get to what that is when we discuss batching), puts the input sequence through the encoder and the CRF, and outputs the NLL loss, the score (which is basically the numerator of \(p(\mathbf{y} \mid \mathbf{x})\)), and the tag_sequence obtained by decoding with Viterbi.
If you don’t have a feeling yet for what the parameters input_mask and input_lengths in the module should be, don’t worry; they will be discussed in the ‘Sequence Batching’ section below.
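As a sketch, the Tagger might look like the following (the CRF method names here match the ones we’ll implement below, but the exact signatures are assumptions of this sketch):

```python
import torch.nn as nn

class Tagger(nn.Module):
    """Sketch: ties the encoder and the CRF together into one end-to-end model."""

    def __init__(self, encoder, crf):
        super().__init__()
        self.encoder = encoder
        self.crf = crf

    def forward(self, inputs, targets, input_mask, input_lengths):
        # Unary scores H: [batch_size, seq_len, num_tags]
        unary = self.encoder(inputs)
        # log Z(x) per example, via the forward pass of belief propagation
        log_Z = self.crf.forward_belief_propagation(unary, input_mask, input_lengths)
        # Log-numerator of p(y | x) per example
        score = self.crf.score_sentence(unary, targets, input_mask, input_lengths)
        # NLL summed over the batch: log Z minus the log-numerator
        loss = (log_Z - score).sum()
        # Predicted tag sequence via Viterbi decoding
        _, tag_sequence = self.crf.viterbi_decode(unary, input_mask, input_lengths)
        return loss, score, tag_sequence
```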
OK, now we can finally get to the definition of \(\boldsymbol{\theta}_2\) and \(f(y_t, y_{t+1})\), and the implementation of the linear-chain CRF!
Implementing a Linear-Chain CRF
We want \(f(y_t, y_{t+1})\) to represent the likelihood of some tag \(y_{t+1}\) following \(y_t\) in the sequence, which can be interpreted as the transition likelihood from one tag to another. We assumed the parameters \(\boldsymbol{\theta}_2\) are shared over time (meaning they are the same for each \(t \in \{1, \dots, m-1\}\)), and thus we can simply define a matrix of transition ‘probabilities’ from each tag to another tag. We define \(\boldsymbol{\theta}_2f(y_t, y_{t+1}) = \mathbf{T}_{y_t,y_{t+1}}\), meaning the \(y_t\)th row (recall that \(y_t\) is an index that represents a tag) and \(y_{t+1}\)th column of a matrix \(\mathbf{T}\). This matrix \(\mathbf{T}\) will be of size \((S + 2) \times (S + 2)\). The 2 extra tags are the <ROOT> and <EOS> tags. We need some tag to start off the sequence and one to end it, because we want to take into account the probability of a particular tag being the first tag of a sequence, and the probability of a tag being the last tag. For the former we will use <ROOT>, and for the latter we’ll use <EOS>.
In this part we will implement:

- The forward pass of belief propagation (ChainCRF.forward_belief_propagation(...)), calculating the partition function (i.e., the denominator of \(p(\mathbf{y} \mid \mathbf{x})\)).
- Calculating the log-numerator of \(p(\mathbf{y} \mid \mathbf{x})\) (ChainCRF.score_sentence(...)).
- Decoding to get the target sequence prediction (ChainCRF.viterbi_decode(...)).
Below, you’ll find the ChainCRF class that holds all these methods. The matrix of transition probabilities \(\mathbf{T}\) is initialized below as log_transitions (note that \(\mathbf{T}\) actually holds log-transition probabilities, because in the CRF equations they appear as \(\exp(\boldsymbol{\theta}_2f(y_t, y_{t+1})) = \exp\mathbf{T}_{y_t,y_{t+1}}\)), and we hardcode the transition scores from any \(y_t\) to the <ROOT> tag to be -10000, because this transition should not be possible (and it becomes 0 in the CRF equation: exp(log_transitions) gives \(\exp(-10000) \approx 0\)). We do the same for any transition from <EOS> to any other tag. The class below implements the methods to calculate the NLL loss, and the total forward pass of the CRF that returns this loss as well as a predicted tag sequence. In the sections below we will implement the necessary methods for our linear-chain CRF, starting with belief propagation.
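A sketch of what the constructor could look like (the index choices for <ROOT> and <EOS>, and the random initialization scale, are assumptions of this sketch):

```python
import torch
import torch.nn as nn

class ChainCRF(nn.Module):
    """Sketch of the CRF constructor; num_tags is S, and we append <ROOT> and <EOS>."""

    def __init__(self, num_tags):
        super().__init__()
        self.num_tags = num_tags
        self.root_idx = num_tags      # assumed index of <ROOT>
        self.eos_idx = num_tags + 1   # assumed index of <EOS>
        log_transitions = 0.1 * torch.randn(num_tags + 2, num_tags + 2)
        # Transitions INTO <ROOT> are impossible: exp(-10000) is (numerically) 0
        log_transitions[:, self.root_idx] = -10000.0
        # Transitions OUT OF <EOS> are impossible as well
        log_transitions[self.eos_idx, :] = -10000.0
        self.log_transitions = nn.Parameter(log_transitions)
```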
But first, since we implement all these methods in batched versions, let’s briefly go over batching.
Sequence Batching
Processing data points in batches has multiple benefits. Averaging the gradient over a mini-batch in SGD lets you control how much noise you want while training your model (a batch size of 1 gives a maximally noisy gradient; a batch size of \(N\) gives a minimally noisy gradient, namely gradient descent without the stochasticity). Batching examples also speeds up training. This motivates us to implement all the methods for the CRF in batched versions, allowing parallel processing. Unfortunately, we can’t parallelize the time dimension in CRFs.
A batch of size 2 would look like this:
For each batch we will additionally keep track of the lengths of the sequences in the batch. For the image above, the list of lengths would be input_lengths = [8, 5].
For example, batched inputs to our encoder will be of size [batch_size, sequence_length] and outputs of size [batch_size, sequence_length, hidden_dim*2].
The input mask for the above batch looks like this:
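In code, such a mask can be built from input_lengths (a small illustrative snippet):

```python
import torch

input_lengths = [8, 5]
max_len = max(input_lengths)
# 1 where a real token sits, 0 where <PAD> fills out the shorter sequences
input_mask = torch.tensor([[1] * n + [0] * (max_len - n) for n in input_lengths])
# input_mask:
# tensor([[1, 1, 1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 0, 0, 0]])
```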
Implementation in Log-Space, Stable ‘Logsumexping’ & Broadcasting
Before we can finally get into the interesting implementations, we need to talk about two things. Firstly, for calculating the partition function we are going to need to sum a bunch of \(\exp(\cdot)\)’s, which might explode. To do this in a numerically stable way, we will use the log-sum-exp trick. The log here comes from the fact that we will implement everything in log-space. Numerical stability doesn’t fare well with recursive multiplication of small values (i.e., values between 0 and 1) or of large values. In log-space, multiplications become summations, which run less of a risk of becoming too small or too large. Recall that the initialization and recursion in the forward pass of belief propagation are given by the following equations:
\[\begin{aligned} \alpha(1, y^{\prime}_2) &= \sum_{y^{\prime}_1}\psi(y^{\prime}_1, \mathbf{x}, 1) \cdot \psi(y^{\prime}_1, y^{\prime}_{2}) \\ \alpha(t, y^{\prime}_{t+1}) &\leftarrow \sum_{y^{\prime}_{t}}\psi(y^{\prime}_t, \mathbf{x}, t) \cdot \psi(y^{\prime}_t, y^{\prime}_{t+1})\cdot \alpha(t-1, y^{\prime}_t) \end{aligned}\]First we convert the alpha initialization to log-space:
\[\begin{aligned} \alpha(1, y^{\prime}_2) &= \sum_{y^{\prime}_1}\psi(y^{\prime}_1, \mathbf{x}, 1) \cdot \psi(y^{\prime}_1, y^{\prime}_{2}) \\ \log \alpha(1, y^{\prime}_2) &= \log \sum_{y^{\prime}_1}\psi(y^{\prime}_1, \mathbf{x}, 1) \cdot \psi(y^{\prime}_1, y^{\prime}_{2}) \\ \end{aligned}\]Then we convert the recursion equation to log-space, plug in the CRF factors, and see why we get a logsumexp:
\[\begin{aligned} \log \alpha(t, y^{\prime}_{t+1}) &\leftarrow \log \sum_{y^{\prime}_{t}}\psi(y^{\prime}_t, \mathbf{x}, t) \cdot \psi(y^{\prime}_t, y^{\prime}_{t+1})\cdot \exp\log\alpha(t1, y^{\prime}_t) \\ \log \alpha(t, y^{\prime}_{t+1}) &\leftarrow \log \sum_{y^{\prime}_{t}}\exp\left(\boldsymbol{\theta}_1f(y^{\prime}_t, \mathbf{x}, t) + \boldsymbol{\theta}_2f(y^{\prime}_t, y^{\prime}_{t+1})\right) \cdot \exp\log\alpha(t1, y^{\prime}_t) \\ \log \alpha(t, y^{\prime}_{t+1}) &\leftarrow \underbrace{\log \sum_{y^{\prime}_{t}}\exp}_{\text{logsumexp}}\left(\boldsymbol{\theta}_1f(y^{\prime}_t, \mathbf{x}, t) + \boldsymbol{\theta}_2f(y^{\prime}_t, y^{\prime}_{t+1}) + \log\alpha(t1, y^{\prime}_t)\right) \end{aligned}\]Now what logsumexp does is rewrite this as follows. Let everything inside the \(exp(\cdot)\) above for simplicity be \(q_t\):
\[\begin{aligned} \log \sum_{y^{\prime}_{t}}\exp(q_t) &= \log \sum_{y^{\prime}_{t}}\exp(q_t - c + c) \\ &= \log \sum_{y^{\prime}_{t}}\exp(q_t - c)\exp(c) \\ &= \log \exp(c)\sum_{y^{\prime}_{t}}\exp(q_t - c) \\ &= c + \log\sum_{y^{\prime}_{t}}\exp(q_t - c) \\ \end{aligned}\]If we take the constant \(c\) to be the maximum value that \(q_t\) takes in the sum, the summation becomes stable. See below the code that implements this, adapted from AllenNLP.
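A version of that code might look like the following (a sketch along the lines of the AllenNLP utility, not a verbatim copy):

```python
import torch

def logsumexp(tensor, dim=-1, keepdim=False):
    """Numerically stable log(sum(exp(tensor))) along dim, with c = the max along dim."""
    max_score, _ = tensor.max(dim, keepdim=True)
    stable = (tensor - max_score).exp().sum(dim, keepdim=True).log()
    result = max_score + stable
    return result if keepdim else result.squeeze(dim)
```

For large inputs this avoids overflow: logsumexp(torch.tensor([1000.0, 1000.0]), dim=0) returns a finite value, while exponentiating 1000.0 directly already overflows to inf.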
Secondly, we need to talk about broadcasting. Broadcasting over dimensions is conceptually similar to matrix multiplication, except that values are combined by summing instead of multiplying. For example, if we sum a + b = c where \(a \in \mathbb{R}^{2 \times 1}\) and \(b \in \mathbb{R}^{1 \times 2}\), the result will be \(c \in \mathbb{R}^{2 \times 2}\):
You can use broadcasting when you have a value that you want to add to every index of a vector; in the above case, 1 is added to 3 and 4 for the top row, and 2 is added to 3 and 4 for the bottom row.
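In PyTorch, that example looks like this:

```python
import torch

a = torch.tensor([[1.0],
                  [2.0]])          # shape [2, 1]
b = torch.tensor([[3.0, 4.0]])    # shape [1, 2]
c = a + b                         # broadcast to shape [2, 2]
# c = tensor([[4., 5.],
#             [5., 6.]])
```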
Forward Belief Propagation
In this section we will implement the forward pass of belief propagation, which we need to calculate part of the NLL loss (namely \(\log\left(Z(\mathbf{x}^{(i)})\right)\)).
We are implementing the forward recursion in a batched version, meaning that the loop over time goes from \(t=1\) to \(m-1\) for the largest \(m\) in the batch. Some sequences might already end earlier in the loop, which is why we will mask out the recursion for those sequences and retain the old \(\alpha\) variables for them. Implementing batched forward belief propagation becomes a lot easier if, instead of the recursion equation as defined above, we use the following recursion:
\[\begin{aligned} \hat{\alpha}(1, y^{\prime}_2) &= \psi(y^{\prime}_2, \mathbf{x}, 2)\sum_{y^{\prime}_1}\psi(y^{\prime}_1, \mathbf{x}, 1) \cdot \psi(y^{\prime}_1, y^{\prime}_{2}) \\ \hat{\alpha}(t, y^{\prime}_{t+1}) &\leftarrow \psi(y^{\prime}_{t+1}, \mathbf{x}, t+1)\sum_{y^{\prime}_{t}}\psi(y^{\prime}_t, y^{\prime}_{t+1})\cdot \hat{\alpha}(t-1, y^{\prime}_t) \end{aligned}\]Why this is equivalent becomes apparent when we look at the full partition function we’re trying to compute. We simply move the unary features \(\psi(y_2^{\prime}, \mathbf{x}, 2)\) to the previous recursion step.
\[\begin{aligned} Z(\mathbf{x}) &= \sum_{y^{\prime}_m}\psi(y^{\prime}_m, \mathbf{x}, m)\sum_{y^{\prime}_{m-1}}\psi(y^{\prime}_{m-1}, \mathbf{x}, m-1)\psi(y^{\prime}_{m-1}, y^{\prime}_{m}) \dots \dots \\ & \quad \quad \quad \dots \sum_{y^{\prime}_2}\psi(y^{\prime}_2, \mathbf{x}, 2) \cdot \psi(y^{\prime}_2, y^{\prime}_{3})\underbrace{\sum_{y^{\prime}_1}\psi(y^{\prime}_1, \mathbf{x}, 1) \cdot \psi(y^{\prime}_1, y^{\prime}_{2})}_{\alpha(1, y^{\prime}_{2})} \\ &= \sum_{y^{\prime}_m}\psi(y^{\prime}_m, \mathbf{x}, m)\sum_{y^{\prime}_{m-1}}\psi(y^{\prime}_{m-1}, \mathbf{x}, m-1)\psi(y^{\prime}_{m-1}, y^{\prime}_{m}) \dots \dots \\ & \quad \quad \quad \dots \sum_{y^{\prime}_2} \psi(y^{\prime}_2, y^{\prime}_{3}) \cdot \underbrace{\psi(y^{\prime}_2, \mathbf{x}, 2)\sum_{y^{\prime}_1}\psi(y^{\prime}_1, \mathbf{x}, 1) \cdot \psi(y^{\prime}_1, y^{\prime}_{2})}_{\hat{\alpha}(1, y^{\prime}_{2})} \end{aligned}\]Remember that in the forward recursion of BP we looped until \(m-1\), and that at time \(m\) we don’t have transition probabilities to \(m+1\) anymore, so we separately did the following:
\[Z(\mathbf{x}) = \sum_{y^{\prime}_m}\psi(y^{\prime}_m, \mathbf{x}, m) \cdot\alpha(m-1, y^{\prime}_m)\]The new recursion results in much easier code because, instead of separately keeping track of when each sequence in the batch ends and writing an if-else statement within the loop over time that calculates the above equation for the sequences that are ending, we simply incorporate the unary features of \(y_{t+1}^{\prime}\) in each recursive calculation already. Then, at the end we only need to sum all new alphas, which we can do outside of the loop because the masking already takes care of keeping the alphas for the ended sequences the same:
\[Z(\mathbf{x}) = \sum_{y^{\prime}_m}\hat{\alpha}(m-1, y^{\prime}_m)\]The equations we are going to implement now are these recursions in log-space. In the equation below I’ve annotated every part of the equations with the corresponding variable names in the implementation below.
\[\begin{aligned} \log\hat{\alpha}(1, y^{\prime}_2) &= \log\psi(y^{\prime}_2, \mathbf{x}, 2) + \log\sum_{y^{\prime}_1}\psi(y^{\prime}_1, \mathbf{x}, 1) \cdot \psi(y^{\prime}_1, y^{\prime}_{2}) \\ \log\hat{\alpha}(t, y^{\prime}_{t+1}) &\leftarrow \log\left(\psi(y^{\prime}_{t+1}, \mathbf{x}, t+1)\sum_{y^{\prime}_{t}}\psi(y^{\prime}_t, y^{\prime}_{t+1})\cdot \exp\log\hat{\alpha}(t-1, y^{\prime}_t)\right) \\ &\leftarrow \log\left(\sum_{y^{\prime}_{t}}\psi(y^{\prime}_{t+1}, \mathbf{x}, t+1) \cdot \psi(y^{\prime}_t, y^{\prime}_{t+1})\cdot \exp\log\hat{\alpha}(t-1, y^{\prime}_t)\right) \\ &\leftarrow \log\left(\sum_{y^{\prime}_{t}}\exp\left(\underbrace{\boldsymbol{\theta}_1f(y^{\prime}_{t+1}, \mathbf{x}, t+1)}_{\text{unary_features}} + \underbrace{\boldsymbol{\theta}_2 f(y^{\prime}_t, y^{\prime}_{t+1})}_{\text{transition_scores}} + \underbrace{\log\hat{\alpha}(t-1, y^{\prime}_t)}_{\text{forward_alpha_e}}\right)\right) \\ \end{aligned}\]We discussed above that \(\boldsymbol{\theta}_1f(y^{\prime}_{t+1}, \mathbf{x}, t+1) = \mathbf{H}_{t+1,y_{t+1}}\) is the \(t+1\)th row and \(y_{t+1}\)th column of the projected output of our encoder. These values are collected in a vector called unary_features for all \(y^{\prime}_{t+1}\) (meaning the vector has size \(S\)). In the above recursion, these values are the same for all \(y_t^{\prime}\) in the sum (meaning we can use broadcasting in the code!). Then \(\boldsymbol{\theta}_2f(y^{\prime}_t, y^{\prime}_{t+1}) = \mathbf{T}_{y_t,y_{t+1}}\) is the \(y_t\)th row and \(y_{t+1}\)th column of the matrix of transition probabilities. The vector transition_scores in the code holds the scores from all \(y_t\)’s to a particular \(y_{t+1}\). These are the same for each example in the batch, which again calls for broadcasting in the implementation.
Now let’s take a look at the code that implements all this; some clarifications follow below the code.
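Here is a sketch of what that function might look like. It is simplified: it leaves out the special <ROOT>/<EOS> transitions and uses torch.logsumexp directly, but the variable names match the annotations above:

```python
import torch

def forward_belief_propagation(unary, log_transitions, input_mask):
    """Simplified batched sketch. unary: [B, m, S], log_transitions: [S, S],
    input_mask: [B, m]. Returns log Z(x) per example."""
    batch_size, seq_len, num_tags = unary.shape
    mask = input_mask.float()
    # log alpha-hat(1, .): the unary scores of the first word
    forward_alpha = unary[:, 0]                               # [B, S]
    for t in range(1, seq_len):
        unary_features = unary[:, t].unsqueeze(1)             # [B, 1, S]: same for all y'_t
        transition_scores = log_transitions.unsqueeze(0)      # [1, S, S]: same for the whole batch
        forward_alpha_e = forward_alpha.unsqueeze(2)          # [B, S, 1]: same for all y'_{t+1}
        next_forward_alpha = unary_features + transition_scores + forward_alpha_e
        next_forward_alpha = torch.logsumexp(next_forward_alpha, 1)  # sum out y'_t
        mask_t = mask[:, t].unsqueeze(1)
        # Sequences that already ended (mask 0) keep their old alphas
        forward_alpha = mask_t * next_forward_alpha + (1 - mask_t) * forward_alpha
    # log Z(x) per example: logsumexp over the final tags
    return torch.logsumexp(forward_alpha, 1)
```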
The code recurses over the time dimension, from \(1\) to \(m-1\), and calculates the alphas in log-space. The first important line is the following:
next_forward_alpha = unary_features + transition_scores + forward_alpha_e
This line basically calculates the recursion above without the logsumexping:
\[\underbrace{\boldsymbol{\theta}_1f(y^{\prime}_{t+1}, \mathbf{x}, t+1)}_{\text{unary_features}} + \underbrace{\boldsymbol{\theta}_2 f(y^{\prime}_t, y^{\prime}_{t+1})}_{\text{transition_scores}} + \underbrace{\log\hat{\alpha}(t-1, y^{\prime}_t)}_{\text{forward_alpha_e}}\]But then for all possible POS tags \(y^{\prime}_t\) and all next tags \(y^{\prime}_{t+1}\) at once. This is where we use broadcasting: the unary features are the same for all \(y^{\prime}_t\), and the forward alphas are the same for all \(y^{\prime}_{t+1}\). Then if we logsumexp over the first dimension:
logsumexp(next_forward_alpha, 1)
We get our subsequent alpha recursion variables for all \(y^{\prime}_{t+1}\). The following image graphically shows what happens in these two lines of code for a single example in a batch:
The image above uses \(\tilde{\alpha}\), which should actually be \(\hat{\alpha}\), but I can’t for the life of me figure out how to do that in Google Drawings >:(.
Note that we add the transition from each tag to the final <EOS> tag outside of the loop over time, because these transitions are independent of time and should be added for every sequence in the batch. Then the whole function returns:
logsumexp(alphas)
which is \(Z(\mathbf{x}) = \sum_{y^{\prime}_m}\hat{\alpha}(m-1, y^{\prime}_m)\) computed in log-space (so the function actually returns \(\log Z(\mathbf{x})\)).
This works for all examples in the batch because, within the loop over time, we make sure to only update the alphas for the sequences that haven’t ended yet. Whenever the value of the input_mask at a particular time step is zero, the following lines of code make sure that we retain the old alpha for those examples:
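In isolation, that masking trick looks like this (with toy values; in the real function, mask_t comes from the input_mask column at time t):

```python
import torch

# Toy values: two sequences, three tags; the second sequence has already ended.
next_forward_alpha = torch.tensor([[1.0, 2.0, 3.0],
                                   [4.0, 5.0, 6.0]])
forward_alpha = torch.tensor([[0.0, 0.0, 0.0],
                              [9.0, 9.0, 9.0]])
mask_t = torch.tensor([[1.0],
                       [0.0]])  # column t of input_mask, unsqueezed to broadcast
# Active sequences (mask 1) take the new alphas; ended ones (mask 0) keep the old values.
forward_alpha = mask_t * next_forward_alpha + (1 - mask_t) * forward_alpha
# forward_alpha = tensor([[1., 2., 3.],
#                         [9., 9., 9.]])
```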
Yay, that’s the partition function calculated! Now let’s look at calculating the numerator of the CRF, i.e., the function ChainCRF.score_sentence(...).
Calculating the Numerator
To finish calculating the loss, we just need to calculate the (log-)numerator.
\[\sum_{t=1}^m \underbrace{\boldsymbol{\theta}_1 f(y_t^{(i)}, \mathbf{x}^{(i)}, t)}_{\text{unary_features}} + \sum_{t=1}^{m-1} \underbrace{\boldsymbol{\theta}_2 f(y_t^{(i)}, y_{t+1}^{(i)})}_{\text{transition_scores}}\]This is simply a sum over time of the unary features and the transition scores for a given input sentence \(\mathbf{x}^{(i)}\) and target tag sequence \(\mathbf{y}^{(i)}\). In the implementation we loop over all the sequences in the batch in parallel, get the current tag for each example in the batch from the ground-truth sequence target_tags, and calculate the current unary scores and the transition scores from the current time to the next.
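A simplified batched sketch of that computation (it ignores the <ROOT> transition and assumes the tag indices already include <EOS>; names are illustrative):

```python
import torch

def score_sentence(unary, log_transitions, target_tags, input_mask, eos_idx):
    """Log-numerator per example. unary: [B, m, S_total]; target_tags, input_mask: [B, m]."""
    batch_size, seq_len, num_tags = unary.shape
    mask = input_mask.float()
    score = unary.new_zeros(batch_size)
    for t in range(seq_len):
        current_tags = target_tags[:, t]                                   # [B]
        # theta_1 f(y_t, x, t): each example's unary score of its gold tag
        unary_t = unary[:, t].gather(1, current_tags.unsqueeze(1)).squeeze(1)
        score = score + unary_t * mask[:, t]                               # masked when padded
        if t < seq_len - 1:
            # theta_2 f(y_t, y_{t+1}); masked when there is no tag at t+1 anymore
            trans_t = log_transitions[current_tags, target_tags[:, t + 1]]
            score = score + trans_t * mask[:, t + 1]
    # Transition from the last real tag of every sequence to <EOS>
    last_idx = mask.long().sum(1) - 1
    last_tags = target_tags.gather(1, last_idx.unsqueeze(1)).squeeze(1)
    return score + log_transitions[last_tags, eos_idx]
```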
In the code above we calculate the unary features for the tags at the current time step with the following line:
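That line is a gather; on toy values (made up just to show the mechanics) it behaves like this:

```python
import torch

unary_features = torch.tensor([[0.1, 0.2, 0.3],
                               [0.4, 0.5, 0.6]])   # [batch_size, num_tags]
current_tags = torch.tensor([2, 0])                # gold tag at time t, per example
# For each row, pick the unary score of that example's current tag
scores = unary_features.gather(1, current_tags.unsqueeze(1)).squeeze(1)
# scores = tensor([0.3000, 0.4000])
```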
Here, current_tags is a vector with the index of the tag at time \(t\) in the sequence for each example in the batch, and the gather function retrieves this index in the first dimension of the tensor unary_features, which is of size [batch_size, num_tags], meaning this code gathers \(\boldsymbol{\theta}_1 f(y_t^{(i)}, \mathbf{x}^{(i)}, t)\) for all examples in the batch.
Then we add these unary features to the transition features from each current tag to the next tag at time \(t+1\) (\(\boldsymbol{\theta}_2 f(y_t^{(i)}, y_{t+1}^{(i)})\)), making sure to mask any unary features for sequences that have ended, and any transition scores for sequences that don’t have a tag anymore at time \(t+1\).
Then, outside of the loop over time, we still need to gather the transition scores from the last tag of each sequence in the batch to the <EOS> token, because we didn’t do this inside the loop. Additionally, we need to add the last unary features for the longest sequences in the batch. Finally, we return a log-numerator score for each sequence in the batch.
This concludes the calculation of the NLL! So we can start decoding.
Viterbi Decoding
We’ll finally implement decoding with Viterbi. Recall that in structured prediction we ideally want to use an efficient DP algorithm, to get rid of the exponential decoding complexity. As mentioned, decoding here is finding the maximum scoring target sequence according to our model. This amounts to solving the following equation:
\(\mathbf{y}^{\star} = \arg\max_{\mathbf{y} \in \mathcal{Y}} p(\mathbf{y} \mid \mathbf{x})\).
With Viterbi, the time complexity of doing this is \(\mathcal{O}(m \cdot S^2)\), instead of the \(\mathcal{O}(S^m)\) of the naive approach. Viterbi works very similarly to BP, but instead of summing we maximize. The Viterbi implementation is by far the most complex one of all the CRF methods, especially in the batched version that we’ll use. It’s also the longest function, as we have to implement both the forward recursion, to calculate the likelihoods of the possible sequences, and the backtracking that follows the backpointers to obtain the maximum-scoring sequence; so we’ll do it in two steps. That said, if you went through forward belief propagation and understood that part, this is not much different.
There are four recursion equations for Viterbi decoding: one for initializing the recursion variables, one for the recursion itself, one for initializing the backpointers, and one for recursively finding the backpointers. If you want a detailed explanation of how we get to these equations from the \(\arg\max\) above, see part one of this series.
\[\begin{aligned} v(1, y^{\prime}_2) &= \max_{y^{\prime}_1}\psi(y^{\prime}_1, \mathbf{x}, 1) \cdot \psi(y^{\prime}_1, y^{\prime}_{2}) \\ v(t, y^{\prime}_{t+1}) &\leftarrow \max_{y^{\prime}_{t}}\psi(y^{\prime}_t, \mathbf{x}, t) \cdot \psi(y^{\prime}_t, y^{\prime}_{t+1})\cdot v(t-1, y^{\prime}_t) \\ \overleftarrow{v}(1, y^{\prime}_2) &= \arg\max_{y^{\prime}_1}\psi(y^{\prime}_1, \mathbf{x}, 1) \cdot \psi(y^{\prime}_1, y^{\prime}_{2}) \\ \overleftarrow{v}(t, y^{\prime}_{t+1}) &\leftarrow \arg\max_{y^{\prime}_{t}}\psi(y^{\prime}_t, \mathbf{x}, t) \cdot \psi(y^{\prime}_t, y^{\prime}_{t+1})\cdot v(t-1, y^{\prime}_t) \\ \end{aligned}\]The following animation graphically depicts what goes on in Viterbi:
We loop over the sequence, calculating the recursion for each example in the batch and for each tag \(y_t\), and at the same time we keep track, for each tag, of which previous tag is the argument that maximizes the sequence up until that tag. This is what the first loop over time in the code below does. Then, we loop backwards over time and follow the backpointers to find the maximum-scoring sequence. The code below implements all this, and below it we’ll go over some more detailed explanations of what’s going on, relating the code to the equations. I also wrote more detailed comments than usual in the code below for clarification.
Now, in this code we do need an if-else statement within the loop over time to keep track of the best last tags (the final column in the above animation). Still, we will rewrite the recursion equations like we did above for forward BP, to get rid of the need to separately calculate the final recursive step for which we don’t have transition probabilities.
I’m just going to redefine the \(\overleftarrow{v}\) below to avoid notational clutter. Let’s also write everything in logspace while we’re at it.
Note that since \(\exp(\cdot)\) and \(\log(\cdot)\) are monotonically increasing functions and we only care about the maximizing scores and tags, we can leave them out of the implementation altogether. Therefore, in the second line of the equations below, for both \(\hat{v}(1, y^{\prime}_2)\) and \(\hat{v}(t, y^{\prime}_{t+1})\), we leave them out, making those equations formally incorrect, but that doesn’t matter for finding the maximizing sequence. Finally, note that below we assume \(\psi(y^{\prime}_1, \mathbf{x}, 1)\) of the first tag to be 1, because the first tag is always <ROOT>, and \(\psi(y^{\prime}_1, y^{\prime}_{2})\) to be the transition probabilities from the root tag to any other tag. This means the initial \(y^{\prime}_1\) that maximizes the equations is simply <ROOT>.
I annotated the equations with the variables used in the code again, even though the equations are for a single example while the code handles a batch of examples. We’re going to write this function in two parts (even though they’re actually a single function), where the first part is the forward recursion given by the above equations and the second part is backtracking.
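To keep things readable, here is a sketch of the forward recursion for a single (unbatched) sequence, without the <ROOT>/<EOS> handling; the batched version follows the same pattern with an extra batch dimension and masking:

```python
import torch

def viterbi_forward(unary, log_transitions):
    """Sketch of the Viterbi forward recursion for one sequence.
    unary: [m, S], log_transitions: [S, S]. Returns final scores and backpointers."""
    seq_len, num_tags = unary.shape
    viterbi_score = unary[0]                      # v-hat(1, .): first word's unary scores
    backpointers = []
    for t in range(1, seq_len):
        # scores[y_t, y_{t+1}] = v-hat(t-1, y_t) + theta_2 f(y_t, y_{t+1}) + theta_1 f(y_{t+1}, x, t+1)
        scores = viterbi_score.unsqueeze(1) + log_transitions + unary[t].unsqueeze(0)
        viterbi_score, best_prev = scores.max(0)  # maximize over y_t, keep the argmax
        backpointers.append(best_prev)            # [S]: best previous tag per next tag
    return viterbi_score, backpointers
```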
Everything in the code above hopefully becomes clear from the annotated equations given above it.
Below I took out a piece of code that sits inside the if statements in the code above. What happens there is that we calculate
\(\hat{y}^{\prime}_m = \arg\max_{y^{\prime}_m}\underbrace{\boldsymbol{\theta}_2 f(y^{\prime}_m, y^{\prime}_{m+1})}_{\text{trans_to_end}} + \underbrace{\hat{v}_{m-1}(y^{\prime}_m)}_{\text{forward_ending}}\),
for every example that is ending at that time \(t\), i.e., \(t = m\), with \(y^{\prime}_{m+1}\) the <EOS> tag.
We do the same calculation outside of the loop over time for all the sequences that have the maximum sequence length in the batch. At the end of the loop we have all the backpointers and scores initialized, and we can start the backtracking.
In the list best_last_tags we appended the final tags that maximized the ending sequences, starting with the shortest sequences in the batch. Because we made sure to sort all the sequences in the batch in descending order of length, we can simply reverse best_last_tags to get the original ordering back. In the code below, we then put these tags in the initialized matrix best_paths that should hold the maximizing sequence for each example. This means best_paths looks like the image below before we loop backwards over time:
(This is just an example; depending on how long each sequence in the batch is, the actual batch might look different.)
Then we loop backwards over time, following the backpointers, and fill best_paths for each time step in the sequence.
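The backward loop itself is short; on a toy single-sequence example it works like this (the backpointer values are made up for illustration):

```python
import torch

# backpointers[t][y] = best previous tag leading into tag y at step t+1
backpointers = [torch.tensor([0, 0, 1]),
                torch.tensor([2, 1, 1]),
                torch.tensor([0, 2, 2])]
best_last_tag = 1            # argmax of the final viterbi scores
best_path = [best_last_tag]
# Loop backwards over time, each step following the backpointer of the current best tag
for bp in reversed(backpointers):
    best_path.append(bp[best_path[-1]].item())
best_path.reverse()
# best_path = [0, 1, 2, 1]
```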
In the above code, at every time step (for the active sequences) we use the previous best tags, starting with the best last tags, to find the correct backpointer to follow back. Now, backpointers is a matrix of size [batch_size, sequence_length, num_tags] (where the batch is still sorted in descending order of length) that holds, at every time \(t\), the backpointers (i.e., the ids of previous tags that maximize the sequence) for every tag. This is depicted below for one example in a batch:
Now best_last_tags holds the tag to start backtracking with for every example, which might be a verb like depicted above. Based on this last tag, we just need to follow the path backwards to find the correct sequence. Selecting the backpointer is what happens in the following lines (where active holds the backpointers for the active sequences):
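On toy values (made up for illustration), that selection looks like:

```python
import torch

active = torch.tensor([[0, 2, 1],
                       [1, 1, 0]])   # backpointers at time t for the 2 active sequences
last_tags = torch.tensor([2, 0])     # current best tag (at time t+1) per sequence
# Follow each sequence's backpointer for its current best tag
last_tags = active.gather(1, last_tags.unsqueeze(1)).squeeze(1)
# last_tags = tensor([1, 1])
```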
In the end, best_paths holds the maximizing sequence for every example in the batch, starting at the <ROOT> tag. We return the scores and the sequences without the root, and we are done!
Summarizing
We’ve implemented all the methods needed for a biLSTM-CRF POS tagger! We’ve taken care of batched forward belief propagation and of calculating the numerator, which together form our loss function. Then we implemented Viterbi decoding to find the predicted tag sequence with maximum probability. Now we only need to get some data and train our model on it. Go to part three to train this model on real POS tagging data!
Sources
Xuezhe Ma and Eduard H. Hovy. 2016. End-to-end Sequence Labeling via Bi-directional LSTM-CNNs-CRF. In ACL.
Matt Gardner, Joel Grus, Mark Neumann, Oyvind Tafjord, Pradeep Dasigi, Nelson F. Liu, Matthew Peters, Michael Schmitz, and Luke S. Zettlemoyer. 2017. AllenNLP: A Deep Semantic Natural Language Processing Platform.