Distributed Training in PyG
===========================

.. note::
    We are thrilled to announce the first **in-house distributed training solution** for :pyg:`PyG` via :class:`torch_geometric.distributed`, available from version 2.5 onwards.
    Developers and researchers can now take full advantage of distributed training on large-scale datasets which cannot be fully loaded into the memory of a single machine.
    This implementation doesn't require any additional packages to be installed on top of the default :pyg:`PyG` stack.

In real-life applications, graphs often consist of billions of nodes that cannot fit into the memory of a single machine.
This is when distributed training of Graph Neural Networks comes in handy.
By allocating partitions of the large graph across a cluster of CPUs, one can deploy synchronized model training on the whole dataset at once by making use of :pytorch:`PyTorch's` `Distributed Data Parallel (DDP) <https://pytorch.org/docs/stable/notes/ddp.html>`_ capabilities.
This architecture seamlessly distributes training of Graph Neural Networks across multiple nodes by combining `Remote Procedure Calls (RPCs) <https://pytorch.org/docs/stable/rpc.html>`_ for efficient sampling and retrieval of non-local features with traditional DDP for model training.

Key Advantages
--------------

#. **Balanced graph partitioning** via METIS ensures minimal communication overhead when sampling subgraphs across compute nodes.
#. Utilizing **DDP for model training in conjunction with RPC for remote sampling and feature fetching routines** (with TCP/IP protocol and `gloo <https://github.com/facebookincubator/gloo>`_ communication backend) allows for data parallelism with distinct data partitions at each node.
#. The implementation via custom :class:`~torch_geometric.data.GraphStore` and :class:`~torch_geometric.data.FeatureStore` APIs provides a flexible and tailored interface for distributing large graph structure information and feature storage.
#. Distributed neighbor sampling is capable of sampling in both local and remote partitions through RPC communication channels.
   All the advanced functionality of single-node sampling is also applicable to distributed training, *e.g.*, heterogeneous sampling, link-level sampling, temporal sampling, *etc.*
#. Distributed data loaders offer a high-level abstraction for managing sampler processes, ensuring simplicity and seamless integration with standard :pyg:`PyG` data loaders.
#. Incorporating the Python `asyncio <https://docs.python.org/3/library/asyncio.html>`_ library for asynchronous processing on top of :pytorch:`PyTorch`-based RPCs further enhances the system's responsiveness and overall performance.

Architecture Components
-----------------------

.. note::
    The purpose of this tutorial is to guide you through the most important steps of deploying distributed training applications in :pyg:`PyG`.
    For code examples, please refer to `examples/distributed/pyg <https://github.com/pyg-team/pytorch_geometric/tree/master/examples/distributed/pyg>`_.

Overall, :class:`torch_geometric.distributed` is divided into the following components:

* :class:`~torch_geometric.distributed.Partitioner` partitions the graph into multiple parts, such that each node only needs to load its local data in memory.
* :class:`~torch_geometric.distributed.LocalGraphStore` and :class:`~torch_geometric.distributed.LocalFeatureStore` store the graph topology and features per partition, respectively.
  In addition, they maintain a mapping between local and global IDs for efficient assignment of nodes and feature lookup.
* :class:`~torch_geometric.distributed.DistNeighborSampler` implements the distributed sampling algorithm, which includes local+remote sampling and the final merge of local and remote sampling results based on :pytorch:`PyTorch's` RPC mechanisms.
* :class:`~torch_geometric.distributed.DistNeighborLoader` manages the distributed neighbor sampling and feature fetching processes via multiple RPC workers.
  Finally, it takes care of forming sampled nodes, edges, and their features into the classic :pyg:`PyG` data format.
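
All of these components are exposed under the :class:`torch_geometric.distributed` namespace, so a training script typically starts with imports along the lines of:

.. code-block:: python

    from torch_geometric.distributed import (
        DistNeighborLoader,
        DistNeighborSampler,
        LocalFeatureStore,
        LocalGraphStore,
        Partitioner,
    )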

.. figure:: ../_figures/dist_proc.png
    :align: center
    :width: 100%

    Schematic breakdown of the main components of :class:`torch_geometric.distributed`.

Graph Partitioning
~~~~~~~~~~~~~~~~~~

The first step for distributed training is to split the graph into multiple smaller portions, which can then be loaded locally into nodes of the cluster.
Partitioning is built on top of :pyg:`null` :obj:`pyg-lib`'s `implementation <https://pyg-lib.readthedocs.io/en/latest/modules/partition.html#pyg_lib.partition.metis>`_ of the METIS algorithm, suitable to perform graph partitioning efficiently, even on large-scale graphs.
Note that METIS requires undirected, homogeneous graphs as input.
:class:`~torch_geometric.distributed.Partitioner` performs the necessary processing steps to partition heterogeneous data objects with correct distribution and indexing.

By default, METIS tries to balance the number of nodes of each type in each partition while minimizing the number of edges between partitions.
This ensures that the resulting partitions provide maximal local access of neighbors, enabling samplers to perform local computations without the need for communication between different compute nodes.
Through this partitioning approach, every edge receives a distinct assignment, while "halo nodes" (1-hop neighbors that fall into a different partition) are replicated.
Halo nodes ensure that neighbor sampling for a single node in a single layer stays purely local.

In our distributed training example, we prepared the `partition_graph.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/distributed/pyg/partition_graph.py>`_ script to demonstrate how to apply partitioning on a selected subset of both homogeneous and heterogeneous graphs.
The :class:`~torch_geometric.distributed.Partitioner` can also preserve node features, edge features, and any temporal attributes at the level of nodes and edges.
Afterwards, each node in the cluster owns a single partition of this graph.
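
A minimal sketch of this step, assuming the :class:`~torch_geometric.distributed.Partitioner` interface used in the example scripts (constructor arguments and the :meth:`generate_partition` entry point may differ slightly between releases):

.. code-block:: python

    from ogb.nodeproppred import PygNodePropPredDataset

    from torch_geometric.distributed import Partitioner

    # Load the full graph once on a single machine:
    dataset = PygNodePropPredDataset('ogbn-products', root='./data/products')
    data = dataset[0]

    # Split the graph into two METIS partitions and save them to disk:
    partitioner = Partitioner(data, num_parts=2, root='./partitions/ogbn-products')
    partitioner.generate_partition()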

.. warning::
    Partitioning via METIS is non-deterministic and as such may differ between iterations.
    However, all compute nodes should access the same partition data.
    Therefore, generate the partitions on one node and copy the data to all members of the cluster, or place the folder into a shared location.

The resulting directory structure of a two-partition split of the homogeneous :obj:`ogbn-products` dataset is shown below:

.. code-block::

    partitions
    └─ ogbn-products
       ├─ ogbn-products-partitions
       │  ├─ part_0
       │  ├─ part_1
       │  ├─ META.json
       │  ├─ node_map.pt
       │  └─ edge_map.pt
       ├─ ogbn-products-label
       │  └─ label.pt
       ├─ ogbn-products-test-partitions
       │  ├─ partition0.pt
       │  └─ partition1.pt
       └─ ogbn-products-train-partitions
          ├─ partition0.pt
          └─ partition1.pt
|
Distributed Data Storage
~~~~~~~~~~~~~~~~~~~~~~~~

To maintain distributed data partitions, we utilize instantiations of :pyg:`PyG's` :class:`~torch_geometric.data.GraphStore` and :class:`~torch_geometric.data.FeatureStore` remote interfaces.
Together with an integrated API for sending and receiving RPC requests, they provide a powerful tool for inter-connected distributed data storage.
Both stores can be filled with data in a number of ways, *e.g.*, from :class:`~torch_geometric.data.Data` and :class:`~torch_geometric.data.HeteroData` objects or initialized directly from generated partition files.
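
For example, a compute node could load its own partition into the two stores roughly as follows (the :meth:`from_partition` constructors and their exact signatures are assumptions based on the example scripts):

.. code-block:: python

    from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore

    # Each cluster node loads only its own partition (here: partition 0):
    root = './partitions/ogbn-products/ogbn-products-partitions'
    graph_store = LocalGraphStore.from_partition(root, 0)
    feature_store = LocalFeatureStore.from_partition(root, 0)

    # Together, the two stores form the local data of this compute node:
    partition_data = (feature_store, graph_store)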

:class:`~torch_geometric.distributed.LocalGraphStore` is a class designed to act as a **container for graph topology information**.
It holds the edge indices that define relationships between nodes in a graph.
It offers methods that provide mapping information of nodes and edges to individual partitions, and supports both homogeneous and heterogeneous data formats.

**Key Features:**

* It only stores information about local graph connections and its halo nodes within a partition.
* Remote connectivity: The affiliation information of individual nodes and edges to partitions (both local and global) can be retrieved through node and edge "partition books", *i.e.*, mappings of partition IDs to global node/edge IDs.
* It maintains global identifiers for nodes and edges, allowing for consistent mapping across partitions.

:class:`~torch_geometric.distributed.LocalFeatureStore` is a class that serves as both a **node-level and edge-level feature storage**.
It provides efficient :obj:`put` and :obj:`get` routines for attribute retrieval for both local and remote node/edge IDs.
The :class:`~torch_geometric.distributed.LocalFeatureStore` is responsible for retrieving and updating features across different partitions and machines during the training process.

**Key Features:**

* It provides functionalities for storing, retrieving, and distributing node and edge features.
  Within the managed partition of a machine, node and edge features are stored locally.
* Remote feature lookup: It implements mechanisms for looking up features of both local and remote nodes during distributed training through RPC requests.
  The class is designed to work seamlessly in distributed training scenarios, allowing for efficient feature handling across partitions.
* It maintains global identifiers for nodes and edges, allowing for consistent mapping across partitions.

Below is an example of how :class:`~torch_geometric.distributed.LocalFeatureStore` is used internally to retrieve both local and remote features:

.. code-block:: python

    import torch

    from torch_geometric.distributed import LocalFeatureStore
    from torch_geometric.distributed.event_loop import to_asyncio_future

    # Create a `LocalFeatureStore` instance (e.g., from a partition file):
    feature_store = LocalFeatureStore(...)

    async def get_node_features():
        # Retrieve node features for specific (local or remote) node IDs.
        # The lookup returns a future which we convert into an `asyncio`
        # future and await:
        node_id = torch.tensor([1])
        future = feature_store.lookup_features(node_id)
        return await to_asyncio_future(future)
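
In a standalone script, such a coroutine could be driven via :obj:`asyncio.run(get_node_features())`.
Inside the distributed loaders, equivalent lookups are scheduled on an internal event loop, so users typically never need to call these routines directly.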
|
Distributed Neighbor Sampling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

:class:`~torch_geometric.distributed.DistNeighborSampler` is a class designed for efficient distributed training of Graph Neural Networks.
It addresses the challenges of sampling neighbors in a distributed environment, in which graph data is partitioned across multiple machines or devices.
The sampler ensures that GNNs can effectively learn from large-scale graphs, maintaining scalability and performance.

**Asynchronous Neighbor Sampling and Feature Collection:**

Distributed neighbor sampling is implemented using asynchronous :class:`torch.distributed.rpc` calls.
It allows machines to independently sample neighbors without strict synchronization.
Each machine autonomously selects neighbors from its local graph partition, without waiting for others to complete their sampling processes.
This approach enhances parallelism, as machines can progress asynchronously, and leads to faster training.
In addition to asynchronous sampling, distributed neighbor sampling also provides asynchronous feature collection.

**Customizable Sampling Strategies:**

Users can customize neighbor sampling strategies based on their specific requirements.
The :class:`~torch_geometric.distributed.DistNeighborSampler` class provides full flexibility in defining sampling techniques, such as:

* Node sampling vs. edge sampling
* Homogeneous vs. heterogeneous sampling
* Temporal sampling vs. static sampling

**Distributed Neighbor Sampling Workflow:**

A batch of seed nodes follows three main steps before it is made available for the model's :meth:`forward` pass by the data loader:

#. **Distributed node sampling:** While the underlying principles of neighbor sampling hold for the distributed case as well, the implementation slightly diverges from single-machine sampling.
   In distributed training, seed nodes can belong to different partitions, leading to simultaneous sampling on multiple machines for a single batch.
   Consequently, synchronization of sampling results across machines is necessary to obtain seed nodes for the subsequent layer, requiring modifications to the basic algorithm.
   For nodes within a local partition, the sampling occurs on the local machine.
   Conversely, for nodes associated with a remote partition, the neighbor sampling is conducted on the machine responsible for storing the respective partition.
   Sampling then happens layer-wise, where sampled nodes act as seed nodes in follow-up layers.
#. **Distributed feature lookup:** Each partition stores an array of features of nodes and edges that are within that partition.
   Consequently, if the output of a sampler on a specific machine includes sampled nodes or edges which do not belong to its partition, the machine initiates an RPC request to the remote machine on which these nodes (or edges) are stored.
#. **Data conversion:** Based on the sampler output and the acquired node (or edge) features, a :pyg:`PyG` :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object is created.
   This object forms a batch used in subsequent computational operations of the model.
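
From a user's perspective, the result of these three steps is indistinguishable from a single-node mini-batch.
As a hedged sketch of the consuming side (assuming a node-level classification setup; :obj:`model`, :obj:`optimizer`, and the loader construction shown in the next section are left out):

.. code-block:: python

    import torch.nn.functional as F

    for batch in loader:  # A `DistNeighborLoader`, see the next section.
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        # Only the first `batch_size` nodes are seed nodes carrying labels:
        loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
        loss.backward()
        optimizer.step()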

Distributed Data Loading
~~~~~~~~~~~~~~~~~~~~~~~~

Distributed data loaders such as :class:`~torch_geometric.distributed.DistNeighborLoader` and :class:`~torch_geometric.distributed.DistLinkNeighborLoader` provide a simple API for the sampling engine described above, as they entirely wrap initialization and cleanup of sampler processes internally.
Notably, the distributed data loaders inherit from the standard :pyg:`PyG` single-node :class:`~torch_geometric.loader.NodeLoader` and :class:`~torch_geometric.loader.LinkLoader` loaders, making their usage inside training scripts nearly identical.

Batch generation differs slightly from the single-node case in that (local and remote) feature fetching happens within the sampler, rather than being encapsulated into two separate steps (sampling followed by feature fetching).
This limits the number of RPCs.
Due to the asynchronous processing between all sampler sub-processes, the samplers then return their output to a :class:`torch.multiprocessing.Queue`.
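
A rough construction sketch is shown below; the argument names (including the :class:`~torch_geometric.distributed.DistContext` fields) follow the distributed example scripts and should be treated as assumptions that may differ between releases. :obj:`feature_store`, :obj:`graph_store`, and :obj:`train_idx` refer to the local partition of this compute node:

.. code-block:: python

    from torch_geometric.distributed import DistContext, DistNeighborLoader

    # Describes this machine's position within the cluster:
    current_ctx = DistContext(
        rank=0,
        global_rank=0,
        world_size=2,
        global_world_size=2,
        group_name='distributed-training',
    )

    train_loader = DistNeighborLoader(
        data=(feature_store, graph_store),
        num_neighbors=[15, 10, 5],
        input_nodes=train_idx,   # Seed nodes of the local partition.
        batch_size=1024,
        shuffle=True,
        current_ctx=current_ctx,
        master_addr='localhost',
        master_port=11111,
        num_workers=2,           # Number of sampler subprocesses.
        concurrency=2,           # Number of in-flight sampling tasks per worker.
    )

Link-level, heterogeneous, and temporal variants are expected to reuse the corresponding arguments of their single-node counterparts, *e.g.*, via :class:`~torch_geometric.distributed.DistLinkNeighborLoader`.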

Setting up Communication using DDP & RPC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In this distributed training implementation, two :class:`torch.distributed` communication technologies are used:

* :class:`torch.distributed.rpc` for remote sampling calls and distributed feature retrieval
* :class:`torch.nn.parallel.DistributedDataParallel` (DDP) for data parallel model training

Our solution opts for :class:`torch.distributed.rpc` over alternatives such as gRPC because :pytorch:`PyTorch` RPC inherently understands tensor-type data.
Unlike other RPC methods, which require serializing JSON or other user data into tensor types, this approach avoids additional serialization overhead.

The DDP group is initialized in a standard way in the main training script:

.. code-block:: python
|
    torch.distributed.init_process_group(
        backend='gloo',
        rank=current_ctx.rank,
        world_size=current_ctx.world_size,
        init_method=f'tcp://{master_addr}:{ddp_port}',
    )
|
.. note::
    For CPU-based sampling, we recommend the `gloo <https://github.com/facebookincubator/gloo>`_ communication backend.

RPC group initialization is more involved because it happens in each sampler subprocess.
This is achieved via the :meth:`~torch.utils.data.DataLoader.worker_init_fn` of the data loader, which is called by :pytorch:`PyTorch` directly at the initialization step of worker processes.
This function first defines a distributed context for each worker and assigns it a group and rank, subsequently initializes its own distributed neighbor sampler, and finally registers a new member in the RPC group.
This RPC connection remains open as long as the subprocess exists.
Additionally, we opted for the `atexit <https://docs.python.org/3/library/atexit.html>`_ module to register additional cleanup behaviors that are triggered when the process is terminated.
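
Conceptually, this cleanup follows the standard :obj:`atexit` pattern (an illustrative sketch, not the library's internal code):

.. code-block:: python

    import atexit

    import torch.distributed.rpc as rpc

    def shutdown_rpc():
        # Gracefully leave the RPC group when this sampler subprocess exits:
        rpc.shutdown()

    atexit.register(shutdown_rpc)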

Results and Performance
-----------------------

We collected the benchmarking results on :pytorch:`PyTorch` 2.1 using the system configuration listed at the bottom of this section.
The table below shows the scaling performance of a :class:`~torch_geometric.nn.models.GraphSAGE` model on the :obj:`ogbn-products` dataset under different partition configurations (1/2/4/8/16):

.. list-table::
    :widths: 15 15 15 15
    :header-rows: 1

    * - #Partitions
      - :obj:`batch_size=1024`
      - :obj:`batch_size=4096`
      - :obj:`batch_size=8192`
    * - 1
      - 98s
      - 47s
      - 38s
    * - 2
      - 45s
      - 30s
      - 24s
    * - 4
      - 38s
      - 21s
      - 16s
    * - 8
      - 29s
      - 14s
      - 10s
    * - 16
      - 22s
      - 13s
      - 9s

* **Hardware:** 2x Intel(R) Xeon(R) Platinum 8360Y CPU @ 2.40GHz, 36 cores, HT On, Turbo On, NUMA 2, Integrated Accelerators Available [used]: DLB 0 [0], DSA 0 [0], IAA 0 [0], QAT 0 [0], Total Memory 256GB (16x16GB DDR4 3200 MT/s [3200 MT/s]), BIOS SE5C620.86B.01.01.0003.2104260124, microcode 0xd000389, 2x Ethernet Controller X710 for 10GbE SFP+, 1x MT28908 Family [ConnectX-6], 1x 894.3G INTEL SSDSC2KG96, Rocky Linux 8.8 (Green Obsidian), 4.18.0-477.21.1.el8_8.x86_64
* **Software:** :python:`Python` 3.9, :pytorch:`PyTorch` 2.1, :pyg:`PyG` 2.5, :pyg:`null` :obj:`pyg-lib` 0.4.0