
In machine/deep learning, it’s essential to save your trained models so you can reuse, share, or deploy them without spending time and computational resources retraining them from scratch. In Keras, one of the most popular deep-learning libraries, this process is both straightforward and versatile. This article will guide you through the process of saving and loading Keras models.
In this guide, we’ll delve deep into the intricacies of saving and loading Keras models. Here’s a glimpse of what you’ll learn:
- Saving model weights and architecture
- Loading model weights and architecture
- Save and load checkpoints (which also contain other information that is crucial during training)
If you only want the quick and easy answer in 10 seconds, here it is:
Saving a model in Keras:
```python
model = ...  # Get the Keras model you wish to save (of course, you need to define the model first)
model.save("path_to_model/model.keras")
```
Loading the same model:
```python
from tensorflow import keras

model = keras.models.load_model("path_to_model/model.keras")
```
One notable advantage of saving models is the ability to pause and resume training, so a long training run doesn’t have to be completed in one go. Moreover, saving your model facilitates sharing, enabling peers to reproduce your results. It’s a common practice among machine learning experts to share not only the code responsible for the model but also its trained parameters or weights.
A typical Keras model encapsulates:
- Architecture: This defines the layers present in the model and their interconnections.
- Weights: These are the learned parameters that allow the model to make predictions.
- Optimizer: Chosen during the model compilation phase, it determines how the model updates based on the data it sees and its loss function.
- Losses and Metrics: Also selected during compilation, they gauge the model’s performance.
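To see these components on a real object, the short sketch below compiles the same small model used later in this article and pulls out each piece with standard Keras calls: to_json() for the architecture, get_weights() for the weights, and the optimizer attribute that compile() sets. It is only an illustration; the loss and metrics passed to compile() are stored on the model as well but are not printed here.

```python
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    layers.Input(shape=(784,)),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Architecture: the layers and how they connect, serialized as a JSON string
architecture_json = model.to_json()

# Weights: the learned parameters, returned as a list of NumPy arrays
weight_shapes = [w.shape for w in model.get_weights()]
print(weight_shapes)  # [(784, 128), (128,), (128, 10), (10,)]

# Optimizer: created by compile() from the 'adam' string
print(type(model.optimizer).__name__)  # Adam
```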
Keras facilitates streamlined saving procedures, each suitable for different kinds of needs and situations. Join us as we explore the steps and intricacies involved in this process.
Table of contents:
- Saving a Whole Model
- Loading a Whole Model
- Saving and Loading Checkpoints During Training
- Saving and Loading Only Weights
- Export Model for Inference
- Saving and Loading Custom Objects
1. Saving a Whole Model
Keras offers a straightforward method to save your entire model, encompassing the architecture, weights, optimizer, and even the loss and metric information. This all-inclusive approach ensures that you can later load the model and resume training or make predictions without any additional setup.
Below, we present a simple code example that covers defining a model, training it, and subsequently saving it using the model.save() method:
```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Define a simple model
model = keras.Sequential([
    layers.Input(shape=(784,)),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Generate dummy data for training
x_train = np.random.random((1000, 784))
y_train = np.random.randint(10, size=(1000,))

# Train the model
model.fit(x_train, y_train, epochs=10)

# Save the model using model.save()
model.save("path_to_model/my_model.keras")
```
Equivalently, instead of the model.save() method, you can also use the keras.models.save_model() function:
```python
keras.models.save_model(model, "save_model_method.keras")
```
Keras recommends the “Keras v3” format, which is used when the file path ends with the .keras extension. This format provides a comprehensive package of the model’s essential components.
However, Keras provides flexibility in the saving format. You can choose between:
- Keras v3 format (.keras): As discussed, this is the recommended format.
```python
model.save("my_model_v3.keras", save_format='keras')
```
- TensorFlow SavedModel format: Another format native to TensorFlow, which offers compatibility with other TensorFlow tools.
```python
model.save("my_model_SavedModel", save_format='tf')
```
- Keras H5 format: This is an older format, mainly used for backward compatibility with older Keras versions.
```python
model.save("my_model.h5", save_format='h5')
```
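If you run the example code above with these three saving options in place of model.save("path_to_model/my_model.keras"), you should end up with a my_model_v3.keras file, a my_model_SavedModel directory, and a my_model.h5 file in your working directory. Note that the save_format argument belongs to Keras 2 (tf.keras in TensorFlow 2.15 and earlier). Keras 3 infers the format from the file extension, accepts only .keras and the legacy .h5 in model.save(), and produces SavedModel artifacts through model.export() instead (see section 5).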
In conclusion, saving a whole model in Keras is straightforward, and the availability of multiple formats ensures compatibility and flexibility based on your project’s needs.
2. Loading a Whole Model
After saving a Keras model, the process to reload it is just as straightforward, ensuring a seamless continuation of tasks like predictions or further training. This is achieved using the load_model function.
Here, we will detail how to load the models saved in the previous examples:
```python
from tensorflow.keras.models import load_model

# Load model from the Keras v3 format
loaded_model_v3 = load_model("my_model_v3.keras")

# Load model from the TensorFlow SavedModel format
loaded_model_SavedModel = load_model("my_model_SavedModel")

# Load model from the Keras H5 format
loaded_model_h5 = load_model("my_model.h5")
```
To verify that the loaded models retain the functionality of the original, you can use them for predictions. For instance:
```python
# Generate some new dummy data for prediction
x_new = np.random.random((5, 784))

# Predict using the original model
original_predictions = model.predict(x_new)

# Predict using the loaded Keras v3 format model
loaded_v3_predictions = loaded_model_v3.predict(x_new)

# Ensure predictions match
if np.allclose(original_predictions, loaded_v3_predictions):
    print("Predictions match!")
else:
    print("Predictions do not match!")

# Repeat similar checks for the other loaded models
loaded_SavedModel_predictions = loaded_model_SavedModel.predict(x_new)
loaded_h5_predictions = loaded_model_h5.predict(x_new)

if np.allclose(original_predictions, loaded_SavedModel_predictions):
    print("Predictions match!")
else:
    print("Predictions do not match!")

if np.allclose(original_predictions, loaded_h5_predictions):
    print("Predictions match!")
else:
    print("Predictions do not match!")
```
In this example, the np.allclose() function checks whether the predictions of each loaded model equal those of the original model within a small floating-point tolerance, which confirms that they behave identically. If you add this code to the end of the code from the previous section, you should get an output similar to this:
```
1/1 [==============================] - 0s 45ms/step
1/1 [==============================] - 0s 29ms/step
Predictions match!
1/1 [==============================] - 0s 33ms/step
1/1 [==============================] - 0s 41ms/step
Predictions match!
Predictions match!
```
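Beyond predictions, a model reloaded from the .keras format also keeps its compile() settings, so you can continue training it without calling compile() again. The sketch below trains briefly, saves to a temporary file, reloads, and runs one more epoch; the random data and epoch counts are placeholders.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    layers.Input(shape=(784,)),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

x_train = np.random.random((1000, 784))
y_train = np.random.randint(10, size=(1000,))
model.fit(x_train, y_train, epochs=2, verbose=0)

path = os.path.join(tempfile.mkdtemp(), "my_model.keras")
model.save(path)

# The compile() configuration is stored in the .keras file, so no compile() call here
restored = keras.models.load_model(path)
history = restored.fit(x_train, y_train, epochs=1, verbose=0)

print(type(restored.optimizer).__name__)  # Adam
print(sorted(history.history.keys()))
```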
To conclude, loading a saved Keras model is a breeze, and the loaded model maintains its original functionality, allowing for an uninterrupted workflow in your deep learning projects.
3. Saving and Loading Checkpoints During Training
Training deep learning models can be a time-consuming process, especially with large datasets or complex architectures. The ability to save and load checkpoints during training not only provides a safety net against unexpected interruptions but also allows for strategies such as early stopping or fine-tuning.
In Keras, the tf.keras.callbacks.ModelCheckpoint callback is designed precisely for this purpose.
Using ModelCheckpoint
Here’s a simple demonstration of how to use the ModelCheckpoint callback during training:
```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint

# Define a simple model
model = keras.Sequential([
    layers.Input(shape=(784,)),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Generate dummy data for training
x_train = np.random.random((1000, 784))
y_train = np.random.randint(10, size=(1000,))

# Create ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(filepath="model_checkpoint.h5",
                                      save_best_only=True,
                                      save_weights_only=False,
                                      monitor='val_loss',
                                      mode='min',
                                      verbose=1)

# Train the model with the callback
model.fit(x_train, y_train, validation_split=0.2, epochs=10, callbacks=[checkpoint_callback])
```
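In this example, save_best_only=True (together with monitor='val_loss' and mode='min') means the checkpoint file is only overwritten when the validation loss improves, so at the end of training it holds the model with the best validation loss. If you only wish to save the model weights and not the entire model, you can set save_weights_only=True.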
Note: At the time of writing, my version of Keras had a problem with saving checkpoints in .keras format. For this reason, I’m using the .h5 format in the examples.
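As a sketch of the save_weights_only=True variant mentioned above, the example below writes a weights-only checkpoint at the end of every epoch and restores it into a freshly built model. Two assumptions: build_model() is a hypothetical helper standing in for your own model-building code, and the file name ends in .weights.h5, which Keras 3 requires for weights-only checkpoints and which older tf.keras versions treat as an ordinary HDF5 weights file.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint


def build_model():
    # Hypothetical helper: rebuilds and compiles the same architecture every time
    model = keras.Sequential([
        layers.Input(shape=(784,)),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model


x_train = np.random.random((1000, 784))
y_train = np.random.randint(10, size=(1000,))

# Weights-only checkpoint, overwritten at the end of every epoch (save_best_only=False)
ckpt_path = os.path.join(tempfile.mkdtemp(), "model_checkpoint.weights.h5")
checkpoint_callback = ModelCheckpoint(filepath=ckpt_path, save_weights_only=True)

model = build_model()
model.fit(x_train, y_train, epochs=3, verbose=0, callbacks=[checkpoint_callback])

# The checkpoint holds no architecture, so rebuild the model before loading the weights
restored = build_model()
restored.load_weights(ckpt_path)

weights_match = all(
    np.allclose(a, b) for a, b in zip(model.get_weights(), restored.get_weights())
)
print(weights_match)  # True
```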
Advanced Checkpointing Options
- Unique Names for Checkpoints:
To generate unique names for each checkpoint (useful if you save checkpoints every epoch), you can use formatting options in the filepath:
```python
checkpoint_callback = ModelCheckpoint(filepath="model_checkpoint_{epoch:02d}.h5")
```
This will save models with filenames like model_checkpoint_01.h5, model_checkpoint_02.h5, and so on.
- Adjusting Checkpointing Frequency:
By default, a checkpoint is saved at the end of every epoch. However, you can adjust this frequency using the save_freq parameter. For instance, to save a checkpoint every other epoch in this example:
```python
checkpoint_callback = ModelCheckpoint(filepath="model_checkpoint_{epoch:02d}.h5", save_freq=50)
```
Note that an integer save_freq is counted in batches, not epochs. In this example, we have 1000 data points, of which 800 are in the training set (validation_split=0.2). The default batch_size is 32, so there are 800 / 32 = 25 batches per epoch, and save_freq=50 therefore saves a checkpoint every other epoch.
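Both options boil down to string formatting and arithmetic, which can be checked without Keras. The sketch below assumes, as the ModelCheckpoint documentation describes, that filepath is filled in with str.format using a 1-based epoch number plus the keys of that epoch’s logs (such as val_loss); the validation losses are made-up values. It then reproduces the save_freq calculation from the note above.

```python
import math

# 1) Unique checkpoint names: 1-based epoch number plus any logged metric
template = "model_checkpoint_{epoch:02d}-{val_loss:.2f}.h5"
fake_val_losses = [0.91, 0.74, 0.69]  # made-up values, for illustration only
names = [
    template.format(epoch=epoch_index + 1, val_loss=val_loss)
    for epoch_index, val_loss in enumerate(fake_val_losses)
]
print(names)
# ['model_checkpoint_01-0.91.h5', 'model_checkpoint_02-0.74.h5', 'model_checkpoint_03-0.69.h5']

# 2) save_freq counts batches, so convert "every other epoch" into a batch count
n_samples = 1000
validation_split = 0.2  # the last 20% of the samples are held out for validation
batch_size = 32         # default batch size of model.fit()

n_train = round(n_samples * (1 - validation_split))  # 800 training samples
steps_per_epoch = math.ceil(n_train / batch_size)     # 25 batches per epoch
save_freq = 2 * steps_per_epoch                       # 50 -> one checkpoint every other epoch
print(n_train, steps_per_epoch, save_freq)  # 800 25 50
```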
Loading from Checkpoints
If training gets interrupted, or if you want to load a model from a specific checkpoint, it’s as straightforward as using the load_model function:
```python
from tensorflow import keras

loaded_from_checkpoint = keras.models.load_model("model_checkpoint.h5")
```
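Resuming, rather than only loading, usually also means telling fit() where to pick up. In this minimal sketch, model.save() to a temporary .keras file stands in for the file ModelCheckpoint would write, and initial_epoch makes the epoch counter (and therefore any {epoch} placeholders in checkpoint names) continue from 3 instead of restarting at 0. The six-epoch plan and the random data are placeholders.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

x_train = np.random.random((1000, 784))
y_train = np.random.randint(10, size=(1000,))

model = keras.Sequential([
    layers.Input(shape=(784,)),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Pretend a planned 6-epoch run was interrupted after epoch 3
model.fit(x_train, y_train, epochs=3, verbose=0)
path = os.path.join(tempfile.mkdtemp(), "model_checkpoint.keras")
model.save(path)  # stands in for the checkpoint written by ModelCheckpoint

# Reload and continue: with initial_epoch, epochs means "final epoch", not "3 more epochs"
resumed = keras.models.load_model(path)
history = resumed.fit(x_train, y_train, initial_epoch=3, epochs=6, verbose=0)
print(history.epoch)  # [3, 4, 5]
```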
In conclusion, checkpoints are pivotal in providing flexibility and security during the training process. They enable resuming from a specific point, experimenting with different model configurations without retraining from scratch, or recovering from interruptions, ensuring that the extensive computational resources utilized during training are not wasted.
4. Saving and Loading Only Weights
Weights form the crux of a deep learning model. They are the learned parameters that allow the model to make accurate predictions based on input data. In many scenarios, there is an interest in saving and loading only the weights of a model, rather than the entire model structure or other associated information.
Why Save and Load Only Weights?
- Flexibility: Saving only the weights allows users to modify the architecture or switch to a different model structure while reusing the same weights.
- Reduced Storage: Weights-only files are typically smaller than full model saves, which can be beneficial when storage space is a consideration.
- Research & Experimentation: Often, when fine-tuning or transferring learned knowledge from one model to another (transfer learning), only the weights are required.
- Shared Architecture: If the model architecture is standard and well-known (e.g., VGG, ResNet), sharing just the weights can be more efficient.
Methods to Save and Load Weights
- get_weights and set_weights:
  - get_weights(): This method retrieves the weights of a model as a list of NumPy arrays. It doesn’t save them to disk but allows for in-memory operations or custom saving routines.
  - set_weights(): It sets the model’s weights with the provided list of NumPy arrays. The list structure should match that of get_weights().
```python
# Get weights of the model
weights = model.get_weights()

# Later, or in another script, you can set those weights
new_model = build_your_model()  # This should have an identical architecture to the original model
new_model.set_weights(weights)
```
- save_weights and load_weights:
  - save_weights(): Saves the model’s weights to a specified file. In Keras 2 the format can be a TensorFlow checkpoint (the default) or H5 (chosen with a .h5 extension); Keras 3 only accepts file names ending in .weights.h5.
  - load_weights(): Loads the model’s weights from a specified file. The architecture of the model should be the same as the one from which the weights were saved.
```python
# Save the weights of the model to a file
model.save_weights("model_weights.tf")

# Later, or in another script, you can load those weights
same_architecture_model = build_your_model()  # This should have an identical architecture to the original model
same_architecture_model.load_weights("model_weights.tf")
```
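Here is a self-contained version of both patterns. build_your_model() is a hypothetical helper standing in for your own architecture code, and the weights file uses the .weights.h5 suffix so the same code runs on Keras 3 (which requires it) as well as on older tf.keras (which writes plain HDF5). The last part shows that set_weights() rejects weights from a model whose layer shapes differ.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def build_your_model(hidden_units=128):
    # Hypothetical helper: returns a new, randomly initialised model
    return keras.Sequential([
        layers.Input(shape=(784,)),
        layers.Dense(hidden_units, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ])


model = build_your_model()
x_new = np.random.random((5, 784)).astype("float32")
reference = model.predict(x_new, verbose=0)

# In memory: get_weights() -> set_weights()
in_memory_copy = build_your_model()
in_memory_copy.set_weights(model.get_weights())
in_memory_ok = np.allclose(reference, in_memory_copy.predict(x_new, verbose=0))

# On disk: save_weights() -> load_weights()
weights_path = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
model.save_weights(weights_path)
from_disk_copy = build_your_model()
from_disk_copy.load_weights(weights_path)
from_disk_ok = np.allclose(reference, from_disk_copy.predict(x_new, verbose=0))

# Mismatched architecture: a 64-unit hidden layer cannot take 128-unit weights
try:
    build_your_model(hidden_units=64).set_weights(model.get_weights())
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True

print(in_memory_ok, from_disk_ok, mismatch_rejected)  # True True True
```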
Similarities and Differences:
- Similarities:
  - Both pairs (get_weights/set_weights and save_weights/load_weights) deal with model weights exclusively, ignoring the architecture, optimizer, and other settings.
  - In both cases, when restoring weights, the architecture of the target model should match the architecture of the source model.
- Differences:
  - get_weights and set_weights work in memory, returning and accepting Python lists of NumPy arrays, respectively. They don’t interact directly with the file system.
  - save_weights and load_weights, on the other hand, directly write to and read from the disk, storing weights in a persistent format.
In conclusion, Keras offers flexible ways to handle model weights, whether you’re looking to perform in-memory operations, save them for future use, or transition between different architectures. Knowing when and how to use these methods can greatly streamline your machine learning workflow.
5. Export Model for Inference
In the realm of machine learning and deep learning, once a model is trained, it’s often desired to deploy it for production or real-time inference. Keras provides intuitive tools to streamline this deployment process, ensuring that the model is lightweight, efficient, and standalone.
Understanding Exporting for Inference
When exporting for inference, the objective is to encapsulate only the forward pass (or the inference functionality) of the model, removing any additional overhead related to training. This results in a more compact and performance-optimized model artifact suitable for various deployment scenarios, including TF-Serving. What makes this exported artifact stand out is its independence from the original code, meaning even custom layers or unique architectures are baked into the artifact, eliminating the need for the original codebase during deployment.
Key APIs for Exporting
- model.export(): This method allows you to convert your Keras model into a lightweight SavedModel artifact exclusively tailored for inference. Once exported, the artifact can be easily served without any dependence on the original code, making deployment hassle-free.
- artifact.serve(): After exporting, this method lets you perform inference on the model artifact, effectively running the forward pass and getting predictions.
For those who require more advanced customization, there’s also:
- keras.export.ExportArchive: This offers a deeper level of customization for the serving endpoints, giving users more control over the deployment process. If you want more information about this customization, I suggest checking out the ExportArchive section of the Keras API documentation. Under the hood, model.export() uses ExportArchive for the export process.
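For a rough idea of what that customization looks like, the sketch below follows the usage pattern documented for ExportArchive: track the model, register an endpoint with an explicit input signature, and write the archive. The endpoint name classify and the fixed (None, 16) float32 signature are illustrative choices, and the code assumes a TensorFlow-backed Keras that provides keras.export.ExportArchive (TensorFlow 2.13 or later, or Keras 3).

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras

inputs = keras.Input(shape=(16,))
outputs = keras.layers.Dense(1, activation="sigmoid")(inputs)
model = keras.Model(inputs, outputs)

export_path = os.path.join(tempfile.mkdtemp(), "custom_export")

# Track the model's variables and expose one endpoint with a fixed input signature
export_archive = keras.export.ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="classify",  # endpoint name of our choosing
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 16), dtype=tf.float32)],
)
export_archive.write_out(export_path)

# Reload with plain TensorFlow; the endpoint is available as a method
reloaded = tf.saved_model.load(export_path)
x = np.random.random((4, 16)).astype("float32")
probabilities = reloaded.classify(x)
print(probabilities.shape)  # (4, 1)
```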
A Simple Walkthrough on Exporting for Inference
Consider a basic Keras Functional model:
```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Define a simple Functional model
inputs = keras.Input(shape=(16,))
x = keras.layers.Dense(8, activation="relu")(inputs)
x = keras.layers.BatchNormalization()(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)

# Generate some random input data (float32, matching the model's input dtype)
input_data = np.random.random((8, 16)).astype("float32")
output_data = model(input_data)

# **NOTE**: Ensure your model is built before exporting!
# Use `model.export()` to save the model for inference
model.export("exported_model_path")

# Load the SavedModel artifact using TensorFlow's utilities
reloaded_artifact = tf.saved_model.load("exported_model_path")

# Get predictions using the `.serve()` method on the loaded artifact
new_output_data = reloaded_artifact.serve(input_data)

# Ensure predictions match
if np.allclose(output_data, new_output_data):
    print("Predictions match!")
else:
    print("Predictions do not match!")
```
6. Saving and Loading Custom Objects
As machine learning and deep learning evolve, practitioners often find themselves implementing custom layers, models, or other objects to meet specific requirements. When it comes to saving and loading models with such custom objects in Keras, special care is needed to ensure these custom objects are recognized and loaded correctly.
Implementing Custom Objects with Save & Load Support
For Keras to be able to save and then later load a custom object (e.g., a custom layer or a custom loss function), the object should define two key methods:
- get_config(): This method should return a dictionary containing the configuration of the custom object, including every argument its constructor needs. The configuration ensures that Keras can recreate the object accurately when loading it later.
- from_config(config): This is a class method that should return a new instance of the custom object, initialized with the configuration provided. The base classes already implement it as cls(**config), so you only need to override it when the configuration contains nested objects that must be deserialized first.
```python
import tensorflow as tf
from tensorflow import keras


class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, param, **kwargs):
        super().__init__(**kwargs)
        self.param = param

    def call(self, inputs):
        # Define the forward pass here (illustrative: scale the inputs by `param`)
        return inputs * self.param

    def get_config(self):
        config = super().get_config()
        # Update the config with the custom layer's parameters
        config.update({'param': self.param})
        return config

    @classmethod
    def from_config(cls, config):
        # `param` is a plain number here, so it can be passed back as-is.
        # If it were a nested Keras object (e.g. a layer), deserialize it first with
        # keras.layers.deserialize(config["param"]) or
        # keras.saving.deserialize_keras_object(config["param"]).
        return cls(**config)
```
Loading Models with Custom Objects
When you save a model containing custom objects, you’ll need to provide these objects when loading the model. There are three main ways to ensure Keras recognizes your custom objects:
- Using the @tf.keras.utils.register_keras_serializable Decorator (Recommended):
This decorator registers a custom object so that Keras recognizes it during the load process without any additional arguments. Note that the decorator is called with parentheses (optionally with a package name).
```python
@tf.keras.utils.register_keras_serializable()
class CustomLayer(tf.keras.layers.Layer):
    ...
```
Then, when loading:
```python
loaded_model = tf.keras.models.load_model('path_to_saved_model')
```
- Directly Passing the Custom Object:
If not registered using the decorator, the custom object can be passed directly to the load_model() function via the custom_objects argument.
```python
loaded_model = tf.keras.models.load_model('path_to_saved_model', custom_objects={'CustomLayer': CustomLayer})
```
- Using tf.keras.utils.custom_object_scope:
If you have multiple custom objects or prefer a scoped approach, you can use custom_object_scope:
```python
with tf.keras.utils.custom_object_scope({'CustomLayer': CustomLayer}):
    loaded_model = tf.keras.models.load_model('path_to_saved_model')
```
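To tie the pieces together, here is a small end-to-end round trip using the custom_objects route. ScaleLayer is a hypothetical example layer (it just multiplies its input by a constant), and the model is saved to a temporary .keras file; the decorator or custom_object_scope approaches should behave the same way at load time.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras


class ScaleLayer(keras.layers.Layer):
    # Hypothetical example layer: multiplies its input by a constant factor
    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        # Include every constructor argument so Keras can rebuild the layer
        config = super().get_config()
        config.update({"factor": self.factor})
        return config


inputs = keras.Input(shape=(4,))
outputs = ScaleLayer(factor=3.0)(keras.layers.Dense(4)(inputs))
model = keras.Model(inputs, outputs)

path = os.path.join(tempfile.mkdtemp(), "custom_model.keras")
model.save(path)

# The layer is not registered, so tell load_model how to find it
restored = keras.models.load_model(path, custom_objects={"ScaleLayer": ScaleLayer})

x = np.random.random((2, 4)).astype("float32")
predictions_match = np.allclose(model.predict(x, verbose=0), restored.predict(x, verbose=0))
print(predictions_match, restored.layers[-1].factor)  # True 3.0
```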
In summary, when working with custom objects in Keras, it’s crucial to ensure they’re serializable and can be accurately recreated during the loading process. Whether you’re defining a novel layer or a unique activation function, Keras offers multiple ways to seamlessly integrate, save, and load your custom implementations.