Use BFloat16 Mixed Precision for PyTorch Lightning Training#
Brain Floating Point Format (BFloat16) is a custom 16-bit floating point format designed for machine learning. BFloat16 consists of 1 sign bit, 8 exponent bits, and 7 mantissa bits. With the same number of exponent bits, BFloat16 has the same dynamic range as FP32 but requires only half the memory.
BFloat16 mixed precision combines BFloat16 and FP32 during training, which can lead to increased performance and reduced memory usage. Compared to FP16 mixed precision, BFloat16 mixed precision has better numerical stability.
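The relationship between BFloat16 and FP32 can be illustrated with a stdlib-only sketch (the helper below is not part of BigDL-Nano; it just mimics BFloat16 rounding by truncating an FP32 bit pattern):

```python
import struct

def fp32_to_bf16_bits(x: float) -> float:
    """Round an FP32 value to BFloat16 precision.

    BFloat16 is the top 16 bits of an IEEE-754 FP32 value
    (1 sign bit, 8 exponent bits, 7 mantissa bits), so truncating
    keeps the full FP32 exponent range but drops 16 mantissa bits.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    # round to nearest even before dropping the low 16 mantissa bits
    rounding = 0x7FFF + ((bits >> 16) & 1)
    bf16_bits = (((bits + rounding) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack(">f", struct.pack(">I", bf16_bits))[0]

# Dynamic range is preserved: 1e38 stays finite in BFloat16
# (in FP16, whose max is ~65504, it would overflow to inf) ...
print(fp32_to_bf16_bits(1e38))
# ... but precision is reduced: only the top 7 mantissa bits survive
print(fp32_to_bf16_bits(1.2345678))  # → 1.234375
```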
The `bigdl.nano.pytorch.Trainer` API extends PyTorch Lightning Trainer with multiple integrated optimizations. You could instantiate a BigDL-Nano `Trainer` with `precision='bf16'` to use BFloat16 mixed precision for training.
Before starting your PyTorch Lightning application, it is highly recommended to run `source bigdl-nano-init` to set several environment variables based on your current hardware. Empirically, these variables bring a big performance increase for most PyTorch Lightning applications on training workloads.
```python
model = MyLightningModule()
train_loader, val_loader = create_dataloaders()
```
The definition of `create_dataloaders` can be found in the runnable example.
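For illustration, a minimal stand-in for `create_dataloaders` (the real definition lives in the runnable example; the random image-like data and sizes below are assumptions) could look like:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

def create_dataloaders(batch_size: int = 32):
    """Hypothetical stand-in for the helper in the runnable example:
    builds train/validation loaders from random image-like tensors."""
    # 64 fake RGB images (3x32x32) with integer class labels in [0, 10)
    images = torch.randn(64, 3, 32, 32)
    labels = torch.randint(0, 10, (64,))
    dataset = TensorDataset(images, labels)
    train_set, val_set = random_split(dataset, [48, 16])
    return (DataLoader(train_set, batch_size=batch_size, shuffle=True),
            DataLoader(val_set, batch_size=batch_size))

train_loader, val_loader = create_dataloaders()
```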
To use BFloat16 mixed precision for your PyTorch Lightning application, you could simply import BigDL-Nano `Trainer` and set `precision` to `'bf16'`:

```python
from bigdl.nano.pytorch import Trainer

trainer = Trainer(max_epochs=5, precision='bf16')
```
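Under the hood, BFloat16 mixed precision runs selected operations (such as matrix multiplications) in BFloat16 while keeping master copies of parameters in FP32. A sketch of the same idea using only native PyTorch autocast (not BigDL-Nano):

```python
import torch

x = torch.randn(8, 16)   # FP32 activations
w = torch.randn(16, 4)   # FP32 weights

# Inside the autocast region, the matmul runs in (and returns)
# bfloat16, while x and w themselves stay in FP32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = x @ w

print(y.dtype)  # torch.bfloat16
print(x.dtype)  # torch.float32
```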
BFloat16 mixed precision in PyTorch Lightning applications requires a sufficiently recent version of PyTorch. Using BFloat16 mixed precision with `torch<1.12` may result in extremely slow training.
You can also set `use_ipex=True` together with `precision='bf16'` to enable IPEX (Intel® Extension for PyTorch*) optimizer fusion for BFloat16 mixed precision training to gain more acceleration:

```python
from bigdl.nano.pytorch import Trainer

trainer = Trainer(max_epochs=5, use_ipex=True, precision='bf16')
```
`Trainer(..., use_ipex=True, precision='bf16')` disables PyTorch native Auto Mixed Precision (AMP) and enables AMP from IPEX instead.
You could then do BFloat16 mixed precision training and evaluation as normal:
```python
trainer.fit(model, train_dataloaders=train_loader)
trainer.validate(model, dataloaders=val_loader)
```