View the runnable example on GitHub

Use BFloat16 Mixed Precision for PyTorch Lightning Training

Brain Floating Point Format (BFloat16) is a custom 16-bit floating-point format designed for machine learning. BFloat16 consists of 1 sign bit, 8 exponent bits, and 7 mantissa bits. Because it has the same number of exponent bits as FP32, BFloat16 has the same dynamic range as FP32 while requiring only half the memory.
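To make the bit layout concrete, here is a minimal pure-Python sketch (not part of BigDL-Nano; all function names are hypothetical) that derives a BFloat16 bit pattern from an FP32 value by keeping its upper 16 bits, then splits it into sign, exponent, and mantissa fields. It uses plain truncation for simplicity, whereas hardware conversions typically round to nearest even.

```python
import struct
from typing import Tuple


def fp32_bits(x: float) -> int:
    """Return the 32-bit IEEE 754 single-precision bit pattern of x."""
    return struct.unpack(">I", struct.pack(">f", x))[0]


def to_bf16_bits(x: float) -> int:
    """Convert x to a 16-bit BFloat16 pattern by truncating the low 16 FP32 bits."""
    return fp32_bits(x) >> 16


def bf16_fields(bits: int) -> Tuple[int, int, int]:
    """Split a BFloat16 pattern into (sign, exponent, mantissa)."""
    sign = (bits >> 15) & 0x1      # 1 sign bit
    exponent = (bits >> 7) & 0xFF  # 8 exponent bits, same width as FP32
    mantissa = bits & 0x7F         # 7 mantissa bits (FP32 has 23)
    return sign, exponent, mantissa


def bf16_to_float(bits: int) -> float:
    """Widen a BFloat16 pattern back to FP32 by zero-filling the low 16 bits."""
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]


bits = to_bf16_bits(3.14159)
print(hex(bits), bf16_fields(bits), bf16_to_float(bits))  # 0x4049 (0, 128, 73) 3.140625

# Largest finite BFloat16 value (exponent 0xFE, all mantissa bits set), about 3.39e38
print(bf16_to_float(0x7F7F))
```

The value 3.14159 keeps only about three significant decimal digits (3.140625), but the largest finite BFloat16 value is close to FP32's maximum of about 3.40e38. For comparison, FP16's 5 exponent bits cap its largest finite value at 65504.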

BFloat16 mixed precision combines BFloat16 and FP32 during training, which can increase performance and reduce memory usage. Compared to FP16 mixed precision, BFloat16 mixed precision offers better numerical stability, since its FP32-sized exponent makes overflow and underflow far less likely.
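One common reason mixed precision keeps FP32 alongside BFloat16 is that BFloat16's 7-bit mantissa cannot represent small updates to values near 1.0: adjacent BFloat16 values in [1, 2) are 2^-7 = 0.0078125 apart. The pure-Python sketch below (hypothetical helper names; not BigDL-Nano code) rounds to BFloat16 with round-to-nearest-even and repeatedly adds a small update, once keeping the weight in BFloat16 and once in FP32. It is a simplified numerical illustration, not how PyTorch or BigDL-Nano implement mixed precision internally.

```python
import struct


def round_to_fp32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack(">f", struct.pack(">f", x))[0]


def round_to_bf16(x: float) -> float:
    """Round x to the nearest BFloat16 value (ties to even), returned as a float.

    NaN/Inf handling is omitted for brevity.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    # Add 0x7FFF plus the lowest kept bit, so exact ties round to an even mantissa.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16  # drop the low 16 bits
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFFFFFF))[0]


update = 1e-3  # a small, hypothetical per-step weight update
w_bf16 = 1.0
w_fp32 = 1.0
for _ in range(100):
    w_bf16 = round_to_bf16(w_bf16 + update)
    w_fp32 = round_to_fp32(w_fp32 + update)

print(w_bf16)  # 1.0: each 0.001 step is below half the BF16 spacing, so it is lost
print(w_fp32)  # close to 1.1: the FP32 copy accumulates every update
```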

The bigdl.nano.pytorch.Trainer API extends the PyTorch Lightning Trainer with multiple integrated optimizations. You can instantiate a BigDL-Nano Trainer with precision='bf16' to use BFloat16 mixed precision for training.

📝 Note

Before starting your PyTorch Lightning application, it is highly recommended to run source bigdl-nano-init to set several environment variables based on your current hardware. Empirically, these variables bring a significant performance increase to most PyTorch Lightning training workloads.

Let’s take as an example a self-defined LightningModule (based on a ResNet-18 model pretrained on the ImageNet dataset) and dataloaders for fine-tuning the model on the OxfordIIITPet dataset:

model = MyLightningModule()
train_loader, val_loader = create_dataloaders()

The definitions of MyLightningModule and create_dataloaders can be found in the runnable example.

To use BFloat16 mixed precision in your PyTorch Lightning application, simply import the BigDL-Nano Trainer and set precision to 'bf16':

from bigdl.nano.pytorch import Trainer

trainer = Trainer(max_epochs=5, precision='bf16')

📝 Note

BFloat16 mixed precision in PyTorch Lightning applications requires torch>=1.10.

⚠️ Warning

Using BFloat16 mixed precision with torch<1.12 may result in extremely slow training.
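If you want to enforce the version requirements above before constructing the Trainer, a small check such as the following may help. It is a hypothetical helper (not part of BigDL-Nano) that parses a version string such as torch.__version__, returns False for versions older than 1.10, and emits a warning for versions older than 1.12.

```python
import re
import warnings
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like '1.12.1+cpu' or '2.0.0'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def check_bf16_torch_version(version: str) -> bool:
    """Return True if this torch version meets the BF16 requirement (torch>=1.10).

    Warns when the version may lead to extremely slow BF16 training (torch<1.12).
    """
    major_minor = parse_major_minor(version)
    if major_minor < (1, 10):
        return False  # BF16 mixed precision requires torch>=1.10
    if major_minor < (1, 12):
        warnings.warn(
            f"torch {version}: BF16 mixed precision with torch<1.12 "
            "may result in extremely slow training."
        )
    return True


# Usage (requires PyTorch to be installed):
#   import torch
#   if not check_bf16_torch_version(torch.__version__):
#       raise RuntimeError("precision='bf16' requires torch>=1.10")
```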

You can also set use_ipex=True together with precision='bf16' to enable IPEX (Intel® Extension for PyTorch*) optimizer fusion for BFloat16 mixed precision training and gain further acceleration:

from bigdl.nano.pytorch import Trainer

trainer = Trainer(max_epochs=5, use_ipex=True, precision='bf16')

📝 Note

Trainer(..., use_ipex=True, precision='bf16') disables PyTorch native Automatic Mixed Precision (AMP) and uses AMP from IPEX instead.
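If the same script must also run on machines without IPEX installed, one option is to enable use_ipex only when the package can be imported. The sketch below uses hypothetical helper names; it assumes IPEX is importable under its usual module name intel_extension_for_pytorch and only builds the keyword arguments you would then pass to the BigDL-Nano Trainer.

```python
import importlib.util
from typing import Any, Dict


def ipex_available() -> bool:
    """Return True if Intel Extension for PyTorch can be imported."""
    return importlib.util.find_spec("intel_extension_for_pytorch") is not None


def bf16_trainer_kwargs(max_epochs: int = 5) -> Dict[str, Any]:
    """Build Trainer arguments for BF16 training, adding use_ipex when IPEX is present."""
    kwargs: Dict[str, Any] = {"max_epochs": max_epochs, "precision": "bf16"}
    if ipex_available():
        # Use IPEX optimizer fusion and IPEX AMP instead of native PyTorch AMP
        kwargs["use_ipex"] = True
    return kwargs


# Usage (requires bigdl-nano):
#   from bigdl.nano.pytorch import Trainer
#   trainer = Trainer(**bf16_trainer_kwargs(max_epochs=5))
```

Falling back to plain precision='bf16' keeps the script portable, at the cost of losing the extra IPEX acceleration on machines where it is missing.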

You can then run BFloat16 mixed precision training and evaluation as usual:

trainer.fit(model, train_dataloaders=train_loader)
trainer.validate(model, dataloaders=val_loader)