Save and Load Models to Disk in Keras, Python: A Complete Guide


In machine and deep learning, it’s essential to save your trained models so you can reuse, share, or deploy them without spending time and computational resources retraining them from scratch. In Keras, one of the most popular deep learning libraries, this process is both straightforward and versatile.

This guide walks you through saving and loading Keras models in detail. Here’s a glimpse of what you’ll learn:

  • Saving model weights and architecture
  • Loading model weights and architecture
  • Saving and loading checkpoints (which can also contain other information needed during training, such as the optimizer state)

If you only want the quick and easy answer in 10 seconds, here it is:

Saving a model in Keras:

from tensorflow import keras

model = ...  # the Keras model you want to save (define, and usually train, it first)
model.save("path_to_model/model.keras")

Loading the same model:

model = keras.models.load_model("path_to_model/model.keras")

One notable advantage of saving models is the ability to pause and resume training, so a long training run can be split across several sessions instead of having to finish in one go. Saving your model also makes it easy to share, so that others can reproduce your results. It’s common practice among machine learning practitioners to share not only the code that defines the model but also its trained parameters (weights).

A typical Keras model encapsulates:

  1. Architecture: This defines the layers present in the model and their interconnections.
  2. Weights: These are the learned parameters that allow the model to make predictions.
  3. Optimizer: Chosen during the model compilation phase, it determines how the model updates based on the data it sees and its loss function.
  4. Losses and Metrics: Also selected during compilation, they gauge the model’s performance.

Keras offers several ways to save a model, each suited to different needs and situations. The sections below explore them one by one.

Table of contents:

  1. Saving a Whole Model
  2. Loading a Whole Model
  3. Saving and Loading Checkpoints During Training
  4. Saving and Loading Only Weights
  5. Export Model for Inference
  6. Saving and Loading Custom Objects

1. Saving a Whole Model

Keras offers a straightforward method to save your entire model, encompassing the architecture, weights, optimizer, and even the loss and metric information. This all-inclusive approach ensures that you can later load the model and resume training or make predictions without any additional setup.

Below is a simple code example that defines a model, trains it, and then saves it with the model.save() method:

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Define a simple model
model = keras.Sequential([
    layers.Input(shape=(784,)),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Generate dummy data for training
x_train = np.random.random((1000, 784))
y_train = np.random.randint(10, size=(1000,))

# Train the model
model.fit(x_train, y_train, epochs=10)

# Save the model using model.save()
model.save("path_to_model/my_model.keras")

Equivalently, instead of the model.save() method, you can also use the keras.models.save_model() function:

keras.models.save_model(model, "save_model_method.keras")

Because the filename ends in .keras, Keras uses the recommended “Keras v3” format, which packages the model’s essential components into a single file.

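However, Keras provides flexibility in the saving format. In tf.keras (Keras 2), you can choose between:

  1. Keras v3 format (.keras): As discussed, this is the recommended format.
model.save("my_model_v3.keras", save_format='keras')
  2. TensorFlow SavedModel format: Another format native to TensorFlow (saved as a directory rather than a single file), which offers compatibility with other TensorFlow tools.
model.save("my_model_SavedModel", save_format='tf')
  3. Keras H5 format: This is an older format, mainly used for backward compatibility with older Keras versions.
model.save("my_model.h5", save_format='h5')

In Keras 3, the save_format argument is deprecated. The format is inferred from the .keras or .h5 extension, and model.save() no longer writes SavedModel directories (use model.export(), described in section 5, instead).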

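If you add these saving options to the example code above (and remove the model.save("path_to_model/my_model.keras") line) and then run it, you should end up with my_model_v3.keras, a my_model_SavedModel directory, and my_model.h5. You will also have save_model_method.keras if you kept the keras.models.save_model() call.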

In conclusion, saving a whole model in Keras is straightforward, and the availability of multiple formats ensures compatibility and flexibility based on your project’s needs.

2. Loading a Whole Model

After saving a Keras model, the process to reload it is just as straightforward, ensuring a seamless continuation of tasks like predictions or further training. This is achieved using the load_model function.

Here, we will detail how to load the models saved in the previous examples:

from tensorflow.keras.models import load_model

# Load model from the Keras v3 format
loaded_model_v3 = load_model("my_model_v3.keras")

# Load model from the TensorFlow SavedModel format
loaded_model_SavedModel = load_model("my_model_SavedModel")

# Load model from the Keras H5 format
loaded_model_h5 = load_model("my_model.h5")

To verify that the loaded models retain the functionality of the original, you can use them for predictions. For instance:

# Generate some new dummy data for prediction
x_new = np.random.random((5, 784))

# Predict using the original model
original_predictions = model.predict(x_new)

# Predict using the loaded Keras v3 format model
loaded_v3_predictions = loaded_model_v3.predict(x_new)

# Ensure predictions match
if np.allclose(original_predictions, loaded_v3_predictions):
    print("Predictions match!")
else:
    print("Predictions do not match!")

# Repeat similar checks for the other loaded models
loaded_SavedModel_predictions = loaded_model_SavedModel.predict(x_new)
loaded_h5_predictions = loaded_model_h5.predict(x_new)

if np.allclose(original_predictions, loaded_SavedModel_predictions):
    print("Predictions match!")
else:
    print("Predictions do not match!")

if np.allclose(original_predictions, loaded_h5_predictions):
    print("Predictions match!")
else:
    print("Predictions do not match!")

In this example, np.allclose() checks whether each loaded model’s predictions match the original model’s predictions within a small floating-point tolerance, which confirms that the loaded models behave the same as the original. If you append this code to the code from the previous section, you should get output similar to this:

1/1 [==============================] - 0s 45ms/step
1/1 [==============================] - 0s 29ms/step
Predictions match!
1/1 [==============================] - 0s 33ms/step
1/1 [==============================] - 0s 41ms/step
Predictions match!
Predictions match!

To conclude, loading a saved Keras model is a breeze, and the loaded model maintains its original functionality, allowing for an uninterrupted workflow in your deep learning projects.
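If you only need the loaded model for predictions, load_model() also accepts compile=False. This restores the architecture and weights but skips the compilation step (optimizer, loss, and metrics). The minimal sketch below uses dummy data and a temporary directory. It checks that every weight array survives the round trip unchanged and that the predictions still match.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Small model trained briefly on dummy data (sizes are arbitrary)
model = keras.Sequential([
    layers.Input(shape=(784,)),
    layers.Dense(128, activation="relu"),
    layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
x_train = np.random.random((200, 784)).astype("float32")
y_train = np.random.randint(10, size=(200,))
model.fit(x_train, y_train, epochs=1, verbose=0)

path = os.path.join(tempfile.mkdtemp(), "my_model.keras")
model.save(path)

# compile=False restores architecture and weights but skips optimizer/loss/metrics
inference_model = keras.models.load_model(path, compile=False)

# Every weight array should come back unchanged
weights_identical = all(
    np.array_equal(original, restored)
    for original, restored in zip(model.get_weights(), inference_model.get_weights())
)

x_new = np.random.random((5, 784)).astype("float32")
predictions_close = np.allclose(model.predict(x_new, verbose=0),
                                inference_model.predict(x_new, verbose=0))
print(weights_identical, predictions_close)
```

If you later decide to train a model loaded this way, you can compile it again with model.compile().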

3. Saving and Loading Checkpoints During Training

Training deep learning models can be a time-consuming process, especially with large datasets or complex architectures. The ability to save and load checkpoints during training not only provides a safety net against unexpected interruptions but also allows for strategies such as early stopping or fine-tuning.

In Keras, the tf.keras.callbacks.ModelCheckpoint callback is designed precisely for this purpose.

Using ModelCheckpoint

Here’s a simple demonstration of how to utilize the ModelCheckpoint callback during training:

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint

# Define a simple model
model = keras.Sequential([
    layers.Input(shape=(784,)),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Generate dummy data for training
x_train = np.random.random((1000, 784))
y_train = np.random.randint(10, size=(1000,))

# Create ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(filepath="model_checkpoint.h5", 
                                      save_best_only=True,
                                      save_weights_only=False,
                                      monitor='val_loss',
                                      mode='min',
                                      verbose=1)

# Train the model with the callback
model.fit(x_train, y_train, validation_split=0.2, epochs=10, callbacks=[checkpoint_callback])

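In this example, save_best_only=True means the checkpoint file is written only when the monitored metric (here val_loss, with mode='min') improves. At the end of training, the file therefore holds the model with the best validation loss. If you only wish to save the model weights rather than the entire model, set save_weights_only=True.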

Note: At the time of writing, my version of Keras had a problem with saving checkpoints in .keras format. For this reason, I’m using the .h5 format in the examples.
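The save_best_only rule can be pictured with a few lines of plain Python. This is a simplified approximation of the callback’s behavior, not Keras’s actual implementation, which handles more options such as mode='auto'. The validation losses are made-up values.

```python
import math


def epochs_that_would_be_saved(monitored_values, mode="min"):
    """Simplified model of ModelCheckpoint(save_best_only=True): return the
    1-based epochs at which the monitored value beats the best value so far."""
    best = math.inf if mode == "min" else -math.inf
    saved_epochs = []
    for epoch, value in enumerate(monitored_values, start=1):
        improved = value < best if mode == "min" else value > best
        if improved:
            best = value
            saved_epochs.append(epoch)  # the checkpoint file is (over)written here
    return saved_epochs


# Made-up validation losses for six epochs
val_losses = [0.90, 0.72, 0.75, 0.61, 0.64, 0.60]
saved = epochs_that_would_be_saved(val_losses)
print(saved)  # [1, 2, 4, 6]
```

Because the same filepath is reused, the file left on disk after training corresponds to the last epoch in that list, which is the best one.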

Advanced Checkpointing Options

  1. Unique Names for Checkpoints:
    To generate a unique name for each checkpoint (useful if you save a checkpoint every epoch), you can use formatting options in the filepath:
checkpoint_callback = ModelCheckpoint(filepath="model_checkpoint_{epoch:02d}.h5")

This will save models with filenames like model_checkpoint_01.h5, model_checkpoint_02.h5, and so on.

  2. Adjusting Checkpointing Frequency:
    By default, a checkpoint is saved after every epoch. You can change this frequency with the save_freq parameter. For instance, to save a checkpoint every other epoch in this example:
checkpoint_callback = ModelCheckpoint(filepath="model_checkpoint_{epoch:02d}.h5", save_freq=50)

Note that save_freq counts batches, not epochs. In this example we have 1000 data points, 800 of which end up in the training set (validation_split=0.2). With the default batch_size of 32, that gives 800 / 32 = 25 batches per epoch, so save_freq=50 saves a checkpoint every other epoch.
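Both options above come down to small calculations, sketched here in plain Python. The first part approximates how the callback fills in the filepath template. To my understanding, it uses str.format and passes the 1-based epoch number along with that epoch’s logged metrics, such as val_loss. The second part reproduces the save_freq arithmetic above. The val_loss value is invented for illustration.

```python
import math

# Part 1: checkpoint file names (approximation of how ModelCheckpoint fills the template)
template = "model_checkpoint_{epoch:02d}.h5"
names = [template.format(epoch=epoch) for epoch in range(1, 4)]
print(names)  # ['model_checkpoint_01.h5', 'model_checkpoint_02.h5', 'model_checkpoint_03.h5']

# Metrics logged for the epoch can be used in the template too (value made up)
template_with_loss = "model_checkpoint_{epoch:02d}_{val_loss:.3f}.h5"
example_name = template_with_loss.format(epoch=7, val_loss=0.41237)
print(example_name)  # model_checkpoint_07_0.412.h5

# Part 2: choosing save_freq (counted in batches) for "every other epoch"
num_samples = 1000
validation_split = 0.2
batch_size = 32  # default batch_size of model.fit()

num_train = math.floor(num_samples * (1 - validation_split))  # samples left for training
steps_per_epoch = math.ceil(num_train / batch_size)  # a smaller final batch still counts
save_every_n_epochs = 2
save_freq = save_every_n_epochs * steps_per_epoch
print(num_train, steps_per_epoch, save_freq)  # 800 25 50
```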

Loading from Checkpoints

If training gets interrupted, or if you want to load a model from a specific checkpoint, it’s as straightforward as using the load_model function:

loaded_from_checkpoint = keras.models.load_model("model_checkpoint.h5")
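A full-model checkpoint also stores the optimizer state, so you can continue training from where it stopped. The sketch below simulates an interruption: it trains for three epochs and then saves. It calls model.save() with a .keras file directly, which sidesteps the checkpoint-callback issue from the note above. It then reloads the model and continues with fit()’s initial_epoch argument, so epoch numbering resumes at 3. The data, epoch counts, and file name are arbitrary.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

x_train = np.random.random((1000, 784)).astype("float32")
y_train = np.random.randint(10, size=(1000,))

model = keras.Sequential([
    layers.Input(shape=(784,)),
    layers.Dense(128, activation="relu"),
    layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# First session: train for 3 epochs, then save the whole model (including optimizer state)
model.fit(x_train, y_train, epochs=3, verbose=0)
checkpoint_path = os.path.join(tempfile.mkdtemp(), "resume_checkpoint.keras")
model.save(checkpoint_path)

# Second session (e.g. after a crash): reload and continue with epochs 4 to 6
restored = keras.models.load_model(checkpoint_path)
history = restored.fit(x_train, y_train, initial_epoch=3, epochs=6, verbose=0)
print(history.epoch)  # [3, 4, 5] (0-based epoch indices)
```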

In conclusion, checkpoints are pivotal in providing flexibility and security during the training process. They enable resuming from a specific point, experimenting with different model configurations without retraining from scratch, or recovering from interruptions, ensuring that the extensive computational resources utilized during training are not wasted.

4. Saving and Loading Only Weights

Weights form the crux of a deep learning model. They are the learned parameters that allow the model to make accurate predictions based on input data. In many scenarios, there is an interest in saving and loading only the weights of a model, rather than the entire model structure or other associated information.

Why Save and Load Only Weights?

  1. Flexibility: Saving only the weights decouples them from the code that defines the architecture. You can rebuild or refactor the model in code and still reuse the same weights, provided the layer structure and shapes still match.
  2. Reduced Storage: Weights-only files are typically smaller than full model saves, which can be beneficial when storage space is a consideration.
  3. Research & Experimentation: Often, when fine-tuning or transferring learned knowledge from one model to another (transfer learning), only the weights are required.
  4. Shared Architecture: If the model architecture is standard and well-known (e.g., VGG, ResNet), sharing just the weights can be more efficient.
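To check the storage point on your own setup, the sketch below trains a small model for one epoch with Adam, which keeps extra per-weight state. It then compares the size of a full .keras save with a weights-only file. Exact byte counts depend on your Keras version, so the code only prints them. The .weights.h5 suffix is used so the sketch also runs on Keras 3.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    layers.Input(shape=(784,)),
    layers.Dense(128, activation="relu"),
    layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
# One short training run so that Adam creates its per-weight state
model.fit(np.random.random((64, 784)).astype("float32"),
          np.random.randint(10, size=(64,)), epochs=1, verbose=0)

tmp_dir = tempfile.mkdtemp()
full_path = os.path.join(tmp_dir, "full_model.keras")
weights_path = os.path.join(tmp_dir, "only.weights.h5")
model.save(full_path)             # architecture + weights + optimizer state
model.save_weights(weights_path)  # weights only

full_size = os.path.getsize(full_path)
weights_size = os.path.getsize(weights_path)
print(f"full model: {full_size} bytes, weights only: {weights_size} bytes")
```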

Methods to Save and Load Weights

  1. get_weights and set_weights:
  • get_weights(): This method retrieves the weights of a model as a list of NumPy arrays. It doesn’t save them to disk but allows for in-memory operations or custom saving routines.
  • set_weights(): It sets the model’s weights from the provided list of NumPy arrays. The list structure should match that of get_weights().
# Get weights of the model (a list of NumPy arrays)
weights = model.get_weights()

# Later, or in another script, you can set those weights.
# build_your_model() is a placeholder for your own function that recreates
# an architecture identical to the original model.
new_model = build_your_model()
new_model.set_weights(weights)
  2. save_weights and load_weights:
  • save_weights(): Saves the model’s weights to a specified file. In tf.keras (Keras 2), the format follows the file extension (TensorFlow checkpoint format by default, H5 for .h5). Keras 3 requires the filename to end in .weights.h5.
  • load_weights(): Loads the model’s weights from a specified file. The architecture of the model should be the same as the one from which the weights were saved.
# Save the weights of the model to a file
# (a .weights.h5 name is saved as H5 by tf.keras and is required by Keras 3)
model.save_weights("model.weights.h5")

# Later, or in another script, you can load those weights
same_architecture_model = build_your_model()  # placeholder: must recreate an identical architecture
same_architecture_model.load_weights("model.weights.h5")
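Here is a self-contained version of both approaches. build_your_model() is a hypothetical helper defined only for illustration; any function that rebuilds the same architecture would do. The model is left untrained, since only the transfer of weights matters here.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def build_your_model():
    """Hypothetical builder: recreates the same architecture on every call."""
    return keras.Sequential([
        layers.Input(shape=(784,)),
        layers.Dense(128, activation="relu"),
        layers.Dense(10, activation="softmax"),
    ])


model = build_your_model()  # stands in for a trained model

# 1. In memory: get_weights() / set_weights()
weights = model.get_weights()   # list of NumPy arrays
new_model = build_your_model()  # fresh random initialization
new_model.set_weights(weights)

# 2. On disk: save_weights() / load_weights()
weights_path = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
model.save_weights(weights_path)
same_architecture_model = build_your_model()
same_architecture_model.load_weights(weights_path)

# All three models should now produce the same predictions
x = np.random.random((3, 784)).astype("float32")
reference = model.predict(x, verbose=0)
matches_in_memory = np.allclose(reference, new_model.predict(x, verbose=0))
matches_from_disk = np.allclose(reference, same_architecture_model.predict(x, verbose=0))
print(matches_in_memory, matches_from_disk)
```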

Similarities and Differences:

  • Similarities:
      • Both pairs (get_weights/set_weights and save_weights/load_weights) deal with model weights exclusively, ignoring the architecture, optimizer, and other settings.
      • In both cases, the architecture of the target model must match that of the source model when restoring weights.
  • Differences:
      • get_weights and set_weights work in memory. They return and accept Python lists of NumPy arrays, respectively, and don’t interact with the file system.
      • save_weights and load_weights, on the other hand, write to and read from disk, storing the weights in a persistent format.
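To make the in-memory side concrete, the sketch below inspects what get_weights() returns for the two-layer model used throughout this guide. It also shows that set_weights() rejects a list with a wrongly shaped array. Catching ValueError reflects the tf.keras and Keras 3 versions I’m aware of; treat the exact exception type as an assumption.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Same two-layer architecture as in the earlier examples
model = keras.Sequential([
    layers.Input(shape=(784,)),
    layers.Dense(128, activation="relu"),
    layers.Dense(10, activation="softmax"),
])

weights = model.get_weights()
# One kernel and one bias per Dense layer, in layer order
shapes = [w.shape for w in weights]
print(type(weights).__name__, shapes)  # list [(784, 128), (128,), (128, 10), (10,)]

# set_weights() checks the list against the model: a wrongly shaped kernel is rejected
bad_weights = [np.zeros((10, 10), dtype="float32")] + weights[1:]
try:
    model.set_weights(bad_weights)
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
print(mismatch_rejected)  # True
```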

In conclusion, Keras offers flexible ways to handle model weights, whether you’re looking to perform in-memory operations, save them for future use, or transition between different architectures. Knowing when and how to use these methods can greatly streamline your machine learning workflow.

5. Export Model for Inference

In the realm of machine learning and deep learning, once a model is trained, it’s often desired to deploy it for production or real-time inference. Keras provides intuitive tools to streamline this deployment process, ensuring that the model is lightweight, efficient, and standalone.

Understanding Exporting for Inference

When exporting for inference, the objective is to encapsulate only the forward pass (or the inference functionality) of the model, removing any additional overhead related to training. This results in a more compact and performance-optimized model artifact suitable for various deployment scenarios, including TF-Serving. What makes this exported artifact stand out is its independence from the original code, meaning even custom layers or unique architectures are baked into the artifact, eliminating the need for the original codebase during deployment.

Key APIs for Exporting

  • model.export(): This method allows you to convert your Keras model into a lightweight SavedModel artifact exclusively tailored for inference. Once exported, the artifact can be easily served without any dependence on the original code, making deployment hassle-free.
  • reloaded_artifact.serve(): After the exported artifact has been reloaded with tf.saved_model.load(), this endpoint runs the forward pass and returns predictions.

For those who require more advanced customization, there’s also:

  • keras.export.ExportArchive: This offers a deeper level of customization of the serving endpoints, giving you more control over the deployment process. Under the hood, model.export() uses ExportArchive for the export process. For more details on this customization, see the ExportArchive page of the Keras API documentation.
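As a rough illustration of that extra control, the sketch below uses ExportArchive to write an artifact with two endpoints. One is a regular serve endpoint. The other is a hypothetical predict_labels endpoint that thresholds the sigmoid output at 0.5. The endpoint names, threshold, and model are illustrative choices. The calls used (track(), add_endpoint(), write_out()) are the ones ExportArchive provides in recent tf.keras and in Keras 3 with the TensorFlow backend.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras

inputs = keras.Input(shape=(16,))
x = keras.layers.Dense(8, activation="relu")(inputs)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)

export_dir = os.path.join(tempfile.mkdtemp(), "custom_export")
archive = keras.export.ExportArchive()
archive.track(model)  # make the model's variables part of the artifact

input_spec = [tf.TensorSpec(shape=(None, 16), dtype=tf.float32)]
# Endpoint 1: the usual forward pass (probabilities), in inference mode
archive.add_endpoint(
    name="serve",
    fn=lambda x: model(x, training=False),
    input_signature=input_spec,
)
# Endpoint 2 (hypothetical): hard 0/1 labels using a 0.5 threshold
archive.add_endpoint(
    name="predict_labels",
    fn=lambda x: tf.cast(model(x, training=False) > 0.5, tf.int32),
    input_signature=input_spec,
)
archive.write_out(export_dir)

# Reload with plain TensorFlow: both endpoints are available as methods
reloaded = tf.saved_model.load(export_dir)
batch = np.random.random((4, 16)).astype("float32")
probabilities = reloaded.serve(batch)
labels = reloaded.predict_labels(batch)
```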

A Simple Walkthrough on Exporting for Inference

Consider a basic Keras Functional model:

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Define a simple Functional model
inputs = keras.Input(shape=(16,))
x = keras.layers.Dense(8, activation="relu")(inputs)
x = keras.layers.BatchNormalization()(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)

# Generate some random input data
input_data = np.random.random((8, 16)).astype("float32")  # the exported endpoint expects float32 inputs
output_data = model(input_data)  # **NOTE**: Ensure your model is built before exporting!

# Use `model.export()` to save the model for inference
model.export("exported_model_path")

# Load the SavedModel artifact using TensorFlow's utilities
reloaded_artifact = tf.saved_model.load("exported_model_path")

# Get predictions using the `.serve()` method on the loaded artifact
new_output_data = reloaded_artifact.serve(input_data)

# Ensure predictions match
if np.allclose(output_data, new_output_data):
    print("Predictions match!")
else:
    print("Predictions do not match!")

6. Saving and Loading Custom Objects

As machine learning and deep learning evolve, practitioners often find themselves implementing custom layers, models, or other objects to meet specific requirements. When it comes to saving and loading models with such custom objects in Keras, special care is needed to ensure these custom objects are recognized and loaded correctly.

Implementing Custom Objects with Save & Load Support

For Keras to be able to save and later load a custom object (e.g., a custom layer or a custom loss function), the object should define two key methods:

  1. get_config(): This method should return a dictionary containing the configuration of the custom object, i.e. the arguments needed to recreate it. Keras uses this configuration to rebuild the object accurately when loading it later.
  2. from_config(config): This class method should return a new instance of the custom object, initialized with the provided configuration. The default implementation simply calls cls(**config). You only need to override it when some constructor arguments (such as another layer) must be deserialized first, as in the example below.
import tensorflow as tf
from tensorflow import keras


class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, param, **kwargs):
        super().__init__(**kwargs)
        # In this example, `param` is itself a Keras layer (e.g. a Dense layer)
        self.param = param

    def call(self, inputs):
        # Forward pass: apply the wrapped layer to the inputs
        return self.param(inputs)

    def get_config(self):
        config = super().get_config()
        # Store the wrapped layer in serialized (JSON-friendly) form
        config.update({"param": keras.layers.serialize(self.param)})
        return config

    @classmethod
    def from_config(cls, config):
        # Rebuild the wrapped layer from its serialized form.
        # Note that you can also use `keras.saving.deserialize_keras_object` here
        config["param"] = keras.layers.deserialize(config["param"])
        return cls(**config)

Loading Models with Custom Objects

When you save a model containing custom objects, you’ll need to provide these objects when loading the model. There are three main ways to ensure Keras recognizes your custom objects:

  1. Using the @tf.keras.utils.register_keras_serializable() Decorator (Recommended):
    This decorator registers a custom object so that Keras recognizes it during loading without any additional arguments. Note the parentheses: register_keras_serializable takes optional package and name arguments and returns the actual decorator.
@tf.keras.utils.register_keras_serializable()
class CustomLayer(tf.keras.layers.Layer):
    ...

Then, when loading:

loaded_model = tf.keras.models.load_model('path_to_saved_model')
  2. Directly Passing the Custom Object:
    If the object is not registered with the decorator, pass it directly to load_model() via the custom_objects argument.
loaded_model = tf.keras.models.load_model('path_to_saved_model',
                                          custom_objects={'CustomLayer': CustomLayer})
  3. Using tf.keras.utils.custom_object_scope:
    If you have multiple custom objects or prefer a scoped approach, you can use custom_object_scope.
with tf.keras.utils.custom_object_scope({'CustomLayer': CustomLayer}):
    loaded_model = tf.keras.models.load_model('path_to_saved_model')
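Putting these pieces together, here is a minimal end-to-end sketch with a hypothetical ScaleLayer that multiplies its inputs by a constant. Its only constructor argument is a plain float, so get_config() is enough and the inherited from_config() can rebuild it. The package name "MyLayers" is an arbitrary choice. Because the class is registered with the decorator, load_model() needs no custom_objects argument.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras


@tf.keras.utils.register_keras_serializable(package="MyLayers")
class ScaleLayer(tf.keras.layers.Layer):
    """Hypothetical custom layer: multiplies its inputs by a constant factor."""

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        config = super().get_config()
        config.update({"factor": self.factor})  # plain float, JSON-friendly
        return config


inputs = keras.Input(shape=(4,))
outputs = ScaleLayer(factor=3.0)(inputs)
model = keras.Model(inputs, outputs)

path = os.path.join(tempfile.mkdtemp(), "custom_model.keras")
model.save(path)

# Registered via the decorator, so no custom_objects argument is needed
reloaded = keras.models.load_model(path)

x = np.ones((2, 4), dtype="float32")
predictions = reloaded.predict(x, verbose=0)
print(predictions)  # every entry should be 3.0
```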

In summary, when working with custom objects in Keras, it’s crucial to ensure they’re serializable and can be accurately recreated during the loading process. Whether you’re defining a novel layer or a unique activation function, Keras offers multiple ways to seamlessly integrate, save, and load your custom implementations.
