PyTorch Lightning Integration for DeepChem Models
In this tutorial we will go through how to set up a DeepChem model inside the PyTorch Lightning framework. Lightning is a PyTorch framework that simplifies experimenting with PyTorch models. A few key features of PyTorch Lightning that DeepChem users may find useful are:
- Multi-GPU training: PyTorch Lightning provides easy multi-GPU and multi-node training. It also simplifies launching multi-GPU, multi-node jobs across different cluster infrastructures, e.g. AWS or Slurm-based clusters.
- Less boilerplate PyTorch code: Lightning takes care of details like optimizer.zero_grad(), model.train() and model.eval(). Lightning also provides experiment logging: whether you train on a CPU, a GPU or multiple nodes, you can call the method self.log inside the LightningModule and the metrics will be logged appropriately.
- Features that can speed up training or reduce its memory footprint: half-precision training, gradient checkpointing and code profiling.
Setup
- This notebook assumes that you have already installed DeepChem. If you have not, follow the instructions on the DeepChem installation page: https://deepchem.readthedocs.io/en/latest/get_started/installation.html
- Install PyTorch Lightning following the instructions on Lightning's home page: https://www.pytorchlightning.ai/
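```
# The outputs shown in this tutorial were produced with deepchem 2.6.1.dev,
# torch 1.10.2 and pytorch_lightning 1.5.8; newer releases may behave differently.
!pip install --pre deepchem
!pip install pytorch_lightning
```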
Import the relevant packages.
```python
import deepchem as dc
from deepchem.models import GCNModel
import numpy as np
import pytorch_lightning as pl
import torch
```
DeepChem Example
Below we show an example of a Graph Convolutional Network (GCN). This is a simple example that uses a GCNModel to predict a label from an input SMILES string. We do not showcase the complete functionality of DeepChem here. The goal is to restructure the DeepChem code so that it can be easily plugged into PyTorch Lightning. This example was inspired by the GCNModel documentation.
Prepare the dataset: to train our DeepChem model we need a dataset. Below we prepare a small sample dataset for the purposes of this tutorial, using the featurizer directly to encode the examples.
```python
smiles = ["C1CCC1", "CCC"]
labels = [0., 1.]
# Convert each SMILES string into a molecular graph (GraphData) representation
featurizer = dc.feat.MolGraphConvFeaturizer()
X = featurizer.featurize(smiles)
dataset = dc.data.NumpyDataset(X=X, y=labels)
```
Set up the model: now we initialize the Graph Convolutional Network model that we will train.
```python
model = GCNModel(
    mode='classification',
    n_tasks=1,
    batch_size=2,
    learning_rate=0.001,
)
```
Train the model: fit the model on the training dataset, specifying the number of epochs to run with nb_epoch. fit returns the average loss over the most recent checkpoint interval, which we print.
```python
loss = model.fit(dataset, nb_epoch=5)
print(loss)
```
0.18830760717391967
PyTorch Lightning + DeepChem Example
Now we will look at an example of the GCN model adapted for PyTorch Lightning. Using PyTorch Lightning involves two important components:
- LightningDataModule: this module defines how the data is prepared and fed into the model for training. It defines the train_dataloader method, which the trainer calls directly to generate batches for the LightningModule. To learn more about the LightningDataModule, refer to the datamodules documentation.
- LightningModule: this module defines the training and validation steps for our model, and we can use it to initialize the model from its hyperparameters. It also provides a number of built-in helpers for tracking experiments. For example, the self.save_hyperparameters() method saves all the hyperparameters used for training. For more details on how to use this module, refer to the LightningModule documentation.
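To make the division of labour concrete, here is a toy, framework-free sketch of the order in which a trainer calls the hooks of these two components. On the data module it calls setup and then train_dataloader. On the model it calls configure_optimizers and then training_step for every batch. ToyDataModule, ToyModule and toy_fit are hypothetical names, and the "model" is a single scalar fitted by plain gradient descent. This is a simplified illustration, not Lightning's actual implementation, which also handles devices, logging, checkpointing and more.

```python
class ToyDataModule:
    """Plays the role of a LightningDataModule (hypothetical, for illustration)."""

    def __init__(self, values, batch_size):
        self.values = values
        self.batch_size = batch_size

    def setup(self, stage=None):
        # A real data module would load or featurize data here.
        self.train_data = list(self.values)

    def train_dataloader(self):
        # Consecutive mini-batches (no shuffling, so the sketch is deterministic).
        return [
            self.train_data[i:i + self.batch_size]
            for i in range(0, len(self.train_data), self.batch_size)
        ]


class ToyModule:
    """Plays the role of a LightningModule with one scalar parameter w."""

    def __init__(self, learning_rate):
        self.w = 0.0
        self.learning_rate = learning_rate

    def configure_optimizers(self):
        # Stand-in for returning a torch optimizer: just the step size.
        return self.learning_rate

    def training_step(self, batch, batch_idx):
        # Mean squared error between w and the batch values, and its gradient.
        loss = sum((self.w - x) ** 2 for x in batch) / len(batch)
        grad = sum(2 * (self.w - x) for x in batch) / len(batch)
        return loss, grad


def toy_fit(module, datamodule, max_epochs):
    """Call the hooks roughly in trainer order; return the mean loss of each epoch."""
    datamodule.setup("fit")
    lr = module.configure_optimizers()
    epoch_losses = []
    for _ in range(max_epochs):
        losses = []
        for batch_idx, batch in enumerate(datamodule.train_dataloader()):
            loss, grad = module.training_step(batch, batch_idx)
            # A real trainer runs optimizer.zero_grad(), backward() and step() here.
            module.w -= lr * grad
            losses.append(loss)
        epoch_losses.append(sum(losses) / len(losses))
    return epoch_losses
```

For example, toy_fit(ToyModule(0.1), ToyDataModule([1.0, 2.0, 3.0, 4.0], 2), max_epochs=50) moves w from 0 towards the centre of the data while the per-epoch loss falls. The model code never touches the loop itself, which is the point of the split.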
Set up the torch dataset: here we need to create a custom SmilesDataset so that we can easily interface with the DeepChem featurizers. We also need to define a collate function so that the DataLoader can combine individual samples into batches in the [X, y, w] format that DeepChem expects.
```python
# prepare LightningDataModule
class SmilesDataset(torch.utils.data.Dataset):
    """Featurizes SMILES strings once and serves (X, y, w) samples."""

    def __init__(self, smiles, labels):
        assert len(smiles) == len(labels)
        featurizer = dc.feat.MolGraphConvFeaturizer()
        X = featurizer.featurize(smiles)
        self._samples = dc.data.NumpyDataset(X=X, y=labels)

    def __len__(self):
        return len(self._samples)

    def __getitem__(self, index):
        # Features, label and weight of a single sample
        return (
            self._samples.X[index],
            self._samples.y[index],
            self._samples.w[index],
        )


class SmilesDatasetBatch:
    """Groups a list of (X, y, w) samples into DeepChem's [inputs, labels, weights] lists."""

    def __init__(self, batch):
        X = [np.array([b[0] for b in batch])]
        y = [np.array([b[1] for b in batch])]
        w = [np.array([b[2] for b in batch])]
        self.batch_list = [X, y, w]


def collate_smiles_dataset_wrapper(batch):
    # Passed to the DataLoader as collate_fn
    return SmilesDatasetBatch(batch)
```
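SmilesDatasetBatch transposes a list of per-sample (X, y, w) tuples into per-field arrays. The same idea can be written more compactly with zip(*batch), as in the minimal sketch below. collate_xyw and FakeGraph are hypothetical names. FakeGraph is only a placeholder for the graph objects the featurizer returns. The sketch assumes that the label and weight arrays of all samples share one shape, so that NumPy can stack them.

```python
import numpy as np


class FakeGraph:
    """Placeholder for a featurized molecule (hypothetical, for illustration only)."""

    def __init__(self, smiles):
        self.smiles = smiles


def collate_xyw(batch):
    """Turn [(X, y, w), ...] into [[X_array], [y_array], [w_array]].

    zip(*batch) regroups the samples field by field. Each field is stacked into
    one NumPy array and wrapped in a one-element list, which is the same nesting
    that SmilesDatasetBatch builds in its batch_list.
    """
    return [[np.array(field)] for field in zip(*batch)]


# Illustrative samples: one graph-like object, a label and a weight each
samples = [
    (FakeGraph("C1CCC1"), np.array([0.0]), np.array([1.0])),
    (FakeGraph("CCC"), np.array([1.0]), np.array([1.0])),
]
X, y, w = collate_xyw(samples)
# Graph-like objects become a 1-D object array; labels and weights stack to (2, 1).
print(X[0].dtype, X[0].shape, y[0].shape, w[0].shape)
```

Keeping non-numeric features in an object array lets the model's own batch-preparation code decide how to merge the graphs, while labels and weights become ordinary numeric arrays.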
Create the LightningDataModule: in this part we use the SmilesDataset created above to build the SmilesDatasetModule, which wraps the dataset in a PyTorch DataLoader.
```python
class SmilesDatasetModule(pl.LightningDataModule):
    def __init__(self, train_smiles, train_labels, batch_size):
        super().__init__()
        self._train_smiles = train_smiles
        self._train_labels = train_labels
        self._batch_size = batch_size

    def setup(self, stage=None):
        # Called by the trainer before training; featurize the SMILES here
        self.train_dataset = SmilesDataset(
            self._train_smiles,
            self._train_labels,
        )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self._batch_size,
            collate_fn=collate_smiles_dataset_wrapper,
            shuffle=True,
        )
```
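With shuffle=True and the default drop_last=False, the DataLoader draws a fresh random order of the samples every epoch and then cuts that order into consecutive batches. The last batch may be smaller than batch_size. The stdlib-only sketch below mimics this behaviour with index lists. shuffled_batches is a hypothetical helper, not the DataLoader's real implementation, which uses a sampler and, optionally, worker processes.

```python
import random


def shuffled_batches(n_samples, batch_size, rng, drop_last=False):
    """Return one epoch of mini-batches as lists of sample indices.

    Mimics DataLoader(shuffle=True): permute all indices, then slice the
    permutation into consecutive chunks of batch_size. The last chunk may be
    smaller unless drop_last=True, in which case it is discarded.
    """
    indices = list(range(n_samples))
    rng.shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, n_samples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


# Seeded generator so the example is reproducible; each call plays one epoch.
rng = random.Random(0)
epoch_1 = shuffled_batches(10, 2, rng)
epoch_2 = shuffled_batches(10, 2, rng)
print(epoch_1)
print(epoch_2)  # a fresh permutation of the same 10 indices
```

Every sample is still seen exactly once per epoch. Only the grouping and order change, which helps keep consecutive gradient steps from seeing the data in the same fixed sequence.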
Create the LightningModule: in this part we create the GCN-specific LightningModule. This class specifies the logic of the training step. It also creates the model, optimizer and loss used during training.
```python
# prepare the LightningModule
class GCNModule(pl.LightningModule):
    def __init__(self, mode, n_tasks, learning_rate):
        super().__init__()
        # Store the constructor arguments in self.hparams (and in checkpoints)
        self.save_hyperparameters(
            "mode",
            "n_tasks",
            "learning_rate",
        )
        # Build the DeepChem wrapper, then reuse its underlying torch model and loss
        self.gcn_model = GCNModel(
            mode=self.hparams.mode,
            n_tasks=self.hparams.n_tasks,
            learning_rate=self.hparams.learning_rate,
        )
        self.pt_model = self.gcn_model.model
        self.loss = self.gcn_model._loss_fn

    def configure_optimizers(self):
        # Turn DeepChem's optimizer configuration into a torch optimizer
        return self.gcn_model.optimizer._create_pytorch_optimizer(
            self.pt_model.parameters(),
        )

    def training_step(self, batch, batch_idx):
        batch = batch.batch_list
        # Convert the collated NumPy batch into model inputs and torch tensors
        inputs, labels, weights = self.gcn_model._prepare_batch(batch)
        outputs = self.pt_model(inputs)
        if isinstance(outputs, torch.Tensor):
            outputs = [outputs]
        # Keep only the outputs that the DeepChem loss is computed on
        if self.gcn_model._loss_outputs is not None:
            outputs = [outputs[i] for i in self.gcn_model._loss_outputs]
        loss_outputs = self.loss(outputs, labels, weights)
        self.log(
            "train_loss",
            loss_outputs,
            on_epoch=True,
            sync_dist=True,
            reduce_fx="mean",
            prog_bar=True,
        )
        return loss_outputs
```
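Because the self.log call in training_step passes on_epoch=True and reduce_fx="mean", the per-step train_loss values are also combined into one value per epoch. The sketch below gives a simplified picture of that reduction. It assumes that each step is weighted by its batch size, so with the equal batch size of 2 used here the result is just the plain mean. epoch_mean is a hypothetical helper, not Lightning's code. With sync_dist=True, Lightning additionally reduces the logged value across processes in distributed runs, which this sketch does not model.

```python
def epoch_mean(step_values, batch_sizes=None):
    """Batch-size-weighted mean of per-step logged values.

    With batch_sizes=None every step counts equally (plain mean).
    """
    if not step_values:
        raise ValueError("no steps were logged this epoch")
    if batch_sizes is None:
        batch_sizes = [1] * len(step_values)
    if len(batch_sizes) != len(step_values):
        raise ValueError("need one batch size per logged value")
    total = sum(v * n for v, n in zip(step_values, batch_sizes))
    return total / sum(batch_sizes)


# Illustrative numbers, not results from this tutorial: five steps of batch size 2.
losses = [0.70, 0.65, 0.60, 0.55, 0.50]
print(epoch_mean(losses), epoch_mean(losses, [2] * 5))  # weighting changes nothing

# An uneven final batch (e.g. 11 samples with batch_size=2) gets less weight.
print(epoch_mean([0.6] * 5 + [0.0], [2] * 5 + [1]))
```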
Create the relevant objects
```python
# create module objects
smiles_datasetmodule = SmilesDatasetModule(
    # Five copies of the two example molecules give 10 training samples
    train_smiles=["C1CCC1", "CCC"] * 5,
    train_labels=[0., 1.] * 5,
    batch_size=2,
)
gcnmodule = GCNModule(
    mode="classification",
    n_tasks=1,
    learning_rate=1e-3,
)
```
Lightning Trainer
The Trainer builds on top of the LightningDataModule and the LightningModule and runs the training loop. When constructing the trainer you can also specify the number of epochs, the maximum number of steps to run, and the number of GPUs and nodes to use. Because the Lightning trainer wraps your distributed training setup, you can build your models in the same simple way you would for local runs.
```python
trainer = pl.Trainer(
    max_epochs=5,
)
```
```
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
```
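Here is a quick check of what max_epochs=5 means for this run. Each epoch makes one optimizer step per batch, so the 10 training SMILES with batch_size=2 give 5 steps per epoch and 25 steps in total. If max_steps is also set, training stops at whichever limit is reached first. planned_steps is a hypothetical helper for this arithmetic. It assumes no gradient accumulation and the DataLoader default drop_last=False unless told otherwise.

```python
import math


def planned_steps(n_samples, batch_size, max_epochs, max_steps=None, drop_last=False):
    """Return (steps per epoch, total optimizer steps), assuming one step per batch.

    Training ends at whichever comes first: max_epochs full epochs or
    max_steps steps (max_steps=None means no step limit).
    """
    if drop_last:
        per_epoch = n_samples // batch_size
    else:
        per_epoch = math.ceil(n_samples / batch_size)  # a partial last batch still counts
    total = per_epoch * max_epochs
    if max_steps is not None:
        total = min(total, max_steps)
    return per_epoch, total


print(planned_steps(10, 2, max_epochs=5))                # → (5, 25), this tutorial's settings
print(planned_steps(10, 2, max_epochs=5, max_steps=12))  # → (5, 12), step limit hit first
print(planned_steps(11, 2, max_epochs=5))                # → (6, 30), partial last batch
```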
Call the fit method to run model training.
```python
# train
trainer.fit(
    model=gcnmodule,
    datamodule=smiles_datasetmodule,
)
```
```
  | Name     | Type | Params
----------------------------------
0 | pt_model | GCN  | 29.4 K
----------------------------------
29.4 K    Trainable params
0         Non-trainable params
29.4 K    Total params
0.118     Total estimated model params size (MB)

UserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
UserWarning: The number of training samples (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
```
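The log_every_n_steps warning above appears because this run has only 5 training batches per epoch, which the warning reports as "training samples". That is fewer than the default logging interval of 50 steps, so per-step metrics are rarely written to the logger. The sketch below counts the step-level logging events a 25-step run would produce under a simplified rule that is assumed here for illustration: a step is logged when its 1-based global step number is a multiple of log_every_n_steps. Lightning's exact condition may differ in details. Passing a smaller value, e.g. pl.Trainer(max_epochs=5, log_every_n_steps=1), is the remedy the warning suggests. logged_steps is a hypothetical helper.

```python
def logged_steps(total_steps, log_every_n_steps):
    """Global steps (1-based) that emit a step-level log under a simplified rule.

    Assumed rule, for illustration only: log when the step number is a
    multiple of log_every_n_steps.
    """
    if log_every_n_steps < 1:
        raise ValueError("log_every_n_steps must be >= 1")
    return [step for step in range(1, total_steps + 1) if step % log_every_n_steps == 0]


total = 5 * 5  # 5 batches per epoch x max_epochs=5
print(len(logged_steps(total, 50)))  # → 0, the default interval never fires
print(len(logged_steps(total, 5)))   # → 5, about once per epoch
print(len(logged_steps(total, 1)))   # → 25, every step
```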