MNIST with DeepChem and TensorGraph

from tensorflow.examples.tutorials.mnist import input_data
# Read the MNIST data sets stored under "MNIST_data/"; one_hot=True asks for
# one-hot encoded labels. The result exposes the `train` and `validation`
# splits that are used further down.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
import deepchem as dc
import tensorflow as tf
from deepchem.models.tensorgraph.layers import Layer, Input, Reshape, Flatten, Conv2D, Label, Feature
from deepchem.models.tensorgraph.layers import Dense, SoftMaxCrossEntropy, ReduceMean, SoftMax
# Wrap the MNIST splits as DeepChem datasets so the model code below can read
# them through `.X` (images) and `.y` (one-hot labels).
# NOTE(review): these two statements were reconstructed from truncated lines;
# confirm the image arrays are `mnist.<split>.images`.
train = dc.data.NumpyDataset(mnist.train.images, mnist.train.labels)
valid = dc.data.NumpyDataset(mnist.validation.images, mnist.validation.labels)
# Build the TensorGraph model: flat 784-value inputs are reshaped to images,
# passed through two convolutions and two dense layers, and trained with a
# softmax cross-entropy loss against the one-hot labels.
tg = dc.models.TensorGraph(tensorboard=True, model_dir='/tmp/mnist', use_queue=False)
feature = Feature(shape=(None, 784))

# Images are square 28x28 (batch, height, width, channel)
make_image = Reshape(shape=(-1, 28, 28, 1), in_layers=[feature])

# Two stacked convolutions with 32 and then 64 output channels; every other
# Conv2D setting is left at the layer's default.
conv2d_1 = Conv2D(num_outputs=32, in_layers=[make_image])

conv2d_2 = Conv2D(num_outputs=64, in_layers=[conv2d_1])

# Flatten the convolutional output so it can feed the dense layers.
flatten = Flatten(in_layers=[conv2d_2])

dense1 = Dense(out_channels=1024, activation_fn=tf.nn.relu, in_layers=[flatten])

# Ten raw class scores (no activation_fn is passed here); the softmax is
# applied separately by the loss and by the output layer below.
dense2 = Dense(out_channels=10, in_layers=[dense1])

# Placeholder for the one-hot targets, one column per digit class.
label = Label(shape=(None, 10))

# Training objective: cross-entropy between labels and class scores, reduced
# to its mean.
smce = SoftMaxCrossEntropy(in_layers=[label, dense2])
loss = ReduceMean(in_layers=[smce])

# Prediction head: class probabilities from the same class scores.
output = SoftMax(in_layers=[dense2])

# Register the prediction head and the objective with the graph. Without
# these two calls the layers above are never attached to `tg`, so fitting
# has no loss to minimise and prediction has no output to return.
tg.add_output(output)
tg.set_loss(loss)
# nb_epoch set to 0 to permit rendering of tutorials online.
# Set nb_epoch=10 for better results
tg.fit(train, nb_epoch=0)
TIMING: model fitting took 2.621 s
# Note that AUCs will be nonsensical without setting nb_epoch higher!
from sklearn.metrics import roc_curve, auc
import numpy as np

# Score the validation images; squeeze drops any singleton axes from the
# returned array before it is indexed per class below.
prediction = np.squeeze(tg.predict_on_batch(valid.X))

# Per-class ROC statistics, keyed by class index 0-9.
fpr, tpr, roc_auc = {}, {}, {}
for i in range(10):
    # Column i of the labels against column i of the predictions.
    y_true, y_score = valid.y[:, i], prediction[:, i]
    fpr[i], tpr[i], thresh = roc_curve(y_true, y_score)
    roc_auc[i] = auc(fpr[i], tpr[i])
    print("class %s:auc=%s" % (i, roc_auc[i]))
class 0:auc=0.170555039138
class 1:auc=0.634619025945
class 2:auc=0.303693111629
class 3:auc=0.549712392397
class 4:auc=0.523565007169
class 5:auc=0.648496399959
class 6:auc=0.316489936331
class 7:auc=0.708521348315
class 8:auc=0.555897147512
class 9:auc=0.377021042837