
Save and Load Optimized JIT Model

This example illustrates how to save and load a model accelerated by JIT. In this example, we use a pretrained ResNet18 model. By calling InferenceOptimizer.trace(model, accelerator="jit", ...), we obtain a model accelerated by JIT. By calling InferenceOptimizer.save(model, path), we save the model to a folder, and by calling InferenceOptimizer.load(path), we load it back from that folder.

To run inference with BigDL-Nano's InferenceOptimizer, the following packages need to be installed first. We recommend using Miniconda to prepare the environment and installing the packages in a conda environment.

You can create a conda environment by executing:

# "nano" is the conda environment name; you can use any name you like.
conda create -n nano python=3.7 setuptools=58.0.4
conda activate nano

📝 Note

During installation, you may see some warnings or errors about package versions; these can be safely ignored.

# Necessary packages for inference acceleration
!pip install --pre --upgrade bigdl-nano[pytorch]
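After installation, it can help to confirm that the required modules are importable before continuing. The following is a minimal sketch using only the standard library; check_packages is a hypothetical helper name, not part of BigDL-Nano.

```python
import importlib.util
import sys


def check_packages(names):
    """Return {module_name: bool} telling whether each module can be found (hypothetical helper)."""
    status = {}
    for name in names:
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except (ModuleNotFoundError, ValueError):
            # find_spec imports the parent of a dotted name; a missing parent raises
            status[name] = False
    return status


if __name__ == "__main__":
    print(f"Python {sys.version_info.major}.{sys.version_info.minor}")
    for name, ok in check_packages(["torch", "torchvision", "bigdl.nano"]).items():
        print(f"{name}: {'found' if ok else 'missing'}")
```

Using find_spec avoids actually importing heavy packages such as torch just to test for their presence.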

First, prepare the model. We load a pretrained ResNet18 model.

import torch
from torchvision.models import resnet18

model_ft = resnet18(pretrained=True)

Accelerate Inference Using JIT

from bigdl.nano.pytorch import InferenceOptimizer
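# accelerator="jit" selects TorchScript; input_sample is the example input used for tracing
jit_model = InferenceOptimizer.trace(model_ft,
                                     accelerator="jit",
                                     input_sample=torch.rand(1, 3, 224, 224))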
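BigDL-Nano's JIT acceleration is built on PyTorch's TorchScript. As a rough plain-PyTorch illustration of what tracing does (not BigDL-Nano's internal code), the sketch below traces a small stand-in model with torch.jit.trace and checks that the traced module reproduces the eager outputs. The tiny model is an assumption made so the example runs quickly without downloading ResNet18 weights.

```python
import torch
import torch.nn as nn

# Tiny stand-in for ResNet18 (assumption), so no pretrained weights are downloaded
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()

input_sample = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    # Tracing runs the model once on input_sample and records the executed ops
    traced = torch.jit.trace(model, input_sample)
    # For this model the recorded graph is not tied to the batch size used for tracing
    x = torch.rand(2, 3, 224, 224)
    eager_out = model(x)
    traced_out = traced(x)

print(torch.allclose(eager_out, traced_out, atol=1e-5))
```

Because tracing only records the operations executed for the sample input, data-dependent control flow in a model would not be captured; ResNet18's forward pass has none, which is why tracing suits it.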

Save Optimized JIT Model

The saved model files will be stored in the “./optimized_model_jit” directory. This directory contains 2 files, and users only need the “ckpt.pth” file for further usage:

  • nano_model_meta.yml: meta information of the saved model checkpoint

  • ckpt.pth: JIT model checkpoint for general use; it describes the model structure

# Save the JIT-optimized model to a folder
InferenceOptimizer.save(jit_model, "./optimized_model_jit")
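Since the saved folder is expected to hold the two files listed above, a quick check before copying or deploying it can catch an incomplete save. This is a minimal standard-library sketch; missing_model_files and EXPECTED_FILES are hypothetical names, and the file list simply mirrors the description above.

```python
import tempfile
from pathlib import Path

# File names as described above for a folder written by the save step (assumption: fixed names)
EXPECTED_FILES = ("nano_model_meta.yml", "ckpt.pth")


def missing_model_files(folder):
    """Return the expected files absent from `folder` (empty list if complete)."""
    folder = Path(folder)
    return [name for name in EXPECTED_FILES if not (folder / name).is_file()]


# Demo on a throwaway folder that contains only ckpt.pth
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "ckpt.pth").touch()
    demo_missing = missing_model_files(tmp)

print(demo_missing)  # ['nano_model_meta.yml']
```

In the tutorial setting you would call missing_model_files("./optimized_model_jit") and expect an empty list.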

Load the Optimized Model

loaded_model = InferenceOptimizer.load("./optimized_model_jit")
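If, as the "general use" description of ckpt.pth suggests, that file is a standard TorchScript archive, the underlying mechanics resemble plain torch.jit.save and torch.jit.load. The sketch below shows that round trip on a tiny stand-in model (an assumption) in a temporary directory. It illustrates the TorchScript mechanism only and is not a replacement for InferenceOptimizer.load, which works on the whole saved folder.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Tiny stand-in for the JIT-optimized ResNet18 (assumption)
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU()).eval()
example = torch.rand(1, 4)

with torch.no_grad():
    traced = torch.jit.trace(model, example)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ckpt.pth")  # file name borrowed from the folder layout above
    torch.jit.save(traced, path)
    # torch.jit.load rebuilds the module from the archive; the Python class is not needed
    restored = torch.jit.load(path)

x = torch.rand(5, 4)
with torch.no_grad():
    same = torch.allclose(traced(x), restored(x))
print(same)  # True
```

Not needing the original Python class definition at load time is the main practical advantage of a TorchScript checkpoint over a plain state_dict.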

Inference with the Loaded Model

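# Run inside the context manager suited to the loaded model (e.g. with gradients disabled)
with InferenceOptimizer.get_context(loaded_model):
    x = torch.rand(2, 3, 224, 224)  # a random batch of 2 RGB 224x224 images
    y_hat = loaded_model(x)
    predictions = y_hat.argmax(dim=1)  # predicted class index for each image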
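Here y_hat holds raw logits, one row per image, and argmax(dim=1) picks the highest-scoring class in each row. The sketch below uses made-up logits for 2 images over 4 classes (ResNet18 would produce 1000 per image) to show argmax alongside softmax probabilities and a top-k lookup; the logit values are purely illustrative.

```python
import torch

# Made-up logits for a batch of 2 images over 4 classes (illustrative only)
y_hat = torch.tensor([[0.1, 2.5, -1.0, 0.3],
                      [1.7, 0.2, 0.9, -0.4]])

predictions = y_hat.argmax(dim=1)     # index of the largest logit per row
probs = torch.softmax(y_hat, dim=1)   # per-class probabilities; each row sums to 1
top2 = probs.topk(2, dim=1)           # two most likely classes per image, highest first

print(predictions.tolist())           # [1, 0]
print(top2.indices.tolist())          # [[1, 3], [0, 2]]
```

Softmax preserves the ordering of the logits, so argmax over y_hat and over probs gives the same predictions; probabilities are only needed when you want confidence scores.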