2021-08-11 15:20:25 +08:00
|
|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
2024-08-07 21:27:40 +08:00
|
|
|
from __future__ import annotations
|
2021-08-11 15:20:25 +08:00
|
|
|
|
2023-11-10 17:04:51 +08:00
|
|
|
from functools import reduce
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
2022-09-16 17:12:38 +08:00
|
|
|
import paddle
|
2023-08-24 14:00:34 +08:00
|
|
|
from paddle.framework import core
|
2022-11-29 18:50:04 +08:00
|
|
|
|
|
|
|
|
from .process_mesh import ProcessMesh, get_current_process_mesh
|
2023-05-30 14:07:49 +08:00
|
|
|
from .static.dist_context import get_default_distributed_context
|
|
|
|
|
from .static.dist_op import DistributedOperatorHelper
|
|
|
|
|
from .static.dist_tensor import DistributedTensor
|
|
|
|
|
from .static.utils import (
|
2022-11-09 14:45:44 +08:00
|
|
|
__no_shape_var_type__,
|
2022-11-29 18:50:04 +08:00
|
|
|
convert_to_dims_mapping,
|
|
|
|
|
verify_shard_spec,
|
2022-11-09 14:45:44 +08:00
|
|
|
)
|
2021-08-11 15:20:25 +08:00
|
|
|
|
|
|
|
|
|
2022-09-15 20:35:52 +08:00
|
|
|
def shard_tensor(x, process_mesh=None, shard_spec=None):
    """
    Shard a tensor on a process mesh according to the shard specification.

    Args:
        x (Tensor): the tensor to be sharded.
        process_mesh (ProcessMesh, optional): An instance of ProcessMesh describes a mesh
            topology of the used logical processes where the tensor is sharded. If it is None,
            the found current process mesh will be used. And an error will be raised if the
            current process mesh cannot be found. Default: None.
        shard_spec (list, optional): a list to describe the sharding mapping between `x` and `process_mesh`,
            which means the dimension `i` of `x` is split across the dimension `shard_spec[i]` of `process_mesh`,
            where `None` means that tensor dimension is not split. For example, given a tensor with
            the shape [6, 12] and a process mesh with the shape [2, 3] and the dimension names ["x", "y"]:
            If `shard_spec=["x", "y"]`, each shard of the tensor will have a shape [3, 4];
            If `shard_spec=["y", "x"]`, each shard of the tensor will have a shape [2, 6];
            If `shard_spec=["x", None]`, each shard of the tensor will have a shape [3, 12];
            If `shard_spec=[None, "x"]`, each shard of the tensor will have a shape [6, 4];
            If `shard_spec=["y", None]`, each shard of the tensor will have a shape [2, 12];
            If `shard_spec=[None, "y"]`, each shard of the tensor will have a shape [6, 4];
            If `shard_spec=[None, None]`, each shard of the tensor will have a shape [6, 12];
            If the `shard_spec` is None, the tensor will be replicated across all the processes of `process_mesh`.
            In the above example, the `shard_spec=None` is same as `shard_spec=[None, None]`. Defaults: None.

    Returns:
        Tensor: the tensor `x` annotated with sharding information.

    Examples:
        .. code-block:: pycon

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import paddle
            >>> from paddle.distributed.fleet import auto

            >>> mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
            >>> x = paddle.ones([4, 6])
            >>> shard_spec = ["x", "y"]
            >>> auto.shard_tensor(x, mesh, shard_spec)
    """
    # Resolve the process mesh: either validate the explicit argument or fall
    # back to the mesh installed by an enclosing ProcessMesh context manager.
    if process_mesh is not None:
        assert isinstance(process_mesh, core.ProcessMesh), (
            f"Argument process_mesh {process_mesh} is not an instance of ProcessMesh"
        )
    else:
        process_mesh = get_current_process_mesh()
        assert process_mesh is not None, (
            "Specify the process mesh argument or use ProcessMesh context manager first."
        )
    # NOTE(review): this assert fires for shard_spec=None even though the
    # docstring documents None as "replicated" — confirm whether callers are
    # required to pass an explicit list here.
    assert isinstance(shard_spec, list), (
        f"Argument shard_spec {shard_spec} is not an instance of list"
    )
    # A string is treated as a variable name in the default main program.
    if isinstance(x, str):
        x = (
            paddle.static.default_main_program()
            .global_block()
            ._var_recursive(x)
        )
        dist_tensor = DistributedTensor(x)
    else:
        dist_tensor = DistributedTensor(x)
    serial_tensor = dist_tensor.serial_tensor
    dist_tensor.dist_attr.process_mesh = process_mesh
    # Shape-less var types (e.g. reader/step-scope vars) get an empty shape.
    if serial_tensor.type in __no_shape_var_type__:
        tensor_shape = []
    else:
        tensor_shape = serial_tensor.shape
    if shard_spec is not None:
        # Validate every entry of shard_spec against the mesh dim names
        # before converting it to a dims_mapping.
        valid_dims = (
            process_mesh.get_dim_names()
            if hasattr(process_mesh, "get_dim_names")
            else process_mesh.dim_names
        )
        for i, dim in enumerate(shard_spec):
            if dim is not None and (
                not isinstance(dim, str) or dim not in valid_dims
            ):
                raise ValueError(
                    f"Invalid shard_spec at index {i}: '{dim}' "
                    f"is not a valid dimension name in process_mesh {valid_dims}."
                )
        assert verify_shard_spec(shard_spec, tensor_shape, process_mesh), (
            f"For tensor {serial_tensor.name}, shard_spec {shard_spec} is invalid with tensor_shape {tensor_shape} and process_mesh {process_mesh}."
        )
        dist_tensor.dist_attr.dims_mapping = convert_to_dims_mapping(
            shard_spec, process_mesh
        )
    # Mark user-provided attributes as annotated so later passes keep them.
    # (process_mesh is always non-None at this point — see the asserts above.)
    if process_mesh is not None:
        dist_tensor.dist_attr.mark_annotated("process_mesh")
    if shard_spec is not None:
        dist_tensor.dist_attr.mark_annotated("dims_mapping")
    # Register the annotated tensor in the default distributed context and
    # record the mesh; the context may return a merged/existing entry.
    default_dist_ctx = get_default_distributed_context()
    default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
    dist_tensor = default_dist_ctx.get_dist_tensor_for_program(x)
    default_dist_ctx.add_process_mesh(process_mesh)
    return x
|
|
|
|
|
|
|
|
|
|
|
2023-11-29 19:07:37 +08:00
|
|
|
def _to_dims_mappings(shard_specs, process_mesh, arg_name):
    # Convert a list of shard_specs into a list of dims_mappings; a None
    # spec is kept as None (the corresponding tensor stays replicated).
    # `arg_name` is only used in the assert message ("in_shard_spec" /
    # "out_shard_spec") so both call sites report errors exactly as before.
    assert all(
        (isinstance(shard_spec, list) or shard_spec is None)
        for shard_spec in shard_specs
    ), f"{arg_name} {shard_specs} is not a list of list or None"
    return [
        (
            convert_to_dims_mapping(shard_spec, process_mesh)
            if shard_spec is not None
            else None
        )
        for shard_spec in shard_specs
    ]


def shard_op(
    op, process_mesh=None, in_shard_specs=None, out_shard_specs=None, **kwargs
):
    """
    Shard an operation on a process mesh according to its input and output shard specification.

    Args:
        op (Callable): a callable operator or module to be sharded.
        process_mesh (ProcessMesh, optional): An instance of ProcessMesh describes a mesh
            topology of the used logical processes where the op is sharded. All of its inputs and
            outputs are sharded by this process mesh. If it is None, the found current process mesh
            will be used. And an error will be raised if the current process mesh cannot be found.
            Default: None.
        in_shard_specs (list of list, optional): a list of list to describe the sharding specifications
            for the inputs. Each item of `in_shard_specs` is a `shard_spec` between the corresponding input
            and `process_mesh`. If one item is None, the corresponding input is replicated across all processes.
            If it is None, all inputs are replicated across all processes. Note that the length of the
            `in_shard_specs` should be equal to the actual number of inputs when calling this operation.
            Default: None.
        out_shard_specs (list of list, optional): a list of list to describe the sharding specifications
            for the outputs. Each item of `out_shard_specs` is a `shard_spec` between the corresponding output
            and `process_mesh`. If one item is None, the corresponding output is replicated across all processes.
            If it is None, all outputs are replicated across all processes. Note that the length of the
            `out_shard_specs` should be equal to the actual number of outputs when calling this operation.
            Default: None.
        **kwargs: extra keyword arguments forwarded to `DistributedOperatorHelper`.

    Returns:
        Outputs of `op`, each of which is annotated with sharding information.

    Examples:
        .. code-block:: pycon

            >>> import paddle
            >>> from paddle.distributed.fleet import auto

            >>> x = paddle.ones([4, 6])
            >>> y = paddle.zeros([4, 6])
            >>> mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
            >>> dist_add = auto.shard_op(
            ...     paddle.add,
            ...     mesh,
            ...     in_shard_specs=[["x", "y"], ["y", None]],
            ...     out_shard_specs=[[None, "x"]],
            ... )
            >>> dist_add(x, y)
    """
    # Resolve the process mesh exactly like shard_tensor does.
    if process_mesh is not None:
        assert isinstance(process_mesh, ProcessMesh), (
            f"Argument process_mesh {process_mesh} is not an instance of ProcessMesh"
        )
    else:
        process_mesh = get_current_process_mesh()
        assert process_mesh is not None, (
            "Specify the process mesh argument or use ProcessMesh context manager first."
        )
    # An absent spec list means "replicate everything" and yields an empty
    # mappings list, matching the helper's behavior before deduplication.
    in_dims_mappings = (
        _to_dims_mappings(in_shard_specs, process_mesh, "in_shard_spec")
        if in_shard_specs is not None
        else []
    )
    out_dims_mappings = (
        _to_dims_mappings(out_shard_specs, process_mesh, "out_shard_spec")
        if out_shard_specs is not None
        else []
    )
    # Wrap the callable so the annotations are applied when it is invoked.
    op = DistributedOperatorHelper(
        op, process_mesh, in_dims_mappings, out_dims_mappings, kwargs
    )
    return op
|
|
|
|
|
|
|
|
|
|
|
2022-11-18 14:57:43 +08:00
|
|
|
# Monotonically increasing id assigned by `recompute`; -1 means no
# recompute region has been created yet.
_g_recompute_idx = -1
|
|
|
|
|
|
|
|
|
|
|
2022-09-15 20:35:52 +08:00
|
|
|
def recompute(op):
    """Wrap `op` so the ops it creates are tagged as a recompute segment.

    The wrapper runs `op` inside an `/auto_parallel/rc_<idx>` name scope and,
    in PIR mode, additionally stamps each op created by the call with an
    integer `fwd_recompute_id` attribute.

    Args:
        op (Callable): the operator or layer to be recomputed.

    Returns:
        RecomputeOperator: a callable wrapper around `op`.
    """
    global _g_recompute_idx
    # Each decorated op advances the shared counter; ids start at 0.
    _g_recompute_idx += 1

    class RecomputeOperator:
        def __init__(self, op):
            self._op = op

        def __call__(self, *args, **kwargs):
            # Remember how many ops the block holds before running `op`,
            # so the newly appended ops can be found afterwards (PIR path).
            block = paddle.static.default_main_program().global_block()
            rc_begin_id = len(block.ops)

            # NOTE(review): the f-string below reads the module-level
            # _g_recompute_idx at *call* time, not the value current when
            # this wrapper was created — wrappers invoked after further
            # `recompute` decorations will use the latest id. Confirm this
            # is intended.
            with paddle.static.name_scope(
                f'/auto_parallel/rc_{_g_recompute_idx}'
            ):
                if paddle.base.dygraph.base.in_to_static_mode():
                    # Under dynamic-to-static translation, route through
                    # dy2static so the callable's body is converted too.
                    output = (
                        paddle.jit.dy2static.convert_call_func.convert_call(
                            self._op
                        )(*args, **kwargs)
                    )
                else:
                    output = self._op(*args, **kwargs)

            if paddle.framework.in_pir_mode():
                # Tag every op appended by the call above with the
                # recompute id (same call-time global read as the scope).
                block = paddle.static.default_main_program().global_block()
                rc_end_id = len(block.ops)
                for idx in range(rc_begin_id, rc_end_id):
                    rc_op = block.ops[idx]
                    rc_op.set_int_attr("fwd_recompute_id", _g_recompute_idx)

            return output

    return RecomputeOperator(op)
|
|
|
|
|
|
|
|
|
|
|
2023-10-31 15:13:37 +08:00
|
|
|
def exclude_ops_in_recompute(run_function):
    """
    Exclude some operators in recompute segments.

    Args:
        run_function (callable): The callable function to be excluded.

    Returns:
        ExcludeOperator: The callable object.
    """

    class ExcludeOperator:
        def __init__(self, run_function):
            self._run_function = run_function

        def __call__(self, *args, **kwargs):
            # Ops created under this name scope are excluded from recompute.
            with paddle.static.name_scope('/exclude_rc'):
                if paddle.base.dygraph.base.in_to_static_mode():
                    # Route through dy2static when translating to static
                    # graph so the wrapped callable is converted as well.
                    converted = (
                        paddle.jit.dy2static.convert_call_func.convert_call(
                            self._run_function
                        )
                    )
                    result = converted(*args, **kwargs)
                else:
                    result = self._run_function(*args, **kwargs)

            return result

    return ExcludeOperator(run_function)
|
|
|
|
|
|
|
|
|
|
|
2022-09-27 16:59:43 +08:00
|
|
|
# Global registry mapping a collection name to a list of (name, value) pairs.
_g_collections = {}
|
|
|
|
|
|
|
|
|
|
|
2022-11-08 11:29:41 +08:00
|
|
|
class CollectionNames:
    """Well-known keys for the module-level `_g_collections` registry."""

    # Entries registered via `fetch(...)`.
    FETCHES = "fetches"
    # Entries registered via `fetch(..., logging=True)`.
    LOGGING = "logging"
|
2022-09-27 16:59:43 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_collection(name):
    """Return the named global collection, creating it empty on first use.

    Args:
        name (str): key identifying the collection (see `CollectionNames`).

    Returns:
        list: the list of entries registered under `name`.
    """
    # setdefault performs the get-or-create-and-register in one step,
    # replacing the original get / None-check / insert sequence.
    return _g_collections.setdefault(name, [])
|
|
|
|
|
|
|
|
|
|
|
2022-10-12 19:29:06 +08:00
|
|
|
def add_to_collection(collection_name, value, name=None):
    """Append `(name, value)` to a global collection, skipping duplicates.

    Args:
        collection_name (str): key of the target collection (see `CollectionNames`).
        value: the value to register; nothing is appended if an equal value
            is already present (the stored label is ignored when comparing).
        name (str|None, optional): optional label stored with the value.
            Default: None.
    """
    # Both branches of the original did an identical duplicate scan, and
    # appending (name, value) covers the name-is-None case too, so the
    # two paths collapse into one.
    entries = _g_collections.setdefault(collection_name, [])
    for _, existing in entries:
        if existing == value:
            return
    entries.append((name, value))
|
2022-09-15 20:35:52 +08:00
|
|
|
|
|
|
|
|
|
2022-10-10 16:00:10 +08:00
|
|
|
def fetch(tensor, name=None, logging=False):
    """Register a tensor in the global FETCHES (and optionally LOGGING) collection.

    Args:
        tensor (Variable|str): a static-graph `Variable`, or its name.
        name (str|None, optional): label stored with the entry. Default: None.
        logging (bool, optional): also register the tensor under the LOGGING
            collection. Default: False.

    Raises:
        TypeError: if `tensor` is neither a `Variable` nor a `str`.
    """
    # Fetch targets are tracked by name; a plain string is used as-is
    # (the original had a no-op `tensor = tensor` branch here).
    if isinstance(tensor, paddle.static.Variable):
        tensor = tensor.name
    elif not isinstance(tensor, str):
        raise TypeError(
            f"Only support fetch `Variable` or `str`[`Variable`'s name], but got `{type(tensor)}`"
        )
    add_to_collection(CollectionNames.FETCHES, tensor, name)
    if logging:
        add_to_collection(CollectionNames.LOGGING, tensor, name)
|
2023-11-10 17:04:51 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Global process mesh installed by `set_mesh`/`create_mesh`; None until set.
_g_mesh = None
|
|
|
|
|
|
|
|
|
|
|
2024-12-06 13:15:43 +08:00
|
|
|
def get_mesh() -> paddle.distributed.ProcessMesh:
    """
    Return the global process mesh previously installed via `set_mesh`.

    Returns:
        mesh (paddle.distributed.ProcessMesh): the global mesh, or None if no
            mesh has been set yet.

    Examples:
        .. code-block:: pycon

            >>> import paddle
            >>> import paddle.distributed as dist
            >>> mesh = dist.ProcessMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=["dp", "mp", "pp"])
            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> dist.auto_parallel.set_mesh(mesh)
            >>> mesh = dist.auto_parallel.get_mesh()
            >>> # This case need to be executed in multi-card environment
            >>> # python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 {test_case}.py
    """
    # Reading a module-level name needs no `global` declaration.
    return _g_mesh
|
|
|
|
|
|
|
|
|
|
|
2024-12-06 13:15:43 +08:00
|
|
|
def set_mesh(mesh: paddle.distributed.ProcessMesh) -> None:
    """
    Set the global mesh.

    Args:
        mesh (paddle.distributed.ProcessMesh): global mesh to be set.

    Returns:
        None

    Examples:
        .. code-block:: pycon

            >>> import paddle
            >>> import paddle.distributed as dist
            >>> mesh = dist.ProcessMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=["dp", "mp", "pp"])
            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> dist.auto_parallel.set_mesh(mesh)
            >>> # This case need to be executed in multi-card environment
            >>> # python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 {test_case}.py
    """
    # Writing the module-level mesh requires the `global` declaration.
    global _g_mesh
    _g_mesh = mesh
|
|
|
|
|
|
|
|
|
|
|
2024-08-07 21:27:40 +08:00
|
|
|
def create_mesh(mesh_dims: list[tuple[str, int]]):
    """
    Create a global process_mesh for auto parallel.

    Args:
        mesh_dims (list[tuple[str, int]]): A list of tuple, each element is (dim_name, dim_degree).
    """
    global _g_mesh
    # Split the (name, degree) pairs into parallel lists.
    dim_names = [name for name, _ in mesh_dims]
    mesh_shape = [degree for _, degree in mesh_dims]
    # Total number of processes is the product of all dimension degrees;
    # process ids are 0..N-1 laid out in the mesh shape.
    num_processes = reduce(lambda a, b: a * b, mesh_shape, 1)
    process_ids = np.arange(num_processes).reshape(mesh_shape)
    _g_mesh = ProcessMesh(process_ids, dim_names)
    return _g_mesh
|