Fine-Tuning Pretrained Networks

Perhaps the most practical, “remember this snippet”-worthy section of Chollet’s notebook on using pretrained networks is the one where he outlines how to fine-tune a pretrained ConvNet for your own use case.

Generally, he breaks this practice up into 5 simple steps:

1. Add a custom network on top of an already-trained base network

To do this, we’ll import one of the pretrained network objects from Keras, without the densely connected classifier attached:

from keras.applications import VGG16

# Load VGG16's convolutional base with its ImageNet weights
base_model = VGG16(weights='imagenet',
                   include_top=False,   ## <- most important step: drop the dense classifier
                   input_shape=(150, 150, 3))
base_model.summary()
And so we’ve got a big ol’ summary() printout for the VGG16 model

Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 150, 150, 3)       0         
block1_conv1 (Conv2D)        (None, 150, 150, 64)      1792      
block1_conv2 (Conv2D)        (None, 150, 150, 64)      36928     
block1_pool (MaxPooling2D)   (None, 75, 75, 64)        0         
block2_conv1 (Conv2D)        (None, 75, 75, 128)       73856     
block2_conv2 (Conv2D)        (None, 75, 75, 128)       147584    
block2_pool (MaxPooling2D)   (None, 37, 37, 128)       0         
block3_conv1 (Conv2D)        (None, 37, 37, 256)       295168    
block3_conv2 (Conv2D)        (None, 37, 37, 256)       590080    
block3_conv3 (Conv2D)        (None, 37, 37, 256)       590080    
block3_pool (MaxPooling2D)   (None, 18, 18, 256)       0         
block4_conv1 (Conv2D)        (None, 18, 18, 512)       1180160   
block4_conv2 (Conv2D)        (None, 18, 18, 512)       2359808   
block4_conv3 (Conv2D)        (None, 18, 18, 512)       2359808   
block4_pool (MaxPooling2D)   (None, 9, 9, 512)         0         
block5_conv1 (Conv2D)        (None, 9, 9, 512)         2359808   
block5_conv2 (Conv2D)        (None, 9, 9, 512)         2359808   
block5_conv3 (Conv2D)        (None, 9, 9, 512)         2359808   
block5_pool (MaxPooling2D)   (None, 4, 4, 512)         0         
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
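Where do the (4, 4, 512) output and the 14.7M params come from? Here's a back-of-the-envelope check in plain Python, assuming the standard VGG16 layout the summary shows (3x3 'same'-padded convs, 2x2 stride-2 max-pooling after each block). The helper names are hypothetical, not part of Keras:

```python
# VGG16 convolutional base: (filters, number of 3x3 conv layers) per block
VGG16_BLOCKS = [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]

def conv_params(in_channels, filters, kernel=3):
    """kernel*kernel*in_channels weights per filter, plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * filters

def pooled_size(n, pool=2):
    """Output size of a 2x2, stride-2 max-pool with 'valid' padding."""
    return (n - pool) // pool + 1

def vgg16_base(side=150, channels=3):
    """Return (per-layer param counts, final output shape) for the conv base."""
    params = {}
    in_ch = channels
    for b, (filters, n_convs) in enumerate(VGG16_BLOCKS, start=1):
        for c in range(1, n_convs + 1):
            params[f'block{b}_conv{c}'] = conv_params(in_ch, filters)
            in_ch = filters
        side = pooled_size(side)   # 'same'-padded convs keep the size; only pooling shrinks it
    return params, (side, side, in_ch)

params, out_shape = vgg16_base()
print(out_shape)              # (4, 4, 512)
print(sum(params.values()))   # 14714688
```

The spatial size goes 150 → 75 → 37 → 18 → 9 → 4 because each pool halves it and rounds down.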

Next, we’re going to make a new Sequential model and add this whole convolutional base to it as its first layer:

from keras.models import Sequential

model = Sequential()
model.add(base_model)   # the entire VGG16 base becomes the model's first "layer"
model.summary()
Layer (type)                 Output Shape              Param #   
vgg16 (Model)                (None, 4, 4, 512)         14714688  
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0

You might have noticed from base_model.summary() that this doesn’t do any sort of classification, so we’ll tack the usual Flatten/Dense layers onto the end:

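from keras.layers import Flatten, Dense

model.add(Flatten())                        # (4, 4, 512) feature maps -> 8192-long vector
model.add(Dense(256, activation='relu'))
model.add(Dense(1, activation='sigmoid'))   # single probability for binary classification
model.summary()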
Layer (type)                 Output Shape              Param #   
vgg16 (Model)                (None, 4, 4, 512)         14714688  
flatten_1 (Flatten)          (None, 8192)              0         
dense_1 (Dense)              (None, 256)               2097408   
dense_2 (Dense)              (None, 1)                 257       
Total params: 16,812,353
Trainable params: 16,812,353
Non-trainable params: 0
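The new head's param counts are easy to verify by hand: Flatten has no weights, and a Dense layer has one weight per input per unit plus one bias per unit. A small sketch (the dense_params helper is hypothetical, not Keras API):

```python
def dense_params(inputs, units):
    """One weight per input per unit, plus one bias per unit."""
    return inputs * units + units

flat = 4 * 4 * 512                                       # Flatten turns (4, 4, 512) into a vector
head = dense_params(flat, 256) + dense_params(256, 1)    # dense_1 + dense_2
total = 14714688 + head                                  # conv base + our new classifier
print(flat, head, total)                                 # 8192 2097665 16812353
```

Almost all of the head's 2.1M params sit in the first Dense layer, because it connects to every one of the 8,192 flattened features.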

2. Freeze the base network

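16 million trainable params is entirely too many for us to retrain, especially since those weights were already learned on ImageNet with far more compute than we’ll have access to, and the lower-level features they encode (edges, colors, textures) transfer well as-is. There’s a subtler problem, too: our new Dense layers start out randomly initialized, so the large error signals early in training would propagate back through the base and wreck those learned features. So we’re going to freeze the base and focus on training the classification element, with a simple one-liner: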

base_model.trainable = False   # freeze all of VGG16's weights; only the new head will learn
model.summary()
Layer (type)                 Output Shape              Param #   
vgg16 (Model)                (None, 4, 4, 512)         14714688  
flatten_1 (Flatten)          (None, 8192)              0         
dense_1 (Dense)              (None, 256)               2097408   
dense_2 (Dense)              (None, 1)                 257       
Total params: 16,812,353
Trainable params: 2,097,665
Non-trainable params: 14,714,688

Much better: now only the 2.1M params in our new classifier will be updated.
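What does trainable = False actually change? In Keras, each Conv2D and Dense layer owns two weight tensors (a kernel and a bias), and freezing the base removes its tensors from model.trainable_weights, the list the optimizer updates. Printing len(model.trainable_weights) before and after freezing is a cheap sanity check; this plain-Python sketch (the trainable_tensors helper is hypothetical) predicts the numbers you should see for this model:

```python
# Each Conv2D and Dense layer holds two weight tensors: a kernel and a bias.
# Pooling, Flatten and Input layers hold none.
CONV_LAYERS = 2 + 2 + 3 + 3 + 3   # conv layers per VGG16 block -> 13
DENSE_LAYERS = 2                  # our Dense(256) and Dense(1)

def trainable_tensors(base_frozen):
    """Expected len(model.trainable_weights) for the VGG16 + classifier model."""
    base = 0 if base_frozen else 2 * CONV_LAYERS
    return base + 2 * DENSE_LAYERS

print(trainable_tensors(base_frozen=False))  # 30
print(trainable_tensors(base_frozen=True))   # 4
```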

3. Train the part you added

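This is the portion where your typical model-training code goes, sketched below. One gotcha: compile the model after freezing the base, since changes to a layer’s trainable flag only take effect when the model is compiled.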

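from keras.preprocessing.image import ImageDataGenerator

model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['acc'])  # compile AFTER freezing

train_dir, test_dir = 'data/train', 'data/test'  # hypothetical paths: one subfolder per class
train_datagen = ImageDataGenerator(rescale=1./255)   # add augmentation options here if you like
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(train_dir, target_size=(150, 150), class_mode='binary')
test_generator = test_datagen.flow_from_directory(test_dir, target_size=(150, 150), class_mode='binary')
history = model.fit_generator(train_generator, epochs=30, validation_data=test_generator)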

4. Unfreeze some layers in the base network

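Now that you’ve got a decent classification section on top of your network, it will likely do you some good to start tailoring the later layers of your ConvNet to your application. Those layers encode the more specialized, abstract features, while the early layers encode generic ones (edges, colors, textures) that are worth leaving alone.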

This is the clever chunk of code you probably opened this notebook up for:

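base_model.trainable = True   # unfreeze the base as a whole first...

set_trainable = False
for layer in base_model.layers:
    if layer.name == 'block5_conv1':
        set_trainable = True      # ...then keep everything before block5_conv1 frozen
    if set_trainable:
        layer.trainable = True
    else:
        layer.trainable = False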

This works because base_model.layers returns the layers in the order they were constructed, from input to output:

print([layer.name for layer in base_model.layers])
['input_1', 'block1_conv1', 'block1_conv2', 'block1_pool', 'block2_conv1', 'block2_conv2', 'block2_pool', 'block3_conv1', 'block3_conv2', 'block3_conv3', 'block3_pool', 'block4_conv1', 'block4_conv2', 'block4_conv3', 'block4_pool', 'block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_pool']

Thus, we loop over each layer, setting layer.trainable = False until we reach block5_conv1 near the end, and setting it and every layer after it to True.
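To convince yourself the flag-flip loop does what it claims, here’s a sketch that runs the same logic over the layer names printed above. It uses a hypothetical StubLayer class instead of real Keras layers (so it runs without downloading any weights), and it collapses the if/else into layer.trainable = set_trainable, which behaves identically:

```python
from dataclasses import dataclass

@dataclass
class StubLayer:
    """Hypothetical stand-in for a Keras layer: just a name and a trainable flag."""
    name: str
    trainable: bool = True

# Layer names exactly as printed from base_model.layers above
NAMES = ['input_1', 'block1_conv1', 'block1_conv2', 'block1_pool',
         'block2_conv1', 'block2_conv2', 'block2_pool',
         'block3_conv1', 'block3_conv2', 'block3_conv3', 'block3_pool',
         'block4_conv1', 'block4_conv2', 'block4_conv3', 'block4_pool',
         'block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_pool']

def freeze_until(layers, first_trainable):
    """Freeze every layer before `first_trainable`; unfreeze it and everything after."""
    set_trainable = False
    for layer in layers:
        if layer.name == first_trainable:
            set_trainable = True
        layer.trainable = set_trainable   # same effect as the if/else version
    return [layer.name for layer in layers if layer.trainable]

print(freeze_until([StubLayer(n) for n in NAMES], 'block5_conv1'))
# ['block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_pool']
```

With real Keras you’d pass base_model.layers instead, then recompile. Passing an earlier name such as 'block4_conv1' unfreezes more of the base, which means more params to fit and more risk of overfitting a small dataset.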

5. Jointly train the unfrozen layers and the part you added

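The result is a network with about 4x as many trainable params as before (9,177,089 vs. 2,097,665). Recompile it and train again, this time with a very low learning rate so the updates to the unfrozen block5 layers stay small; this extra tuning can yield better accuracy than the frozen-base model: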

Layer (type)                 Output Shape              Param #   
vgg16 (Model)                (None, 4, 4, 512)         14714688  
flatten_1 (Flatten)          (None, 8192)              0         
dense_1 (Dense)              (None, 256)               2097408   
dense_2 (Dense)              (None, 1)                 257       
Total params: 16,812,353
Trainable params: 9,177,089
Non-trainable params: 7,635,264
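The numbers add up: only the three block5 convolutions were unfrozen, so the new trainable count is those layers plus the classifier. A quick arithmetic check in plain Python, using the per-layer counts from the summaries above:

```python
# Per-layer param counts taken from the summaries above
conv_512 = (3 * 3 * 512 + 1) * 512     # a 3x3 conv from 512 to 512 channels: 2,359,808
block5 = 3 * conv_512                  # block5_conv1..3; block5_pool has no params
head = 2097408 + 257                   # dense_1 + dense_2 (Flatten has no params)
base_total = 14714688                  # the whole VGG16 convolutional base

trainable = block5 + head
non_trainable = base_total - block5    # blocks 1-4 stay frozen
print(trainable, non_trainable)        # 9177089 7635264
print(round(trainable / head, 1))      # 4.4
```

The ratio of roughly 4.4 is where the “about 4x” comes from.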