PyTorch on Aurora
PyTorch is a popular, open-source deep learning framework developed and released by Facebook. More information about PyTorch is available on the PyTorch home page. For troubleshooting on Aurora, please contact [email protected].
Provided Installation
PyTorch is already installed on Aurora with GPU support and available through the frameworks module. To use it from a compute node, please load the following modules:
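For example (a minimal sketch; the module versions available may differ):

```bash
# Load the AI/ML frameworks module, which provides PyTorch with GPU support
module load frameworks
```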
Then, you can import PyTorch in Python as usual (below showing results from the frameworks/2024.2.1_u1 module):
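A minimal import sketch, assuming the frameworks module is loaded (the version strings printed will depend on the module):

```python
import torch
import intel_extension_for_pytorch as ipex  # exposes Intel GPUs as xpu devices

print(torch.__version__)
print(ipex.__version__)
```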
A simple but useful check could be to use PyTorch to get device information on a compute node. You can do this the following way:
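One possible sketch of such a check using IPEX's torch.xpu namespace (illustrative rather than the exact original listing, but it should produce output similar to the example below):

```python
import torch
import intel_extension_for_pytorch as ipex

# Query GPU (tile) availability and properties on the current node
print(f"GPU availability: {torch.xpu.is_available()}")
print(f"Number of tiles = {torch.xpu.device_count()}")
current_tile = torch.xpu.current_device()
print(f"Current tile = {current_tile}")
print(f"Current device ID = {torch.xpu.device(current_tile)}")
print(f"Device properties = {torch.xpu.get_device_properties(current_tile)}")
```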
Example output:
GPU availability: True
Number of tiles = 12
Current tile = 0
Current device ID = <intel_extension_for_pytorch.xpu.device object at 0x1540a9f25790>
Device properties = _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) Level-Zero', \
type='gpu', driver_version='1.3.30872', total_memory=65536MB, max_compute_units=448, gpu_eu_count=448, \
gpu_subslice_count=56, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, \
has_atomic64=1)
Each Aurora node has 6 GPUs (also called "Devices" or "cards") and each GPU is composed of two tiles (also called "Sub-device"). By default, each tile is mapped to one PyTorch device, giving a total of 12 devices per node in the above output.
Along with importing the torch module, you need to import the intel_extension_for_pytorch module in order to detect Intel GPUs as xpu devices:

import intel_extension_for_pytorch as ipex
Warning

It is highly recommended to import intel_extension_for_pytorch right after import torch, prior to importing other packages (from Intel's "Getting Started" documentation).
Using GPU Devices as PyTorch devices
By default, each tile is mapped to one PyTorch device, giving a total of 12 devices per node, as seen above. To map a PyTorch device to one particular GPU Device out of the 6 available on a compute node, the following environment variables should be set:
export ZE_FLAT_DEVICE_HIERARCHY=COMPOSITE
export ZE_AFFINITY_MASK=0
# or, equivalently, following the syntax `Device.Sub-device`
export ZE_AFFINITY_MASK=0.0,0.1
With these settings, PyTorch sees only Device: 0 and Sub-devices: 0, 1, i.e. the two tiles of GPU 0. This is particularly important in setting a performance benchmarking baseline.
Setting the above environment variables after loading the frameworks module, you can check that each PyTorch device is now mapped to one GPU:
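For example (a sketch; with ZE_FLAT_DEVICE_HIERARCHY=COMPOSITE and ZE_AFFINITY_MASK=0 set, the device count drops to 1 and the reported device aggregates both tiles, as in the output below):

```python
import torch
import intel_extension_for_pytorch as ipex

# One PyTorch device should now correspond to one full GPU (both tiles)
print(torch.xpu.device_count())
print(torch.xpu.get_device_properties(0))
```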
Example output
1
_XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) Level-Zero', type='gpu', driver_version='1.3.30872', total_memory=131072MB, max_compute_units=896, gpu_eu_count=896, gpu_subslice_count=112, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
More information and details are available through the Level Zero Specification Documentation - Affinity Mask
Code changes to run PyTorch on Aurora GPUs
Intel Extension for PyTorch (IPEX) is an open-source project that extends PyTorch with optimizations for an extra performance boost on Intel CPUs and enables the use of Intel GPUs.
Here we list some common changes that you may need to make to your PyTorch code in order to use Intel GPUs.
Please consult Intel's IPEX Documentation for additional details and useful tutorials.
- Import the intel_extension_for_pytorch module right after importing torch.
- All the API calls involving torch.cuda should be replaced with torch.xpu.
- When moving tensors and the model to the GPU, replace "cuda" with "xpu".
- Convert the model and loss criterion to xpu, and then call ipex.optimize for an additional performance boost.

These changes are illustrated together in the sketch below.
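A combined sketch of these changes; the model, data, and optimizer below are placeholders, not the original listing:

```python
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex   # import IPEX right after torch

# torch.cuda calls become torch.xpu calls
print(f"Number of xpu devices: {torch.xpu.device_count()}")

# Move tensors and the model to the GPU with "xpu" instead of "cuda"
device = torch.device("xpu")
model = nn.Linear(8, 1).to(device)           # placeholder model
data = torch.randn(4, 8).to(device)          # placeholder tensor
criterion = nn.MSELoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Apply IPEX optimizations to the model and optimizer
model, optimizer = ipex.optimize(model, optimizer=optimizer)
```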
Tip
A more portable solution to select the appropriate device is the following:
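For example (a sketch of the usual pattern, not necessarily the exact line from the original page):

```python
# Fall back to the CPU when no Intel GPU is available
device = torch.device("xpu" if torch.xpu.is_available() else "cpu")
```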
Example: training a PyTorch model on a single GPU tile
Here is a simple script that trains a dummy PyTorch model on the CPU:
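The original listing is not reproduced here; the following is a minimal sketch of such a script, with an illustrative model, random data, and arbitrary hyperparameters:

```python
import torch
import torch.nn as nn

# Random data standing in for a real dataset
x = torch.randn(1024, 32)
y = torch.randn(1024, 1)

# A small dummy model, loss, and optimizer
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1))
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(10):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss = {loss.item():.4f}")
```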
And here is the code to train the same model on a single GPU tile on Aurora, with new or modified lines indicated:
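In this sketch the new or modified lines are marked with # NEW and # CHANGED comments instead of highlighting; as above, the model, data, and hyperparameters are illustrative:

```python
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex                 # NEW: import IPEX right after torch

device = torch.device("xpu")                               # NEW: select a GPU tile

x = torch.randn(1024, 32).to(device)                       # CHANGED: move data to the tile
y = torch.randn(1024, 1).to(device)                        # CHANGED

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1)).to(device)  # CHANGED
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

model, optimizer = ipex.optimize(model, optimizer=optimizer)  # NEW: IPEX optimization

for epoch in range(10):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss = {loss.item():.4f}")
```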
Here are the steps to run the above code on Aurora (a consolidated command sketch follows the list):

- Login to Aurora.
- Request a one-node interactive job for 30 minutes.
- Copy the above Python script into a file called pytorch_xpu.py and make it executable with chmod a+x pytorch_xpu.py.
- Load the frameworks module.
- Run the script.
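A consolidated sketch of these steps; the queue, filesystems, and project name are placeholders that depend on your allocation:

```bash
# Log in to an Aurora login node
ssh <username>@aurora.alcf.anl.gov

# Request a one-node interactive job for 30 minutes (queue/project are placeholders)
qsub -I -l select=1 -l walltime=00:30:00 -l filesystems=home:flare -q debug -A <ProjectName>

# On the compute node: load the frameworks module and run the script
module load frameworks
python pytorch_xpu.py
```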
PyTorch Best Practices on Aurora
When running PyTorch applications, we have found the following practices to be generally, if not universally, useful and encourage you to try some of these techniques to boost performance of your own applications.
- Use Reduced Precision. Reduced precision is available on the Intel Max 1550 and is supported with PyTorch operations. In general, the way to do this is via the PyTorch Automatic Mixed Precision (AMP) package, as described in the mixed precision documentation. In PyTorch, users generally need to manage casting and loss scaling manually, though context managers and function decorators can provide easy tools to do this. A reduced-precision sketch is shown after this list.
- PyTorch has a JIT module as well as backends to support op fusion, similar to TensorFlow's tf.function tools. See TorchScript for more information.
- torch.compile will be available through the next frameworks release.
- In order to run an application with the TF32 precision type, one must set the following environment variable: export IPEX_FP32_MATH_MODE=TF32. This allows calculations to use TF32 as opposed to the default FP32, and is handled through the intel_extension_for_pytorch module.
- For convolutional neural networks, using the channels_last (NHWC) memory format gives better performance. More info here and here.
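As an illustration of the reduced-precision and channels_last points above, here is a minimal sketch (the model and data are placeholders; with IPEX loaded, torch.autocast accepts device_type="xpu", though your application may instead use fp16 with loss scaling):

```python
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex

device = torch.device("xpu")

# channels_last (NHWC) memory format for a convolutional model
model = nn.Conv2d(3, 16, kernel_size=3).to(device).to(memory_format=torch.channels_last)
data = torch.randn(8, 3, 32, 32, device=device).to(memory_format=torch.channels_last)
target = torch.randn(8, 16, 30, 30, device=device)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

optimizer.zero_grad()
# Automatic mixed precision: run the forward pass in bfloat16 on the xpu device
with torch.autocast(device_type="xpu", dtype=torch.bfloat16):
    output = model(data)
# Compute the loss in fp32 outside the autocast region
loss = criterion(output.float(), target)
loss.backward()
optimizer.step()
```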
Distributed Training on multiple GPUs
Distributed training with PyTorch on Aurora is facilitated through both Distributed Data Parallel (DDP) and Horovod, with comparable performance. We recommend using native PyTorch DDP to perform Data Parallel training on Aurora.
Distributed Data Parallel (DDP)
DDP training is accelerated using the oneAPI Collective Communications Library Bindings for PyTorch (oneccl_bindings_for_pytorch). The extension supports FP32 and BF16 data types.
More detailed information and examples are available at the Intel oneCCL repo, formerly known as torch-ccl.
Code changes to train on multiple GPUs using DDP
The key steps in performing distributed training are:
- Load the oneccl_bindings_for_pytorch module, which enables efficient distributed deep learning training in PyTorch using Intel's oneCCL library, implementing collectives like allreduce, allgather, and alltoall.
- Initialize PyTorch's DistributedDataParallel.
- Use DistributedSampler to partition the training data among the ranks.
- Pin each rank to a GPU.
- Wrap the model in DDP to keep it in sync across the ranks.
- Rescale the learning rate.
- Use set_epoch for shuffling data across epochs.
Here is the code to train the same dummy PyTorch model on multiple GPUs, where new or modified lines are indicated:
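The original listing is not reproduced here; the sketch below marks the DDP-specific additions with # NEW comments. The rank/size discovery through PMI_* and PALS_* environment variables and the MASTER_ADDR handling are assumptions and may differ from the original example:

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import intel_extension_for_pytorch as ipex             # NEW
import oneccl_bindings_for_pytorch                     # NEW: registers the "ccl" backend
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader, DistributedSampler

# NEW: initialize the process group with the oneCCL (ccl) backend
# (in a multi-node run, MASTER_ADDR must point to the rank-0 node)
rank = int(os.environ.get("PMI_RANK", 0))
world_size = int(os.environ.get("PMI_SIZE", 1))
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)

# NEW: pin each rank to one GPU tile (12 tiles visible per node by default)
local_rank = int(os.environ.get("PALS_LOCAL_RANKID", 0))
device = torch.device(f"xpu:{local_rank}")
torch.xpu.set_device(device)

dataset = TensorDataset(torch.randn(1024, 32), torch.randn(1024, 1))
sampler = DistributedSampler(dataset)                  # NEW: partition data among ranks
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1)).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * world_size)  # NEW: rescale the learning rate
model = DDP(model)                                     # NEW: wrap the model in DDP

for epoch in range(10):
    sampler.set_epoch(epoch)                           # NEW: reshuffle data across epochs
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
    if rank == 0:
        print(f"epoch {epoch}: loss = {loss.item():.4f}")
```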
Here are the steps to run the above code on Aurora (a consolidated command sketch follows the list):

- Login to Aurora.
- Request a two-node interactive job for 30 minutes.
- Copy the above Python script into a file called pytorch_ddp.py and make it executable with chmod a+x pytorch_ddp.py.
- Load the frameworks module.
- Run the script on 24 tiles, 12 per node.
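A consolidated sketch of these steps; the queue, filesystems, and project name are placeholders that depend on your allocation:

```bash
# Log in to an Aurora login node
ssh <username>@aurora.alcf.anl.gov

# Request a two-node interactive job for 30 minutes (queue/project are placeholders)
qsub -I -l select=2 -l walltime=00:30:00 -l filesystems=home:flare -q debug -A <ProjectName>

# Load the frameworks module and launch 24 ranks, 12 per node (one per tile)
module load frameworks
mpiexec -n 24 -ppn 12 python pytorch_ddp.py
```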
Settings for training beyond 16 nodes
When training at medium and large scales, we recommend using the frameworks_optimized module, which provides an optimized setup based on observed performance.
To use this optimized setup, the last two steps of the above instructions should be replaced with the following ones:
- Load the frameworks_optimized module.
- Run the script on 24 tiles, 12 per node (see the sketch below).
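A sketch of the two replacement steps, assuming the frameworks_optimized module takes care of the optimized environment settings (the exact launch flags may differ for your job):

```bash
# Load the optimized frameworks module instead of the default one
module load frameworks_optimized

# Launch 24 ranks, 12 per node (one per tile)
mpiexec -n 24 -ppn 12 python pytorch_ddp.py
```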
Setting the CPU Affinity
The CPU affinity can be set manually through mpiexec. You can do this the following way (after having loaded all needed modules):
export CPU_BIND="verbose,list:2-4:10-12:18-20:26-28:34-36:42-44:54-56:62-64:70-72:78-80:86-88:94-96"
mpiexec ... --cpu-bind=${CPU_BIND}
These bindings should be used along with the following oneCCL and Horovod environment variable settings:
HOROVOD_THREAD_AFFINITY="4,12,20,28,36,44,56,64,72,80,88,96"
CCL_WORKER_AFFINITY="5,13,21,29,37,45,57,65,73,81,89,97"
When running 12 ranks per node with these settings, the frameworks use 3 cores each, with Horovod tightly coupled with the frameworks using one of the 3 cores, and oneCCL using a separate core for better performance. For example, with rank 0 the frameworks would use cores 2, 3, and 4, Horovod would use core 4, and oneCCL would use core 5.
Each workload may perform better with different settings. The criteria for choosing the CPU bindings are:
- Binding for GPU and NIC affinity – To bind the ranks to cores on the proper socket or NUMA nodes.
- Binding for cache access – This is the part that will change per application and some experimentation is needed.
Important: This setup is a work in progress and is based on observed performance. The recommended settings are likely to change with new framework releases.
Distributed Training with Multiple CCSs
The Intel PVC GPUs contain 4 Compute Command Streamers (CCSs) on each tile, which can be used to group Execution Units (EUs) into common pools. These pools can then be accessed by separate processes thereby enabling distributed training with multiple MPI processes per tile. This feature on PVC is similar to MPS on NVIDIA GPUs and can be beneficial for increasing computational throughput when training or performing inference with smaller models which do not require the entire memory of a PVC tile. For more information, see the section on using multiple CCSs under the Running Jobs on Aurora page.
For both DDP and Horovod, distributed training with multiple CCSs can be enabled programmatically within the user code by explicitly setting the xpu device in PyTorch (as sketched below), and then adding the proper environment variables and mpiexec settings in the run script. Note that the PVC GPUs allow the use of 1, 2, or 4 CCSs on each tile.
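A sketch of setting the xpu device per rank; the use of PALS_LOCAL_RANKID to derive the local rank is an assumption and may differ from the original example:

```python
import os
import torch
import intel_extension_for_pytorch as ipex

# With several ranks per tile, an affinity mask or affinity script maps each
# rank to a tile; each rank then selects one of the xpu devices it can see.
local_rank = int(os.environ.get("PALS_LOCAL_RANKID", 0))
device = torch.device(f"xpu:{local_rank % torch.xpu.device_count()}")
torch.xpu.set_device(device)
```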
For example, to run distributed training with 48 MPI processes per node exposing 4 CCSs per tile, set
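A sketch of the corresponding settings; the ZEX_NUMBER_OF_CCS syntax and launch flags should be checked against the multiple-CCS section of the Running Jobs on Aurora page:

```bash
# Expose 4 CCSs on each of the 6 GPU devices (Device:number-of-CCSs pairs)
export ZEX_NUMBER_OF_CCS=0:4,1:4,2:4,3:4,4:4,5:4

# Launch 48 ranks per node, i.e. 4 ranks on each of the 12 tiles
mpiexec -n 48 -ppn 48 python pytorch_ddp.py
```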
Alternatively, users can use a modified GPU affinity script, gpu_affinity_ccs.sh, in their mpiexec command in order to bind multiple MPI processes to each tile by setting ZE_AFFINITY_MASK. Note that the script takes the number of CCSs exposed as a command line argument.
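The script itself is not reproduced on this page; the following is only a hedged sketch of what such a wrapper might look like, assuming mpiexec (PALS) sets PALS_LOCAL_RANKID and that consecutive local ranks share a tile:

```bash
#!/bin/bash
# gpu_affinity_ccs.sh (illustrative sketch, not the actual ALCF script)
num_ccs=${1:-1}                      # CCSs exposed per tile (1, 2, or 4), passed as argument
shift

local_rank=${PALS_LOCAL_RANKID:-0}
tile=$(( local_rank / num_ccs ))     # 0..11: which of the 12 tiles on the node
gpu=$(( tile / 2 ))                  # 0..5: which GPU (Device)
subdev=$(( tile % 2 ))               # 0 or 1: which tile (Sub-device)

export ZE_AFFINITY_MASK=${gpu}.${subdev}
exec "$@"
```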
Checking PVC usage with xpu-smi
Users are invited to check the correct placement of the MPI ranks on the different tiles by connecting to the compute node being used, running xpu-smi (see the sketch below), and checking the GPU and memory utilization of both tiles. Note that in this case GPU_ID refers to one of the 6 GPUs on each node, not an individual tile.
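A sketch of such a check (the exact xpu-smi invocation may differ; see the xpu-smi documentation):

```bash
# From a login node, connect to the compute node running the job
ssh <compute-node-name>

# Show utilization and memory for one GPU (GPU_ID between 0 and 5); both tiles are reported
xpu-smi stats -d <GPU_ID>
```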
Multiple CCSs and oneCCL
- When performing distributed training exposing multiple CCSs, the collective communications with the oneCCL backend are delegated to the CPU. This is done in the background by oneCCL, so no change to the users' code is required to move data between host and device; however, it may impact the performance of the collectives at scale.
- When using PyTorch DDP, the model must be offloaded to the XPU device after calling the DDP() wrapper on the model to avoid hangs.
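A minimal sketch of that ordering (the single-rank process-group setup is only there to make the example self-contained):

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import intel_extension_for_pytorch as ipex
import oneccl_bindings_for_pytorch
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-rank setup just to make the ordering runnable
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="ccl", rank=0, world_size=1)

device = torch.device("xpu:0")
model = nn.Linear(8, 1)

model = DDP(model)   # wrap the model in DDP first...
model.to(device)     # ...and only then offload it to the xpu device, to avoid hangs
```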