Guide to Ensembling (Work In Progress)

In this guide, you’ll learn how to create an ensemble model. For a more conceptual discussion, see the concepts documents. First, import the necessary libraries.

[1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa

import sys
sys.path.append('../../../')  # Make the local masterful package importable.

import masterful

Let’s prepare our dataset. We’ll use Imagenette, a small subset of ImageNet, with some minimal preprocessing. In a real production environment, we would follow the dataset performance guide.

[8]:
BATCHSIZE = 128

train, val = tfds.load('imagenette/160px', split=['train', 'validation'], as_supervised=True, shuffle_files=True)

train = train.map(lambda image, label: (tf.cast(image, tf.float32) / 127.5 - 1.0, tf.one_hot(label, 10)), num_parallel_calls=tf.data.AUTOTUNE)
val = val.map(lambda image, label: (tf.cast(image, tf.float32) / 127.5 - 1.0, tf.one_hot(label, 10)), num_parallel_calls=tf.data.AUTOTUNE)

train = train.map(lambda image, label: (tf.image.resize(image, (160,160)), label), num_parallel_calls=tf.data.AUTOTUNE)
val = val.map(lambda image, label: (tf.image.resize(image, (160,160)), label), num_parallel_calls=tf.data.AUTOTUNE)

train = train.cache()
val = val.cache()

train = train.shuffle(1000)

train = train.prefetch(tf.data.AUTOTUNE)
val = val.prefetch(tf.data.AUTOTUNE)
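
As a quick sanity check (an optional step, not part of the pipeline above), we can pull one batch and confirm the shapes and value range we expect: 160×160×3 float images scaled to roughly [-1, 1], and one-hot labels over 10 classes.

[ ]:
# Optional sanity check: inspect a single batch from the training pipeline.
images, labels = next(iter(train.batch(BATCHSIZE)))
print(images.shape)  # (128, 160, 160, 3)
print(labels.shape)  # (128, 10)
print(tf.reduce_min(images).numpy(), tf.reduce_max(images).numpy())  # approximately -1.0 and 1.0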

Now, let’s train a model on this dataset. We’ll use a very small model for demonstration purposes.

[9]:
def get_model():
  # EfficientNetB0 backbone, trained from scratch (weights=None).
  backbone = tf.keras.applications.EfficientNetB0(include_top=False, weights=None, input_shape=(160,160,3))
  retval = tf.keras.models.Sequential()
  retval.add(backbone)
  retval.add(tf.keras.layers.GlobalAveragePooling2D())
  retval.add(tf.keras.layers.Dense(10, activation='softmax'))

  retval.compile(optimizer=tfa.optimizers.LAMB(tf.sqrt(float(BATCHSIZE)) / tf.sqrt(2.) / 32000), loss='categorical_crossentropy', metrics=['acc'])
  return retval

model = get_model()
model.summary()
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
efficientnetb0 (Functional)  (None, 5, 5, 1280)        4049571
_________________________________________________________________
global_average_pooling2d_1 ( (None, 1280)              0
_________________________________________________________________
dense_2 (Dense)              (None, 10)                12810
=================================================================
Total params: 4,062,381
Trainable params: 4,020,358
Non-trainable params: 42,023
_________________________________________________________________
[10]:
model.fit(x=train.batch(BATCHSIZE), validation_data=val.batch(BATCHSIZE), epochs=20)
Epoch 1/20
101/101 [==============================] - 45s 289ms/step - loss: 2.2532 - acc: 0.1597 - val_loss: 2.3500 - val_acc: 0.1000
Epoch 2/20
101/101 [==============================] - 28s 274ms/step - loss: 1.8171 - acc: 0.3709 - val_loss: 3.0967 - val_acc: 0.1000
Epoch 3/20
101/101 [==============================] - 28s 273ms/step - loss: 1.4765 - acc: 0.5042 - val_loss: 4.2789 - val_acc: 0.1000
Epoch 4/20
101/101 [==============================] - 28s 273ms/step - loss: 1.2076 - acc: 0.6028 - val_loss: 4.3226 - val_acc: 0.1000
Epoch 5/20
101/101 [==============================] - 28s 274ms/step - loss: 1.0039 - acc: 0.6675 - val_loss: 3.5492 - val_acc: 0.1000
Epoch 6/20
101/101 [==============================] - 28s 276ms/step - loss: 0.8081 - acc: 0.7363 - val_loss: 2.9440 - val_acc: 0.1720
Epoch 7/20
101/101 [==============================] - 28s 275ms/step - loss: 0.6417 - acc: 0.7908 - val_loss: 1.3759 - val_acc: 0.5700
Epoch 8/20
101/101 [==============================] - 28s 274ms/step - loss: 0.4930 - acc: 0.8363 - val_loss: 1.0968 - val_acc: 0.7080
Epoch 9/20
101/101 [==============================] - 28s 275ms/step - loss: 0.3975 - acc: 0.8708 - val_loss: 1.1089 - val_acc: 0.7120
Epoch 10/20
101/101 [==============================] - 28s 274ms/step - loss: 0.3300 - acc: 0.8945 - val_loss: 1.5058 - val_acc: 0.6700
Epoch 11/20
 61/101 [=================>............] - ETA: 10s - loss: 0.2752 - acc: 0.9121

(TODO: Access masterful.core.ensemble)
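
While the masterful ensembling API is still marked TODO above, the underlying idea can be sketched in plain Keras: train several models independently and average their softmax predictions. The sketch below is illustrative only; the number of members (NUM_MEMBERS) and the averaging approach are assumptions for demonstration, not masterful’s API.

[ ]:
# Illustrative sketch only: plain Keras prediction averaging,
# not the masterful.core.ensemble API referenced in the TODO above.
NUM_MEMBERS = 3  # hypothetical ensemble size

# Train several members independently; random initialization and shuffling
# give each one a different solution.
members = []
for _ in range(NUM_MEMBERS):
  member = get_model()
  member.fit(x=train.batch(BATCHSIZE), validation_data=val.batch(BATCHSIZE), epochs=20)
  members.append(member)

# Average the members' softmax outputs to form the ensemble prediction.
inputs = tf.keras.Input(shape=(160, 160, 3))
outputs = tf.keras.layers.Average()([member(inputs) for member in members])
ensemble = tf.keras.Model(inputs, outputs)

ensemble.compile(loss='categorical_crossentropy', metrics=['acc'])
ensemble.evaluate(val.batch(BATCHSIZE))

Averaging predictions typically improves validation accuracy and calibration relative to any single member, at the cost of NUM_MEMBERS times the training and inference compute.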