Save and Load Models to Disk in PyTorch, Python: A Complete Guide


Learn the essentials of saving and loading models in PyTorch with this complete guide. It explains the `state_dict` and its central role in managing model parameters, which matters for tasks like transfer learning and model sharing. It then shows how to save a classifier for inference, easing the move from training to deployment, and how to checkpoint training so it can be resumed after an interruption. Finally, it covers saving multiple models in one file and moving models between GPUs and CPUs.

If you are only interested in the quick and easy 10-second introduction to saving and loading PyTorch models, here it is:

For saving a model’s `state_dict`, use `torch.save()`:

import torch
import torch.nn as nn

# Define file path
file_path = 'path_to_directory/model_name.pth'

# Define, instantiate, and train a model here:
class MyModel(nn.Module):
    ...

model = MyModel()

# ... [training procedure here] ...

# Save model parameters
torch.save(model.state_dict(), file_path)

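For loading a model, use `torch.load()` and `model.load_state_dict()`:

import torch

# Define file path
file_path = 'path_to_directory/model_name.pth'

# Instantiate the model (MyModel must be defined or imported here as well)
model = MyModel()

# torch.load() deserializes the file into a dict; load_state_dict() copies it into the model
model.load_state_dict(torch.load(file_path))

model.eval()  # Switch to evaluation mode for inference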

This allows you to save and load the model’s `state_dict`, which you need for inference with the saved model. If you are interested in further details, like resuming training, keep reading!

Table of contents:

  1. What is a state_dict?
  2. How to Save a Classifier to Disk? (Inference)
  3. How to Load a Classifier from Disk?
  4. Save and Load PyTorch Model from a Checkpoint (Resume Training)
  5. Saving Many Models in One File
  6. Save and Load Entire PyTorch Model
  7. Saving and Loading Across GPU and CPU

1. What is a `state_dict`?

In PyTorch, a `state_dict` (or “state dictionary”) is a Python dictionary object (an `OrderedDict`) that maps the name of each parameter and buffer of a `torch.nn.Module` model (for example `fc.weight`) to its tensor. It is a lightweight representation of the model’s learned state and an integral element for saving and loading model weights in PyTorch.

A `state_dict` contains the parameters and persistent buffers of every layer, which makes it the vehicle for transferring trained model parameters during checkpointing, for inference, or when sharing models among developers. Importantly, only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (such as batchnorm’s `running_mean`) have entries in the model’s `state_dict`; parameter-free layers such as ReLU do not. For models composed of sub-models, the `state_dict` also contains the sub-models’ entries, stored in the same flat dictionary with keys prefixed by the sub-model’s attribute name (e.g. `encoder.0.weight`).
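To make the key naming concrete, here is a minimal sketch with a hypothetical `Net` class (not from the original) that nests an `nn.Sequential`, uses a parameter-free `ReLU`, and includes a `BatchNorm1d` layer with registered buffers:

```python
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical model used only to inspect state_dict keys."""

    def __init__(self):
        super().__init__()
        # Nested sub-model: Linear (has parameters) followed by ReLU (has none)
        self.encoder = nn.Sequential(nn.Linear(3, 4), nn.ReLU())
        # BatchNorm1d has a learnable weight/bias plus registered buffers
        self.bn = nn.BatchNorm1d(4)
        self.head = nn.Linear(4, 1)

    def forward(self, x):
        return self.head(self.bn(self.encoder(x)))


net = Net()
state_keys = list(net.state_dict().keys())
param_names = [name for name, _ in net.named_parameters()]

print(state_keys)
# ['encoder.0.weight', 'encoder.0.bias', 'bn.weight', 'bn.bias',
#  'bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked',
#  'head.weight', 'head.bias']

# Buffers appear in the state_dict but are not parameters
print([k for k in state_keys if k not in param_names])
# ['bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']
```

Note that the ReLU at index 1 of `encoder` contributes no keys, and the sub-model entries are flattened with dotted prefixes rather than stored as a nested dictionary.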

Additionally, optimizer objects (`torch.optim`) have a `state_dict` as well. It contains the optimizer’s hyperparameters (`param_groups`) and its internal state (`state`), such as momentum buffers. Save this `state_dict` too if you wish to continue training later, or to create a checkpoint in case the training process fails for whatever reason.

Example:

import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)

# Instantiate the model
model = SimpleModel()

# Instantiate optimizer
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Print the state_dict of the model
print("Model's state_dict:")
for name, tensor in model.state_dict().items():
    print(name, "\t", tensor.size())

# Print the state_dict of the optimizer
print("\nOptimizer's state_dict:")
for key, value in optimizer.state_dict().items():
    print(key, "\t", value)

Running this code should yield output similar to the following (the exact keys listed under `param_groups` vary between PyTorch versions):

Model's state_dict:
fc.weight        torch.Size([1, 3])
fc.bias          torch.Size([1])

Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1]}]
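The `state` entry in this output is empty because the optimizer has not taken a step yet. As a minimal sketch (the momentum value of 0.9 and the bare `nn.Linear` model are arbitrary choices for illustration), enabling momentum and taking a single step fills `state` with one momentum buffer per parameter, keyed by the parameter’s index:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Before any step, the optimizer holds no per-parameter state
state_before = optimizer.state_dict()["state"]
print(state_before)  # {}

# One dummy training step on random data
loss = model(torch.randn(8, 3)).pow(2).mean()
loss.backward()
optimizer.step()

# Afterwards, state is keyed by parameter index (0 = weight, 1 = bias)
state_after = optimizer.state_dict()["state"]
print(sorted(state_after.keys()))               # [0, 1]
print(state_after[0]["momentum_buffer"].shape)  # torch.Size([1, 3])
```

This per-parameter state is exactly what would be lost if you saved only the model’s `state_dict` and later resumed training.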

As we saw, both the model and the optimizer have a `state_dict` that can be stored, allowing users to save the entire training state. This can be useful for resuming training at a later point, for model evaluation, or for inference. When saving a model in PyTorch, using `torch.save()` to serialize the model’s `state_dict` is a common and recommended approach. This way, you have the flexibility to load the saved parameters into models with identical architectures, facilitating tasks like fine-tuning or transfer learning. Let’s take a look at that in the next section:

2. How to Save a Classifier to Disk? (Inference)

Saving a model for inference in PyTorch involves storing the model parameters, contained within the `state_dict`, to disk. To do so, you utilize `torch.save()` to serialize the `state_dict` and ensure that it can be later loaded into the same model architecture.

Here is a step-by-step guide and sample code to illustrate how to save a PyTorch model’s state_dict to disk:

Step 1: Train the Model

Train your model using your training data and ensure it has achieved satisfactory performance based on your evaluation metrics.

Step 2: Save the Model’s `state_dict`

Utilize `torch.save()` to serialize the model’s `state_dict` and specify the file path where the model parameters will be saved.

# Define file path
file_path = 'path_to_directory/model_name.pth'

# Save model parameters
torch.save(model.state_dict(), file_path)

Example Code:

import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)

# Instantiate and train the model
model = SimpleModel()

# ... [training procedure here] ...

# Save the model parameters
file_path = 'simple_model.pth'
torch.save(model.state_dict(), file_path)

In the above example, the model’s `state_dict` is saved to a file named `simple_model.pth`. Because the path is relative, the file is created in the current working directory of the Python process (the directory you ran the script from), which is not necessarily the directory containing the script. The file holds the model’s weights and biases and can be used to initialize an identically structured model later, so the learned parameters can be reused for inference on new data. Since no training was performed here, the stored values are just the random initial weights and are not useful for any real task. Note also that this procedure saves only what is needed for inference; training-specific details like the optimizer state are not retained.
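The sketch below illustrates the working-directory point and checks that the saved tensors survive the round trip unchanged. It switches to a throwaway temporary directory and uses a bare `nn.Linear(3, 1)` as a stand-in for the trained `SimpleModel`; both are assumptions made purely for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(3, 1)  # stand-in for the trained SimpleModel

# Change the process's working directory to a fresh temporary directory
workdir = tempfile.mkdtemp()
os.chdir(workdir)

# A relative path is resolved against the current working directory
torch.save(model.state_dict(), "simple_model.pth")
print(os.path.exists(os.path.join(workdir, "simple_model.pth")))  # True

# Every tensor read back from disk is identical to the one that was saved
reloaded = torch.load("simple_model.pth")
print(all(torch.equal(reloaded[k], v) for k, v in model.state_dict().items()))  # True
```

If you want the file next to the script regardless of where it is launched from, build the path from the script’s location instead of relying on a relative path.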

3. How to Load a Classifier from Disk? (Inference)

Loading a pre-trained model from disk involves initializing the model architecture and then populating it with the parameters stored in the `state_dict`. For inference purposes, this entails using the stored model weights to make predictions, without further modification to the parameter values. 

Below is a guide and accompanying sample code that demonstrates how to load a saved PyTorch model from disk, specifically for conducting inference:

Step 1: Define the Model Architecture

Recreate the exact same model architecture used when the model was saved.

import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)

Step 2: Initialize the Model

Instantiate the model:

# Instantiate the model
model = SimpleModel()

Step 3: Load the Model Parameters

Utilize `torch.load()` to deserialize the stored `state_dict` and use `model.load_state_dict()` to populate the model with the loaded parameters. Then ensure the model is set to evaluation mode using `model.eval()`.

Note: `load_state_dict()` takes a dictionary object, not a path to the saved file. This is why you must first deserialize the `state_dict` saved on disk with `torch.load()`.

# Define file path
file_path = 'simple_model.pth'

# Load the stored model parameters
model.load_state_dict(torch.load(file_path))
model.eval()
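Because `load_state_dict()` matches tensors by key and shape, the architecture really does have to match. The following sketch uses two hypothetical classes, `WiderModel` (same layer name, different output size) and `ExtendedModel` (one extra layer), to show what happens otherwise: a shape mismatch raises a `RuntimeError`, while `strict=False` tolerates missing or unexpected keys and reports them in the return value.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)


class WiderModel(nn.Module):  # hypothetical: same key names, different shapes
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 2)


class ExtendedModel(nn.Module):  # hypothetical: one extra layer
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)
        self.extra = nn.Linear(1, 1)


path = os.path.join(tempfile.mkdtemp(), "simple_model.pth")
torch.save(SimpleModel().state_dict(), path)
state_dict = torch.load(path)  # a dict mapping names to tensors, not a model

# Shape mismatch: raises even though the key names match
try:
    WiderModel().load_state_dict(state_dict)
    raised = False
except RuntimeError as err:
    raised = True
    print(str(err).splitlines()[0])

# Missing keys are tolerated with strict=False and reported in the result
result = ExtendedModel().load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # ['extra.weight', 'extra.bias']
print(result.unexpected_keys)  # []
```

`strict=False` is handy for partial loading (for example in transfer learning), but it silently leaves the missing layers at their initial values, so check `missing_keys` when you use it.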

Full example code that also checks whether loading the `state_dict` had an effect:

import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)

# Instantiate the model and switch to evaluation mode
model = SimpleModel()
model.eval()

# Generate a random data point
random_input = torch.randn(1, 3)

# Get the output of the model before loading the parameters
before_loading_output = model(random_input)

# Define file path
file_path = 'simple_model.pth'

# Load the stored model parameters
model.load_state_dict(torch.load(file_path))

# Get the output of the model after loading the parameters
after_loading_output = model(random_input)

# Compare the outputs (the outputs should be different before and after loading)
if torch.allclose(before_loading_output, after_loading_output, atol=1e-7):
    print("The model output is the same before and after loading.")
else:
    print("The model output changed after loading!")

You should get the following output:

The model output changed after loading!

After performing these steps, the model is now populated with the learned parameters and is set in evaluation mode, ready to make predictions on new data for inference tasks. When conducting inference, it’s crucial to also apply the same preprocessing to the input data as was applied during training to ensure consistency and reliability in the model’s predictions.
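To see why `model.eval()` matters, here is a small sketch with a hypothetical model containing a `Dropout` layer (not from the original). In evaluation mode dropout is disabled, so repeated predictions on the same input agree; wrapping inference in `torch.no_grad()` additionally skips building the autograd graph, which saves memory.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical model with dropout, which behaves differently in train and eval mode
model = nn.Sequential(nn.Linear(3, 8), nn.Dropout(p=0.5), nn.Linear(8, 1))
x = torch.randn(4, 3)

model.eval()           # disables dropout (batchnorm layers would also switch to running stats)
with torch.no_grad():  # no gradients are tracked during inference
    first = model(x)
    second = model(x)

print(torch.equal(first, second))  # True: eval-mode predictions are deterministic
print(first.requires_grad)         # False: no autograd graph was built
```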

In summary, loading a model for inference in PyTorch involves defining and initializing the model, and then loading the previously saved parameters. This provides the ability to reuse trained models, facilitate model sharing, and deploy models into production environments.

4. Save and Load PyTorch Model from a Checkpoint (Resume Training)

Checkpointing in PyTorch involves saving the `state_dict` of both the model and the optimizer, in addition to other training metadata like the epoch number, to ensure that training can be resumed accurately if interrupted. This is crucial for scenarios involving long training times, accidental disruptions, or planned training across multiple sessions. 

Below is a comprehensive guide along with example code on how to effectively create a checkpoint during training, and subsequently load from it.

Saving a Checkpoint

Step 1: Define Model and Optimizer

Ensure your model and optimizer are defined and instantiated.

import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(3, 1)
    
    def forward(self, x):
        return self.fc(x)

# Instantiate model and optimizer
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)

Step 2: Save the Checkpoint

Create a checkpoint by saving the model’s `state_dict`, the optimizer’s `state_dict`, and any additional training metadata as a serialized object.

# Define file path
checkpoint_path = 'path_to_directory/checkpoint.pth'

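# Define checkpoint ('epoch' and 'loss' are assumed to come from your training loop)
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,  # prefer a plain number, e.g. loss.item(), over a tensor
    # add more model training metadata if needed
}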

# Save checkpoint
torch.save(checkpoint, checkpoint_path)

Full example code (with dummy data and a training loop):

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(3, 1)
    
    def forward(self, x):
        return self.fc(x)

# Generate some dummy data for training
x_train = torch.randn(100, 3)
y_train = torch.randn(100, 1)
dataset = TensorDataset(x_train, y_train)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Instantiate model, loss function, and optimizer
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    total_loss = 0
    for x_batch, y_batch in data_loader:
        optimizer.zero_grad()
        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(data_loader)}')
    
    # Define file path
    checkpoint_path = f'checkpoint_epoch_{epoch+1}.pth'
    
    # Define checkpoint
    checkpoint = {
        'epoch': epoch+1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': total_loss/len(data_loader),
    }

    # Save checkpoint
    torch.save(checkpoint, checkpoint_path)

This code trains the SimpleModel on the dummy data, prints the average loss for each epoch, and saves a checkpoint at the end of each epoch. Running it should produce output similar to the following (the loss values will differ from run to run):

Epoch 1/5, Loss: 1.3930703178048134
Epoch 2/5, Loss: 1.632726639509201
Epoch 3/5, Loss: 1.6239219456911087
Epoch 4/5, Loss: 1.300402745604515
Epoch 5/5, Loss: 1.1997566521167755

The checkpoints (`checkpoint_epoch_1.pth` through `checkpoint_epoch_5.pth`) are now saved in the directory where you ran the code.
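A crash in the middle of `torch.save()` can leave a truncated checkpoint file behind, which defeats the purpose of checkpointing. One common safeguard, sketched below, is to write to a temporary file in the same directory and then rename it over the target with `os.replace()`. The helper `save_checkpoint_atomically` is hypothetical (it is not part of PyTorch), and the rename is guaranteed to be atomic only on POSIX systems when both paths are on the same filesystem.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint_atomically(checkpoint, path):
    """Hypothetical helper: never leaves a half-written file at `path`."""
    directory = os.path.dirname(os.path.abspath(path))
    # Create the temporary file next to the target so the rename stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)  # swap the complete file into place
    except BaseException:
        os.remove(tmp_path)  # clean up the partial file, then re-raise
        raise


model = nn.Linear(3, 1)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
save_checkpoint_atomically({"epoch": 1, "model_state_dict": model.state_dict()}, path)
print(torch.load(path)["epoch"])  # 1
```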

Loading from a Checkpoint

Step 1: Load the Checkpoint

Utilize `torch.load()` to deserialize the checkpoint and subsequently load the `state_dict` for both the model and the optimizer. Also, retrieve the training metadata, such as epoch number, to resume training accurately.

# Load the checkpoint
checkpoint = torch.load(checkpoint_path)

# Re-create the model and apply its state_dict (the architecture must be the same)
model = SimpleModel()
model.load_state_dict(checkpoint['model_state_dict'])

# Re-create the optimizer (same type) and apply its state_dict; the saved
# hyperparameters, including lr, replace the values passed here
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Retrieve the training metadata
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()  # For inference mode
# - or -
# model.train()  # For training mode (resuming training)
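A detail worth knowing about the snippet above: `optimizer.load_state_dict()` restores the saved hyperparameters, so the learning rate passed when re-creating the optimizer is replaced by the checkpointed one. A minimal sketch, using a bare `nn.Linear` and two arbitrary learning rates for illustration:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(3, 1)

# Optimizer state as it was when the checkpoint was written
saved_state = optim.SGD(model.parameters(), lr=0.01).state_dict()

# Re-created optimizer with a different learning rate
optimizer = optim.SGD(model.parameters(), lr=0.001)
optimizer.load_state_dict(saved_state)

print(optimizer.param_groups[0]["lr"])  # 0.01: the saved value wins
```

If you really do want a new learning rate after resuming, set it after loading, e.g. `for group in optimizer.param_groups: group["lr"] = 0.001`.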

Step 2: Resume Training

Ensure the model is placed in training mode using `model.train()` and resume the training process from the saved epoch.

# Ensure model is in training mode
model.train()

# Continue training from the saved epoch (num_epochs as defined in the training script)
for e in range(epoch, num_epochs):
    ...  # Training code here (same loop body as before)

By employing this checkpointing mechanism, training can be reliably resumed after crashes, preemptions, or planned pauses. This makes it practical to spread long training runs over several sessions without losing the progress made so far.
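One way to convince yourself that a checkpoint captures everything needed is to compare an uninterrupted run with an interrupted one. The sketch below works under simplifying assumptions (fixed dummy data visited without shuffling, a bare `nn.Linear` model, SGD with momentum, and hypothetical helpers `make` and `train`): it trains for four epochs straight, then repeats the run with a checkpoint after two epochs, restores that checkpoint into freshly created objects, and finishes the remaining epochs.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)
X, Y = torch.randn(64, 3), torch.randn(64, 1)  # dummy data, always visited in the same order


def make():
    """Hypothetical helper: build a model and its optimizer."""
    model = nn.Linear(3, 1)
    return model, optim.SGD(model.parameters(), lr=0.01, momentum=0.9)


def train(model, optimizer, epochs):
    """Hypothetical helper: plain mini-batch training loop."""
    for _ in range(epochs):
        for i in range(0, len(X), 16):
            optimizer.zero_grad()
            loss = F.mse_loss(model(X[i:i + 16]), Y[i:i + 16])
            loss.backward()
            optimizer.step()


# Reference: four epochs without interruption
torch.manual_seed(1)
ref_model, ref_optimizer = make()
train(ref_model, ref_optimizer, 4)

# Same initialization, two epochs, then a checkpoint
torch.manual_seed(1)
model, optimizer = make()
train(model, optimizer, 2)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save({"epoch": 2,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()}, path)

# Resume in brand-new objects, as a restarted script would
resumed_model, resumed_optimizer = make()
checkpoint = torch.load(path)
resumed_model.load_state_dict(checkpoint["model_state_dict"])
resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
train(resumed_model, resumed_optimizer, 4 - checkpoint["epoch"])

print(torch.allclose(resumed_model.weight, ref_model.weight))  # True
```

If you drop the `resumed_optimizer.load_state_dict(...)` line, the momentum buffers start from scratch and the final weights will generally no longer match the reference run.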

5. Saving Many Models in One File

In certain scenarios, such as ensemble learning or multitask learning, you might need to manage multiple models. PyTorch lets you save multiple models, or additional training-related variables, in a single file by following exactly the same approach as for a general checkpoint: store each model’s (and possibly each optimizer’s) `state_dict` in a single dictionary.

Storing Multiple Models:

To save multiple models or several `state_dict`s in one file, use a dictionary to bundle them and pass this to `torch.save()`.

Example:

# Save
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    # ... any other models or states
}, 'path_to_directory/multiple_models.pth')

# Load
checkpoint = torch.load('path_to_directory/multiple_models.pth')
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

Saving Models Alongside Other Parameters:

Beyond models, you might want to save other training parameters or variables. Combine the model’s `state_dict` and any additional variables into a Python dictionary and save it.

Example:

# Save
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    'epoch': epoch,
    'loss': loss,
    # ... any other states
}, 'path_to_directory/models_and_params.pth')

# Load
checkpoint = torch.load('path_to_directory/models_and_params.pth')
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

Full Example Code:

import os

import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)

# Initialize models and optimizers
modelA = SimpleModel()
optimizerA = optim.SGD(modelA.parameters(), lr=0.001)

modelB = SimpleModel()
optimizerB = optim.SGD(modelB.parameters(), lr=0.001)

# Assume some training happens...

# Create the target directory if it does not exist yet
os.makedirs('path_to_directory', exist_ok=True)

# Saving multiple models and additional training variables
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    'epoch': 50,  # Dummy value for reference
    'loss': 0.03,  # Dummy value for reference
}, 'path_to_directory/models_and_params.pth')

# Assume we now want to load the models and parameters, perhaps in a separate file...

# Initialize new models and optimizers
newModelA = SimpleModel()
newOptimizerA = optim.SGD(newModelA.parameters(), lr=0.001)

newModelB = SimpleModel()
newOptimizerB = optim.SGD(newModelB.parameters(), lr=0.001)

# Loading
checkpoint = torch.load('path_to_directory/models_and_params.pth')
newModelA.load_state_dict(checkpoint['modelA_state_dict'])
newModelB.load_state_dict(checkpoint['modelB_state_dict'])
newOptimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
newOptimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

newModelA.eval()
newModelB.eval()
# - or -
newModelA.train()
newModelB.train()

In the example above, `modelA` and `modelB` (and their respective optimizer states) are saved alongside other training parameters like the epoch number and loss value. Utilizing a single dictionary to encapsulate multiple states for saving ensures organized and efficient storage, while also allowing you to load everything you need from just one file. This approach streamlines both the saving and loading process, particularly when managing multiple models or training states.
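When the number of models is not fixed, as in an ensemble, the same dictionary approach works with generated keys. A minimal sketch, assuming a hypothetical ensemble of three `nn.Linear(3, 1)` members and key names of the form `model_0`, `model_1`, and so on:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical ensemble of three identically structured members
ensemble = [nn.Linear(3, 1) for _ in range(3)]
path = os.path.join(tempfile.mkdtemp(), "ensemble.pth")

# One file, one generated key per member
torch.save({f"model_{i}": m.state_dict() for i, m in enumerate(ensemble)}, path)

# Rebuild as many members as the file contains and restore each one
checkpoint = torch.load(path)
restored = [nn.Linear(3, 1) for _ in range(len(checkpoint))]
for i, member in enumerate(restored):
    member.load_state_dict(checkpoint[f"model_{i}"])
    member.eval()

# The averaged ensemble prediction is unchanged by the round trip
x = torch.randn(5, 3)
with torch.no_grad():
    original = torch.stack([m(x) for m in ensemble]).mean(dim=0)
    reloaded = torch.stack([m(x) for m in restored]).mean(dim=0)
print(torch.allclose(original, reloaded))  # True
```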

6. Save and Load Entire PyTorch Model

PyTorch also allows you to save the whole model by passing the model object itself to `torch.save()`. This pickles the entire module, storing the architecture and the parameter tensors together.

Saving the Entire Model

# Save
torch.save(model, 'path_to_directory/entire_model.pth')

Loading the Entire Model:

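To load the entire model, use `torch.load()` again, then call `model.eval()` to set dropout and batch normalization layers to evaluation mode. The model class must be defined or importable at this point, but you do not need to instantiate it yourself. Since PyTorch 2.6, `torch.load()` defaults to `weights_only=True`, which refuses to unpickle arbitrary classes, so loading a full model requires passing `weights_only=False` explicitly; do this only for files you trust, because unpickling can execute arbitrary code.

# Load
model = torch.load('path_to_directory/entire_model.pth', weights_only=False)
model.eval()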

Note on Portability and Flexibility:

Saving the entire model is convenient, but it is not very portable: it relies on Python’s `pickle` module, which stores only a reference to the model’s class (its module path and name), not the class’s code. The class must therefore be importable from the same location whenever the model is loaded, and your code can break in various ways:

  • Different versions of PyTorch or other dependencies may break your code
  • Refactoring (for example, renaming or moving the model class) or reusing the file in another project may break your code

Saving the entire model is fine for quick experiments, but it is generally not recommended. Saving and sharing the model’s `state_dict` gives maximum flexibility with minimal dependency issues.
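The sketch below shows a full-model round trip, assuming PyTorch 1.13 or newer (the versions where `torch.load()` accepts `weights_only`). It uses an `nn.Sequential` of built-in layers so the classes are always importable. Loading with `weights_only=True` is rejected because the file contains pickled module objects rather than just tensors, whereas `weights_only=False` succeeds; reserve the latter for files you trust.

```python
import os
import pickle
import tempfile

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
path = os.path.join(tempfile.mkdtemp(), "entire_model.pth")
torch.save(model, path)  # pickles the module object itself

# The restricted weights-only unpickler refuses arbitrary classes such as nn.Sequential
try:
    torch.load(path, weights_only=True)
    rejected = False
except (pickle.UnpicklingError, RuntimeError):
    rejected = True
print(rejected)  # True

# Full unpickling works, but only use it for trusted files
loaded = torch.load(path, weights_only=False)
loaded.eval()
print(type(loaded).__name__)  # Sequential
print(all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                             loaded.state_dict().values())))  # True
```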

7. Saving and Loading Across GPU and CPU

In machine learning workflows, models are often trained on one kind of device and used on another. PyTorch makes it straightforward to save and load models across GPUs and CPUs. Below are the typical scenarios and how to handle them:

1. Save on GPU, Load on CPU

If you’ve trained a model on GPU and want to load it on a CPU for inference or further training, you can instruct PyTorch to map the stored GPU tensors to CPU tensors using the `map_location` argument.

# Save model on GPU
torch.save(model.state_dict(), 'model_gpu.pth')

# Load on CPU
device = torch.device("cpu")
model = SimpleModel()
model.load_state_dict(torch.load('model_gpu.pth', map_location=device))

2. Save on GPU, Load on GPU

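If both saving and loading happen on a GPU, you can load the `state_dict` as is: `torch.load()` restores each tensor to the device it was saved from. Move the model to the GPU with `.to(device)` so that its parameters live there; `load_state_dict()` then copies the loaded values into those parameters.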

# Save model on GPU
torch.save(model.state_dict(), 'model_gpu.pth')

# Load on GPU
device = torch.device("cuda")
model = SimpleModel().to(device)  # Ensure the model is on GPU
model.load_state_dict(torch.load('model_gpu.pth'))
model.eval()  # Set the model to evaluation mode if needed

3. Save on CPU, Load on GPU

If you’ve saved the model on a CPU and want to load it on a GPU for further training or inference, load the model and then use `model.to()` to move it to the GPU.

# Save model on CPU
torch.save(model.state_dict(), 'model_cpu.pth')

# Load on GPU
device = torch.device("cuda")
model = SimpleModel()
model.load_state_dict(torch.load('model_cpu.pth', map_location="cuda:0"))  # Optional: load the stored tensors directly onto the GPU
model.to(device)  # Move the model to GPU
model.eval()  # Set the model to evaluation mode if needed

The `map_location` argument controls which device the stored tensors are loaded onto. It prevents errors when the device a tensor was saved from is not available (for example, a GPU checkpoint loaded on a CPU-only machine), so pre-trained models can be shared and reused regardless of the hardware at hand.
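In practice it is convenient to write loading code that runs unchanged on both kinds of machine. A minimal sketch, assuming a bare `nn.Linear(3, 1)` model: pick the device at runtime and pass it as `map_location`, which remaps every stored tensor to that device no matter where it was saved. (`map_location` also accepts a string such as "cpu", a dict such as {"cuda:1": "cuda:0"}, or a function.)

```python
import os
import tempfile

import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(nn.Linear(3, 1).state_dict(), path)

# map_location moves every stored tensor to `device` while loading
state_dict = torch.load(path, map_location=device)

model = nn.Linear(3, 1).to(device)
model.load_state_dict(state_dict)
model.eval()
print(next(model.parameters()).device.type == device.type)  # True
```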
