{ "cells": [ { "cell_type": "markdown", "source": [ "# Guide to Distillation\n", "\n", "In this guide, you will learn how to distill a large model into a smaller model using Masterful. For a more conceptual discussion, see the concepts documents.\n", "\n", "This guide closely follows the Keras Knowledge Distillation Guide, and its main goal is to show you how to replicate that work using Masterful. The Keras Knowledge Distillation guide can be found [here](https://keras.io/examples/vision/knowledge_distillation/).\n", "\n", "First step, import all of the required dependencies." ], "metadata": {} }, { "cell_type": "code", "execution_count": 1, "source": [ "import tensorflow as tf\n", "import masterful\n", "import masterful.core" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "You are going to use the MNIST dataset for this guide. You should limit yourself to very simple preprocessing, as required by the model you are distilling into." ], "metadata": {} }, { "cell_type": "code", "execution_count": 2, "source": [ "num_classes = 10\n", "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n", "\n", "# Normalize data into the range (0,1)\n", "x_train = x_train.astype(\"float32\") / 255.0\n", "x_test = x_test.astype(\"float32\") / 255.0\n", "\n", "# Masterful needs an explicit channels parameter, so for single channel\n", "# data like MNIST we add the channels parameter explicitly.\n", "x_train = tf.reshape(x_train, (-1, 28, 28, 1))\n", "x_test = tf.reshape(x_train, (-1, 28, 28, 1))\n", "\n", "# Masterful performs best with one-hot labels.\n", "y_train = tf.keras.utils.to_categorical(y_train, num_classes)\n", "y_test = tf.keras.utils.to_categorical(y_test, num_classes)\n", "\n", "# Convert to Tensorflow Datasets for fast pipeline processing.\n", "labeled_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", "test_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "This guide follows the same experimental setup as the Keras guide, so setup the teacher and student models respectively. These can also be called the source and target models. The teacher is a simple convolutional neural network, sized for the MNIST data. 
" ], "metadata": {} }, { "cell_type": "code", "execution_count": 3, "source": [ "teacher_model = tf.keras.Sequential(\n", " [\n", " tf.keras.Input(shape=(28, 28, 1)),\n", " tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding=\"same\"),\n", " tf.keras.layers.LeakyReLU(alpha=0.2),\n", " tf.keras.layers.MaxPooling2D(\n", " pool_size=(2, 2), strides=(1, 1), padding=\"same\"),\n", " tf.keras.layers.Conv2D(512, (3, 3), strides=(2, 2), padding=\"same\"),\n", " tf.keras.layers.Flatten(),\n", " tf.keras.layers.Dense(num_classes),\n", " ],\n", " name=\"teacher\",\n", ")\n", "teacher_model.summary()" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Model: \"teacher\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "conv2d (Conv2D) (None, 14, 14, 256) 2560 \n", "_________________________________________________________________\n", "leaky_re_lu (LeakyReLU) (None, 14, 14, 256) 0 \n", "_________________________________________________________________\n", "max_pooling2d (MaxPooling2D) (None, 14, 14, 256) 0 \n", "_________________________________________________________________\n", "conv2d_1 (Conv2D) (None, 7, 7, 512) 1180160 \n", "_________________________________________________________________\n", "flatten (Flatten) (None, 25088) 0 \n", "_________________________________________________________________\n", "dense (Dense) (None, 10) 250890 \n", "=================================================================\n", "Total params: 1,433,610\n", "Trainable params: 1,433,610\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The student model is an even simpler convolutional neural network, containing fewer parameters than the teacher network." 
], "metadata": {} }, { "cell_type": "code", "execution_count": 4, "source": [ "student_model = tf.keras.Sequential(\n", " [\n", " tf.keras.Input(shape=(28, 28, 1)),\n", " tf.keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding=\"same\"),\n", " tf.keras.layers.LeakyReLU(alpha=0.2),\n", " tf.keras.layers.MaxPooling2D(\n", " pool_size=(2, 2), strides=(1, 1), padding=\"same\"),\n", " tf.keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding=\"same\"),\n", " tf.keras.layers.Flatten(),\n", " tf.keras.layers.Dense(num_classes),\n", " ],\n", " name=\"student\",\n", ")\n", "student_model.summary()" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Model: \"student\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "conv2d_2 (Conv2D) (None, 14, 14, 16) 160 \n", "_________________________________________________________________\n", "leaky_re_lu_1 (LeakyReLU) (None, 14, 14, 16) 0 \n", "_________________________________________________________________\n", "max_pooling2d_1 (MaxPooling2 (None, 14, 14, 16) 0 \n", "_________________________________________________________________\n", "conv2d_3 (Conv2D) (None, 7, 7, 32) 4640 \n", "_________________________________________________________________\n", "flatten_1 (Flatten) (None, 1568) 0 \n", "_________________________________________________________________\n", "dense_1 (Dense) (None, 10) 15690 \n", "=================================================================\n", "Total params: 20,490\n", "Trainable params: 20,490\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Train the Teacher \n", "Typically, you would use an already trained teacher model. In this guide, you need to explicitly \n", "train the teacher first before you can perform distillation. The teacher should achieve 97-98% accuracy in just five epochs. 
" ], "metadata": {} }, { "cell_type": "code", "execution_count": 5, "source": [ "batch_size = 64\n", "teacher_model.compile(\n", " optimizer=tf.keras.optimizers.Adam(),\n", " loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),\n", " metrics=[tf.keras.metrics.CategoricalAccuracy()],\n", ")\n", "teacher_model.fit(labeled_dataset.batch(batch_size), epochs=5)\n", "teacher_evaluation_metrics = teacher_model.evaluate(\n", " test_dataset.batch(batch_size), return_dict=True)\n", "print(f'Teacher evaluation metrics: {teacher_evaluation_metrics}')" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/5\n", "938/938 [==============================] - 8s 6ms/step - loss: 0.2936 - categorical_accuracy: 0.9122\n", "Epoch 2/5\n", "938/938 [==============================] - 6s 6ms/step - loss: 0.0854 - categorical_accuracy: 0.9738\n", "Epoch 3/5\n", "938/938 [==============================] - 6s 6ms/step - loss: 0.0694 - categorical_accuracy: 0.9788\n", "Epoch 4/5\n", "938/938 [==============================] - 5s 6ms/step - loss: 0.0621 - categorical_accuracy: 0.9822\n", "Epoch 5/5\n", "938/938 [==============================] - 6s 6ms/step - loss: 0.0625 - categorical_accuracy: 0.9812\n", "938/938 [==============================] - 4s 4ms/step - loss: 0.0691 - categorical_accuracy: 0.9796\n", "Teacher evaluation metrics: {'loss': 0.06912115961313248, 'categorical_accuracy': 0.979616641998291}\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Distill to the student\n", "Now that you have a teacher model, you can distill that knowledge into the student model. The first step is to set up the model and data specifications that you will pass to Masterful. This lets Masterful know a little bit more about the model, data, and the task you are trying to perform." ], "metadata": {} }, { "cell_type": "code", "execution_count": 6, "source": [ "# Create a dataset specification from the training dataset.\n", "labeled_data_spec = masterful.spec.DataSpec.from_dataset(\n", " masterful.spec.Task.CLASSIFICATION,\n", " labeled_dataset,\n", " masterful.spec.ImageRange.ZERO_ONE,\n", " num_classes,\n", " sparse=False)\n", "\n", "# Create a model specification from the teacher model.\n", "teacher_model_spec = masterful.spec.ModelSpec.from_model(\n", " masterful.spec.Task.CLASSIFICATION,\n", " teacher_model,\n", " masterful.spec.ImageRange.ZERO_ONE,\n", " num_classes,\n", " from_logits=True)\n", "\n", "# Create a model specification from the student model.\n", "student_model_spec = masterful.spec.ModelSpec.from_model(\n", " masterful.spec.Task.CLASSIFICATION,\n", " student_model,\n", " masterful.spec.ImageRange.ZERO_ONE,\n", " num_classes,\n", " from_logits=True)" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "Next step is to create the distillation policy that Masterful will use during training.\n", "\n", "Notice that Masterful automatically infers the optimal training batch size\n", "based on your model, data, and hardware. 
" ], "metadata": {} }, { "cell_type": "code", "execution_count": 7, "source": [ "# Find the optimial batch size.\n", "batch_size = masterful.core.find_batch_size(teacher_model, teacher_model_spec,\n", " labeled_dataset,\n", " labeled_data_spec)\n", "\n", "# Create the policy we will use for distillation.\n", "distillation_policy = masterful.DistillationPolicy(batch_size=batch_size)" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "find_batch_size: phase 1 of 2 (exponential): Trying batch size 2.\n", "1/1 [==============================] - 0s 221ms/step - loss: 8.4543e-04\n", "2/2 [==============================] - 0s 6ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): Trying batch size 4.\n", "1/1 [==============================] - 0s 34ms/step - loss: 0.0000e+00\n", "2/2 [==============================] - 0s 4ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): Trying batch size 8.\n", "1/1 [==============================] - 0s 29ms/step - loss: 0.0000e+00\n", "2/2 [==============================] - 0s 4ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): Trying batch size 16.\n", "1/1 [==============================] - 0s 36ms/step - loss: 0.0000e+00\n", "2/2 [==============================] - 0s 4ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): Trying batch size 32.\n", "1/1 [==============================] - 0s 5ms/step - loss: 0.0000e+00\n", "2/2 [==============================] - 0s 4ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): Trying batch size 64.\n", "1/1 [==============================] - 0s 7ms/step - loss: 0.0000e+00\n", "2/2 [==============================] - 0s 7ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): Trying batch size 128.\n", "1/1 [==============================] - 0s 96ms/step - loss: 0.0000e+00\n", "2/2 [==============================] - 0s 8ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): Trying batch size 256.\n", "1/1 [==============================] - 0s 141ms/step - loss: 0.0000e+00\n", "2/2 [==============================] - 0s 14ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): Trying batch size 512.\n", "1/1 [==============================] - 0s 212ms/step - loss: 0.0000e+00\n", "2/2 [==============================] - 0s 19ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): Trying batch size 1024.\n", "1/1 [==============================] - 0s 348ms/step - loss: 0.0000e+00\n", "2/2 [==============================] - 0s 31ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): Trying batch size 2048.\n", "1/1 [==============================] - 0s 424ms/step - loss: 0.0000e+00\n", "2/2 [==============================] - 0s 55ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): Trying batch size 4096.\n", "1/1 [==============================] - 1s 784ms/step - loss: 0.0000e+00\n", "2/2 [==============================] - 0s 115ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): Trying batch size 8192.\n", "1/1 [==============================] - 2s 2s/step - loss: 0.0000e+00\n", "2/2 [==============================] - 0s 219ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 1 of 2 (exponential): searched batch_size 16384 value is beyond updated_max_batch_size: 15000.\n", "find_batch_size: phase 2 of 2 (binary): Trying batch 
size 12288.\n", "1/1 [==============================] - 2s 2s/step - loss: 0.0000e+00\n", "2/2 [==============================] - 1s 437ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 2 of 2 (binary):Found lower/upper: 12288 / 16384.\n", "find_batch_size: phase 2 of 2 (binary): Trying batch size 14336.\n", "1/1 [==============================] - 3s 3s/step - loss: 0.0000e+00\n", "2/2 [==============================] - 1s 547ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 2 of 2 (binary):Found lower/upper: 14336 / 16384.\n", "find_batch_size: phase 2 of 2 (binary): Trying batch size 15360.\n", "1/1 [==============================] - 3s 3s/step - loss: 0.0000e+00\n", "2/2 [==============================] - 1s 586ms/step - loss: 0.0000e+00\n", "find_batch_size: phase 2 of 2 (binary):Found lower/upper: 15360 / 16384.\n", "find_batch_size: phase 2 of 2 (binary): Found batch size within 10% of ideal: 15360.\n", "find_batch_size: DONE. Returning 4096, delivering nearly the best images per second.The max batch_size that did not OOM was 15360, but that returns a worse images per second (within tolerances).The returned value may have been clipped by the user passed max_batch_size: 2147483647, or the calculated max_batch_size 15000 that would allow at least min_steps_per_epoch 4.\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The final step is to call into Masterful to initiate the distillation process." ], "metadata": {} }, { "cell_type": "code", "execution_count": 8, "source": [ "distillation_report = masterful.core.distill(\n", " distillation_policy, teacher_model, teacher_model_spec, student_model,\n", " student_model_spec, labeled_dataset, None, None, labeled_data_spec)" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/2147483647\n", "14/14 [==============================] - 3s 111ms/step - loss: 1.9759 - val_categorical_accuracy: 0.4433 - val_student_loss: 1.8047\n", "Epoch 2/2147483647\n", "14/14 [==============================] - 2s 83ms/step - loss: 1.6491 - val_categorical_accuracy: 0.7183 - val_student_loss: 0.9635\n", "Epoch 3/2147483647\n", "14/14 [==============================] - 2s 82ms/step - loss: 1.0391 - val_categorical_accuracy: 0.7565 - val_student_loss: 0.7687\n", "Epoch 4/2147483647\n", "14/14 [==============================] - 2s 80ms/step - loss: 0.5478 - val_categorical_accuracy: 0.8293 - val_student_loss: 0.7633\n", "Epoch 5/2147483647\n", "14/14 [==============================] - 2s 86ms/step - loss: 0.3958 - val_categorical_accuracy: 0.8677 - val_student_loss: 0.6726\n", "Epoch 6/2147483647\n", "14/14 [==============================] - 2s 83ms/step - loss: 0.3163 - val_categorical_accuracy: 0.8778 - val_student_loss: 0.5398\n", "Epoch 7/2147483647\n", "14/14 [==============================] - 1s 81ms/step - loss: 0.2684 - val_categorical_accuracy: 0.8988 - val_student_loss: 0.4701\n", "Epoch 8/2147483647\n", "14/14 [==============================] - 2s 82ms/step - loss: 0.2194 - val_categorical_accuracy: 0.9187 - val_student_loss: 0.3656\n", "Epoch 9/2147483647\n", "14/14 [==============================] - 2s 90ms/step - loss: 0.1598 - val_categorical_accuracy: 0.9370 - val_student_loss: 0.2740\n", "Epoch 10/2147483647\n", "14/14 [==============================] - 2s 83ms/step - loss: 0.1168 - val_categorical_accuracy: 0.9472 - val_student_loss: 0.2059\n", "Epoch 11/2147483647\n", "14/14 [==============================] - 1s 81ms/step - loss: 0.0931 - val_categorical_accuracy: 0.9602 - val_student_loss: 
0.1733\n", "Epoch 12/2147483647\n", "14/14 [==============================] - 2s 86ms/step - loss: 0.0727 - val_categorical_accuracy: 0.9570 - val_student_loss: 0.1775\n", "Epoch 13/2147483647\n", "14/14 [==============================] - 2s 86ms/step - loss: 0.0639 - val_categorical_accuracy: 0.9597 - val_student_loss: 0.1895\n", "Epoch 14/2147483647\n", "14/14 [==============================] - 2s 81ms/step - loss: 0.0524 - val_categorical_accuracy: 0.9632 - val_student_loss: 0.1491\n", "Epoch 15/2147483647\n", "14/14 [==============================] - 2s 88ms/step - loss: 0.0460 - val_categorical_accuracy: 0.9683 - val_student_loss: 0.1268\n", "Epoch 16/2147483647\n", "14/14 [==============================] - 2s 89ms/step - loss: 0.0418 - val_categorical_accuracy: 0.9655 - val_student_loss: 0.1469\n", "Epoch 17/2147483647\n", "14/14 [==============================] - 2s 81ms/step - loss: 0.0390 - val_categorical_accuracy: 0.9648 - val_student_loss: 0.1367\n", "Epoch 18/2147483647\n", "14/14 [==============================] - 2s 80ms/step - loss: 0.0384 - val_categorical_accuracy: 0.9650 - val_student_loss: 0.1487\n", "Epoch 19/2147483647\n", "14/14 [==============================] - 2s 84ms/step - loss: 0.0339 - val_categorical_accuracy: 0.9672 - val_student_loss: 0.1360\n", "Epoch 20/2147483647\n", "14/14 [==============================] - 1s 82ms/step - loss: 0.0318 - val_categorical_accuracy: 0.9663 - val_student_loss: 0.1320\n", "Epoch 21/2147483647\n", "14/14 [==============================] - 2s 82ms/step - loss: 0.0266 - val_categorical_accuracy: 0.9690 - val_student_loss: 0.1342\n", "Epoch 22/2147483647\n", "14/14 [==============================] - 2s 87ms/step - loss: 0.0243 - val_categorical_accuracy: 0.9683 - val_student_loss: 0.1368\n", "Epoch 23/2147483647\n", "14/14 [==============================] - 2s 82ms/step - loss: 0.0235 - val_categorical_accuracy: 0.9687 - val_student_loss: 0.1356\n", "Epoch 24/2147483647\n", "14/14 [==============================] - 2s 89ms/step - loss: 0.0226 - val_categorical_accuracy: 0.9692 - val_student_loss: 0.1299\n", "Epoch 25/2147483647\n", "14/14 [==============================] - 2s 83ms/step - loss: 0.0230 - val_categorical_accuracy: 0.9683 - val_student_loss: 0.1279\n", "2/2 [==============================] - 0s 7ms/step - categorical_accuracy: 0.9683 - student_loss: 0.1280\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Measure Results\n", "Let's see how well you did. The DistillationReport returned by Masterful contains information about the distillation process. Your student model achieves nearly the same accuracy as the teacher model using roughly 20,000 parameters instead of roughly 1.4 million. You can also evaluate the student model directly on your holdout set. 
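As a quick check on the size claim, the next cell compares parameter counts directly using the standard Keras count_params() API." ], "metadata": {} }, { "cell_type": "code", "execution_count": null, "source": [ "# Compare model sizes directly (standard Keras API).\n", "print(f'Teacher parameters: {teacher_model.count_params():,}')\n", "print(f'Student parameters: {student_model.count_params():,}')" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "Now run the evaluation and compare the student's holdout metrics against the teacher's. 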
" ], "metadata": {} }, { "cell_type": "code", "execution_count": 9, "source": [ "student_evaluation_metrics = student_model.evaluate(\n", " test_dataset.batch(batch_size), return_dict=True)\n", "print(f'Teacher Evaluation metrics: {teacher_evaluation_metrics}')\n", "print(f'Student Evaluation metrics: {student_evaluation_metrics}')" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "15/15 [==============================] - 1s 19ms/step - loss: 0.1083 - categorical_accuracy: 0.9710\n", "Teacher Evaluation metrics: {'loss': 0.06912115961313248, 'categorical_accuracy': 0.979616641998291}\n", "Student Evaluation metrics: {'loss': 0.10426905751228333, 'categorical_accuracy': 0.97198486328125}\n" ] } ], "metadata": {} } ], "metadata": { "orig_nbformat": 4, "language_info": { "name": "python", "version": "3.6.9", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "kernelspec": { "name": "python3", "display_name": "Python 3.6.9 64-bit ('masterful_tf2': venv)" }, "interpreter": { "hash": "9716fef025f64a2b69cf36238ef46ef956cc7030912a8258187d9e0c43537004" } }, "nbformat": 4, "nbformat_minor": 2 }