
In this final part of the series on structured prediction with linear-chain CRFs, we will use our implementation from part two to train a model on real data. To learn such a model, we need a dataset of input sentences annotated with POS tags; we will use the Universal Dependencies dataset (Silveira et al., 2014).

The things we still need to implement are:

  • A Vocabulary to convert from strings to numerical values for computational models.

  • A TaggingDataset to convert all our data to Tensors that can be processed by PyTorch.

  • A train() loop to train our CRF and feature-extractor end-to-end on data.

  • A test() loop to test a trained model on new data.

Additionally, we will slightly change the Encoder and Tagger from part two to incorporate a character-based model.

Imports

Let’s install and import the libraries we need (TorchNLP isn’t part of the default runtime in Google Colab, which I used for the implementation):

!pip install pytorch-nlp

import torch.nn as nn
import torch
import torchnlp
from torchnlp.datasets import ud_pos_dataset

from typing import List, Tuple, Dict, Iterator
from torch import Tensor
from collections import defaultdict, Counter
import numpy as np
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

A GPU is definitely nice, and arguably necessary, for training on the Universal Dependencies dataset in reasonable time. Google Colab offers sessions with a GPU: if you’re coding everything yourself, select one under the runtime settings before running the code.
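
A quick, optional sanity check that the GPU runtime is actually active (the first line should print True on a GPU session):

print(torch.cuda.is_available())  # True on a GPU runtime
print(device)                     # cuda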

The Vocabulary & Dataset

First, we’ll implement the vocabulary: a class that reads sentences as lists of strings and converts them to indices. A pragmatic trick for sequence prediction with neural methods is the use of an <UNK> token. If a word occurs very infrequently in our training set, we probably cannot learn a meaningful embedding for it, so we can replace its occurrences with <UNK>. This means we learn a kind of average embedding for all infrequent words, and we can reuse this token at test time. At test time there will inevitably be words that don’t occur in the training set, and since we don’t have trained embeddings for those, they will map onto the <UNK> token.
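
The Vocabulary below keeps every training word and only tracks frequencies, but to make the idea concrete, a frequency-based replacement could look like the following minimal sketch (toy sentences and a hypothetical min_freq threshold; Counter is already imported above):

counts = Counter(word for sentence in [["the", "dog", "barks"], ["the", "cat"]]
                 for word in sentence)
min_freq = 2
sentence = ["the", "dog", "runs"]
replaced = [word if counts[word] >= min_freq else "<UNK>" for word in sentence]
print(replaced)  # ['the', '<UNK>', '<UNK>']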

Another approach to dealing with unknown words in the test data is a character model. It’s unlikely that we encounter a character we have never seen before, so even when a word is unknown, a character model can still represent it meaningfully. Imagine that at test time the model encounters a word ending in “-ing” that it hasn’t seen during training: the word model will represent it with the unknown token <UNK>, but the character model might recognize that it is likely a verb. Additionally, if there are misspelled words in the test data, the word vocabulary won’t know them, but the character model might still recognize them. Adding a character model thus makes our method more robust to noise in the data.

Note that adding a character model doesn’t change anything about the implementation we described in part two of the series. We simply concatenate a learned, character-based representation to the word representations we already had. The only change lies in batching, which I’ll discuss below.

The code for both the Vocabulary and the TaggingDataset below is very straightforward, so if you’re familiar with these kinds of classes, feel free to skip them and go to the part below where we look at the Universal Dependencies dataset.

class Vocabulary(object):
    """
    Object that maps tokens (e.g., words, characters) to indices to be 
    processed by numerical models.
    """

    def __init__(self, pad_token="<PAD>", unk_token="<UNK>"):
      """
      <PAD> and <UNK> tokens are by construction idxs 0 and 1.
      """
      self.pad_token = pad_token
      self.unk_token = unk_token
      self._idx_to_token = [pad_token, unk_token]
      self._token_to_idx = defaultdict(
          lambda: self._idx_to_token.index(self.pad_token))
      self._token_to_idx[pad_token] = 0
      self._token_to_idx[unk_token] = 1
      self._token_frequencies = Counter()

    def token_to_idx(self, token: str) -> int:
      if token not in self._token_to_idx:
        return self._token_to_idx[self.unk_token]
      else:
        return self._token_to_idx[token]

    def idx_to_token(self, idx: int) -> str:
      return self._idx_to_token[idx]
    
    @property
    def unk_idx(self):
      return self.token_to_idx(self.unk_token)

    @property
    def size(self) -> int:
      return len(self._idx_to_token)
    
    @property
    def pad_idx(self):
      return self.token_to_idx(self.pad_token)

    def add_token_sequence(self, token_sequence: List[str]):
      for token in token_sequence:
        if token not in self._token_to_idx:
          self._token_to_idx[token] = self.size
          self._idx_to_token.append(token)
        self._token_frequencies[token] += 1
    
    def most_common(self, n=10):
      return self._token_frequencies.most_common(n=n)

We will use the Vocabulary class above three times in what follows: once for the input data consisting of words, once for the input data processed character by character, and once for the target data consisting of POS tags.
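
As a quick sanity check (hypothetical usage, not part of the pipeline below): words added to the vocabulary get their own indices, while unseen words fall back to the <UNK> index.

vocab = Vocabulary()
vocab.add_token_sequence(["the", "dog", "barks"])
print(vocab.token_to_idx("dog"))     # 3 (<PAD>=0, <UNK>=1, "the"=2)
print(vocab.token_to_idx("purple"))  # 1, the <UNK> index
print(vocab.pad_idx, vocab.unk_idx)  # 0 1
print(vocab.most_common(2))          # [('the', 1), ('dog', 1)]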

Before we implement the dataset class, let’s have a look at what a batch looks like once we add characters. Everything becomes a bit more involved in the code, because we want a batched implementation of the forward pass. Recall that in part two of this series the bidirectional LSTM produced a vector of features for each word: we denoted by \(\mathbf{\bar{H}} \in \mathbb{R}^{m \times 2d_h}\) the hidden vectors for each input word in a sentence of length \(m\). We now also want a vector per word that represents that same word broken up into characters, which we will concatenate to the word features. Let’s denote the number of characters of the longest word in the sentence by \(k\) and define the character input as \(\mathbf{x}_c \in \mathbb{R}^{m \times k \times |C|}\) (renaming the input sequence of words from \(\mathbf{x}\) in part two to \(\mathbf{x}_w\)). If we want to combine these with the word features in a batched fashion instead of with a loop, the character-based features need to line up with the word features the biLSTM returns for every position in the batch. This means we need character sequences for the padding tokens in our batch as well. That feels a bit silly, because we will just be encoding padding, but it’s simply done so we can implement everything in a batched fashion instead of with a loop. A batch that contains characters then holds tuples of examples of the form:

\[\mathbf{x}_w \in \mathbb{R}^{m \times |I|}, \mathbf{x}_c \in \mathbb{R}^{m \times k \times |C|}, \mathbf{y} \in \mathbb{R}^{m \times |S|}\]

Where \(\mathbf{x}_w\) is the input sequence in words, with \(|I|\) the size of the input vocabulary, \(\mathbf{x}_c\) is the input sequence broken up into characters for each word, and \(\mathbf{y}\) the tag sequence, with \(|S|\) the number of tags in our dataset. To be able to combine the character-based word features with the regular word features, we pad each sequence in the batch to the maximum sentence length \(m\) that occurs in the batch (both the sequence of words and the sequence of character sequences), and we pad each word to the maximum number of characters \(k\) appearing in the entire batch. Below is a graphical depiction of a batch with batch size \(B = 2\).

[Figure: a batch of word-level and character-level inputs with batch size B = 2]
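
To make the shapes concrete, here is a small illustration with hypothetical sizes (a batch of B = 2 sentences padded to m = 4 words, with the longest word at k = 6 characters); these are exactly the shapes that get_batch below yields, with zeros as padding indices:

B, m, k = 2, 4, 6
word_batch = torch.zeros(B, m, dtype=torch.long, device=device)      # word indices
char_batch = torch.zeros(B * m, k, dtype=torch.long, device=device)  # char indices per word
target_batch = torch.zeros(B, m, dtype=torch.long, device=device)    # tag indices
print(word_batch.shape, char_batch.shape, target_batch.shape)
# torch.Size([2, 4]) torch.Size([8, 6]) torch.Size([2, 4])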

Then the next class to implement is the class that holds the TaggingDataset:

class TaggingDataset(object):
  """
  A class to hold data pairs of input words, characters, and target tags.
  """

  def __init__(self, data: List[Tuple[List, List]]):
    self._input_vocabulary = Vocabulary()
    self._char_vocabulary = Vocabulary()
    self._target_vocabulary = Vocabulary()

    # Read the training data and add each example to the vocabularies.
    examples, example_lengths, char_max_lengths = self.read_dataset(
        data, add_to_vocabularies=True)
    self.data = {
        "train": {
            "examples": examples,
            "example_lengths": example_lengths,
            "char_max_lengths": char_max_lengths
        },
        "test": {}. # We will add the test examples later.
    }

  def add_testset(self, examples: List[Dict[str, torch.Tensor]],
                  example_lengths: List[int], char_max_lengths: List[int]):
    self.data["test"]["examples"] = examples
    self.data["test"]["example_lengths"] = example_lengths
    self.data["test"]["char_max_lengths"] = char_max_lengths

  def read_testset(self, input_data: List[Tuple[List, List]]):
    """Convert the test data without adding it to the vocabularies."""
    examples, example_lengths, char_max_lengths = self.read_dataset(
        input_data, add_to_vocabularies=False)
    self.add_testset(examples, example_lengths, char_max_lengths)

  def read_dataset(self, input_data: List[Tuple[List, List]],
                   add_to_vocabularies: bool):
    """Convert each example to a tensor and save it's length."""
    examples = []
    example_lengths = []
    char_max_lengths = []
    for input_list, target_list in input_data:
      assert len(input_list) == len(target_list), "Invalid data example."

      # We don't want to add the test examples to the vocabulary.
      if add_to_vocabularies:
        self._input_vocabulary.add_token_sequence(input_list)
        self._target_vocabulary.add_token_sequence(target_list)
      
      # Convert the input sequence to an array of ints.
      input_array = self.sentence_to_array(input_list, vocabulary="input")

      # Convert each word in the sentence into a sequence of ints.
      char_inputs = []
      char_max_length = -1
      for word in input_list:
        char_list = list(word)
        char_length = len(char_list)
        if add_to_vocabularies:
          self._char_vocabulary.add_token_sequence(char_list)
        
        # Keep track of the maximum character length in a sentence for padding.
        if char_length > char_max_length:
          char_max_length = char_length
        char_array = self.sentence_to_array(char_list, vocabulary="char")
        char_inputs.append(torch.tensor(char_array, dtype=torch.long,
                                        device=device).unsqueeze(0))
      char_max_lengths.append(char_max_length)
      target_array = self.sentence_to_array(target_list, vocabulary="target")
      input_tensor = torch.tensor(input_array, dtype=torch.long, device=device)
      target_tensor = torch.tensor(target_array, dtype=torch.long, device=device)
      example_lengths.append(len(input_tensor))
      examples.append({"input_tensor": input_tensor.unsqueeze(0),
                       "char_input_tensor": char_inputs,
                       "target_tensor": target_tensor.unsqueeze(0)})
    return examples, example_lengths, char_max_lengths

  def get_vocabulary(self, vocabulary: str) -> Vocabulary:
    if vocabulary == "input":
      vocab = self._input_vocabulary
    elif vocabulary == "char":
      vocab = self._char_vocabulary
    elif vocabulary == "target":
      vocab = self._target_vocabulary
    else:
      raise ValueError(
          "Specified unknown vocabulary in sentence_to_array: {}".format(
              vocabulary))
    return vocab
  
  def sentence_to_array(self, sentence: List[str], vocabulary: str) -> List[int]:
    """
    Convert each str word in a sentence to an integer from the vocabulary.
    :param sentence: the sentence in words (strings).
    :param vocabulary: whether to use the input or target vocabulary.
    :return: the sentence in integers.
    """
    vocab = self.get_vocabulary(vocabulary)
    sentence_array = []
    for word in sentence:
      sentence_array.append(vocab.token_to_idx(word))
    return sentence_array
  
  def print_stats(self, split="train"):
    print("Number of %s examples in dataset: %d\n" % (split, len(self.data[split]["examples"])))
    print("Input vocabulary size: %d" % self._input_vocabulary.size)
    print("Most common input tokens: ", self._input_vocabulary.most_common(5))
    print("\nChar vocabulary size: %d" % self._char_vocabulary.size)
    print("Most common input tokens: ", self._char_vocabulary.most_common(5))
    print("\nTarget vocabulary size: %d" % self._target_vocabulary.size)
    print("Most common target tokens: ", self._target_vocabulary.most_common(5))
    if len(self.data[split]["examples"]) > 0:
      print("\n%s example: " % split)
      self.print_example(0, split)

  def get_example(self, idx: int, split="train"):
    if idx >= len(self.data[split]["examples"]):
      raise ValueError("Dataset has no example at idx %d for split %s" 
                       % (idx, split))
    input_tensor = self.array_to_sentence(self.data[split]["examples"][idx]["input_tensor"],
                                          "input")
    
    # Convert each word in the sentence into an array of integers per char.
    char_inputs = [self.array_to_sentence(char_input, "char") for char_input in
                   self.data[split]["examples"][idx]["char_input_tensor"]]
    target_tensor = self.array_to_sentence(self.data[split]["examples"][idx]["target_tensor"],
                                           "target")
    return input_tensor, char_inputs, target_tensor

  def print_example(self, idx: int, split="train"):
    input_tensor, char_input_tensor, target_tensor = self.get_example(idx, split=split)
    print(" ".join(target_tensor))
    print(" ".join(input_tensor))
    for char_input in char_input_tensor:
      print(" ".join(char_input), end='')
      print("    ", end='')
    print()

  def array_to_sentence(self, sentence_array: List[int], 
                        vocabulary: str) -> List[str]:
    """
    Translate each integer in a sentence array to the corresponding word.
    :param sentence_array: array with integers representing words from the vocabulary.
    :param vocabulary: whether to use the input or target vocabulary.
    :return: the sentence in words.
    """
    vocab = self.get_vocabulary(vocabulary)
    return [vocab.idx_to_token(token_idx) for token_idx in 
            sentence_array.squeeze(dim=0)]

  def shuffle_train_data(self):
    zipped_data = list(zip(self.data["train"]["examples"], 
                           self.data["train"]["example_lengths"],
                           self.data["train"]["char_max_lengths"]))
    random.shuffle(zipped_data)
    self.data["train"]["examples"], self.data["train"]["example_lengths"], self.data["train"]["char_max_lengths"] = zip(*zipped_data)
    self.data["train"]["examples"] = list(self.data["train"]["examples"])
    self.data["train"]["example_lengths"] = list(self.data["train"]["example_lengths"])
    self.data["train"]["char_max_lengths"] = list(self.data["train"]["char_max_lengths"])

  def print_batch(self, batch_tuple):
    input_tensors, ex_lengths, char_tensors, char_lengths, target_tensors = batch_tuple
    current_char_idx = 0
    for input_tensor, ex_length, char_length, target_tensor in zip(input_tensors, ex_lengths, char_lengths, target_tensors):
      char_tensor = char_tensors[current_char_idx:current_char_idx+ex_length]
      current_char_idx += ex_length
      print("Example length %d" % ex_length)
      print("Input tensor: ", input_tensor)
      input_sentence = self.array_to_sentence(input_tensor, "input")
      print("Input sentence: %s" % " ".join(input_sentence))
      print("Input sentence in chars: ")
      for word, char_length in zip(char_tensor, char_length):
        print("Tensor: ", word)
        print("Word: %s " % self.array_to_sentence(word, "char"))
      print()
      print("Target tensor: ", target_tensor)
      target_sentence = self.array_to_sentence(target_tensor, "target")
      print("Target sentence: %s" % " ".join(target_sentence))

  def get_batch(self, batch_size=2, split="train") -> Iterator[Tuple[
      torch.Tensor, List[int], torch.Tensor, List[List[int]], torch.Tensor]]:
    """
    Combines `batch_size` input examples into a batch. Loops over all examples
    in jumps of `batch_size`, and pads everything such that the batch becomes
    of size: 
    inputs: [batch_size, sequence_length, input_vocabulary_size]
    characters: [batch_size, sequence_length, word_length, char_vocabulary_size]
    targets: [batch_size, sequence_length, target_vocabulary_size]
    """
    all_examples = self.data[split]["examples"]
    all_example_lengths = self.data[split]["example_lengths"]
    char_example_lengths = self.data[split]["char_max_lengths"]

    for example_i in range(0, len(all_examples), batch_size):

      # Select the examples, lengths, and max character lengths per sequence.
      examples = all_examples[example_i:example_i + batch_size]
      example_lengths = all_example_lengths[example_i:example_i + batch_size]
      char_max_lengths = char_example_lengths[example_i:example_i + batch_size]

      # Sort the examples by length (longest first) if they differ in length.
      if len(set(example_lengths)) > 1:
        examples.sort(reverse=True, key=lambda x: len(x["input_tensor"][0]))
        example_lengths, char_max_lengths = (list(t) for t in 
                                             zip(*sorted(zip(example_lengths, 
                                                             char_max_lengths),
                                                         reverse=True)))
      
      # We need to pad every example sequence to the max length of the batch.
      max_length = np.max(example_lengths)

      # We need to pad each character to the max number of characters of any
      # word in the entire batch.
      max_char_length = np.max(char_max_lengths)

      input_batch = []
      char_lengths_batch = []
      char_batch = []
      target_batch = []
      for example in examples:
          to_pad = max_length - example["input_tensor"].size(1)

          # Loop over the maximum number of words in the batch.
          padded_char_inputs = []
          char_lengths = []
          for i_char_input in range(max_length):

            # If we still have words for this example, get the word.
            if i_char_input < len(example["char_input_tensor"]):
              char_input = example["char_input_tensor"][i_char_input]
            # Else we add a padding word.
            else:
              char_input = torch.zeros(max_char_length, 
                                       dtype=torch.long,
                                       device=device).unsqueeze(0)
            
            # Pad the character input
            to_pad_chars = max_char_length - char_input.size(1)
            char_lengths.append(char_input.size(1))
            padded_char_input = torch.cat([
              char_input,
              torch.zeros(int(to_pad_chars), dtype=torch.long, device=device).unsqueeze(0)], dim=1)
            char_batch.append(padded_char_input)
          char_lengths_batch.append(char_lengths)
          
          # Pad input and target to the maximum sequence length in the batch.
          padded_input = torch.cat([
              example["input_tensor"],
              torch.zeros(int(to_pad), dtype=torch.long, device=device).unsqueeze(0)], dim=1)
          padded_target = torch.cat([
              example["target_tensor"],
              torch.zeros(int(to_pad), dtype=torch.long, device=device).unsqueeze(0)], dim=1)
          input_batch.append(padded_input)
          target_batch.append(padded_target)
      yield (torch.cat(input_batch, dim=0), example_lengths, 
             torch.cat(char_batch, dim=0), char_lengths_batch,
             torch.cat(target_batch, dim=0))
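
Once we have constructed a tagging_set from the UD data further below, a quick (purely illustrative) way to sanity-check the batching is to peek at a single batch:

batch = next(tagging_set.get_batch(batch_size=2))
tagging_set.print_batch(batch)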

Let’s also adjust the encoder and tagger from part two to process the character sequences:

class Encoder(nn.Module):
  """
  A simple encoder model to encode sentences. Bi-LSTM over word embeddings.
  """
  def __init__(self, vocabulary_size: int, embedding_dim: int,
               char_vocabulary_size: int, char_embedding_dim: int,
               hidden_dimension: int, char_hidden_dimension: int,
               padding_idx: int):
    super(Encoder, self).__init__()
    # The word embeddings.
    self.embedding = nn.Embedding(num_embeddings=vocabulary_size, 
                                  embedding_dim=embedding_dim,
                                  padding_idx=padding_idx)
    
    # And the character embeddings.
    self.char_embedding = nn.Embedding(num_embeddings=char_vocabulary_size,
                                       embedding_dim=char_embedding_dim,
                                       padding_idx=padding_idx)
    
    # The bi-LSTM.
    self.bi_lstm = nn.LSTM(input_size=embedding_dim, 
                           hidden_size=hidden_dimension, num_layers=1, 
                           bias=True, bidirectional=True, batch_first=True)
    
    # And a bi-LSTM to summarize the character-based words.
    self.char_lstm = nn.LSTM(input_size=char_embedding_dim,
                             hidden_size=char_hidden_dimension,
                             bias=True, bidirectional=True, batch_first=True)
  
  def forward(self, sentence: torch.Tensor, char_inputs: torch.Tensor) -> torch.Tensor:
    """
    :param sentence: the sentences at word level, [batch_size, sequence_length]
    :param char_inputs: the same sentences at character level,
                        [batch_size * sequence_length, word_length]
    Returns: tensor of size
    [batch_size, sequence_length, hidden_size * 2 + char_hidden_size * 2],
    the concatenated word- and character-based features for each time step.
    """
    # sentence: [batch_size, sequence_length]
    # char_inputs: [batch_size * sequence_length, word_length]
    num_words, word_lengths = char_inputs.shape
    batch_size, sequence_length = sentence.shape

    embedded_chars = self.char_embedding(char_inputs)
    # embedded_chars: [batch_size * sequence_length, word_length, char_embedding_dim]

    _, (hidden, _) = self.char_lstm(embedded_chars)
    # hidden: [2, batch_size * sequence_length, char_hidden_dimension]
    # Concatenate the final forward and backward hidden state of each word.
    hidden = hidden.transpose(0, 1).contiguous().view(num_words, -1)
    embedded_words = hidden.view(batch_size, sequence_length, -1)

    embedded = self.embedding(sentence)
    # embedded: [batch_size, sequence_length, embedding_dimension]
    
    output, (hidden, cell) = self.bi_lstm(embedded)
    # output: [batch_size, sequence_length, hidden_size * 2]
    # hidden, cell: [2, batch_size, hidden_size] (unused)

    # Concatenate the word-level and character-level features.
    output = torch.cat([output, embedded_words], dim=2)
    # output: [batch_size, sequence_length, hidden_size * 2 + char_hidden_size * 2]
    return output
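
Here’s a quick shape check with hypothetical sizes (random indices, tiny dimensions), just to see that the word and character features end up concatenated:

encoder = Encoder(vocabulary_size=100, embedding_dim=32,
                  char_vocabulary_size=30, char_embedding_dim=16,
                  hidden_dimension=64, char_hidden_dimension=20,
                  padding_idx=0).to(device)
words = torch.randint(0, 100, (2, 5), device=device)     # [batch_size, sequence_length]
chars = torch.randint(0, 30, (2 * 5, 7), device=device)  # [batch_size * sequence_length, word_length]
features = encoder(words, chars)
print(features.shape)  # torch.Size([2, 5, 168]) = [2, 5, 64 * 2 + 20 * 2]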

And the tagger:

class Tagger(nn.Module):
  """
  A POS tagger.
  """
  
  def __init__(self, input_vocabulary: Vocabulary, 
               char_vocabulary: Vocabulary,
               target_vocabulary: Vocabulary,
               embedding_dimension: int, char_embedding_dimension: int,
               hidden_dimension: int, char_hidden_dimension: int,
               crf: bool):
    super(Tagger, self).__init__()
    # The Encoder to extract features from the input sequence.
    self.encoder = Encoder(vocabulary_size=input_vocabulary.size, 
                           char_vocabulary_size=char_vocabulary.size,
                           char_embedding_dim=char_embedding_dimension,
                           embedding_dim=embedding_dimension, 
                           hidden_dimension=hidden_dimension, 
                           char_hidden_dimension=char_hidden_dimension,
                           padding_idx=input_vocabulary.pad_idx)
    
    # The linear projection (with parameters W and b).  
    encoder_output_dim = hidden_dimension * 2 + char_hidden_dimension * 2
    self.encoder_to_tags = nn.Linear(encoder_output_dim, 
                                     target_vocabulary.size)
    
    # The linear-chain CRF.
    self.crf = crf
    if self.crf:
      self.tagger = ChainCRF(num_tags=target_vocabulary.size, 
                             tag_vocabulary=target_vocabulary)
    else:
      self.tagger = torch.nn.Softmax(dim=-1)
      self.loss_fn = torch.nn.CrossEntropyLoss()
  
  def forward(self, input_sequence: torch.Tensor, 
              input_sequence_chars: torch.Tensor,
              target_sequence: torch.Tensor, 
              input_mask: torch.Tensor, input_lengths: torch.Tensor):
    """
    :param input_sequence: input sequence of size 
            [batch_size, sequence_length, input_vocabulary_size]
    :param input_sequence_chars: character-based input sequence of size 
            [batch_size, sequence_length, word_length, char_input_vocab_size]
    :param target_sequence: POS tags target, [batch_size, sequence_length]
    :param input_mask: padding-mask, [batch_size, sequence_length]
    :param input_lengths: lengths of each example in the batch [batch_size]
    Returns: the loss, the scores over tags, and the predicted tag sequence.
    """
    # input_sequence: [batch_size, sequence_length, input_vocabulary_size]
    # input_sequence_chars: [batch_size, sequence_length, word_length, 
    #                        char_input_vocabulary_size]
    lstm_features = self.encoder(input_sequence, input_sequence_chars)
    # lstm_features: [batch_size, sequence_length, 
    #                 hidden_dimension*2 + char_hidden_dimension*2]
    
    crf_features = self.encoder_to_tags(lstm_features)
    # crf_features: [batch_size, sequence_length, target_vocabulary_size]
    
    if self.crf:
      loss, scores, tag_sequence = self.tagger(input_features=crf_features,
                                              target_tags=target_sequence,
                                              input_mask=input_mask,
                                              input_lengths=input_lengths)
      # loss, score: scalars
      # tag_sequence: [batch_size, sequence_length]
    else:
      batch_size, sequence_length, target_vocabulary_size = crf_features.shape
      scores = self.tagger(crf_features)
      loss = self.loss_fn(scores.reshape(batch_size * sequence_length,
                                         target_vocabulary_size),
                          target_sequence.reshape(batch_size * sequence_length))
      tag_sequence = torch.argmax(scores, dim=-1)
    return loss, scores, tag_sequence
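
As a small sketch (hypothetical toy vocabularies and sizes), here is a forward pass through the Tagger with crf=False, which doesn’t require the ChainCRF from part two; the crf=True path is called with exactly the same arguments:

toy_words = Vocabulary()
toy_words.add_token_sequence(["the", "dog", "barks"])
toy_chars = Vocabulary()
toy_chars.add_token_sequence(list("thedogbarks"))
toy_tags = Vocabulary()
toy_tags.add_token_sequence(["DET", "NOUN", "VERB"])
toy_tagger = Tagger(input_vocabulary=toy_words, char_vocabulary=toy_chars,
                    target_vocabulary=toy_tags, embedding_dimension=8,
                    char_embedding_dimension=4, hidden_dimension=6,
                    char_hidden_dimension=3, crf=False).to(device)
words = torch.tensor([[2, 3, 4]], device=device)                 # "the dog barks"
chars = torch.randint(0, toy_chars.size, (3, 5), device=device)  # 3 words, padded to 5 chars
tags = torch.tensor([[2, 3, 4]], device=device)                  # DET NOUN VERB
mask = torch.ones(1, 3, dtype=torch.long, device=device)
lengths = torch.tensor([3], device=device)
loss, scores, predicted = toy_tagger(words, chars, tags, mask, lengths)
print(loss.item(), predicted.shape)  # an (untrained) loss value, torch.Size([1, 3])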

The most important class, the ChainCRF, hasn’t changed from part two!

Now we’ll grab the UD dataset from the awesome TorchNLP library, which makes this incredibly easy:

ud_dataset = ud_pos_dataset(train=True, test=True)

Let’s pass the data we want to our TaggingDataset and look at some stats. ud_dataset now holds the training and test set of Universal Dependencies, which are both lists of data points, where each data point is a dict with the tokens (i.e., the input sequence), the ud_tags (the UPOS tags), and the ptb_tags (the Penn Treebank tags). For our class we need to put these in a list of tuples with the input sequence and target tags. We’ll choose the ud_tags as targets, since there are fewer classes than for the Penn Treebank tags, which makes the task a bit easier.
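
Before converting, we can peek at the first raw training example to see this structure (the exact sentence depends on the dataset version TorchNLP downloads):

for example in ud_dataset[0]:
  print(example["tokens"][:6])
  print(example["ud_tags"][:6])
  break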

training_data_ud = [(example["tokens"], example["ud_tags"]) 
                        for example in ud_dataset[0]]
test_data_ud = [(example["tokens"], example["ud_tags"]) 
                        for example in ud_dataset[1]]
tagging_set = TaggingDataset(data=training_data_ud)
tagging_set.read_testset(test_data_ud)
tagging_set.print_stats()
Number of train examples in dataset: 12543

Input vocabulary size: 19674
Most common input tokens:  [('.', 8640), ('the', 8152), (',', 7021), ('to', 5076), ('and', 4855)]

Char vocabulary size: 110
Most common input tokens:  [('e', 93412), ('t', 67775), ('a', 63699), ('o', 58826), ('n', 53490)]

Target vocabulary size: 19
Most common target tokens:  [('NOUN', 34781), ('PUNCT', 23679), ('VERB', 23081), ('PRON', 18577), ('ADP', 17638)]

train example: 
PROPN PUNCT PROPN PUNCT ADJ NOUN VERB PROPN PROPN PROPN PUNCT PROPN PUNCT DET NOUN ADP DET NOUN ADP DET NOUN ADP PROPN PUNCT ADP DET ADJ NOUN PUNCT
Al - Zaman : American forces killed Shaikh Abdullah al - Ani , the preacher at the mosque in the town of Qaim , near the Syrian border .
A l    -    Z a m a n    :    A m e r i c a n    f o r c e s    k i l l e d    S h a i k h    A b d u l l a h    a l    -    A n i    ,    t h e    p r e a c h e r    a t    t h e    m o s q u e    i n    t h e    t o w n    o f    Q a i m    ,    n e a r    t h e    S y r i a n    b o r d e r    .    


Number of test examples in dataset: 2077

Input vocabulary size: 19674
Most common input tokens:  [('.', 8640), ('the', 8152), (',', 7021), ('to', 5076), ('and', 4855)]

Char vocabulary size: 110
Most common input tokens:  [('e', 93412), ('t', 67775), ('a', 63699), ('o', 58826), ('n', 53490)]

Target vocabulary size: 19
Most common target tokens:  [('NOUN', 34781), ('PUNCT', 23679), ('VERB', 23081), ('PRON', 18577), ('ADP', 17638)]

test example: 
PRON SCONJ PROPN VERB ADP PROPN PUNCT
What if Google <UNK> Into <UNK> ?
W h a t    i f    G o o g l e    M o r p h e d    I n t o    G o o g l e O S    ?    

As we can see above, we have 12543 training examples, 2077 test examples, an input vocabulary of size 19674, a character vocabulary of size 110, and 19 possible POS tags. The most common input tokens are, unsurprisingly, punctuation and frequent function words, and the most common target tokens are the corresponding frequent POS tags.

In the printed test example “What if Google <UNK> Into <UNK> ?” we encounter two unknown words, “Morphed” and “GoogleOS”, as we can see from the character representations. Without those character representations, the model wouldn’t have any information about these words and could only guess a tag based on the other words in the input sequence, the previously predicted tag (if the CRF is used), and the average <UNK> embedding.

Training & Testing

We can now train our bi-LSTM-CRF on the Universal Dependencies training data! We’ll use the Adam optimizer with parameters taken from Ma & Hovy. Below, each epoch loops over the entire dataset, passes each batch through our model, computes the loss, and takes a gradient step for the batch.

We’re going to choose a batch size of 100 as opposed to 10 in Ma & Hovy, because we want to train a bit faster and don’t care too much about squeezing out performance right now. We will never reach the performance of Ma & Hovy anyway, because they use several additional techniques that boost performance, most importantly a character-based CNN and pre-trained word embeddings. If we were optimizing for performance we would do many more things than discussed here, some of which are briefly mentioned in the Disclaimers section below.

def train(data: TaggingDataset, model: Tagger, batch_size: int, 
          num_epochs: int):
  """
  :param data: a TaggingDataset filled with training data.
  :param model: an initialized tagger model.
  :param batch_size: a minibatch size.
  :param num_epochs: how many times to go over the entire training data.
  """
  trainable_parameters = [p for p in model.parameters() if p.requires_grad]
  optimizer = torch.optim.Adam(trainable_parameters,
                               lr=1e-3, betas=(0.9, 0.9))
  for epoch in range(num_epochs):
    epoch_loss = 0
    num_iterations = 0
    for iteration, (input_sequence, example_lengths, 
                    char_input_sequence, char_example_lengths,
                    target_sequence) in enumerate(data.get_batch(batch_size=batch_size)):
      input_mask = (input_sequence > 0).long()
      input_lengths = torch.tensor(example_lengths, device=device)
      batch_loss, score, output_tags = model(input_sequence=input_sequence, 
                                             input_sequence_chars=char_input_sequence,
                                             target_sequence=target_sequence, 
                                             input_mask=input_mask,
                                             input_lengths=input_lengths)
      loss = torch.mean(batch_loss)
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()
      epoch_loss += loss.item()
      num_iterations += 1
    print("Epoch %d" % (epoch + 1))
    print("Trained for %d iterations." % num_iterations)
    print("Epoch loss: ", epoch_loss / num_iterations)

We’ll also implement the test loop, so that below we can immediately evaluate our trained model on unseen data.

def test(data: TaggingDataset, model: Tagger):
  """
  Loops over the test data in `data`, calculates the best-scoring
  tag sequence according to the trained `model` for each example,
  and returns the sequences, predictions, and the accuracies for
  all examples.
  :param data: An instance of TaggingDataset containing test data.
  :param model: A (trained) Tagger.
  
  Returns: a tuple of inputs, targets, predicted targets, accuracies, 
            and mean accuracy.
  """
  input_sequences = []
  target_sequences = []
  decoded_tags = []
  accuracies = []
  total_accs = 0
  n_examples = 0
  for i, (input_sequence, example_lengths, 
          char_input_sequence, char_example_lengths,
          target_sequence) in enumerate(data.get_batch(batch_size=1, split="test")):
      # Save the sequences in string-form instead of numerical form.
      input_sequences.append(data.array_to_sentence(input_sequence, "input"))
      target_sequences.append(data.array_to_sentence(target_sequence, "target"))
      input_mask = (input_sequence > 0).long()
      input_lengths = torch.tensor(example_lengths, device=device)
      
      # Get the predicted output sequence of tags.
      batch_loss, score, output_tags = model(input_sequence=input_sequence, 
                                             input_sequence_chars=char_input_sequence,
                                             target_sequence=target_sequence, 
                                             input_mask=input_mask,
                                             input_lengths=input_lengths)
      
      # Calculate the accuracy and save it.
      target_sequence = target_sequence[0]
      accuracy = (target_sequence 
                   == output_tags[0]).long().sum().float() / len(output_tags[0])
      accuracies.append(accuracy.item())
      total_accs += accuracy.item()
      n_examples += 1
      decoded_tags.append(data.array_to_sentence(output_tags, "target"))
  mean_acc = total_accs / n_examples
  return input_sequences, target_sequences, decoded_tags, accuracies, mean_acc

Alright, now let’s initialize our model and train it for 15 epochs. First, let’s train a simple bi-LSTM without the CRF.

model = Tagger(input_vocabulary=tagging_set.get_vocabulary("input"),
               target_vocabulary=tagging_set.get_vocabulary("target"),
               char_vocabulary=tagging_set.get_vocabulary("char"),
               embedding_dimension=128,
               char_embedding_dimension=50,
               char_hidden_dimension=50, 
               hidden_dimension=100,
               crf=False)
model.to(device)
train(data=tagging_set, model=model, batch_size=100, num_epochs=15)
Epoch 1
Trained for 126 iterations.
Epoch loss:  2.177231947580973
Epoch 2
Trained for 126 iterations.
Epoch loss:  2.109471258662996
Epoch 3
Trained for 126 iterations.
Epoch loss:  2.1045296589533486
Epoch 4
Trained for 126 iterations.
Epoch loss:  2.087123019354684
Epoch 5
Trained for 126 iterations.
Epoch loss:  2.0774436413295687
Epoch 6
Trained for 126 iterations.
Epoch loss:  2.0735075284564304
Epoch 7
Trained for 126 iterations.
Epoch loss:  2.0646931557428267
Epoch 8
Trained for 126 iterations.
Epoch loss:  2.0587902144780235
Epoch 9
Trained for 126 iterations.
Epoch loss:  2.0560855203204684
Epoch 10
Trained for 126 iterations.
Epoch loss:  2.0547851891744706
Epoch 11
Trained for 126 iterations.
Epoch loss:  2.0541635732802135
Epoch 12
Trained for 126 iterations.
Epoch loss:  2.0540022339139665
Epoch 13
Trained for 126 iterations.
Epoch loss:  2.053600659446111
Epoch 14
Trained for 126 iterations.
Epoch loss:  2.0533236397637262
Epoch 15
Trained for 126 iterations.
Epoch loss:  2.0526692545603193

Training for 15 epochs using a GPU in Google Colab takes about 10 minutes. The loss steadily goes down each epoch (although not by much). We can now test our model on the test data of the UD dataset and look at the mean accuracy per example. The accuracy for one example is calculated (in the test loop above) as:

\[\text{acc} = \frac{1}{m}\sum_{t=1}^{m} \mathbb{1}(\hat{y}_t, y_t)\]

Where \(\mathbb{1}(\hat{y}_t, y_t)\) denotes the indicator function that equals 1 if the predicted tag \(\hat{y}_t\) equals the ground-truth tag \(y_t\) and 0 otherwise. The mean accuracy is then the mean of this metric over the entire test set.
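
As a tiny numerical illustration (hypothetical tag indices), this is the same computation the test loop performs per example:

predicted = torch.tensor([2, 3, 3, 5])
gold = torch.tensor([2, 3, 4, 5])
accuracy = (predicted == gold).float().mean()
print(accuracy.item())  # 0.75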

inputs, targets, predictions, accs, mean_acc = test(tagging_set, model)
print(mean_acc)
0.89

So without the CRF, we already get a mean accuracy of 89%. That’s pretty good. Let’s see what we can do with the CRF enabled.

crf_model = Tagger(input_vocabulary=tagging_set.get_vocabulary("input"),
                   target_vocabulary=tagging_set.get_vocabulary("target"),
                   char_vocabulary=tagging_set.get_vocabulary("char"),
                   embedding_dimension=128,
                   char_embedding_dimension=50,
                   char_hidden_dimension=50, 
                   hidden_dimension=100,
                   crf=True)
crf_model.to(device)
train(data=tagging_set, model=crf_model, batch_size=100, num_epochs=15)
Epoch 1
Trained for 126 iterations.
Epoch loss:  9.147202217389667
Epoch 2
Trained for 126 iterations.
Epoch loss:  2.6090604748044695
Epoch 3
Trained for 126 iterations.
Epoch loss:  1.0739414606775557
Epoch 4
Trained for 126 iterations.
Epoch loss:  0.5113307234077227
Epoch 5
Trained for 126 iterations.
Epoch loss:  0.24992894984426953
Epoch 6
Trained for 126 iterations.
Epoch loss:  0.11944707252439998
Epoch 7
Trained for 126 iterations.
Epoch loss:  0.058415038792032095
Epoch 8
Trained for 126 iterations.
Epoch loss:  0.02836258260030595
Epoch 9
Trained for 126 iterations.
Epoch loss:  0.016542019189468453
Epoch 10
Trained for 126 iterations.
Epoch loss:  0.011555857158132963
Epoch 11
Trained for 126 iterations.
Epoch loss:  0.009479762773798217
Epoch 12
Trained for 126 iterations.
Epoch loss:  0.0074993557710614465
Epoch 13
Trained for 126 iterations.
Epoch loss:  0.006215568138508215
Epoch 14
Trained for 126 iterations.
Epoch loss:  0.006054751822606675
Epoch 15
Trained for 126 iterations.
Epoch loss:  0.005618702975981351

Training now takes roughly 20 minutes instead, but the loss goes down much more over the epochs. Do note, however, that we cannot compare the loss values between this model and the one without the CRF: they use completely different loss functions. Let’s look at the accuracy we get now.

crf_inputs, crf_targets, crf_predictions, crf_accs, crf_mean_acc = test(tagging_set, 
                                                                        crf_model)
print(crf_mean_acc)
0.94

That’s a pretty significant improvement! It seems this test data benefits from modeling dependencies in the output sequence, which makes sense for the POS tagging task, as motivated before; it helps with ambiguous words like “book”, for example.

Let’s look at some predictions of both models. Below is a quick function to print the predictions for a given example index.

def print_predicted_examples(idx: int):
  print("bi-LSTM outputs:")
  print("      Input sequence: " + ' '.join(inputs[idx]))
  print("     Target sequence: " + ' '.join(targets[idx]))
  print("Predictions sequence: " + ' '.join(predictions[idx]))
  print("Accuracy: ", accs[idx])
  print()

  print("bi-LSTM-CRF outputs:")
  print("      Input sequence: " + ' '.join(crf_inputs[idx]))
  print("     Target sequence: " + ' '.join(crf_targets[idx]))
  print("Predictions sequence: " + ' '.join(crf_predictions[idx]))
  print("Accuracy: ", crf_accs[idx])
  print()

Let’s print the examples at indices 1 and 2, by calling print_predicted_examples(1) and print_predicted_examples(2).

bi-LSTM outputs:
      Input sequence: [ via Microsoft Watch from Mary <UNK> <UNK> ]
     Target sequence: PUNCT ADP PROPN PROPN ADP PROPN PROPN PROPN PUNCT
Predictions sequence: PUNCT ADP PROPN VERB ADP PROPN PROPN NOUN PUNCT
Accuracy:  0.86

bi-LSTM-CRF outputs:
      Input sequence: [ via Microsoft Watch from Mary <UNK> <UNK> ]
     Target sequence: PUNCT ADP PROPN PROPN ADP PROPN PROPN PROPN PUNCT
Predictions sequence: PUNCT ADP PROPN VERB ADP PROPN PROPN VERB PUNCT
Accuracy:  0.86

This example shows the benefit of the character model. Even though the sentence contains words that are unknown because they didn’t occur during training, the character representations still help to predict the right tag for one of the unknown words.

bi-LSTM outputs:
      Input sequence: They own blogger , of course .
     Target sequence: PRON VERB PROPN PUNCT ADV ADV PUNCT
Predictions sequence: PRON VERB ADJ PUNCT ADP NOUN PUNCT
Accuracy:  0.57

bi-LSTM-CRF outputs:
      Input sequence: They own blogger , of course .
     Target sequence: PRON VERB PROPN PUNCT ADV ADV PUNCT
Predictions sequence: PRON ADJ VERB PUNCT ADV ADV PUNCT
Accuracy:  0.71

This example shows the power of the CRF. For the last words, “of course”, the bi-LSTM fails to capture the connection between “of” and “course”, namely that an “ADV” tag follows another “ADV” here, whereas the CRF handles this situation correctly. There are many more examples like this where the CRF properly disambiguates words. Check it out for yourself in the Google Colab!

Conclusion

We’re done! We derived, implemented, and trained a linear-chain CRF, showing that it gets significantly higher test accuracy on a real-world dataset than a simple biLSTM model.

Disclaimers

There are many things that are good practice in deep learning, or that might improve performance, which we didn’t do here. For example, every epoch we looped over the data in the same order, so it’s not really SGD. We also didn’t tune any hyperparameters and just picked reasonable-looking values, and we skipped dropout, learning rate decay, pre-trained embeddings, and so on.
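
For instance, shuffling between epochs is as simple as calling the (so far unused) helper we defined on the TaggingDataset, e.g. at the start of each epoch in train():

tagging_set.shuffle_train_data()
print(len(tagging_set.data["train"]["examples"]))  # still 12543 examples, just in a new order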

Sources

Natalia Silveira, Timothy Dozat, Marie-Catherine de Marneffe, Samuel Bowman, Miriam Connor, John Bauer, and Christopher D. Manning (2014). A Gold Standard Dependency Corpus for English. In Proceedings of LREC 2014.