Guide to Pretraining

In this guide, you’ll learn how to pretrain a model without using labels.

First, import Masterful alongside the other necessary libraries:

[19]:
import numpy as np
import tensorflow as tf

# Make the local Masterful package importable when running this guide from the repo.
import sys
sys.path.append('../../../')

import masterful
import masterful.backend.pretrain as pretrain

For this guide we’ll be using the CIFAR10 dataset, which consists of 60,000 32×32 color images in 10 mutually exclusive classes. There are 6,000 images per class: 5,000 in the training set and 1,000 held out for testing.

[20]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
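
As a quick, optional sanity check, the array shapes confirm the split sizes described above:

[ ]:
print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 1)
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000, 1)
print(np.unique(y_train).size)       # 10 classes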

Create a Data Specification, which captures CIFAR10’s relevant metadata:

[21]:
data_spec = masterful.DataSpec.from_dataset(
        task=masterful.spec.Task.CLASSIFICATION,
        dataset=train_dataset,
        image_range=masterful.spec.ImageRange.ZERO_255,
        num_classes=10,
        sparse=False)
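
As an illustrative check of the metadata above, the raw CIFAR10 arrays are uint8 pixels in [0, 255], which is exactly what ImageRange.ZERO_255 declares:

[ ]:
print(x_train.dtype)                 # uint8
print(x_train.min(), x_train.max())  # 0 255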

Now create a Pretraining Policy, which establishes pretraining-related hyperparameters:

[22]:
pretrain_policy = masterful.PretrainPolicy(batch_size=512,
                                           epochs=3,
                                           warmup_steps=10)

Finally, create a model to be pretrained – here, the same simple CNN used for labeled training in the TensorFlow CNN tutorial (https://www.tensorflow.org/tutorials/images/cnn):

[23]:
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
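
Because this CNN has no classification head, its output is a 4×4×64 feature map; these are the backbone features that pretraining shapes, and that the KNN evaluation below classifies. You can confirm the shape with standard Keras:

[ ]:
print(model.output_shape)  # (None, 4, 4, 64)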

Now you’re ready to pretrain! Pretraining with Masterful returns a training report, including fields like:

- loss: the unsupervised loss at the end of pretraining.
- accuracy: the accuracy achieved by performing K-Nearest Neighbors (KNN) classification on the pretrained backbone’s output features.

(Note that the labels in the dataset passed to the labeled_data parameter below are only used for KNN classification, not for pretraining the backbone.)
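To make that evaluation concrete, here is a minimal sketch of the same idea using scikit-learn: extract backbone features for the train and test images, then fit a KNN classifier on them. This is illustrative only; Masterful’s built-in evaluator may use a different k, distance metric, or feature normalization. Run before pretraining, it would show the random-initialization baseline.

[ ]:
# Illustrative sketch only; not Masterful's evaluator. k=5 is an assumption.
from sklearn.neighbors import KNeighborsClassifier

def extract_features(backbone, images, batch_size=512):
    # Flatten each (4, 4, 64) feature map into a single vector.
    feats = backbone.predict(images, batch_size=batch_size)
    return feats.reshape(len(images), -1)

knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(extract_features(model, x_train), y_train.ravel())
print('KNN accuracy:', knn.score(extract_features(model, x_test), y_test.ravel()))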

[24]:
training_report = pretrain.pretrain(
        pretrain_policy=pretrain_policy,
        model=model,
        validation_data=validation_dataset,
        labeled_data=train_dataset,
        unlabeled_data=None,
        synthetic_data=None,
        data_spec=data_spec,
    )
Epoch 1/3
98/98 [==============================] - 155s 2s/step - loss: 1159.5908
Feature extracting: 100%|██████████| 98/98 [00:00<00:00, 124.99it/s]
Test Epoch: Acc@1:32.01%: 100%|██████████| 20/20 [00:00<00:00, 37.86it/s]
kNN Test Accuracy at epoch 0: 32.01000213623047 Max Accuracy so far: 32.01000213623047
Epoch 2/3
98/98 [==============================] - 153s 2s/step - loss: 982.5400
Feature extracting: 100%|██████████| 98/98 [00:00<00:00, 132.23it/s]
Test Epoch: Acc@1:34.45%: 100%|██████████| 20/20 [00:00<00:00, 40.02it/s]
kNN Test Accuracy at epoch 1: 34.45000076293945 Max Accuracy so far: 34.45000076293945
Epoch 3/3
98/98 [==============================] - 153s 2s/step - loss: 923.1833
Feature extracting: 100%|██████████| 98/98 [00:00<00:00, 135.27it/s]
Test Epoch: Acc@1:35.30%: 100%|██████████| 20/20 [00:00<00:00, 37.78it/s]
kNN Test Accuracy at epoch 2: 35.29999923706055 Max Accuracy so far: 35.29999923706055

The pretraining policy above runs for only a few epochs to save time; you should expect better results with more epochs (and faster epochs with a smaller model). Even so, note that the pretrained model’s output features, learned without any labels, already outperform those of a randomly initialized network, and far exceed the 10% accuracy of random guessing across CIFAR10’s 10 classes:

[25]:
loss = training_report.validation_metrics['loss']
acc = training_report.validation_metrics['accuracy']
print(f'Final pretraining loss: {loss}')
print(f'Final pretraining KNN accuracy: {acc}')
Final pretraining loss: 923.1832885742188
Final pretraining KNN accuracy: 35.29999923706055

Now you’re ready to use Masterful’s unsupervised pretraining API!
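
As a typical next step, not covered by this guide, you might attach a classification head to the pretrained backbone and fine-tune it on the labeled data. Here is a minimal Keras sketch, mirroring the head from the TensorFlow CNN tutorial referenced above (the optimizer and epoch count are placeholders, not recommendations):

[ ]:
# Hypothetical follow-on: fine-tune the pretrained backbone with labels.
classifier = tf.keras.models.Sequential([
    model,  # the pretrained backbone from above
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10),
])
classifier.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
classifier.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test))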