General utility code for building PyTorch-based agents in ParlAI.

Contains the following main utilities:
  • TorchAgent class which serves as a useful parent class for other model agents
  • Batch namedtuple which is the input type of the main abstract methods of the TorchAgent class
  • Output namedtuple which is the expected output type of the main abstract methods of the TorchAgent class
  • Beam class which provides some generic beam functionality for classes to use

See below for documentation on each specific tool.

class parlai.core.torch_agent.Batch(text_vec, text_lengths, label_vec, label_lengths, labels, valid_indices, candidates, candidate_vecs, image, memory_vecs)

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

static __new__(_cls, text_vec=None, text_lengths=None, label_vec=None, label_lengths=None, labels=None, valid_indices=None, candidates=None, candidate_vecs=None, image=None, memory_vecs=None)

Create new instance of Batch(text_vec, text_lengths, label_vec, label_lengths, labels, valid_indices, candidates, candidate_vecs, image, memory_vecs)

__repr__()

Return a nicely formatted representation string

candidate_vecs

Alias for field number 7

candidates

Alias for field number 6

image

Alias for field number 8

label_lengths

Alias for field number 3

label_vec

Alias for field number 2

labels

Alias for field number 4

memory_vecs

Alias for field number 9

text_lengths

Alias for field number 1

text_vec

Alias for field number 0

valid_indices

Alias for field number 5
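
The field numbers above follow the order of the Batch signature, and every field defaults to None. A minimal sketch of that layout with a plain namedtuple follows; it is a stand-alone replica rather than an import from ParlAI, assumes Python 3.7+ (for the defaults argument of collections.namedtuple), and uses made-up token ids.

```python
from collections import namedtuple

# Field order copied from the Batch signature; index i is "field number i".
BATCH_FIELDS = (
    'text_vec', 'text_lengths', 'label_vec', 'label_lengths', 'labels',
    'valid_indices', 'candidates', 'candidate_vecs', 'image', 'memory_vecs',
)

# Stand-alone replica (not ParlAI's own object): every field defaults to None,
# matching the documented __new__ signature.
Batch = namedtuple('Batch', BATCH_FIELDS, defaults=(None,) * len(BATCH_FIELDS))

batch = Batch(text_vec=[[5, 6, 7]], text_lengths=[3])
print(batch.text_vec)   # [[5, 6, 7]]  (field number 0)
print(batch[1])         # [3]          (same object as batch.text_lengths)
print(batch.label_vec)  # None: unset fields fall back to the default
```

Because it is an ordinary namedtuple, a copy with one field changed can be made with batch._replace(labels=[...]) without touching the original.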

class parlai.core.torch_agent.Output(text, text_candidates)

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

static __new__(_cls, text=None, text_candidates=None)

Create new instance of Output(text, text_candidates)

__repr__()

Return a nicely formatted representation string

text

Alias for field number 0

text_candidates

Alias for field number 1

class parlai.core.torch_agent.TorchAgent(opt, shared=None)

A provided base agent for any model that wants to use Torch.

Exists to make it easier to implement a new agent. Not necessary, but reduces duplicated code.

Many methods are intended either to be used as-is when the default is acceptable, or to be overridden and called with super(), with the extra functionality added to the initial result. See each method's docstring for the recommended behavior.

This agent serves as a common framework for all ParlAI models which want to use PyTorch.

static dictionary_class()

Return the dictionary class that this agent expects to use.

Can be overridden if a more complex dictionary is required.

classmethod add_cmdline_args(argparser)

Add the default command-line args we expect most agents to want.

__init__(opt, shared=None)

Initialize agent.


Use the metrics to decide when to adjust LR schedule.

This uses the loss as the validation metric if present; otherwise, this function does nothing. Note that the model must be reporting loss for this to work. Override this method to change the behavior.
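
The decision described above (step the schedule on the loss, do nothing without it) can be sketched as follows. This assumes the metrics arrive as a dict with a 'loss' key; receive_metrics here is an illustrative function, and PlateauScheduler is a hypothetical stdlib stand-in for a PyTorch scheduler such as torch.optim.lr_scheduler.ReduceLROnPlateau, used so the example runs without PyTorch.

```python
class PlateauScheduler:
    """Hypothetical stand-in: decay the LR when the loss stops improving."""

    def __init__(self, lr, factor=0.5, patience=1):
        self.lr = lr
        self.factor = factor          # multiply lr by this on a plateau
        self.patience = patience      # non-improving steps tolerated first
        self.best = float('inf')
        self.bad_steps = 0

    def step(self, loss):
        if loss < self.best:
            self.best = loss
            self.bad_steps = 0
        else:
            self.bad_steps += 1
            if self.bad_steps > self.patience:
                self.lr *= self.factor
                self.bad_steps = 0


def receive_metrics(scheduler, metrics_dict):
    """Step the schedule on the validation loss; do nothing without a loss."""
    if 'loss' in metrics_dict:
        scheduler.step(metrics_dict['loss'])


sched = PlateauScheduler(lr=1.0, factor=0.5, patience=1)
for metrics in [{'loss': 2.0}, {'loss': 2.1}, {'ppl': 9.0}, {'loss': 2.2}]:
    receive_metrics(sched, metrics)
print(sched.lr)  # 0.5: two non-improving losses exceeded patience=1
```

Note that the {'ppl': 9.0} report is silently ignored, which mirrors the documented behavior when the model does not report a loss.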


Share fields from parent as well as useful objects in this class.

Subclasses will likely want to share their model as well.

vectorize(obs, add_start=True, add_end=True, truncate=None, split_lines=False)

Make vectors out of observation fields and store in the observation.

In particular, the ‘text’ and ‘labels’/‘eval_labels’ fields are processed and a new field is added to the observation with the suffix ‘_vec’.

If you want to use additional fields on your subclass, you can override this function, call super().vectorize(...) to process the text and labels, and then process the other fields in your subclass.

  • obs – Single observation from observe function.
  • add_start – default True, adds the start token to each label.
  • add_end – default True, adds the end token to each label.
  • truncate – default None; if set, truncates all vectors to the specified length. Note that inputs keep their rightmost tokens, while labels and, when applicable, candidates keep their leftmost tokens.
  • split_lines – If set, returns list of vectors instead of a single vector for input text, one for each substring after splitting on newlines.
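
The truncation rules for truncate can be illustrated on plain lists of token ids. This is a sketch, not ParlAI's implementation: the helper names are hypothetical, the start/end ids 1 and 2 are assumed (they match the Beam defaults), and the sketch assumes truncation happens after the start and end tokens are added.

```python
def vectorize_text(tokens, truncate=None):
    """Input side: keep the rightmost tokens (the most recent context)."""
    vec = list(tokens)
    if truncate is not None and len(vec) > truncate:
        vec = vec[len(vec) - truncate:]
    return vec


def vectorize_label(tokens, start_idx, end_idx, add_start=True, add_end=True,
                    truncate=None):
    """Label side: optionally wrap in start/end ids, then keep the leftmost tokens."""
    vec = list(tokens)
    if add_start:
        vec = [start_idx] + vec
    if add_end:
        vec = vec + [end_idx]
    if truncate is not None and len(vec) > truncate:
        vec = vec[:truncate]
    return vec


print(vectorize_text([10, 11, 12, 13, 14], truncate=3))  # [12, 13, 14]
print(vectorize_label([20, 21, 22], 1, 2, truncate=4))    # [1, 20, 21, 22]
```

Keeping the rightmost input tokens preserves the most recent part of the dialog, while keeping the leftmost label tokens preserves the beginning of the target sequence.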
batchify(obs_batch, sort=False, is_valid=<function TorchAgent.<lambda>>)

Create a batch of valid observations from an unchecked batch.

A valid observation is one that passes the is_valid function provided to this method, which by default checks whether the preprocessed ‘text_vec’ field, set by this agent’s ‘vectorize’ function, is present.

Returns a namedtuple Batch. See original definition above for in-depth explanation of each field.

If you want to include additional fields in the batch, you can override this method and return your own “Batch” namedtuple: copy the Batch namedtuple at the top of this module, and then add whatever additional fields you want to be able to access. You can then call super().batchify(...) to set up the original fields, set up the additional fields in your subclass, and return that batch instead.

  • obs_batch – List of vectorized observations
  • sort – default False; orders the observations by the length of their vectors. Set to True when using torch.nn.utils.rnn.pack_padded_sequence. Uses the text vectors if available, otherwise the label vectors if available.
  • is_valid – function that determines whether an observation is valid; by default, it checks whether ‘text_vec’ is in the observation
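
The filtering, sorting, and padding steps, plus the extension pattern described above, can be sketched with plain lists. This is not ParlAI's batchify: SketchBatch is a hypothetical reduced Batch with only three fields, padding with index 0 is assumed, and persona_vecs is a made-up extra field. Sorting is longest-first because that is the order pack_padded_sequence expects by default.

```python
from collections import namedtuple

# Hypothetical, reduced Batch with only the fields this sketch fills in.
SketchBatch = namedtuple('SketchBatch', ['text_vec', 'text_lengths', 'valid_indices'])


def has_text_vec(obs):
    """Default validity check: the observation was vectorized."""
    return 'text_vec' in obs


def batchify(obs_batch, sort=False, is_valid=has_text_vec, pad_idx=0):
    # Keep the original position of every valid observation.
    valid = [(i, obs) for i, obs in enumerate(obs_batch) if is_valid(obs)]
    if not valid:
        return SketchBatch(None, None, [])
    if sort:
        # Longest first, the order pack_padded_sequence expects by default.
        valid.sort(key=lambda pair: len(pair[1]['text_vec']), reverse=True)
    lengths = [len(obs['text_vec']) for _, obs in valid]
    max_len = max(lengths)
    # Right-pad every vector to the longest one with pad_idx.
    text_vec = [obs['text_vec'] + [pad_idx] * (max_len - len(obs['text_vec']))
                for _, obs in valid]
    return SketchBatch(text_vec, lengths, [i for i, _ in valid])


# Extending the batch: copy the fields and append your own.
# 'persona_vecs' is a hypothetical extra field for illustration.
ExtendedBatch = namedtuple('ExtendedBatch', SketchBatch._fields + ('persona_vecs',))


def batchify_extended(obs_batch, sort=False):
    base = batchify(obs_batch, sort=sort)
    personas = [obs_batch[i].get('persona_vec') for i in base.valid_indices]
    return ExtendedBatch(*base, persona_vecs=personas)


obs_batch = [{'text_vec': [5, 6]}, {'text': 'not vectorized'}, {'text_vec': [7, 8, 9]}]
print(batchify(obs_batch, sort=True))
# SketchBatch(text_vec=[[7, 8, 9], [5, 6, 0]], text_lengths=[3, 2], valid_indices=[2, 0])
```

The valid_indices field is what later lets predictions for the sorted sub-batch be written back to the right rows.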
match_batch(batch_reply, valid_inds, output=None)

Match sub-batch of predictions to the original batch indices.

Batches may be only partially filled (i.e. when completing the remainder at the end of the validation or test set), or we may want to sort by e.g. the length of the input sequences if using pack_padded_sequence.

This matches rows back with their original row in the batch for calculating metrics like accuracy.

If output is None (model choosing not to provide any predictions), we will just return the batch of replies.

Otherwise, output should be a parlai.core.torch_agent.Output object. This is a namedtuple, which can provide text predictions and/or text_candidates predictions. If you would like to map additional fields into the batch_reply, you can override this method as well as providing your own namedtuple with additional fields.

  • batch_reply – Full-batchsize list of message dictionaries to put responses into.
  • valid_inds – Original indices of the predictions.
  • output – Output namedtuple which contains sub-batchsize list of text outputs from model. May be None (default) if model chooses not to answer. This method will check for text and text_candidates fields.
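
The mapping described above can be sketched as follows, assuming replies are plain dicts. Output here is a local replica of the documented namedtuple rather than an import, and the reply contents are made up.

```python
from collections import namedtuple

# Replica of the Output namedtuple documented above (not imported from ParlAI).
Output = namedtuple('Output', ['text', 'text_candidates'], defaults=(None, None))


def match_batch(batch_reply, valid_inds, output=None):
    """Write sub-batch predictions back into their original rows."""
    if output is None:
        # Model chose not to answer: return the replies untouched.
        return batch_reply
    if output.text is not None:
        for i, text in zip(valid_inds, output.text):
            batch_reply[i]['text'] = text
    if output.text_candidates is not None:
        for i, cands in zip(valid_inds, output.text_candidates):
            batch_reply[i]['text_candidates'] = cands
    return batch_reply


# Rows 2 and 0 were valid (sorted longest first); row 1 was skipped.
replies = match_batch([{}, {}, {}], [2, 0], Output(text=['reply for 2', 'reply for 0']))
print(replies)  # [{'text': 'reply for 0'}, {}, {'text': 'reply for 2'}]
```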
get_dialog_history(observation, reply=None, add_person_tokens=False, add_p1_after_newln=False)

Retrieve dialog history and add current observations to it.

  • observation – current observation
  • reply – past utterance from the model to add to the history, such as the past label or response generated by the model.
  • add_person_tokens – add tokens identifying each speaker before utterances in the text & history.
  • add_p1_after_newln – add the other speaker's token at the start of the last line of the input (just after the last newline) instead of at the beginning of the input. This is useful for tasks that include some kind of context before the actual utterance (e.g. squad, babi, personachat).

Returns: observation with text replaced with full dialog
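
A sketch of how the history and person tokens might fit together. Several pieces are assumptions: the token strings '__p1__' (other speaker) and '__p2__' (the model), the explicit history list passed in by the caller, and placing the p1 token at the start of the final line when add_p1_after_newln is set, as the parameter name suggests.

```python
P1_TOKEN = '__p1__'  # assumed token for the other speaker
P2_TOKEN = '__p2__'  # assumed token for the model itself


def add_person_token(text, token, after_newline=False):
    """Prefix text with token, or prefix only its last line if after_newline."""
    if after_newline and '\n' in text:
        context, _, utterance = text.rpartition('\n')
        return context + '\n' + token + ' ' + utterance
    return token + ' ' + text


def get_dialog_history(history, observation, reply=None,
                       add_person_tokens=False, add_p1_after_newln=False):
    """Append reply and observation text to history; return obs with full dialog."""
    if reply is not None:
        history.append(add_person_token(reply, P2_TOKEN) if add_person_tokens else reply)
    text = observation.get('text')
    if text is not None:
        if add_person_tokens:
            text = add_person_token(text, P1_TOKEN, add_p1_after_newln)
        history.append(text)
    updated = dict(observation)  # leave the caller's observation untouched
    updated['text'] = '\n'.join(history)
    return updated


history = []
first = get_dialog_history(history, {'text': 'your persona: likes cats\nhello'},
                           add_person_tokens=True, add_p1_after_newln=True)
print(first['text'])
# your persona: likes cats
# __p1__ hello
second = get_dialog_history(history, {'text': 'how are you?'}, reply='hi there',
                            add_person_tokens=True)
```

Putting the token after the context line keeps the persona text unmarked, so only real utterances carry a speaker token.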


Retrieve the last reply from the model.

If available, we use the true label instead of the model’s prediction.

By default, batch_act stores the batch of replies and this method will extract the reply of the current instance from the batch.

Parameters: use_label – default True, use the label when available instead of the model’s generated response.
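
The label-versus-prediction choice can be sketched as below, assuming observations and replies are plain dicts with 'labels'/'eval_labels' and 'text' keys. Taking the first label when several exist is an assumption of this sketch.

```python
def last_reply(last_observation, last_model_reply, use_label=True):
    """Return the text to record as the model's previous turn."""
    if use_label and last_observation is not None:
        # Prefer the true label if the last observation carried one.
        labels = last_observation.get('labels') or last_observation.get('eval_labels')
        if labels:
            return labels[0]  # assumption: use the first label
    if last_model_reply is not None:
        return last_model_reply.get('text')
    return None


print(last_reply({'labels': ['gold answer']}, {'text': 'model guess'}))  # gold answer
print(last_reply({'labels': ['gold answer']}, {'text': 'model guess'}, use_label=False))  # model guess
```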

Process incoming message in preparation for producing a response.

This includes remembering the past history of the conversation.


Save model parameters to path (or default to model_file arg).

Override this method for more specific saving.


Return opt and model states.

Override this method for more specific loading.


Clear internal states.


Call batch_act with the singleton batch.


Process a batch of observations (batchsize list of message dicts).

These observations have been preprocessed by the observe method.

Subclasses can override this for special functionality, but if the default behaviors are fine then just override the train_step and eval_step methods instead. The former is called when labels are present in the observations batch; otherwise, the latter is called.
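
The dispatch between the two step methods can be sketched as follows. In this simplified version, train_step and eval_step are caller-supplied callables that return a list of reply texts (or None) instead of an Output namedtuple, and validity is assumed to mean "has a 'text_vec' field".

```python
def batch_act(observations, train_step, eval_step):
    """Run train_step if any observation has 'labels', else eval_step."""
    batch_reply = [{} for _ in observations]
    # Only observations that were vectorized take part in the step.
    valid_inds = [i for i, obs in enumerate(observations) if 'text_vec' in obs]
    batch = [observations[i] for i in valid_inds]
    is_training = any('labels' in obs for obs in observations)
    texts = train_step(batch) if is_training else eval_step(batch)
    if texts is not None:
        # Map sub-batch predictions back to their original rows.
        for i, text in zip(valid_inds, texts):
            batch_reply[i]['text'] = text
    return batch_reply


calls = []


def train_step(batch):
    calls.append('train')
    return None  # e.g. a training step that produces no text


def eval_step(batch):
    calls.append('eval')
    return ['pred'] * len(batch)


print(batch_act([{'text_vec': [1], 'labels': ['a']}], train_step, eval_step))  # [{}]
print(batch_act([{'text_vec': [1]}, {'text': 'skip'}], train_step, eval_step))  # [{'text': 'pred'}, {}]
```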


Process one batch with training labels.


Process one batch but do not train on it.

class parlai.core.torch_agent.Beam(beam_size, min_length=3, padding_token=0, bos_token=1, eos_token=2, min_n_best=3, cuda='cpu')

Generic beam class. It keeps track of beam_size hypotheses.

__init__(beam_size, min_length=3, padding_token=0, bos_token=1, eos_token=2, min_n_best=3, cuda='cpu')

Instantiate Beam object.

  • beam_size – number of hypotheses in the beam
  • min_length – minimum length of the predicted sequence
  • padding_token – padding token index; defaults to 0, as usual in ParlAI
  • bos_token – beginning-of-sentence token index; defaults to 1, as usual in ParlAI
  • eos_token – end-of-sentence token index; defaults to 2, as usual in ParlAI
  • min_n_best – the beam is not done until at least this many hypotheses have finished (i.e. produced EOS)
  • cuda – the device to use for computations

Get the output at the current step.


Get the backtrack at the current step.


Advance the beam one step.
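
One beam step can be sketched on plain Python lists: every live hypothesis is extended by every token, EOS is blocked before min_length, and the best beam_size candidates survive together with their parent pointers. This is a simplified sketch, not ParlAI's Beam.advance; it works on log-probabilities in nested lists rather than tensors and does not remove hypotheses that already finished.

```python
import math


def advance(beam_scores, log_probs, beam_size, step, min_length, eos_token=2):
    """Expand every live hypothesis by every token and keep the best beam_size.

    beam_scores[k]  -- running log-probability of live hypothesis k
    log_probs[k][v] -- log-probability of token v following hypothesis k
    Returns (new_scores, token_ids, parent_ids); parent_ids are the backtrack
    pointers needed later to rebuild each hypothesis.
    """
    candidates = []
    for parent, (base, row) in enumerate(zip(beam_scores, log_probs)):
        for token, lp in enumerate(row):
            if token == eos_token and step < min_length:
                lp = -math.inf  # EOS is not allowed before min_length
            candidates.append((base + lp, token, parent))
    candidates.sort(key=lambda c: c[0], reverse=True)
    best = candidates[:beam_size]
    return [c[0] for c in best], [c[1] for c in best], [c[2] for c in best]


# Vocabulary: 0=pad, 1=bos, 2=eos, 3 and 4 are ordinary tokens.
row = [-8.0, -8.0, -0.25, -0.5, -1.0]
print(advance([0.0], [row], beam_size=2, step=1, min_length=3))
# ([-0.5, -1.0], [3, 4], [0, 0])  EOS was masked at step 1
```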


Return whether beam search is complete.


Get single best hypothesis.

Returns: hypothesis sequence and the final score

Extract hypothesis ending with EOS at timestep with hyp_id.

  • timestep – timestep with range up to len(self.outputs)-1
  • hyp_id – id with range up to beam_size-1

Returns: hypothesis sequence
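
Extracting a finished hypothesis amounts to following backtrack pointers from the EOS position back to the first step. A minimal sketch, assuming outputs and backtracks are plain nested lists (hypothetical layouts chosen for illustration; the real Beam keeps its own per-step tensors and also tracks scores):

```python
def get_hyp_from_finished(outputs, backtracks, timestep, hyp_id):
    """Rebuild the token sequence ending at (timestep, hyp_id).

    outputs[t][k]      -- token chosen by beam slot k at step t
    backtracks[t-1][k] -- slot at step t-1 that slot k at step t extends
    """
    tokens = []
    for t in range(timestep, -1, -1):
        tokens.append(outputs[t][hyp_id])
        if t > 0:
            hyp_id = backtracks[t - 1][hyp_id]  # move to the parent slot
    return tokens[::-1]


# Step 0 holds BOS (1) in both slots; token 2 is EOS.
outputs = [[1, 1], [3, 4], [5, 2]]
backtracks = [[0, 0], [1, 0]]
print(get_hyp_from_finished(outputs, backtracks, timestep=2, hyp_id=1))  # [1, 3, 2]
```

Since slots are reshuffled at every step, the slot index alone does not identify a hypothesis; the chain of parent pointers does.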

static get_pretty_hypothesis(list_of_hypotails)

Return a prettier version of the hypotheses.


Return finished hypotheses in rescored order.

Parameters: n_best – how many of the best hypotheses to return
Returns: list of hypotheses
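
Rescoring typically normalizes the summed log-probability by length so that short hypotheses are not automatically favored. The sketch below uses the length penalty from Google's NMT system, ((5 + |Y|) / 6) ** alpha with alpha = 0.65; this is one common choice assumed here, not necessarily the rule the Beam class applies. Finished hypotheses are assumed to be (timestep, hyp_id, score) tuples.

```python
def get_rescored_finished(finished, n_best=None, alpha=0.65):
    """Sort finished hypotheses by length-normalized score, best first."""
    rescored = []
    for timestep, hyp_id, score in finished:
        length = timestep + 1
        # GNMT-style length penalty; alpha=0 disables normalization.
        penalty = ((5 + length) / 6) ** alpha
        rescored.append((timestep, hyp_id, score / penalty))
    rescored.sort(key=lambda h: h[2], reverse=True)
    return rescored if n_best is None else rescored[:n_best]


finished = [(1, 0, -2.0), (5, 1, -2.4)]
print([h[:2] for h in get_rescored_finished(finished)])  # [(5, 1), (1, 0)]
```

Without the penalty the shorter hypothesis (raw score -2.0) would win; after normalization the longer one ranks first.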

Check whether self.finished is empty and, if so, add hyptail to it.

The result will be a suboptimal hypothesis, since the model did not produce any EOS.

get_beam_dot(dictionary=None, n_best=None)

Create pydot graph representation of the beam.

  • outputs – self.outputs from the beam
  • dictionary – token-to-word dictionary used to show words in the tree nodes

Returns: pydot graph