deepchem.metalearning package

Submodules

deepchem.metalearning.maml module

Model-Agnostic Meta-Learning (MAML) algorithm for low data learning.

class deepchem.metalearning.maml.MAML(learner, learning_rate=0.001, optimization_steps=1, meta_batch_size=10, optimizer=Adam(), model_dir=None)[source]

Bases: object

Implements the Model-Agnostic Meta-Learning algorithm for low data learning.

The algorithm is described in Finn et al., “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks” (https://arxiv.org/abs/1703.03400). It is used for training models that can perform a variety of tasks, depending on what data they are trained on. It assumes you have training data for many tasks, but only a small amount for each one. It performs “meta-learning” by looping over tasks and trying to minimize the loss on each one after one or a few steps of gradient descent. That is, it does not try to create a model that can directly solve the tasks, but rather tries to create a model that is very easy to train.

To use this class, create a subclass of MetaLearner that encapsulates the model and data for your learning problem. Pass it to a MAML object and call fit(). You can then use train_on_current_task() to fine-tune the model for a particular task, as shown in the sketch below.
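
A minimal end-to-end sketch (SineLearner is a hypothetical MetaLearner subclass; one possible definition is sketched under MetaLearner below):

    from deepchem.metalearning.maml import MAML

    learner = SineLearner()
    maml = MAML(learner, learning_rate=0.001, meta_batch_size=10)
    maml.fit(9000)

    # Fine-tune the meta-learned parameters on one particular task.
    learner.select_task()
    maml.train_on_current_task(optimization_steps=1)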

fit(steps, max_checkpoints_to_keep=5, checkpoint_interval=600, restore=False)[source]

Perform meta-learning to train the model.

Parameters:
  • steps (int) – the number of steps of meta-learning to perform
  • max_checkpoints_to_keep (int) – the maximum number of checkpoint files to keep. When this number is reached, older files are deleted.
  • checkpoint_interval (float) – the time interval at which to save checkpoints, measured in seconds
  • restore (bool) – if True, restore the model from the most recent checkpoint and continue training from there. If False, retrain the model from scratch.
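
For example, assuming a MAML object named maml, this runs 4000 steps of meta-learning, keeps the three most recent checkpoint files, and saves one every five minutes:

    maml.fit(4000, max_checkpoints_to_keep=3, checkpoint_interval=300)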

restore()[source]

Reload the model parameters from the most recent checkpoint file.

train_on_current_task(optimization_steps=1, restore=True)[source]

Perform a few steps of gradient descent to fine-tune the model on the current task.

Parameters:
  • optimization_steps (int) – the number of steps of gradient descent to perform
  • restore (bool) – if True, restore the model from the most recent checkpoint before optimizing
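
For example (continuing the sketch above), restoring before each fine-tuning run makes every new task start from the meta-learned parameters rather than from the previous task's fine-tuned weights:

    learner.select_task()
    maml.train_on_current_task(optimization_steps=5, restore=True)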

class deepchem.metalearning.maml.MetaLearner[source]

Bases: object

Model and data to which the MAML algorithm can be applied.

To use MAML, create a subclass of this class that defines the learning problem to solve. The subclass encapsulates a model that can be trained to perform many different tasks, along with data for training it on a large (possibly infinite) set of different tasks. A sketch of one possible subclass follows.
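
This is a minimal sketch, assuming the TensorFlow 1.x placeholder/feed-dict style used by the TensorGraph-era API; the sine-wave regression problem and all names in it are illustrative, not part of the library:

    import numpy as np
    import tensorflow as tf
    from deepchem.metalearning.maml import MetaLearner

    class SineLearner(MetaLearner):
        """Each task is to regress y = amplitude * sin(x + phase)."""

        def __init__(self):
            self.batch_size = 10
            self.inputs = tf.placeholder(tf.float32, shape=(None, 1))
            self.labels = tf.placeholder(tf.float32, shape=(None, 1))
            hidden = tf.layers.dense(self.inputs, 40, activation=tf.nn.relu)
            outputs = tf.layers.dense(hidden, 1)
            self._loss = tf.reduce_mean(tf.square(outputs - self.labels))
            self.select_task()

        @property
        def loss(self):
            # The Tensor that MAML minimizes.
            return self._loss

        def select_task(self):
            # Draw a new task: a sine wave with random amplitude and phase.
            self.amplitude = np.random.uniform(0.1, 5.0)
            self.phase = np.random.uniform(0.0, np.pi)

        def get_batch(self):
            # Return a feed dict mapping input tensors to sampled values.
            x = np.random.uniform(-5.0, 5.0, (self.batch_size, 1))
            y = self.amplitude * np.sin(x + self.phase)
            return {self.inputs: x, self.labels: y}

The default variables property (all trainable variables in the graph) is used unchanged here.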

get_batch()[source]

Get a batch of data for training.

This should return the data in the form of a TensorFlow feed dict, that is, a dict mapping tensors to values. This will usually be called twice for each task, and should return a different batch on each call.

loss

Get the model’s loss function, represented as a Layer or Tensor.

select_task()[source]

Select a new task to train on.

If there is a fixed set of training tasks, this will typically cycle through them. If there are infinitely many training tasks, this can simply select a new one each time it is called.
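
For example, a learner with a fixed set of tasks might cycle through them like this (self.tasks and self.task_index are illustrative fields, assumed to be initialized in __init__):

    def select_task(self):
        self.task_index = (self.task_index + 1) % len(self.tasks)
        self.current_task = self.tasks[self.task_index]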

variables

Get the list of TensorFlow variables to train.

The default implementation returns all trainable variables in the graph. This is usually what you want, but subclasses can customize it if needed.
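
A sketch of a custom override, assuming the variables to meta-train were created under a known variable scope (the scope name 'shared' is illustrative, and tf is the TensorFlow 1.x import from the sketch above):

    @property
    def variables(self):
        # Meta-train only the variables created under the 'shared' scope.
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='shared')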

Module contents