Guide to Distillation

In this guide, you will learn how to distill a large model into a smaller one using Masterful. For a more conceptual discussion of distillation, see the Masterful concepts documentation.

This guide closely follows the Keras Knowledge Distillation guide (one of the Keras code examples). Its main goal is to show you how to replicate that work using Masterful.

First, import the required dependencies.

[1]:
import tensorflow as tf
import masterful
import masterful.core

You are going to use the MNIST dataset for this guide. Keep preprocessing minimal: apply only what the student model (the model you are distilling into) requires.

[2]:
num_classes = 10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize pixel values into the range [0, 1].
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Masterful needs an explicit channels dimension, so for single-channel
# data like MNIST, add a trailing channels axis explicitly.
x_train = tf.reshape(x_train, (-1, 28, 28, 1))
x_test = tf.reshape(x_test, (-1, 28, 28, 1))

# Masterful performs best with one-hot labels.
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

# Convert to TensorFlow Datasets for fast pipeline processing.
labeled_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
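
The preprocessing above can be checked without downloading MNIST. The following NumPy sketch applies the same three steps (scale to [0, 1], add a channels axis, one-hot encode) to a small synthetic batch. The preprocess helper and the random data are hypothetical stand-ins, not part of Masterful or Keras; np.eye(num_classes)[labels] yields the same one-hot rows as tf.keras.utils.to_categorical for integer labels.

```python
import numpy as np

def preprocess(images, labels, num_classes=10):
    """Hypothetical helper mirroring the guide's preprocessing on uint8 arrays."""
    # Scale pixel values from [0, 255] to [0, 1].
    x = images.astype("float32") / 255.0
    # Add an explicit trailing channels axis: (N, 28, 28) -> (N, 28, 28, 1).
    x = x.reshape(-1, 28, 28, 1)
    # One-hot encode integer labels: (N,) -> (N, num_classes).
    y = np.eye(num_classes, dtype="float32")[labels]
    return x, y

# Synthetic stand-in for MNIST so the sketch runs without downloading data.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)
fake_labels = np.array([3, 0, 9, 1, 3])

x, y = preprocess(fake_images, fake_labels)
print(x.shape, y.shape)  # (5, 28, 28, 1) (5, 10)
```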

This guide follows the same experimental setup as the Keras guide, so set up the teacher and student models next. These are also called the source and target models. The teacher is a simple convolutional neural network sized for the MNIST data.

[3]:
teacher_model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.MaxPooling2D(
            pool_size=(2, 2), strides=(1, 1), padding="same"),
        tf.keras.layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(num_classes),
    ],
    name="teacher",
)
teacher_model.summary()
Model: "teacher"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 14, 14, 256)       2560
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 14, 14, 256)       0
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 256)       0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 7, 7, 512)         1180160
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0
_________________________________________________________________
dense (Dense)                (None, 10)                250890
=================================================================
Total params: 1,433,610
Trainable params: 1,433,610
Non-trainable params: 0
_________________________________________________________________

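The student model is an even simpler convolutional neural network with the same layer structure but far fewer filters (16 and 32 instead of 256 and 512), giving it about 70 times fewer parameters than the teacher (20,490 versus 1,433,610).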

[4]:
student_model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.MaxPooling2D(
            pool_size=(2, 2), strides=(1, 1), padding="same"),
        tf.keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(num_classes),
    ],
    name="student",
)
student_model.summary()
Model: "student"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_2 (Conv2D)            (None, 14, 14, 16)        160
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 16)        0
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 16)        0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 32)          4640
_________________________________________________________________
flatten_1 (Flatten)          (None, 1568)              0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                15690
=================================================================
Total params: 20,490
Trainable params: 20,490
Non-trainable params: 0
_________________________________________________________________
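
The parameter counts in the two summaries follow directly from the layer shapes. As a small worked check, this plain-Python sketch recomputes them: a 3x3 Conv2D has 3 * 3 * in_channels * filters weights plus one bias per filter, a Dense layer has in_features * units weights plus one bias per unit, and LeakyReLU, MaxPooling2D, and Flatten add none. The helper names are illustrative only.

```python
def conv2d_params(kernel_size, in_channels, filters):
    # Each filter has kernel_size * kernel_size * in_channels weights plus one bias.
    return kernel_size * kernel_size * in_channels * filters + filters

def dense_params(in_features, units):
    # Weight matrix plus one bias per unit.
    return in_features * units + units

def count_params(width1, width2, num_classes=10, image_size=28, in_channels=1):
    # Two stride-2 "same" convolutions shrink 28x28 -> 14x14 -> 7x7.
    # The stride-1 "same" max pooling and LeakyReLU keep the shape and add no parameters.
    spatial = image_size // 2 // 2
    return (conv2d_params(3, in_channels, width1)
            + conv2d_params(3, width1, width2)
            + dense_params(spatial * spatial * width2, num_classes))

teacher_params = count_params(256, 512)
student_params = count_params(16, 32)
print(teacher_params, student_params)          # 1433610 20490
print(round(teacher_params / student_params))  # 70
```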

Train the Teacher

Typically, you would start from an already trained teacher model. In this guide, you first train the teacher explicitly before performing distillation. The teacher should reach 97-98% accuracy in five epochs.

[5]:
batch_size = 64
teacher_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)
teacher_model.fit(labeled_dataset.batch(batch_size), epochs=5)
teacher_evaluation_metrics = teacher_model.evaluate(
    test_dataset.batch(batch_size), return_dict=True)
print(f'Teacher evaluation metrics: {teacher_evaluation_metrics}')
Epoch 1/5
938/938 [==============================] - 8s 6ms/step - loss: 0.2936 - categorical_accuracy: 0.9122
Epoch 2/5
938/938 [==============================] - 6s 6ms/step - loss: 0.0854 - categorical_accuracy: 0.9738
Epoch 3/5
938/938 [==============================] - 6s 6ms/step - loss: 0.0694 - categorical_accuracy: 0.9788
Epoch 4/5
938/938 [==============================] - 5s 6ms/step - loss: 0.0621 - categorical_accuracy: 0.9822
Epoch 5/5
938/938 [==============================] - 6s 6ms/step - loss: 0.0625 - categorical_accuracy: 0.9812
938/938 [==============================] - 4s 4ms/step - loss: 0.0691 - categorical_accuracy: 0.9796
Teacher evaluation metrics: {'loss': 0.06912115961313248, 'categorical_accuracy': 0.979616641998291}
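
The teacher's last layer is a plain Dense layer with no softmax, so it outputs raw logits; that is why the loss is built with from_logits=True. As a rough illustration of what that loss computes, here is a NumPy sketch of categorical cross-entropy from logits, averaged over the batch. It is not Keras's implementation, which handles details such as reduction options differently.

```python
import numpy as np

def categorical_crossentropy_from_logits(y_true, logits):
    """Mean cross-entropy between one-hot labels and raw (unnormalized) logits."""
    # Numerically stable log-softmax: subtract the row max before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick out the log-probability of the true class and average over the batch.
    return float(-(y_true * log_probs).sum(axis=-1).mean())

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
y_true = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
print(categorical_crossentropy_from_logits(y_true, logits))
```

Working in log space keeps the computation stable even for large logits, which is one reason losses accept logits directly instead of requiring a softmax layer in the model.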

Distill to the Student

Now that you have a teacher model, you can distill its knowledge into the student model. The first step is to set up the data and model specifications that you pass to Masterful. These tell Masterful about the data, the models, and the task you are trying to perform.

[6]:
# Create a dataset specification from the training dataset.
labeled_data_spec = masterful.spec.DataSpec.from_dataset(
    masterful.spec.Task.CLASSIFICATION,
    labeled_dataset,
    masterful.spec.ImageRange.ZERO_ONE,
    num_classes,
    sparse=False)

# Create a model specification from the teacher model.
teacher_model_spec = masterful.spec.ModelSpec.from_model(
    masterful.spec.Task.CLASSIFICATION,
    teacher_model,
    masterful.spec.ImageRange.ZERO_ONE,
    num_classes,
    from_logits=True)

# Create a model specification from the student model.
student_model_spec = masterful.spec.ModelSpec.from_model(
    masterful.spec.Task.CLASSIFICATION,
    student_model,
    masterful.spec.ImageRange.ZERO_ONE,
    num_classes,
    from_logits=True)
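
The specifications above declare that images lie in [0, 1] (ImageRange.ZERO_ONE), that there are num_classes classes, and that labels are dense one-hot vectors (sparse=False). A quick way to confirm your arrays actually match those declarations is a check like the hypothetical helper below. It is plain NumPy, not a Masterful function, and it only inspects the data you give it.

```python
import numpy as np

def check_classification_data(images, labels, num_classes, sparse=False):
    """Hypothetical sanity check that data matches what the specs declare.

    Returns a list of human-readable problems; an empty list means no issues found.
    This is an illustration, not part of the Masterful API.
    """
    images = np.asarray(images)
    labels = np.asarray(labels)
    problems = []
    if images.ndim != 4:
        problems.append(f"expected (N, H, W, C) images, got shape {images.shape}")
    if images.min() < 0.0 or images.max() > 1.0:
        problems.append("pixel values fall outside [0, 1]")
    if sparse:
        # Sparse labels are integer class indices.
        if labels.ndim != 1 or labels.min() < 0 or labels.max() >= num_classes:
            problems.append("sparse labels must be integers in [0, num_classes)")
    else:
        # Dense labels are one-hot rows of length num_classes.
        if labels.ndim != 2 or labels.shape[-1] != num_classes:
            problems.append(f"one-hot labels must have shape (N, {num_classes})")
        elif not np.allclose(labels.sum(axis=-1), 1.0):
            problems.append("one-hot label rows must sum to 1")
    return problems

good_x = np.full((2, 28, 28, 1), 0.5, dtype="float32")
good_y = np.eye(10, dtype="float32")[[4, 7]]
print(check_classification_data(good_x, good_y, num_classes=10))        # []
print(check_classification_data(good_x * 255, good_y, num_classes=10))  # ['pixel values fall outside [0, 1]']
```

Forgetting the division by 255 is the kind of mismatch this catches: the data would still load, but the declared image range would be wrong.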

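The next step is to create the distillation policy that Masterful uses during training.

The policy needs a batch size. Rather than picking one by hand, you can use masterful.core.find_batch_size, which searches for a batch size that delivers near-best training throughput for your model, data, and hardware.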

[7]:
# Search for a batch size that gives near-optimal training throughput.
batch_size = masterful.core.find_batch_size(teacher_model, teacher_model_spec,
                                            labeled_dataset,
                                            labeled_data_spec)

# Create the policy we will use for distillation.
distillation_policy = masterful.DistillationPolicy(batch_size=batch_size)
find_batch_size: phase 1 of 2 (exponential): Trying batch size 2.
1/1 [==============================] - 0s 221ms/step - loss: 8.4543e-04
2/2 [==============================] - 0s 6ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): Trying batch size 4.
1/1 [==============================] - 0s 34ms/step - loss: 0.0000e+00
2/2 [==============================] - 0s 4ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): Trying batch size 8.
1/1 [==============================] - 0s 29ms/step - loss: 0.0000e+00
2/2 [==============================] - 0s 4ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): Trying batch size 16.
1/1 [==============================] - 0s 36ms/step - loss: 0.0000e+00
2/2 [==============================] - 0s 4ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): Trying batch size 32.
1/1 [==============================] - 0s 5ms/step - loss: 0.0000e+00
2/2 [==============================] - 0s 4ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): Trying batch size 64.
1/1 [==============================] - 0s 7ms/step - loss: 0.0000e+00
2/2 [==============================] - 0s 7ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): Trying batch size 128.
1/1 [==============================] - 0s 96ms/step - loss: 0.0000e+00
2/2 [==============================] - 0s 8ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): Trying batch size 256.
1/1 [==============================] - 0s 141ms/step - loss: 0.0000e+00
2/2 [==============================] - 0s 14ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): Trying batch size 512.
1/1 [==============================] - 0s 212ms/step - loss: 0.0000e+00
2/2 [==============================] - 0s 19ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): Trying batch size 1024.
1/1 [==============================] - 0s 348ms/step - loss: 0.0000e+00
2/2 [==============================] - 0s 31ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): Trying batch size 2048.
1/1 [==============================] - 0s 424ms/step - loss: 0.0000e+00
2/2 [==============================] - 0s 55ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): Trying batch size 4096.
1/1 [==============================] - 1s 784ms/step - loss: 0.0000e+00
2/2 [==============================] - 0s 115ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): Trying batch size 8192.
1/1 [==============================] - 2s 2s/step - loss: 0.0000e+00
2/2 [==============================] - 0s 219ms/step - loss: 0.0000e+00
find_batch_size: phase 1 of 2 (exponential): searched batch_size 16384 value is beyond updated_max_batch_size: 15000.
find_batch_size: phase 2 of 2 (binary): Trying batch size 12288.
1/1 [==============================] - 2s 2s/step - loss: 0.0000e+00
2/2 [==============================] - 1s 437ms/step - loss: 0.0000e+00
find_batch_size: phase 2 of 2 (binary):Found lower/upper: 12288 / 16384.
find_batch_size: phase 2 of 2 (binary): Trying batch size 14336.
1/1 [==============================] - 3s 3s/step - loss: 0.0000e+00
2/2 [==============================] - 1s 547ms/step - loss: 0.0000e+00
find_batch_size: phase 2 of 2 (binary):Found lower/upper: 14336 / 16384.
find_batch_size: phase 2 of 2 (binary): Trying batch size 15360.
1/1 [==============================] - 3s 3s/step - loss: 0.0000e+00
2/2 [==============================] - 1s 586ms/step - loss: 0.0000e+00
find_batch_size: phase 2 of 2 (binary):Found lower/upper: 15360 / 16384.
find_batch_size: phase 2 of 2 (binary): Found batch size within 10% of ideal: 15360.
find_batch_size: DONE. Returning 4096, delivering nearly the best images per second.The max batch_size that did not OOM was 15360, but that returns a worse images per second (within tolerances).The returned value may have been clipped by the user passed max_batch_size: 2147483647, or the calculated max_batch_size 15000 that would allow at least min_steps_per_epoch 4.
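
The find_batch_size log above shows a two-phase search: exponential doubling from 2 until a cap is reached (here 15000, the largest batch that still yields at least 4 steps per epoch on the 60,000 training images), followed by a binary search until the bracket is within 10%. The sketch below illustrates that idea with a hypothetical fits callback standing in for "a training step at this batch size ran without running out of memory". It is a simplified illustration, not Masterful's implementation: it never tries a size above the cap, and it does not model the final throughput comparison that led Masterful to return 4096 rather than the largest size that fit.

```python
def find_max_batch_size(fits, max_batch_size, start=2, tolerance=0.1):
    """Approximately the largest batch size for which fits(batch_size) is True.

    fits is a hypothetical callback, e.g. "one training step succeeded".
    """
    # Phase 1 (exponential): double until a size fails or exceeds the cap.
    lower = None
    size = start
    while size <= max_batch_size:
        if not fits(size):
            upper = size  # first size known to fail
            break
        lower = size
        size *= 2
    else:
        upper = max_batch_size + 1  # never search above the cap
    if lower is None:
        raise ValueError("even the starting batch size does not fit")

    # Phase 2 (binary): lower always fits, upper fails or lies beyond the cap.
    # Stop once the bracket is within `tolerance` of the upper bound.
    while upper - lower > 1 and upper - lower > tolerance * upper:
        mid = (lower + upper) // 2
        if fits(mid):
            lower = mid
        else:
            upper = mid
    return lower

# Simulated hardware that runs out of memory above 5000 images per batch.
print(find_max_batch_size(lambda b: b <= 5000, max_batch_size=15000))  # 4608
```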

The final step is to call masterful.core.distill, which runs the distillation process and returns a report.

[8]:
distillation_report = masterful.core.distill(
    distillation_policy, teacher_model, teacher_model_spec, student_model,
    student_model_spec, labeled_dataset, None, None, labeled_data_spec)
Epoch 1/2147483647
14/14 [==============================] - 3s 111ms/step - loss: 1.9759 - val_categorical_accuracy: 0.4433 - val_student_loss: 1.8047
Epoch 2/2147483647
14/14 [==============================] - 2s 83ms/step - loss: 1.6491 - val_categorical_accuracy: 0.7183 - val_student_loss: 0.9635
Epoch 3/2147483647
14/14 [==============================] - 2s 82ms/step - loss: 1.0391 - val_categorical_accuracy: 0.7565 - val_student_loss: 0.7687
Epoch 4/2147483647
14/14 [==============================] - 2s 80ms/step - loss: 0.5478 - val_categorical_accuracy: 0.8293 - val_student_loss: 0.7633
Epoch 5/2147483647
14/14 [==============================] - 2s 86ms/step - loss: 0.3958 - val_categorical_accuracy: 0.8677 - val_student_loss: 0.6726
Epoch 6/2147483647
14/14 [==============================] - 2s 83ms/step - loss: 0.3163 - val_categorical_accuracy: 0.8778 - val_student_loss: 0.5398
Epoch 7/2147483647
14/14 [==============================] - 1s 81ms/step - loss: 0.2684 - val_categorical_accuracy: 0.8988 - val_student_loss: 0.4701
Epoch 8/2147483647
14/14 [==============================] - 2s 82ms/step - loss: 0.2194 - val_categorical_accuracy: 0.9187 - val_student_loss: 0.3656
Epoch 9/2147483647
14/14 [==============================] - 2s 90ms/step - loss: 0.1598 - val_categorical_accuracy: 0.9370 - val_student_loss: 0.2740
Epoch 10/2147483647
14/14 [==============================] - 2s 83ms/step - loss: 0.1168 - val_categorical_accuracy: 0.9472 - val_student_loss: 0.2059
Epoch 11/2147483647
14/14 [==============================] - 1s 81ms/step - loss: 0.0931 - val_categorical_accuracy: 0.9602 - val_student_loss: 0.1733
Epoch 12/2147483647
14/14 [==============================] - 2s 86ms/step - loss: 0.0727 - val_categorical_accuracy: 0.9570 - val_student_loss: 0.1775
Epoch 13/2147483647
14/14 [==============================] - 2s 86ms/step - loss: 0.0639 - val_categorical_accuracy: 0.9597 - val_student_loss: 0.1895
Epoch 14/2147483647
14/14 [==============================] - 2s 81ms/step - loss: 0.0524 - val_categorical_accuracy: 0.9632 - val_student_loss: 0.1491
Epoch 15/2147483647
14/14 [==============================] - 2s 88ms/step - loss: 0.0460 - val_categorical_accuracy: 0.9683 - val_student_loss: 0.1268
Epoch 16/2147483647
14/14 [==============================] - 2s 89ms/step - loss: 0.0418 - val_categorical_accuracy: 0.9655 - val_student_loss: 0.1469
Epoch 17/2147483647
14/14 [==============================] - 2s 81ms/step - loss: 0.0390 - val_categorical_accuracy: 0.9648 - val_student_loss: 0.1367
Epoch 18/2147483647
14/14 [==============================] - 2s 80ms/step - loss: 0.0384 - val_categorical_accuracy: 0.9650 - val_student_loss: 0.1487
Epoch 19/2147483647
14/14 [==============================] - 2s 84ms/step - loss: 0.0339 - val_categorical_accuracy: 0.9672 - val_student_loss: 0.1360
Epoch 20/2147483647
14/14 [==============================] - 1s 82ms/step - loss: 0.0318 - val_categorical_accuracy: 0.9663 - val_student_loss: 0.1320
Epoch 21/2147483647
14/14 [==============================] - 2s 82ms/step - loss: 0.0266 - val_categorical_accuracy: 0.9690 - val_student_loss: 0.1342
Epoch 22/2147483647
14/14 [==============================] - 2s 87ms/step - loss: 0.0243 - val_categorical_accuracy: 0.9683 - val_student_loss: 0.1368
Epoch 23/2147483647
14/14 [==============================] - 2s 82ms/step - loss: 0.0235 - val_categorical_accuracy: 0.9687 - val_student_loss: 0.1356
Epoch 24/2147483647
14/14 [==============================] - 2s 89ms/step - loss: 0.0226 - val_categorical_accuracy: 0.9692 - val_student_loss: 0.1299
Epoch 25/2147483647
14/14 [==============================] - 2s 83ms/step - loss: 0.0230 - val_categorical_accuracy: 0.9683 - val_student_loss: 0.1279
2/2 [==============================] - 0s 7ms/step - categorical_accuracy: 0.9683 - student_loss: 0.1280
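
This guide does not show the loss Masterful optimizes internally. As a point of reference, here is a NumPy sketch of the classic formulation from Hinton et al. (2015), which the Keras Knowledge Distillation guide is based on: a weighted sum of ordinary cross-entropy against the true labels and a KL divergence between temperature-softened teacher and student distributions. The alpha and temperature defaults are illustrative assumptions, and the temperature**2 scaling follows the convention from that paper. Treat this as a sketch of the technique, not Masterful's implementation.

```python
import numpy as np

def softmax(logits, temperature=1.0):
    # Dividing by a temperature > 1 softens the distribution.
    # Subtracting the row max keeps exp() numerically stable.
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distillation_loss(student_logits, teacher_logits, y_true,
                      alpha=0.1, temperature=10.0):
    """Hinton-style distillation loss (illustrative; not Masterful's internals)."""
    eps = 1e-12  # avoids log(0)
    # Hard-label term: ordinary cross-entropy of the student against y_true.
    p_student_hard = softmax(student_logits)
    hard = -(y_true * np.log(p_student_hard + eps)).sum(axis=-1).mean()
    # Soft-label term: KL(teacher || student) on temperature-softened outputs.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = (p_teacher * (np.log(p_teacher + eps)
                       - np.log(p_student + eps))).sum(axis=-1).mean()
    # The temperature**2 factor keeps the soft-target gradients on a scale
    # comparable to the hard-label term as the temperature changes.
    return float(alpha * hard + (1.0 - alpha) * kl * temperature ** 2)

# Hypothetical logits for a single 3-class example whose true label is class 1.
y_true = np.array([[0.0, 1.0, 0.0]])
teacher_logits = np.array([[1.0, 4.0, 0.5]])
student_logits = np.array([[0.2, 1.0, 0.1]])
print(distillation_loss(student_logits, teacher_logits, y_true))
```

When the student's logits match the teacher's exactly, the KL term vanishes and only the weighted hard-label term remains.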

Measure Results

Let’s see how well you did. The DistillationReport returned by Masterful contains information about the distillation process. Your student model achieves nearly the same accuracy as the teacher (about 97.2% versus 98.0% in this run) while using about 20,000 parameters instead of about 1.4 million. You can also evaluate the student model directly on your holdout set.

[9]:
student_evaluation_metrics = student_model.evaluate(
    test_dataset.batch(batch_size), return_dict=True)
print(f'Teacher Evaluation metrics: {teacher_evaluation_metrics}')
print(f'Student Evaluation metrics: {student_evaluation_metrics}')
15/15 [==============================] - 1s 19ms/step - loss: 0.1083 - categorical_accuracy: 0.9710
Teacher Evaluation metrics: {'loss': 0.06912115961313248, 'categorical_accuracy': 0.979616641998291}
Student Evaluation metrics: {'loss': 0.10426905751228333, 'categorical_accuracy': 0.97198486328125}
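
To put the two results side by side, this short snippet computes the accuracy gap and the parameter reduction from the metrics printed above and the parameter counts in the model summaries. The numbers come from this particular run; yours will vary with random initialization and hardware.

```python
# Metrics copied from the printed evaluation output above.
teacher = {'loss': 0.06912115961313248, 'categorical_accuracy': 0.979616641998291}
student = {'loss': 0.10426905751228333, 'categorical_accuracy': 0.97198486328125}
# Total parameter counts from the two model summaries.
teacher_params, student_params = 1_433_610, 20_490

gap = teacher['categorical_accuracy'] - student['categorical_accuracy']
print(f"Accuracy gap: {100 * gap:.2f} percentage points")            # Accuracy gap: 0.76 percentage points
print(f"Parameter reduction: {teacher_params / student_params:.1f}x")  # Parameter reduction: 70.0x
```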