Guide to Unsupervised Pretraining¶
In this guide, you’ll learn how to train a backbone without using labels. At the end of this guide, you’ll build a supervised k-nearest-neighbors (KNN) classifier on top of this backbone. Although the KNN classifier demonstrates that the backbone has learned useful representations of the data, see the “Training with a backbone” guide for the recommended applications of a backbone (e.g. a linear or MLP head for classification, a feature pyramid network for detection, or a U-Net for segmentation); a brief preview appears at the end of this guide.
Prerequisites¶
Please follow the Masterful installation instructions here in order to run this guide.
Imports¶
Import libraries and register Masterful.
[ ]:
import numpy as np
import tensorflow as tf

import masterful

# Activate the Masterful package; register() returns the activated package.
masterful = masterful.register()
For this guide you’ll be using the CIFAR10 dataset, which consists of 60,000 color images drawn from 10 non-overlapping classes of objects. There are 6,000 images of each class: 5,000 in the training set and 1,000 held out for testing.
[3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
training_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
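As a quick sanity check, you can inspect each dataset’s element spec; every element is an unbatched (image, label) pair of 8-bit tensors:

[ ]:
# Each element is a (32, 32, 3) uint8 image paired with a (1,) uint8 label.
print(training_dataset.element_spec)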
Create a DataParams instance for each dataset, capturing CIFAR10’s relevant metadata:
[4]:
training_dataset_params = masterful.data.learn_data_params(
    dataset=training_dataset,
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.ZERO_255,
    num_classes=10,
    sparse_labels=True,
)
validation_dataset_params = masterful.data.learn_data_params(
    dataset=validation_dataset,
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.ZERO_255,
    num_classes=10,
    sparse_labels=True,
)
Now set up the OptimizationParams, which establishes pretraining-related optimization hyperparameters, and create the SSL parameters that you will use for training. In this guide, you will use Barlow Twins to learn the self-supervised representation; a sketch of its loss follows the next cell.
[5]:
optimization_params = masterful.optimization.OptimizationParams(
    batch_size=512,
    epochs=3,
    warmup_epochs=1,
)
ssl_params = masterful.ssl.SemiSupervisedParams(
    algorithms=["barlow_twins"],
)
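Under the hood, Barlow Twins takes two augmented views of each batch, embeds both, and pushes the cross-correlation matrix of the two embeddings toward the identity matrix. Masterful implements this for you; purely for intuition, here is a minimal sketch of the loss, assuming the imports above (the barlow_twins_loss helper and the lambda_ coefficient are illustrative, not Masterful’s internals):

[ ]:
def barlow_twins_loss(z_a, z_b, lambda_=5e-3):
    """Illustrative Barlow Twins loss for two views' embeddings of shape (N, D)."""
    n = tf.cast(tf.shape(z_a)[0], tf.float32)
    # Standardize each embedding dimension across the batch.
    z_a = (z_a - tf.reduce_mean(z_a, axis=0)) / (tf.math.reduce_std(z_a, axis=0) + 1e-8)
    z_b = (z_b - tf.reduce_mean(z_b, axis=0)) / (tf.math.reduce_std(z_b, axis=0) + 1e-8)
    # Empirical cross-correlation matrix between the two views, shape (D, D).
    c = tf.matmul(z_a, z_b, transpose_a=True) / n
    # Diagonal entries are pulled toward 1 (the two views agree on each feature)...
    on_diag = tf.reduce_sum(tf.square(tf.linalg.diag_part(c) - 1.0))
    # ...and off-diagonal entries toward 0 (embedding dimensions decorrelate).
    off_diag = tf.reduce_sum(tf.square(c)) - tf.reduce_sum(tf.square(tf.linalg.diag_part(c)))
    return on_diag + lambda_ * off_diag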
Finally, create a model to be pretrained. Here, it is the same simple CNN used in this TensorFlow tutorial for labeled training:
[6]:
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model_params = masterful.architecture.learn_architecture_params(
    model=model,
    task=masterful.enums.Task.CLASSIFICATION,
    input_range=masterful.enums.ImageRange.ZERO_255,
    num_classes=10,
    prediction_logits=True,
    backbone_only=True,
)
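Note that the model ends at its last convolutional layer, so its output is a feature map rather than class predictions. You can confirm the backbone’s output shape before pretraining:

[ ]:
# The backbone maps each 32x32x3 image to a 4x4x64 feature map.
print(model.output_shape)  # (None, 4, 4, 64)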
Now you’re ready to pretrain! Pretraining with Masterful returns a training report, including fields like:

- loss: the unsupervised loss at the end of pretraining.
- accuracy: the accuracy achieved by performing k-nearest-neighbors (KNN) classification on the pretrained backbone’s output features.

(Note that the labels in the dataset passed to the training_dataset parameter below are only used for KNN classification, not for pretraining the backbone.)
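To make the accuracy metric concrete: a KNN evaluation extracts backbone features for the training and test sets, then classifies each test image by a majority vote over the labels of its nearest training features. Masterful’s KnnEvaluator handles this internally; the sketch below, which assumes the imports and data loaded above, is only an illustrative stand-in (the knn_accuracy helper and k=5 are assumptions, not Masterful’s API):

[ ]:
def knn_accuracy(model, x_train, y_train, x_test, y_test, k=5):
    """Illustrative KNN evaluation over backbone features (not Masterful's API)."""
    # Extract and flatten backbone features for both splits.
    f_train = model.predict(x_train, batch_size=512).reshape(len(x_train), -1)
    f_test = model.predict(x_test, batch_size=512).reshape(len(x_test), -1)
    # L2-normalize so a dot product equals cosine similarity.
    f_train /= np.linalg.norm(f_train, axis=1, keepdims=True) + 1e-8
    f_test /= np.linalg.norm(f_test, axis=1, keepdims=True) + 1e-8
    correct = 0
    for i in range(0, len(f_test), 256):  # chunk the test set to bound memory
        sims = f_test[i:i + 256] @ f_train.T        # cosine similarities
        nearest = np.argsort(-sims, axis=1)[:, :k]  # k nearest training images
        votes = y_train.flatten()[nearest]          # their labels, shape (chunk, k)
        preds = np.array([np.bincount(v).argmax() for v in votes])
        correct += np.sum(preds == y_test.flatten()[i:i + 256])
    return correct / len(f_test)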
[7]:
training_report = masterful.ssl.learn_representation(
    model=model,
    model_params=model_params,
    optimization_params=optimization_params,
    ssl_params=ssl_params,
    training_dataset=training_dataset,
    training_dataset_params=training_dataset_params,
    validation_dataset=validation_dataset,
    validation_dataset_params=validation_dataset_params,
)
Epoch 1/3
98/98 [==============================] - 128s 1s/step - loss: 997.4935
Feature extracting: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:02<00:00, 34.59it/s]
Test Epoch: Acc@1:42.27%: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 14.59it/s]
kNN Test Accuracy at epoch 0: 42.27000045776367 Max Accuracy so far: 42.27000045776367
Epoch 2/3
98/98 [==============================] - 126s 1s/step - loss: 796.1797
Feature extracting: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:02<00:00, 35.67it/s]
Test Epoch: Acc@1:44.84%: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 14.96it/s]
kNN Test Accuracy at epoch 1: 44.84000015258789 Max Accuracy so far: 44.84000015258789
Epoch 3/3
98/98 [==============================] - 125s 1s/step - loss: 715.0130
Feature extracting: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:02<00:00, 36.52it/s]
Test Epoch: Acc@1:45.50%: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 15.09it/s]
kNN Test Accuracy at epoch 2: 45.5 Max Accuracy so far: 45.5
KnnEvaluator: Restoring model weights from epoch 3 with accuracy 45.5.
To save time, the pretraining policy above runs for only a few epochs; expect better results with more epochs (and faster runs with a smaller model). Even so, note that the pretrained model’s output features, learned without using any labels, already classify far better than the features of randomly initialized weights (a sketch for checking this yourself follows the next cell):
[8]:
loss = training_report.validation_results['loss']
acc = training_report.validation_results['accuracy']
print(f'Final pretraining loss: {loss}')
print(f'Final pretraining KNN accuracy: {acc}')
Final pretraining loss: 715.0130004882812
Final pretraining KNN accuracy: 45.5
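If you want to verify the comparison against randomly initialized weights yourself, you can run the illustrative knn_accuracy helper from earlier on a freshly initialized copy of the architecture. Random convolutional features typically land above the 10% chance level for CIFAR10’s ten classes, but well below the pretrained result:

[ ]:
# Hypothetical check: KNN accuracy of the same architecture with random weights.
fresh_model = tf.keras.models.clone_model(model)  # same layers, new random weights
print(knn_accuracy(fresh_model, x_train, y_train, x_test, y_test))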
Now you’re ready to use Masterful’s unsupervised pretraining API!
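As a preview of the “Training with a backbone” guide, a typical next step is to freeze the pretrained backbone and train a small supervised head on its features. Here is a minimal sketch (the pooling-plus-linear head is an illustrative choice, not Masterful’s recommended recipe):

[ ]:
# Freeze the pretrained backbone and attach a linear classification head.
model.trainable = False
classifier = tf.keras.Sequential([
    model,                                     # pretrained backbone
    tf.keras.layers.GlobalAveragePooling2D(),  # pool the 4x4x64 feature map
    tf.keras.layers.Dense(10),                 # linear head producing logits
])
classifier.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
classifier.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=5)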