2020-08-28 14:46:28 +08:00
|
|
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
2024-07-27 16:46:07 +08:00
|
|
|
from __future__ import annotations
|
2020-08-28 14:46:28 +08:00
|
|
|
|
2023-03-02 18:57:07 +08:00
|
|
|
import itertools
|
2020-08-28 14:46:28 +08:00
|
|
|
import os
|
2023-04-25 15:34:04 +08:00
|
|
|
import sys
|
2022-11-29 18:50:04 +08:00
|
|
|
import time
|
2020-08-28 14:46:28 +08:00
|
|
|
import warnings
|
2023-04-25 15:34:04 +08:00
|
|
|
from collections import OrderedDict, namedtuple
|
2023-03-02 18:57:07 +08:00
|
|
|
from contextlib import contextmanager
|
2023-10-09 10:20:29 +08:00
|
|
|
from multiprocessing import Manager, Process
|
2024-07-27 16:46:07 +08:00
|
|
|
from typing import (
|
|
|
|
|
TYPE_CHECKING,
|
|
|
|
|
Any,
|
|
|
|
|
)
|
2022-11-29 18:50:04 +08:00
|
|
|
|
2023-03-02 18:57:07 +08:00
|
|
|
import numpy as np
|
|
|
|
|
|
2022-06-02 17:41:29 +08:00
|
|
|
import paddle
|
2023-03-02 18:57:07 +08:00
|
|
|
from paddle import _legacy_C_ops, framework
|
2025-12-20 10:21:50 +08:00
|
|
|
from paddle.base.core import get_all_custom_device_type
|
2022-11-29 18:50:04 +08:00
|
|
|
from paddle.distributed.collective import (
|
|
|
|
|
Group,
|
|
|
|
|
_default_group_name,
|
|
|
|
|
_get_group_map_by_name,
|
|
|
|
|
_new_process_group_impl,
|
|
|
|
|
_set_default_backend,
|
|
|
|
|
_set_default_store,
|
|
|
|
|
_set_group_map,
|
|
|
|
|
_set_group_map_backend,
|
|
|
|
|
_set_group_map_by_name,
|
|
|
|
|
_valid_backend_list,
|
|
|
|
|
)
|
2023-06-19 16:26:20 +08:00
|
|
|
from paddle.distributed.communication.group import (
|
|
|
|
|
_add_new_group,
|
|
|
|
|
_get_global_group,
|
|
|
|
|
is_initialized,
|
|
|
|
|
)
|
2023-10-09 10:20:29 +08:00
|
|
|
from paddle.distributed.fleet.base.private_helper_function import (
|
2022-11-29 18:50:04 +08:00
|
|
|
wait_server_ready,
|
|
|
|
|
)
|
|
|
|
|
from paddle.distributed.fleet.launch_utils import check_backend
|
2020-08-28 14:46:28 +08:00
|
|
|
|
2022-12-08 11:21:21 +08:00
|
|
|
# (TODO: GhostScreaming) It will be removed later.
|
2023-11-15 10:35:51 +08:00
|
|
|
from paddle.framework import (
|
|
|
|
|
_set_expected_place,
|
|
|
|
|
base as imperative_base,
|
|
|
|
|
core,
|
|
|
|
|
in_dynamic_mode,
|
|
|
|
|
)
|
2024-07-27 16:46:07 +08:00
|
|
|
from paddle.nn.layer import Layer
|
2023-03-02 18:57:07 +08:00
|
|
|
from paddle.utils import deprecated
|
2022-12-08 11:21:21 +08:00
|
|
|
|
2023-03-02 18:57:07 +08:00
|
|
|
from . import parallel_helper
|
2023-07-22 18:19:33 +08:00
|
|
|
from .backup_env import getenv_or_backup
|
2020-08-28 14:46:28 +08:00
|
|
|
|
2024-07-27 16:46:07 +08:00
|
|
|
if TYPE_CHECKING:
|
2024-08-08 11:14:59 +08:00
|
|
|
from collections.abc import Generator
|
|
|
|
|
|
2024-07-27 16:46:07 +08:00
|
|
|
from paddle import Tensor
|
2025-07-02 14:17:24 +08:00
|
|
|
from paddle.base.libpaddle import NCCLConfig
|
2024-07-27 16:46:07 +08:00
|
|
|
from paddle.nn.layer.layers import _StateDict
|
2021-06-11 14:44:29 +08:00
|
|
|
# Nothing is exported from this module directly; public entry points are
# re-exported through the paddle.distributed package.
__all__ = []


# Alias for the C++-backed strategy object; kept at module level for
# backward compatibility with code that referenced it from here.
ParallelStrategy = core.ParallelStrategy
|
|
|
|
|
|
2023-03-02 18:57:07 +08:00
|
|
|
|
|
|
|
|
def _build_default_parallel_strategy():
    """Build a ``ParallelStrategy`` populated from the current parallel env.

    Copies nranks, local_rank, trainer_endpoints and current_endpoint from
    ``paddle.distributed.ParallelEnv`` (which reads the ``PADDLE_*``
    environment variables) into a fresh strategy object.

    Returns:
        ParallelStrategy: strategy describing this process's parallel setup.
    """
    strategy = ParallelStrategy()
    # Construct ParallelEnv once instead of once per field: every
    # construction re-reads the environment variables.
    parallel_env = paddle.distributed.ParallelEnv()
    strategy.nranks = parallel_env.nranks
    strategy.local_rank = parallel_env.local_rank
    strategy.trainer_endpoints = parallel_env.trainer_endpoints
    strategy.current_endpoint = parallel_env.current_endpoint
    return strategy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _coalesce_tensors(var_groups):
    """Fuse each group of gradient tensors into a single flat tensor.

    Every tensor in a group is reshaped to 1-D and the results are
    concatenated into one coalesced tensor per group.

    Args:
        var_groups: mapping of group id -> list of gradient tensors.

    Returns:
        list: one ``[coalesced_tensor, original_tensors, original_shapes]``
        triple per group, in iteration order of *var_groups*.
    """
    coalesced = []
    for grad_vars in var_groups.values():
        shapes = [g.shape for g in grad_vars]
        flat = [
            paddle.reshape(x=g, shape=[np.prod(g.shape, dtype="int64")])
            for g in grad_vars
        ]
        coalesced.append([paddle.concat(flat), grad_vars, shapes])
    return coalesced
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@framework.dygraph_only
def _reshape_inplace(x, shape):
    # Reshape ``x`` in place through the legacy ``reshape2`` op by writing
    # the result back into ``x`` itself ('Out': x). ``reshape2`` also
    # requires an 'XShape' output (used to recover the original shape in
    # backward), so a scratch tensor of matching dtype is created for it.
    x_shape = framework._create_tensor(dtype=x.dtype)
    framework._dygraph_tracer().trace_op(
        type="reshape2",
        inputs={'X': x},
        outputs={'Out': x, 'XShape': x_shape},
        attrs={'shape': shape},
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    # Inverse of _coalesce_tensors: scatter each coalesced buffer back into
    # its original gradient tensors and restore their original shapes.
    if in_dynamic_mode():
        for (
            coalesced_grad,
            origin_grad_vars,
            grad_shapes,
        ) in coalesced_grads_and_grad_vars:
            # Section length of each output equals its element count.
            grad_var_len = [
                np.prod(g_shape, dtype="int64") for g_shape in grad_shapes
            ]
            # _legacy_C_ops takes attributes as a flat (name, value, ...)
            # tuple, so build it positionally.
            attrs = ()
            attrs += ('sections', grad_var_len)
            attrs += ('axis', 0)
            _legacy_C_ops.split(coalesced_grad, origin_grad_vars, *attrs)
            # split produced 1-D slices; reshape each back in place.
            for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
                g_var.reshape_(shape=g_shape)
                assert g_var.shape == g_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@imperative_base.no_grad
@framework.dygraph_only
def build_groups(
    vars: list[Tensor], group_size: int
) -> list[list[Tensor | list[Tensor] | list[int]]]:
    """Partition *vars* into fused groups of roughly *group_size* bytes.

    Tensors are scanned in order; a new group is started whenever the
    running byte counter reaches *group_size* or the dtype changes
    (tensors of different dtypes are never fused together). Each group is
    then coalesced into one flat tensor by ``_coalesce_tensors``.

    Args:
        vars: tensors to group; must be non-empty.
        group_size: soft upper bound on a group's size, in bytes.

    Returns:
        A list of ``[coalesced_tensor, original_tensors, original_shapes]``
        triples, one per group.
    """
    group_idx = 0
    memory_counter = 0
    var_groups = OrderedDict()
    dtype = vars[0].dtype

    for var in vars:
        var_dtype = var.dtype
        if isinstance(var_dtype, core.DataType):
            # PIR DataType -> legacy VarType so size_of_dtype accepts it.
            var_dtype = paddle.pir.core.datatype_to_vartype[var_dtype]
        # Renamed from `bytes` to stop shadowing the builtin.
        var_size_in_bytes = np.prod(
            var.shape, dtype="int64"
        ) * core.size_of_dtype(var_dtype)
        if memory_counter < group_size and dtype == var.dtype:
            memory_counter += var_size_in_bytes
        else:
            # Current group is full or dtype changed: open a new group.
            memory_counter = var_size_in_bytes
            dtype = var.dtype
            group_idx += 1
        var_groups.setdefault(group_idx, []).append(var)
    return _coalesce_tensors(var_groups)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@imperative_base.no_grad
@framework.dygraph_only
def sync_params_buffers(
    model: Layer,
    comm_group: Group | None = None,
    src_rank: int = 0,
    is_model_parallel: bool = False,
    fuse_params: bool = True,
    is_moe_sharding_parallel: bool = False,
) -> None:
    """Broadcast all parameters and buffers of *model* from *src_rank* so
    that every rank in *comm_group* starts from identical values.

    Args:
        model: layer whose parameters and buffers are synchronized.
        comm_group: process group to broadcast in; None uses the default group.
        src_rank: rank whose values are taken as the source of truth.
        is_model_parallel: skip parameters marked ``is_distributed`` (they
            are intentionally sharded and must not be made identical).
        fuse_params: fuse parameters into 128MB buffers before broadcasting
            to reduce the number of collective calls.
        is_moe_sharding_parallel: when True, ``no_sync`` expert parameters
            are still broadcast (moe sharding group semantics).

    Raises:
        TypeError: if any parameter/buffer is not a ``core.eager.Tensor``.
    """
    model_vars = []
    for _, param in model._obtain_parameters_buffers().items():
        if not isinstance(param, core.eager.Tensor):
            raise TypeError(
                f"The data type of '{param.name}' must be core.eager.Tensor"
            )

        if is_model_parallel:
            if hasattr(param, "is_distributed") and param.is_distributed:
                continue

        if not is_moe_sharding_parallel:
            # NOTE(shenliang03): Support situations that do not require synchronization parameters,
            # such as moe's expert parameters
            if getattr(param, "no_sync", False):
                continue
        else:
            # NOTE(zhangyuqin1998): In moe sharding parallel, we do need to broadcast expert parameters
            # in moe sharding group.
            if getattr(param, "no_sync", False) and not getattr(
                param, "expert", False
            ):
                continue

        # VOCAB-typed variables are lookup-table metadata, not dense data.
        if param.type == core.VarDesc.VarType.VOCAB:
            continue

        model_vars.append(param.detach())
    if len(model_vars) == 0:
        return

    if fuse_params:
        # group size is 128M
        coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)

        for coalesced_var, _, _ in coalesced_vars:
            paddle.distributed.broadcast(
                coalesced_var, src=src_rank, group=comm_group, sync_op=True
            )
        # Scatter the broadcast buffers back into the original tensors.
        for coalesced_var, origin_vars, var_shapes in coalesced_vars:
            var_len = [
                np.prod(v_shape, dtype="int64") for v_shape in var_shapes
            ]
            paddle.base.framework._dygraph_tracer().trace_op(
                type='split',
                inputs={'X': coalesced_var},
                outputs={'Out': origin_vars},
                attrs={'sections': var_len, 'axis': 0},
            )
    else:
        for var in model_vars:
            # NOTE(shenliang03): Now, we dont support contiguous tensor in dp
            # NOTE(review): if var was non-contiguous, contiguous() likely
            # returns a copy and broadcast writes into that copy — confirm
            # this is intended for non-contiguous parameters.
            var = var.contiguous()
            paddle.distributed.broadcast(
                var, src=src_rank, group=comm_group, sync_op=True
            )
|
2023-03-02 18:57:07 +08:00
|
|
|
|
|
|
|
|
|
2024-07-27 16:46:07 +08:00
|
|
|
class DataParallel(Layer):
|
2023-03-02 18:57:07 +08:00
|
|
|
"""
|
|
|
|
|
Run the dygraph module with data parallelism.
|
|
|
|
|
|
|
|
|
|
Currently, DataParallel class only supports to run the dynamic graph
|
|
|
|
|
with multi-process.
|
|
|
|
|
|
|
|
|
|
Now supports two ways to start training:
|
|
|
|
|
|
|
|
|
|
1. start by ``paddle.distributed.spawn`` method, for example:
|
|
|
|
|
|
|
|
|
|
``python demo.py`` (spawn need to be called in ``__main__`` method)
|
|
|
|
|
|
|
|
|
|
2. start by ``paddle.distributed.launch`` module, for example:
|
|
|
|
|
|
|
|
|
|
``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .
|
|
|
|
|
|
|
|
|
|
And the content of `demo.py` is the code of examples.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
layers(Layer): The module that should be executed by data parallel.
|
|
|
|
|
strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
|
|
|
|
|
contains environment configuration related to parallel execution. Default: None.
|
|
|
|
|
comm_buffer_size(int, optional): It limits the memory size(MB) of one buffer
|
|
|
|
|
parameters' gradient which is the input of communication
|
|
|
|
|
calling(e.g NCCLAllReduce). Default: 25.
|
|
|
|
|
last_comm_buffer_size(float, optional): It limits memory size(MB) of last buffer in communication
|
|
|
|
|
calling. Making the last communication buffer size small is useful to
|
|
|
|
|
improve performance. Default: 1.
|
|
|
|
|
find_unused_parameters(bool, optional): Whether to traverse the entire backward graph from the
|
|
|
|
|
all tensors in the return value of the wrapped model's
|
|
|
|
|
forward function. For parameters not involved in loss
|
|
|
|
|
calculation, their gradients will be marked as ready in
|
|
|
|
|
advance to prepare reduce. Please note that all forward
|
|
|
|
|
outputs derived from the wrapped model parameters must
|
|
|
|
|
participate in the calculation of loss and subsequent
|
|
|
|
|
gradient calculations. If not, serious error will occur.
|
|
|
|
|
Note that setting the find_unused_parameters to True
|
|
|
|
|
will affect computing performance. Therefore, if all parameters
|
|
|
|
|
are sure to participate in the loss calculation and the
|
|
|
|
|
autograd graph construction, please set it False. Default: False.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Layer: The data paralleled module.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
|
2026-02-18 16:12:40 +08:00
|
|
|
.. code-block:: pycon
|
2023-03-02 18:57:07 +08:00
|
|
|
:name: dp-example
|
|
|
|
|
|
2023-10-12 15:39:32 +08:00
|
|
|
>>> # doctest: +REQUIRES(env:DISTRIBUTED)
|
|
|
|
|
>>> import paddle
|
|
|
|
|
>>> import paddle.nn as nn
|
|
|
|
|
>>> import paddle.optimizer as opt
|
|
|
|
|
>>> import paddle.distributed as dist
|
|
|
|
|
|
|
|
|
|
>>> class LinearNet(nn.Layer):
|
|
|
|
|
... def __init__(self):
|
|
|
|
|
... super().__init__()
|
|
|
|
|
... self._linear1 = nn.Linear(10, 10)
|
|
|
|
|
... self._linear2 = nn.Linear(10, 1)
|
2026-02-18 16:12:40 +08:00
|
|
|
...
|
2023-10-12 15:39:32 +08:00
|
|
|
... def forward(self, x):
|
|
|
|
|
... return self._linear2(self._linear1(x))
|
|
|
|
|
|
|
|
|
|
>>> def train():
|
|
|
|
|
... # 1. initialize parallel environment
|
|
|
|
|
... dist.init_parallel_env()
|
|
|
|
|
... # 2. create data parallel layer & optimizer
|
|
|
|
|
... layer = LinearNet()
|
|
|
|
|
... dp_layer = paddle.DataParallel(layer)
|
|
|
|
|
... loss_fn = nn.MSELoss()
|
2026-02-18 16:12:40 +08:00
|
|
|
... adam = opt.Adam(learning_rate=0.001, parameters=dp_layer.parameters())
|
2023-10-12 15:39:32 +08:00
|
|
|
... # 3. run layer
|
|
|
|
|
... inputs = paddle.randn([10, 10], 'float32')
|
|
|
|
|
... outputs = dp_layer(inputs)
|
|
|
|
|
... labels = paddle.randn([10, 1], 'float32')
|
|
|
|
|
... loss = loss_fn(outputs, labels)
|
|
|
|
|
... loss.backward()
|
|
|
|
|
... adam.step()
|
|
|
|
|
... adam.clear_grad()
|
|
|
|
|
|
|
|
|
|
>>> if __name__ == '__main__':
|
|
|
|
|
... # 1. start by ``paddle.distributed.spawn`` (default)
|
|
|
|
|
... dist.spawn(train, nprocs=2)
|
|
|
|
|
... # 2. start by ``paddle.distributed.launch``
|
|
|
|
|
... # train()
|
2023-03-02 18:57:07 +08:00
|
|
|
|
|
|
|
|
.. note::
|
|
|
|
|
``PyLayer`` is not supported in DataParallel. To solve problems of this kind,
|
|
|
|
|
it's recommended to skip gradient synchronization among multiple cards by 'no_sync',
|
|
|
|
|
and manually implement 'all_reduce' before model optimization. There is an example
|
2024-02-19 16:10:53 +08:00
|
|
|
showing specific implementation processing.
|
2023-03-02 18:57:07 +08:00
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
|
2026-02-18 16:12:40 +08:00
|
|
|
.. code-block:: pycon
|
2023-03-02 18:57:07 +08:00
|
|
|
:name: dp-pylayer-example
|
|
|
|
|
|
2023-10-12 15:39:32 +08:00
|
|
|
>>> # doctest: +REQUIRES(env:DISTRIBUTED)
|
|
|
|
|
>>> import numpy
|
|
|
|
|
>>> import paddle
|
|
|
|
|
>>> import paddle.distributed as dist
|
|
|
|
|
>>> from paddle.autograd import PyLayer
|
|
|
|
|
>>> from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
|
|
|
|
|
|
|
|
|
|
>>> class cus_tanh(PyLayer):
|
|
|
|
|
... @staticmethod
|
|
|
|
|
... def forward(ctx, x):
|
|
|
|
|
... y = paddle.tanh(x)
|
|
|
|
|
... ctx.save_for_backward(y)
|
|
|
|
|
... return y
|
2026-02-18 16:12:40 +08:00
|
|
|
...
|
2023-10-12 15:39:32 +08:00
|
|
|
... @staticmethod
|
|
|
|
|
... def backward(ctx, dy):
|
2026-02-18 16:12:40 +08:00
|
|
|
... (y,) = ctx.saved_tensor()
|
2023-10-12 15:39:32 +08:00
|
|
|
... grad = dy * (1 - paddle.square(y))
|
|
|
|
|
... return grad
|
|
|
|
|
|
|
|
|
|
>>> class SimpleNet(paddle.nn.Layer):
|
|
|
|
|
... def __init__(self):
|
|
|
|
|
... super().__init__()
|
|
|
|
|
... self.linear = paddle.nn.Linear(2, 2)
|
2026-02-18 16:12:40 +08:00
|
|
|
...
|
2023-10-12 15:39:32 +08:00
|
|
|
... def forward(self, inputs):
|
|
|
|
|
... inputs = cus_tanh.apply(inputs)
|
|
|
|
|
... return self.linear(inputs)
|
|
|
|
|
|
|
|
|
|
>>> if __name__ == '__main__':
|
|
|
|
|
... dist.init_parallel_env()
|
|
|
|
|
... model = SimpleNet()
|
|
|
|
|
... model = paddle.DataParallel(model)
|
|
|
|
|
... opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
|
|
|
|
|
... for step in range(10):
|
2025-05-08 15:55:03 +08:00
|
|
|
... x_data = numpy.random.randn(2, 2).astype(numpy.float32)
|
2023-10-12 15:39:32 +08:00
|
|
|
... x = paddle.to_tensor(x_data)
|
|
|
|
|
... x.stop_gradient = False
|
|
|
|
|
... # step 1 : skip gradient synchronization by 'no_sync'
|
|
|
|
|
... with model.no_sync():
|
|
|
|
|
... y_pred = model(x)
|
|
|
|
|
... loss = y_pred.mean()
|
|
|
|
|
... loss.backward()
|
|
|
|
|
... # step 2 : fuse + allreduce manually before optimization
|
|
|
|
|
... fused_allreduce_gradients(list(model.parameters()), None)
|
|
|
|
|
... opt.step()
|
|
|
|
|
... opt.clear_grad()
|
2023-03-02 18:57:07 +08:00
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
2024-07-27 16:46:07 +08:00
|
|
|
    # Whether backward traverses the whole graph to mark parameters that do
    # not contribute to the loss as ready for reduce.
    find_unused_parameters: bool
    # Toggled by no_sync() to temporarily suppress gradient synchronization.
    grad_need_sync: bool
    # Communication group used for gradient allreduce (default group if None).
    group: Group | None
    # Expected tensor class of every trainable parameter.
    var_dtype: Tensor
    # Gradient fusion buffer size, stored in bytes.
    comm_buffer_size: int
    # Size in bytes of the last (deliberately small) fusion buffer.
    last_comm_buffer_size: int
|
|
|
|
|
|
2023-03-02 18:57:07 +08:00
|
|
|
def __init__(
|
|
|
|
|
self,
|
2024-07-27 16:46:07 +08:00
|
|
|
layers: Layer,
|
|
|
|
|
strategy: ParallelStrategy | None = None,
|
|
|
|
|
comm_buffer_size: int = 25,
|
|
|
|
|
last_comm_buffer_size: float = 1,
|
|
|
|
|
find_unused_parameters: bool = False,
|
|
|
|
|
group: Group | None = None,
|
|
|
|
|
) -> None:
|
2023-03-02 18:57:07 +08:00
|
|
|
super().__init__(layers.full_name() + "_data_parallel")
|
|
|
|
|
|
2025-08-21 02:00:58 +08:00
|
|
|
assert in_dynamic_mode(), (
|
|
|
|
|
"It's not supported to construct DataParallel in static graph mode."
|
|
|
|
|
)
|
2023-03-02 18:57:07 +08:00
|
|
|
|
|
|
|
|
self._layers = layers
|
|
|
|
|
self.find_unused_parameters = find_unused_parameters
|
|
|
|
|
self.grad_need_sync = True
|
|
|
|
|
self.group = group
|
2023-03-30 10:11:14 +08:00
|
|
|
self.var_dtype = core.eager.Tensor
|
2023-03-02 18:57:07 +08:00
|
|
|
|
|
|
|
|
# NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
|
|
|
|
|
# It just stores some environment variables, which can be constructed by
|
|
|
|
|
# ParallelEnv. Here it is set as an optional argument.
|
|
|
|
|
# This parameter is not removed because of compatibility with 1.x writing.
|
|
|
|
|
if strategy is not None:
|
|
|
|
|
self._strategy = strategy
|
|
|
|
|
else:
|
|
|
|
|
self._strategy = _build_default_parallel_strategy()
|
|
|
|
|
|
|
|
|
|
if self._strategy.nranks > 1:
|
|
|
|
|
# check the environment
|
|
|
|
|
assert parallel_helper.__parallel_ctx__clz__ is not None, (
|
|
|
|
|
"ParallelContext must be initialized before. You should use init_parallel_env() before"
|
|
|
|
|
"constructing the DataParallel."
|
|
|
|
|
)
|
|
|
|
|
|
2023-05-22 20:56:38 +08:00
|
|
|
if in_dynamic_mode():
|
2023-03-02 18:57:07 +08:00
|
|
|
self.group = (
|
|
|
|
|
paddle.distributed.collective._get_default_group()
|
|
|
|
|
if self.group is None
|
|
|
|
|
else self.group
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert isinstance(
|
|
|
|
|
self.group, paddle.distributed.collective.Group
|
|
|
|
|
), "ProcessGroup must be an instance of Group in DataParallel."
|
|
|
|
|
|
2024-07-08 19:21:22 +08:00
|
|
|
[
|
|
|
|
|
warnings.warn(
|
|
|
|
|
f"param [{name}] is not contiguous, please check it and make it contiguous."
|
|
|
|
|
)
|
|
|
|
|
for name, param in self._layers.named_parameters()
|
|
|
|
|
if not param.is_contiguous()
|
|
|
|
|
]
|
2023-03-02 18:57:07 +08:00
|
|
|
# sync buffer and params
|
2023-08-19 06:37:43 +08:00
|
|
|
sync_params_buffers(self._layers, fuse_params=False)
|
2023-03-02 18:57:07 +08:00
|
|
|
|
|
|
|
|
self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
|
|
|
|
|
# NOTE(shenliang03): We can set environment variables to control
|
|
|
|
|
# the size of the group, Default: 1MB. The role of this small group is:
|
|
|
|
|
# when the last group allreduce, the overlap cannot work. Making the
|
|
|
|
|
# the last group small is useful to improve performance.
|
|
|
|
|
self.last_comm_buffer_size = int(
|
|
|
|
|
last_comm_buffer_size * 1024 * 1024
|
|
|
|
|
)
|
|
|
|
|
self.init_reducer()
|
|
|
|
|
else:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The program will return to single-card operation. "
|
|
|
|
|
"Please check 1, whether you use spawn or fleetrun "
|
|
|
|
|
"to start the program. 2, Whether it is a multi-card "
|
|
|
|
|
"program. 3, Is the current environment multi-card."
|
|
|
|
|
)
|
|
|
|
|
|
2024-07-27 16:46:07 +08:00
|
|
|
    def init_reducer(self) -> None:
        """Collect trainable parameters and build the gradient reducer.

        Scans every sublayer for trainable, non-``no_sync`` parameters,
        groups them by size into fusion buckets, and constructs the
        ``EagerReducer`` that performs fused gradient allreduce in backward.

        Raises:
            TypeError: if a parameter is not of the expected tensor class.
        """
        layers_param = []
        params_set = set()
        for sublayer in self.sublayers():
            for _, param in sublayer.named_parameters(include_sublayers=False):
                # Process each (possibly shared) parameter only once.
                if param is None or param in params_set:
                    continue
                params_set.add(param)
                if not isinstance(param, self.var_dtype):
                    raise TypeError(
                        f"The data type of '{param.name}' must be '{self.var_dtype}'"
                    )
                if param.trainable:
                    layers_param.append((sublayer, param))

        # Parameters marked no_sync are excluded from synchronization.
        trainable_parameters = list(
            filter(
                lambda x: not getattr(x, "no_sync", False),
                [param for _, param in layers_param],
            )
        )

        assert len(trainable_parameters) > 0, (
            "This model does not have any parameters to train, and "
            "does not need to use DataParallel"
        )

        # NOTE(shenliang03): Here we can only use the attributes to judge whether
        # parameter is sparse(or SelectedRows). The reason is that the sparse message
        # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter,
        # we should add the layer here like "paddle.nn.layer.common.Embedding".
        def check_layer_sparse(sublayer):
            # True when the sublayer is a sparse Embedding (sparse grads).
            if isinstance(sublayer, paddle.nn.layer.common.Embedding):
                return sublayer._sparse
            return False

        # Same no_sync filter as trainable_parameters so both lists stay
        # index-aligned.
        is_sparse_gradient = [
            check_layer_sparse(sublayer)
            for sublayer, param in layers_param
            if not getattr(param, "no_sync", False)
        ]

        if in_dynamic_mode():
            self.group_indices = core.eager_assign_group_by_size(
                trainable_parameters,
                is_sparse_gradient,
                [self.last_comm_buffer_size, self.comm_buffer_size],
            )

            self._reducer = core.EagerReducer(
                trainable_parameters,
                list(reversed(self.group_indices)),
                is_sparse_gradient,
                self.group.process_group,
                [self.last_comm_buffer_size, self.comm_buffer_size],
                self.find_unused_parameters,
            )
|
|
|
|
|
|
2023-04-28 15:47:38 +08:00
|
|
|
def _find_tensor(self, obj):
|
2023-03-30 10:11:14 +08:00
|
|
|
var_type = core.eager.Tensor
|
2023-03-02 18:57:07 +08:00
|
|
|
if isinstance(obj, var_type):
|
|
|
|
|
return [obj]
|
|
|
|
|
if isinstance(obj, (list, tuple)):
|
2023-04-28 15:47:38 +08:00
|
|
|
return itertools.chain(*map(self._find_tensor, obj))
|
2023-03-02 18:57:07 +08:00
|
|
|
if isinstance(obj, dict):
|
2023-04-28 15:47:38 +08:00
|
|
|
return itertools.chain(*map(self._find_tensor, obj.values()))
|
2023-03-02 18:57:07 +08:00
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
    @contextmanager
    def no_sync(self) -> Generator[None, None, None]:
        """
        A context manager to stop gradient synchronization. Within no_sync(),
        gradients of parameters will only be accumulated on model and not
        synchronized until the first forward-backward out of this context.

        Examples:
            .. code-block:: pycon

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> import paddle
                >>> import paddle.nn as nn
                >>> import paddle.distributed as dist

                >>> class SimpleNet(nn.Layer):
                ...     def __init__(self):
                ...         super().__init__()
                ...         self._linear = nn.Linear(10, 1)
                ...
                ...     def forward(self, x):
                ...         return self._linear(x)

                >>> dist.init_parallel_env()
                >>> model = SimpleNet()
                >>> dp_model = paddle.DataParallel(model)

                >>> inputs_1 = paddle.randn([10, 10], 'float32')
                >>> inputs_2 = paddle.ones([10, 10], 'float32')

                >>> with dp_model.no_sync():
                ...     # gradients will not be synchronized
                ...     dp_model(inputs_1).backward()

                >>> # synchronization happens here
                >>> dp_model(inputs_2).backward()

        """
        # Snapshot the current flag so nested no_sync() contexts restore it
        # correctly, and restore it even if the body raises.
        tmp_grad_need_sync = self.grad_need_sync
        self.grad_need_sync = False
        try:
            yield
        finally:
            self.grad_need_sync = tmp_grad_need_sync
|
|
|
|
|
|
2024-07-27 16:46:07 +08:00
|
|
|
def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
|
2023-03-02 18:57:07 +08:00
|
|
|
outputs = self._layers(*inputs, **kwargs)
|
|
|
|
|
if (
|
|
|
|
|
self._strategy.nranks > 1
|
|
|
|
|
and framework._dygraph_tracer()._has_grad
|
|
|
|
|
and self.grad_need_sync
|
|
|
|
|
):
|
2023-04-28 15:47:38 +08:00
|
|
|
self._reducer.prepare_for_backward(list(self._find_tensor(outputs)))
|
2023-03-02 18:57:07 +08:00
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
@deprecated(
|
|
|
|
|
since="2.0.0", reason="This method does not need to be called anymore."
|
|
|
|
|
)
|
|
|
|
|
def scale_loss(self, loss):
|
|
|
|
|
"""
|
|
|
|
|
Deprecated method, now ``scale_loss`` is an empty method,
|
|
|
|
|
keep this method just for compatibility.
|
|
|
|
|
"""
|
|
|
|
|
return loss
|
|
|
|
|
|
|
|
|
|
@deprecated(
|
|
|
|
|
since="2.0.0", reason="This method does not need to be called anymore."
|
|
|
|
|
)
|
|
|
|
|
def apply_collective_grads(self):
|
|
|
|
|
"""
|
|
|
|
|
Deprecated method, now ``apply_collective_grads`` is an empty method,
|
|
|
|
|
keep this method just for compatibility.
|
|
|
|
|
"""
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
    def state_dict(
        self,
        destination: _StateDict | None = None,
        include_sublayers: bool = True,
        structured_name_prefix: str = "",
    ) -> _StateDict:
        '''
        Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict

        Parameters:
            destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Prefix prepended to every structured parameter name. Default: ""

        Returns:
            dict: a dict contains all the parameters and persistable buffers.

        Examples:
            .. code-block:: pycon

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> import paddle
                >>> import paddle.distributed as dist

                >>> dist.init_parallel_env()

                >>> emb = paddle.nn.Embedding(10, 10)
                >>> emb = paddle.DataParallel(emb)

                >>> state_dict = emb.state_dict()
                >>> paddle.save(state_dict, "paddle_dy.pdparams")

        '''

        # Delegate to the wrapped layer so keys match the unwrapped model
        # and checkpoints remain interchangeable with it.
        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix,
        )
|
|
|
|
|
|
|
|
|
|
    @framework.deprecate_stat_dict
    def set_state_dict(
        self, state_dict: _StateDict, use_structured_name: bool = True
    ) -> None:
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict

        Parameters:
            state_dict(dict) : Dict contains all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: pycon

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> import paddle
                >>> import paddle.distributed as dist

                >>> dist.init_parallel_env()

                >>> emb = paddle.nn.Embedding(10, 10)
                >>> emb = paddle.DataParallel(emb)

                >>> state_dict = emb.state_dict()
                >>> paddle.save(state_dict, "paddle_dy.pdparams")

                >>> para_state_dict = paddle.load("paddle_dy.pdparams")
                >>> emb.set_state_dict(para_state_dict)

        '''

        # Delegate to the wrapped layer; DataParallel itself holds no extra
        # state beyond what the inner model owns.
        self._layers.set_state_dict(
            state_dict, use_structured_name=use_structured_name
        )

    # [aliases] Compatible with old method names
    set_dict = set_state_dict
    load_dict = set_state_dict
|
|
|
|
|
|
|
|
|
|
|
2022-06-05 10:58:58 +08:00
|
|
|
# NOTE(chenweihang): Maintain a global parallel env to avoid
# initializing ParallelEnv every time and improve performance
# NOTE(review): presumably populated lazily elsewhere in this module
# (e.g. by init_parallel_env) — confirm against the full file.
_global_parallel_env = None
|
|
|
|
|
|
|
|
|
|
|
2023-02-09 16:07:05 +08:00
|
|
|
class ParallelEnv:
    """
    .. note::
        This API is not recommended, if you need to get rank and world_size,
        it is recommended to use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` .

    This class is used to obtain the environment variables required for
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.

    The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch``
    or ``paddle.distributed.spawn`` .

    Examples:
        .. code-block:: pycon

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import paddle
            >>> import paddle.distributed as dist

            >>> def train():
            ...     # 1. initialize parallel environment
            ...     dist.init_parallel_env()
            ...     # 2. get current ParallelEnv
            ...     parallel_env = dist.ParallelEnv()
            ...     print("rank: ", parallel_env.rank)
            ...     print("world_size: ", parallel_env.world_size)

            >>> if __name__ == '__main__':
            ...     # 1. start by ``paddle.distributed.spawn`` (default)
            ...     dist.spawn(train, nprocs=2)
            ...     # 2. start by ``paddle.distributed.launch``
            ...     train()

            # Print result in the first process:
            rank: 0
            world_size: 2

            # Print result in the second process:
            rank: 1
            world_size: 2
    """

    def __init__(self):
        # Trainer identity is injected by the launcher via environment
        # variables; a standalone run falls back to rank 0 of world size 1.
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))

        # If any custom (plugin) device type is registered, the first
        # registered type is used for this process; otherwise "".
        custom_device_types = get_all_custom_device_type()
        self._device_type = (
            str(custom_device_types[0]) if custom_device_types else ""
        )

        # Process group timeout in milliseconds (default: 30 minutes).
        self._pg_timeout = int(os.getenv("PADDLE_PG_TIMEOUT", "1800000"))

        # imperative only support one gpu or xpu
        if self._device_type != "":
            # Custom devices are selected via a per-type flag, e.g.
            # ``FLAGS_selected_npus`` for device type "npu".
            FLAGS_selected_custom_devices = (
                f'FLAGS_selected_{self._device_type}s'
            )
            selected_custom_devices = os.getenv(
                FLAGS_selected_custom_devices, "0"
            ).split(",")
            self._device_id = int(selected_custom_devices[0])
        else:
            if core.is_compiled_with_cuda():
                selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
                self._device_id = int(selected_gpus[0])
            elif core.is_compiled_with_xpu():
                selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
                self._device_id = int(selected_xpus[0])

        self._trainer_endpoints = getenv_or_backup(
            "PADDLE_TRAINER_ENDPOINTS", ""
        ).split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
        self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
        assert self._nrings > 0, (
            "nccl_nrings must be an integer greater than 0."
        )
        assert self._nrings < 9, (
            "nccl_nrings should be less than 9, which is enough in most scenarios."
        )

    @property
    def rank(self) -> int:
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
            .. code-block:: pycon

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> # execute this command in terminal: export PADDLE_TRAINER_ID=0
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> print("The rank is %d" % env.rank)
                The rank is 0
        """
        return self._rank

    @property
    def world_size(self) -> int:
        """
        The number of trainers (number of processes participating in current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
            .. code-block:: pycon

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> print("The world_size is %d" % env.world_size)
                The world_size is 4
        """
        return self._world_size

    @property
    def device_id(self) -> int:
        """
        The ID of selected GPU card for parallel training.

        Its value is equal to the value of the environment variable ``FLAGS_selected_gpus`` . The default value is 0.

        Examples:
            .. code-block:: pycon

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> # execute this command in terminal: export FLAGS_selected_gpus=1
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> print("The device id are %d" % env.device_id)
                The device id are 1
        """
        return self._device_id

    @property
    def device_type(self) -> str:
        """
        The type of custom device for parallel training.

        Its value is the first entry of paddle.device.get_all_custom_device_type() .
        The default value is "" when no custom device type is registered.
        """
        return self._device_type

    @property
    def current_endpoint(self) -> str:
        """
        The endpoint of current trainer, it is in the form of (node IP + port).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
            .. code-block:: pycon

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> print("The current endpoint are %s" % env.current_endpoint)
                The current endpoint are 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self) -> list[str]:
        """
        The endpoints of all trainer nodes in the task,
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ENDPOINTS`` . The default value is "".

        Examples:
            .. code-block:: pycon

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> print("The trainer endpoints are %s" % env.trainer_endpoints)
                The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints

    @property
    def nrings(self) -> int:
        """
        Nrings of current trainer.

        Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.

        Examples:
            .. code-block:: pycon

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> # execute this command in terminal: export FLAGS_nccl_nrings=1
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> print("The nrings is %d" % env.nrings)
                The nrings is 1
        """
        return self._nrings

    @property
    def pg_timeout(self) -> int:
        """
        Timeout of process group, in milliseconds.

        Its value is equal to the value of the environment variable ``PADDLE_PG_TIMEOUT`` . The default value is 30 minutes (1800000 ms).

        Examples:
            .. code-block:: pycon

                >>> # execute this command in terminal: export PADDLE_PG_TIMEOUT=1800000
                >>> import paddle.distributed as dist

                >>> env = dist.ParallelEnv()
                >>> # the pg_timeout of process group 1800000
        """
        return self._pg_timeout

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id
|
|
|
|
|
|
|
|
|
|
|
2020-11-24 21:21:38 +08:00
|
|
|
def _get_global_parallel_env():
    """Return the process-wide ``ParallelEnv``, creating it lazily on first use."""
    global _global_parallel_env
    env = _global_parallel_env
    if env is None:
        env = ParallelEnv()
        _global_parallel_env = env
    return env
|
|
|
|
|
|
2020-08-28 14:46:28 +08:00
|
|
|
|
2020-12-08 14:51:40 +08:00
|
|
|
def _start_kv_server(port, http_server_d, size):
    """Serve a KV HTTP server (run in a child process) until told to stop.

    Polls the shared ``http_server_d['running']`` flag every few seconds and
    shuts the server down once the flag is cleared and the server reports it
    should stop.
    """
    from paddle.distributed.fleet.utils.http_server import KVServer

    server = KVServer(int(port), size=size)
    server.start()
    poll_interval = 3
    while http_server_d.get("running", False) or not server.should_stop():
        time.sleep(poll_interval)
    server.stop()
|
|
|
|
|
|
|
|
|
|
|
2021-10-21 14:07:13 +08:00
|
|
|
def _is_cpuonly(backend):
    """Return True when ``backend`` implies CPU-only (gloo) training.

    'xccl' is never CPU-only; the device-capable backends are CPU-only only
    when this build has neither CUDA nor XPU support compiled in.
    """
    check_backend(backend)
    if backend == 'xccl':
        # Custom-device backend always has a device target.
        return False
    device_backends = ('auto', 'nccl', 'bkcl', 'heter', 'flagcx')
    has_device = core.is_compiled_with_cuda() or core.is_compiled_with_xpu()
    if backend in device_backends and has_device:
        # passes 'auto' and can use cuda or xpu, use the default logics. so return False
        return False
    return True
|
|
|
|
|
|
|
|
|
|
|
2021-12-06 09:01:15 +08:00
|
|
|
def _check_var_exists(var_name):
    """Raise ``ValueError`` if environment variable ``var_name`` is unset."""
    if getenv_or_backup(var_name, None) is None:
        message = (
            "paddle.distributed initialize error, "
            f"environment variable {var_name} is needed, but not set."
        )
        raise ValueError(message)
|
2021-12-06 09:01:15 +08:00
|
|
|
|
|
|
|
|
|
2023-04-25 15:34:04 +08:00
|
|
|
def _get_modified_flags():
    """Collect global flags whose current value differs from their default.

    Returns a list of ``FLAGS(name, current_value, default_value)`` tuples.
    """
    FLAGS = namedtuple('FLAGS', ['name', 'current_value', 'default_value'])
    global_flags = core.globals()
    modified = []
    for name in global_flags.keys():
        current = global_flags.get(name)
        default = global_flags.get_default(name)
        # Keep the original `not ==` comparison to preserve semantics for
        # flag value types that may not implement `!=` symmetrically.
        if not current == default:
            modified.append(FLAGS(name, current, default))
    return modified
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _print_modified_flags(modified_flags):
|
|
|
|
|
if len(modified_flags) > 0:
|
|
|
|
|
sys.stderr.write(
|
|
|
|
|
"======================= Modified FLAGS detected =======================\n"
|
|
|
|
|
)
|
|
|
|
|
for flag in modified_flags:
|
|
|
|
|
sys.stderr.write(str(flag))
|
|
|
|
|
sys.stderr.write("\n")
|
|
|
|
|
sys.stderr.write(
|
|
|
|
|
"=======================================================================\n"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
2025-07-02 14:17:24 +08:00
|
|
|
def init_parallel_env(nccl_config: NCCLConfig | None = None) -> Group:
    """

    Initialize parallel training environment in dynamic graph mode.

    Note:
        Now initialize both `NCCL` and `GLOO` contexts for communication.

    Args:
        nccl_config (NCCLConfig, optional): Communicator configuration for the
            default process group; it is converted with ``message2nccl_config``
            before being passed to the process group implementation. Default
            is None, which uses the default communicator settings.

    Returns:
        Group: the default process group, when a multi-process environment is
        initialized. Returns None when ``world_size < 2`` (nothing is
        initialized in that case).

    Examples:
        .. code-block:: pycon

            >>> # doctest: +REQUIRES(env:GPU, env:DISTRIBUTED)
            >>> import paddle
            >>> import paddle.nn as nn
            >>> import paddle.optimizer as opt
            >>> import paddle.distributed as dist

            >>> class LinearNet(nn.Layer):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self._linear1 = nn.Linear(10, 10)
            ...         self._linear2 = nn.Linear(10, 1)
            ...
            ...     def forward(self, x):
            ...         return self._linear2(self._linear1(x))

            >>> def train():
            ...     # 1. initialize parallel environment
            ...     dist.init_parallel_env()
            ...     # 2. create data parallel layer & optimizer
            ...     layer = LinearNet()
            ...     dp_layer = paddle.DataParallel(layer)
            ...     loss_fn = nn.MSELoss()
            ...     adam = opt.Adam(learning_rate=0.001, parameters=dp_layer.parameters())
            ...     # 3. run layer
            ...     inputs = paddle.randn([10, 10], 'float32')
            ...     outputs = dp_layer(inputs)
            ...     labels = paddle.randn([10, 1], 'float32')
            ...     loss = loss_fn(outputs, labels)
            ...     loss.backward()
            ...     adam.step()
            ...     adam.clear_grad()

            >>> if __name__ == '__main__':
            ...     dist.spawn(train)

    """

    # Surface any globally modified FLAGS up-front so distributed runs are
    # easier to reproduce and debug.
    modified_flags = _get_modified_flags()
    _print_modified_flags(modified_flags)

    # 0. get env & check world size
    global _global_parallel_env
    # when call init_parallel_env, need update `_global_parallel_env`
    _global_parallel_env = ParallelEnv()
    parallel_env = _global_parallel_env
    # if not parallel, `init_parallel_env` do nothing
    if parallel_env.world_size < 2:
        warnings.warn(
            "Currently not a parallel execution environment, `paddle.distributed.init_parallel_env` will not do anything."
        )
        return
    # NOTE(xiongkun): support cpu gloo only, add this environment variable to
    # enable cpu only gloo parallel training)
    backend = os.environ.get('PADDLE_DISTRI_BACKEND', 'auto')
    # if we want to use flagcx as backend in xpu environment, we need to
    # set backend to bkcl, and process_group_bkcl will internally invoke
    # flagcx to perform communication tasks
    if backend == "flagcx" and core.is_compiled_with_xpu():
        os.environ['PADDLE_DISTRI_BACKEND'] = "bkcl"
        backend = "bkcl"
    is_cpu_only = _is_cpuonly(backend)
    # 1. gpu xpu check, must be gpu or xpu,
    if not (
        is_cpu_only
        or core.is_compiled_with_cuda()
        or core.is_compiled_with_xpu()
        or backend == "xccl"
    ):
        raise NotImplementedError(
            "If you want to use CPU-only version, please use 'gloo' as backend"
        )

    # 2. make sure the device-selection env var for the chosen backend exists.
    if backend == "xccl":
        FLAGS_selected_custom_devices = (
            f'FLAGS_selected_{parallel_env.device_type}s'
        )
        _check_var_exists(FLAGS_selected_custom_devices)
    else:
        if not is_cpu_only and core.is_compiled_with_cuda():
            _check_var_exists("FLAGS_selected_gpus")
            backend = "nccl" if backend == "auto" else backend
        elif not is_cpu_only and core.is_compiled_with_xpu():
            _check_var_exists('FLAGS_selected_xpus')
            backend = "bkcl" if backend == "auto" else backend

    _check_var_exists("PADDLE_TRAINER_ID")
    _check_var_exists("PADDLE_CURRENT_ENDPOINT")
    _check_var_exists("PADDLE_TRAINERS_NUM")

    # NOTE(chenweihang): [ why config global place here? ]
    # the dygraph mode will be set to default mode,
    # users will not call `dygraph.guard` or `enable_dygraph`
    # directly, if they want to switch default place,
    # they need to call a function to change default place,
    # here just set correctly place to users
    if backend == "xccl":
        place = core.CustomPlace(
            parallel_env.device_type, parallel_env.device_id
        )
    elif is_cpu_only:
        place = core.CPUPlace()
    elif core.is_compiled_with_cuda():
        place = core.CUDAPlace(parallel_env.device_id)
    elif core.is_compiled_with_xpu():
        place = core.XPUPlace(parallel_env.device_id)
    _set_expected_place(place)

    group = None

    # Modern path: register the default process group in dynamic mode.
    if backend in _valid_backend_list and in_dynamic_mode():
        # Idempotent: reuse the default group if it already exists.
        if _default_group_name in _get_group_map_by_name():
            return _get_group_map_by_name()[_default_group_name]
        _set_default_backend(backend)
        rank = int(os.getenv("PADDLE_TRAINER_ID"))
        world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))
        assert rank >= 0 and world_size > rank and world_size > 1, (
            "rank must be non-negative and world_size must be the "
            "maximum rank plus one. Moreover, at least two processes are "
            "required to create a process group."
        )
        # Resolve the master endpoint: MASTER_ADDR/MASTER_PORT, then
        # PADDLE_MASTER, then the first trainer endpoint.
        master_addr = os.getenv("MASTER_ADDR", None)
        master_port = os.getenv("MASTER_PORT", None)
        endpoints = (
            ":".join([master_addr, master_port])
            if master_addr and master_port
            else None
        )
        if endpoints is None:
            endpoints = os.getenv("PADDLE_MASTER", None)
        if endpoints is None:
            endpoints = getenv_or_backup("PADDLE_TRAINER_ENDPOINTS").split(',')[
                0
            ]
        assert endpoints, (
            "The environment variable 'MASTER_ADDR' and 'MASTER_PORT' "
            "must be specified, for example 'export MASTER_ADDR=127.0.0.1' "
            "and 'export MASTER_ADDR=54612'. Or you can start your training"
            "with paddle.distributed.run module."
        )
        master_addr, master_port = endpoints.split(":")
        master_port = int(master_port)
        is_master = rank == 0
        stop_check_timeout = int(os.getenv("FLAGS_stop_check_timeout", "900"))
        # NOTE(review): master_addr/master_port/is_master/stop_check_timeout
        # appear unused after store creation moved to
        # create_or_get_global_tcp_store() — confirm before removing.
        default_store = core.create_or_get_global_tcp_store()
        _set_default_store(default_store)

        if backend in ["nccl", 'xccl', 'bkcl', 'flagcx']:
            core.CommContextManager.set_device_id(parallel_env.device_id)

        # Imported lazily to avoid a circular import with fleet topology.
        from paddle.distributed.fleet.base.topology import (
            message2nccl_config,
        )

        pg = _new_process_group_impl(
            backend,
            default_store,
            rank,
            world_size,
            _default_group_name,
            pg_options=None,
            nccl_config=message2nccl_config(
                nccl_config,
                "default",
            ),
        )
        ranks = list(range(world_size))
        group = Group(rank, 0, ranks, pg=pg, name=_default_group_name)
        _set_group_map_by_name(_default_group_name, group)
        _set_group_map(0, group)
        _set_group_map_backend(group, backend)
        _add_new_group(group)
        parallel_helper._set_parallel_ctx(True)
        return group

    # Legacy path below: ParallelStrategy-based context initialization.
    node_num = {i.split(":")[0] for i in parallel_env.trainer_endpoints}
    # 3: init gloo context (step 1: httpserver start)
    init_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0"))
    if is_cpu_only or init_gloo or backend == "heter":
        ep_rank_0 = parallel_env.trainer_endpoints[0].split(":")
        manager = Manager()
        # global dict to store status
        http_server_d = manager.dict()
        http_server_d["running"] = False
        if parallel_env.rank == 0:
            # The scope for worker used by http server is '_worker'
            size = {'_worker': parallel_env.world_size}
            if backend == "heter":
                size = {'_worker': len(node_num)}
            http_server = Process(
                target=_start_kv_server,
                args=(int(ep_rank_0[1]), http_server_d, size),
            )
            http_server.daemon = True
            http_server_d["running"] = True
            http_server.start()

    # 4. init NCCL ParallelStrategy
    strategy = ParallelStrategy()
    if parallel_helper._is_parallel_ctx_initialized():
        warnings.warn("The parallel environment has been initialized.")
    strategy.nranks = parallel_env.world_size
    strategy.local_rank = parallel_env.rank
    strategy.trainer_endpoints = parallel_env.trainer_endpoints
    strategy.current_endpoint = parallel_env.current_endpoint
    strategy.nrings = parallel_env.nrings

    # init nccl or bkcl or heter context
    if is_cpu_only:
        parallel_helper._set_parallel_ctx(
            core.GLOOParallelContext(strategy, place)
        )
    elif backend == "heter":
        parallel_helper._set_parallel_ctx(
            core.HeterParallelContext(strategy, parallel_env.device_id)
        )
    elif core.is_compiled_with_cuda():
        parallel_helper._set_parallel_ctx(
            core.NCCLParallelContext(strategy, place)
        )
    elif core.is_compiled_with_xpu():
        parallel_helper._set_parallel_ctx(
            core.BKCLParallelContext(strategy, place)
        )

    if backend != "heter":
        # Rank 0 waits for all peers' ports to come up before NCCL init.
        other_endpoints = strategy.trainer_endpoints[:]
        other_endpoints.remove(strategy.current_endpoint)
        if not is_cpu_only and strategy.local_rank == 0:
            wait_server_ready(other_endpoints)

    parallel_helper._init_parallel_ctx()

    # 5: init gloo context (step 2: gloo init)
    # dividing init_gloo into two part because nccl and gloo
    # are separately looking for free ports which sometimes
    # leads to port-conflict.
    if (is_cpu_only or backend == "heter") and parallel_env.rank == 0:
        # compare to init_gloo, we don't need to
        # init gloo, because we do this in _init_parallel_ctx;
        http_server_d["running"] = False
        http_server.join()

    elif init_gloo:
        wait_server_ready([parallel_env.trainer_endpoints[0]])
        gloo_strategy = core.GlooParallelStrategy()
        gloo_strategy.rank = parallel_env.rank
        gloo_strategy.rank_num = parallel_env.world_size
        gloo_strategy.ip_address = ep_rank_0[0]
        gloo_strategy.ip_port = int(ep_rank_0[1])
        default_init_timeout_seconds = 3600
        default_run_timeout_seconds = 9999999
        gloo_strategy.init_seconds = default_init_timeout_seconds
        gloo_strategy.run_seconds = default_run_timeout_seconds
        gloo = core.GlooParallelContext(gloo_strategy)
        gloo.init()
        if parallel_env.rank == 0:
            http_server_d["running"] = False
            http_server.join()
    return group
|
2020-11-16 11:19:28 +08:00
|
|
|
|
2020-08-28 14:46:28 +08:00
|
|
|
|
2024-07-27 16:46:07 +08:00
|
|
|
def get_rank(group: Group | None = None) -> int:
    """
    Returns the rank of current trainer in the given group, ranks are consecutive integers in [0, ``world_size``).
    If none of the group is given, the global group will be used as default.

    Args:
        group (Group, optional): The communication group you want to get rank of current trainer, use global group as default if group is None.

    Returns:
        (int) The rank of current trainer in the given group. Return -1 if the process is not part of the given group.

    Warning:
        Argument ``group`` only supports in dygraph mode.

    Examples:
        .. code-block:: pycon

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> # Execute this script using distributed launch with one card configs.
            >>> import paddle
            >>> import paddle.distributed as dist

            >>> dist.init_parallel_env()
            >>> print("The rank is %d" % dist.get_rank())
            The rank is 0
    """
    if in_dynamic_mode():
        if group:
            return group.rank

    assert group is None, "Only support group argument in eager mode."
    return _get_global_parallel_env().rank
|
2020-08-28 14:46:28 +08:00
|
|
|
|
|
|
|
|
|
2024-07-27 16:46:07 +08:00
|
|
|
def get_world_size(group: Group | None = None) -> int:
    """
    Returns the number of trainers (number of processes participating in current job) in the given group.
    If none of the group is given, the global group will be used as default.

    Args:
        group (Group, optional): The communication group you want to check world size, use global group as default if group is None.

    Returns:
        (int) The number of trainers in the given group. Return -1 if the process if not part of the given group.

    Warning:
        Argument ``group`` only supports in dygraph mode.

    Examples:
        .. code-block:: pycon

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> # Execute this script using distributed launch with one card configs.
            >>> import paddle
            >>> import paddle.distributed as dist

            >>> dist.init_parallel_env()
            >>> print("The world_size is %d" % dist.get_world_size())
            The world_size is 1
    """
    target = group
    if in_dynamic_mode():
        # Fall back to the global group when none is supplied.
        if target is None and is_initialized():
            target = _get_global_group()
        if target:
            return target.world_size

    assert target is None, "Only support group argument in eager mode."
    return _get_global_parallel_env().world_size
|