# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import copy from contextlib import nullcontext from typing import TYPE_CHECKING, Any, TypeVar import paddle from paddle.base.data_feeder import convert_dtype from paddle.base.framework import convert_np_dtype_to_dtype_ from paddle.base.unique_name import ( UniqueNameGenerator, guard as UniqueNameGuard, ) from paddle.distributed.auto_parallel.placement_type import ( get_shard_spec, to_placements, ) from paddle.distributed.auto_parallel.static.dist_input_spec import ( DistributedInputSpec, ) from paddle.distributed.auto_parallel.static.utils import ( convert_to_dims_mapping, ) from paddle.jit.dy2static.utils import ( ALREADY_D2S, extract_tensor_dynamic_dims, graph_tracing_guard, ) from paddle.pir import is_fake_value from paddle.static import InputSpec from paddle.utils import flatten, is_sequence from .symbolic_shape.symbolic_value import SymbolicInt from .utils import ( Cache, Singleton, get_min_non_specialized_number, map_if_extend, meta_str, ) from .utils.exceptions import BreakGraphError, NullMetaBreak if TYPE_CHECKING: import numpy.typing as npt DynamicSymbolT = TypeVar("DynamicSymbolT") SOT_INFER_META_INNER_VAR = "___SOT_INFER_META_INNER_VAR" class DistInfo: def __init__(self, mesh=None, dims_mapping=None, local_shape=None): self.mesh = mesh self.dims_mapping = dims_mapping self.local_shape = local_shape @staticmethod def from_tensor(tensor: paddle.Tensor) -> DistInfo: assert isinstance(tensor, paddle.Tensor) and tensor.is_dist(), ( f"Expect a Tensor, but got a {type(tensor)}." ) mesh = tensor.process_mesh sharding_specs = get_shard_spec( mesh, tensor.placements, len(tensor.shape) ) dims_mapping = convert_to_dims_mapping(sharding_specs, mesh) local_shape = tensor._local_value().shape return DistInfo(mesh, dims_mapping, local_shape) @staticmethod def from_value(value: paddle.pir.Value) -> DistInfo: assert isinstance(value, paddle.pir.Value) and value.is_dist(), ( f"Expect a Value, but got a {type(value)}." ) return DistInfo( value.dist_attr().process_mesh, value.dist_attr().dims_mapping, value._local_shape, ) def __deepcopy__(self, memo): return DistInfo( mesh=copy.deepcopy(self.mesh), dims_mapping=copy.deepcopy(self.dims_mapping), local_shape=copy.deepcopy(self.local_shape), ) def __repr__(self) -> str: return f"DistInfo(mesh={self.mesh}, dims_mapping={self.dims_mapping}, local_shape={self.local_shape})" class MetaInfoOrNull: def __init__(self, meta: MetaInfo | None): self.meta = meta @staticmethod def null(): return MetaInfoOrNull(None) def is_null(self): return self.meta is None def unwrap_or_breakgraph(self): if self.meta is None: raise BreakGraphError(NullMetaBreak()) return self.meta def unwrap_unsafe(self): assert self.meta is not None, "MetaInfo is None" return self.meta def with_dynamic_axes( self, name: str, dynamic_axes: list[int] ) -> MetaInfoOrNull: if self.meta is None: return MetaInfoOrNull.null() return self.meta.with_dynamic_axes(name, dynamic_axes).wrap() def to_input_spec(self): if self.meta is None: return None return self.meta.to_input_spec() def guard_str(self): if self.meta is None: return "(Null)" return self.meta.guard_str() def __deepcopy__(self, memo): if self.meta is None: return MetaInfoOrNull(None) return MetaInfoOrNull(copy.deepcopy(self.meta)) @staticmethod def mix_axes(axes1: list[int], axes2: list[int]) -> list[int]: return sorted(set(axes1 + axes2)) @staticmethod def from_tensor( tensor: paddle.Tensor, *, dynamic_axes: list[int] | None = None ) -> MetaInfoOrNull: if not tensor._is_dense_tensor_hold_allocation(): return MetaInfoOrNull.null() assert isinstance(tensor, paddle.Tensor), ( "Expect a Tensor, but got a Value." ) assert -1 not in tensor.shape, ( "Tensor shape should not contain -1, maybe you pass a Value to from_tensor" ) user_specified_dynamic_axes = extract_tensor_dynamic_dims(tensor) dynamic_axes = dynamic_axes or [] dynamic_axes = MetaInfoOrNull.mix_axes( dynamic_axes, list(user_specified_dynamic_axes) ) shape = [ SymbolicInt(dim) if i in dynamic_axes else dim for i, dim in enumerate(tensor.shape) ] if tensor.is_dist(): dist_info = DistInfo.from_tensor(tensor) else: dist_info = None return MetaInfo( shape, tensor.dtype, tensor.stop_gradient, tensor.name, tensor.persistable, tensor.type, tensor.place, None, dist_info=dist_info, ).wrap() @staticmethod def from_numpy( nparray: npt.NDArray[Any], *, dynamic_axes: list[int] | None = None ) -> MetaInfoOrNull: dtype = convert_np_dtype_to_dtype_(nparray.dtype) dynamic_axes = dynamic_axes or [] shape = [ SymbolicInt() if i in dynamic_axes else dim for i, dim in enumerate(nparray.shape) ] return MetaInfo( shape, dtype, True, # stop_gradient None, None, # persistable None, None, None, dist_info=None, ).wrap() @staticmethod def from_value(value) -> MetaInfoOrNull: if is_fake_value(value): return MetaInfoOrNull.null() name = SOT_INFER_META_INNER_VAR shape = [SymbolicInt() if dim == -1 else dim for dim in value.shape] for dim in shape: if isinstance(dim, int): assert dim >= 0, ( "Dimensions must be non-negative integers or SymbolicInt. " f"Encountered value {dim} in shape {shape}." ) if isinstance(value, paddle.pir.Value) and value.is_dist(): dist_info = DistInfo.from_value(value) else: dist_info = None return MetaInfo( shape, value.dtype, value.stop_gradient, name, value.persistable, None, # type is not a unified attribute in dygraph and static mode. None, # We can't infer the right place in compile time. None, # there's no spec_name specified when from_value. dist_info=dist_info, ).wrap() def __repr__(self): if self.meta is None: return "MetaInfoOrNull(None)" return f"MetaInfoOrNull({self.meta})" def __eq__(self, other): if self.meta is None: return other.meta is None if other.meta is None: return False return self.meta == other.meta def __hash__(self): if self.meta is None: return hash(None) return hash(self.meta) class MetaInfo: shape: list[int | SymbolicInt] def __init__( self, shape, dtype, stop_gradient, name, persistable, type, place, spec_name=None, dist_info=None, ): assert -1 not in shape, ( "NOTE: Shape should not contain -1, consider convert it to SymbolicInt." ) self.name = name self.persistable = persistable self.type = type self.place = place self.shape = shape self.dtype = dtype self.stop_gradient = stop_gradient self.dist_info = dist_info self.spec_name = spec_name def wrap(self): return MetaInfoOrNull(self) def shape_with_special_symbol( self, dynamic_symbol: DynamicSymbolT = -1 ) -> list[int | DynamicSymbolT]: return [ dynamic_symbol if isinstance(dim, SymbolicInt) else dim for dim in self.shape ] def with_dynamic_axes(self, name: str, dynamic_axes: list[int]) -> MetaInfo: mixed_dynamic_axes = MetaInfoOrNull.mix_axes( self.dynamic_axes, dynamic_axes ) # NOTE(SigureMo): Make sure create a new shape list with dynamic axes. # We will create a new shape list variable lazily in the future. shape = [ ( SymbolicInt(dim) if ( i in mixed_dynamic_axes and not isinstance(dim, SymbolicInt) ) else dim ) for i, dim in enumerate(self.shape) ] return MetaInfo( shape, self.dtype, self.stop_gradient, self.name, self.persistable, self.type, self.place, spec_name=name, dist_info=self.dist_info, ) @property def dynamic_axes(self): return [ i for i, dim in enumerate(self.shape) if isinstance(dim, SymbolicInt) ] def is_inner_var(self): return self.name == SOT_INFER_META_INNER_VAR def is_dynamic_shape(self): """ if SymbolicInt in shape, return True else: return False """ return len(self.dynamic_axes) > 0 def to_input_spec(self) -> DistributedInputSpec | ConstrainedInputSpec: shape = self.shape_with_special_symbol(None) if self.dist_info is not None: placements = to_placements( self.dist_info.dims_mapping, self.dist_info.mesh ) return DistributedInputSpec( shape, dtype=self.dtype, stop_gradient=self.stop_gradient, mesh=self.dist_info.mesh, placements=placements, local_shape=self.dist_info.local_shape, ) else: return ConstrainedInputSpec( self.dynamic_axes, shape, dtype=self.dtype, name=self.spec_name, stop_gradient=self.stop_gradient, ) def guard_str(self): shape = self.shape_with_special_symbol(SymbolicInt()) return f"({shape}, {self.dtype}, {self.stop_gradient})" def __deepcopy__(self, memo): return MetaInfo( list(self.shape), self.dtype, self.stop_gradient, self.name, self.persistable, self.type, self.place, self.spec_name, dist_info=copy.deepcopy(self.dist_info), ) def __repr__(self): return meta_str(self.shape, self.dtype, self.stop_gradient) def __eq__(self, meta): return ( self.shape == meta.shape and self.dtype == meta.dtype and self.stop_gradient == meta.stop_gradient ) def __hash__(self): return hash((tuple(self.shape), self.dtype, self.stop_gradient)) class VariableCreator(metaclass=Singleton): """ We use the static graph Variable to infer the meta information of Tensor. This singleton class is used to create Variable for infer meta. """ def __init__(self): self.var_name_generator = UniqueNameGenerator(SOT_INFER_META_INNER_VAR) self.var_cache = {} self.main_program = paddle.static.Program() self.startup_program = paddle.static.Program() def gen_name(self, meta_or_null: MetaInfoOrNull): if meta_or_null.is_null(): return "null" meta = meta_or_null.unwrap_unsafe() name = f"{meta.dtype}_{meta.stop_gradient}_" name += "_".join(map(str, meta.shape)) return name def create_var(self, meta_or_null: MetaInfoOrNull): if meta_or_null.is_null(): return None meta = meta_or_null.unwrap_unsafe() shape = meta.shape_with_special_symbol(-1) with paddle.static.program_guard( self.main_program, self.startup_program ): var = paddle.static.input.data( name=self.gen_name(meta.wrap()), shape=shape, dtype=convert_dtype(meta.dtype), ) var.stop_gradient = meta.stop_gradient if meta.dist_info is not None: mesh = meta.dist_info.mesh placements = to_placements(meta.dist_info.dims_mapping, mesh) var = paddle._pir_ops.shard_tensor(var, mesh, placements) var.stop_gradient = meta.stop_gradient assert not isinstance(var, paddle.Tensor), ( "Expect a Variable, but got a Tensor." ) return var def get_variable(self, meta: MetaInfoOrNull, without_cache=False): var_feature_name = self.gen_name(meta) if without_cache: return self.create_var(meta) if var_feature_name not in self.var_cache: self.var_cache[var_feature_name] = self.create_var(meta) return self.var_cache[var_feature_name] def infer_meta(self, func, *args, **kwargs): with ( paddle.base.framework._dygraph_guard(None), UniqueNameGuard(self.var_name_generator), ): if func is paddle.distributed.shard_tensor: args, kwargs = ( convert_meta_to_variable(args, without_cache=True), convert_meta_to_variable(kwargs, without_cache=True), ) else: args, kwargs = ( convert_meta_to_variable(args), convert_meta_to_variable(kwargs), ) graph_tracing_context_manager = nullcontext() with paddle.static.program_guard( self.main_program, self.startup_program ): if isinstance(func, str): # TODO(Aurelius84): Is length of args always greater than 0? # Do we need add condition check here? func = getattr(args[0], func) args = args[1:] if hasattr(func, ALREADY_D2S): graph_tracing_context_manager = graph_tracing_guard( self.main_program ) with graph_tracing_context_manager: out = func(*args, **kwargs) return convert_variable_to_meta_info(out) def convert_meta_to_variable(args, without_cache=False): return map_if_extend( args, pred=lambda x: isinstance(x, MetaInfoOrNull), true_fn=lambda x: VariableCreator().get_variable( x, without_cache=without_cache ), false_fn=lambda x: x, ) def convert_meta_to_input_spec(args): return map_if_extend( args, pred=lambda x: isinstance(x, MetaInfoOrNull), true_fn=lambda x: x.to_input_spec(), # TODO(xiongkun): can x be tensor ? false_fn=lambda x: ( paddle.static.InputSpec.from_tensor(x) if isinstance(x, paddle.Tensor) else x ), ) def convert_variable_to_meta_info(args): return map_if_extend( args, pred=lambda x: isinstance(x, paddle.pir.Value), true_fn=lambda x: MetaInfoOrNull.from_value(x), false_fn=lambda x: x, ) def infer_meta(func, *args, **kwargs): fn = SpecialInferMeta().get_infermeta_fn(func) if fn: return fn(*args, **kwargs) return VariableCreator().infer_meta(func, *args, **kwargs) def infer_meta_for_layer(layer, *args, **kwargs): assert isinstance(layer, paddle.nn.Layer), ( f"Expect a Layer, but got {layer}." ) layer = paddle.jit.to_static(layer, full_graph=True) args_, kwargs_ = convert_meta_to_input_spec((args, kwargs)) ( concrete_program, partial_program_layer, ) = layer.forward.get_concrete_program(*args_, **kwargs_) output_values = partial_program_layer._outputs.var_list out = partial_program_layer._restore_out( [ x for x in paddle.utils.flatten( convert_variable_to_meta_info(output_values) ) if isinstance(x, MetaInfoOrNull) ] ) layer.forward.rollback() return out def ast_infer_meta(static_function, *args, **kwargs): args_, kwargs_ = convert_meta_to_input_spec((args, kwargs)) ( concrete_program, partial_program_layer, ) = static_function.get_concrete_program(*args_, **kwargs_) out = partial_program_layer._restore_out( [ x for x in paddle.utils.flatten( convert_variable_to_meta_info(concrete_program.outputs) ) if isinstance(x, MetaInfoOrNull) ] ) return out class SpecialInferMeta(metaclass=Singleton): """ There are some functions that cannot be inferred directly through static graph, and need to be implemented manually. This class is used to implement infer meta for these functions. """ def __init__(self): pass def get_infermeta_fn(self, fn): try: funcname = fn.__name__ return getattr(self, f"infermeta_{funcname}") except: pass return None def infermeta_grad( self, outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False, only_inputs=True, allow_unused=False, no_grad_vars=None, ): if not is_sequence(inputs): inputs = [inputs] return inputs class InferMetaCache(Cache, metaclass=Singleton): def __init__(self): super().__init__(copy=True) def key_fn( self, func, *args, **kwargs ): # args & kwargs have transformed to MetaInfo return ( func, tuple(flatten(args)), tuple(kwargs.keys()), tuple(flatten(kwargs)), ) def value_fn(self, func, *args, **kwargs): return infer_meta(func, *args, **kwargs) class LayerInferMetaCache(Cache, metaclass=Singleton): def __init__(self): super().__init__(copy=True) def key_fn(self, layer, *args, **kwargs): params = [ MetaInfoOrNull.from_tensor(x) for x in layer.parameters(include_sublayers=True) ] return ( layer, tuple(params), tuple(flatten(args)), tuple(kwargs.keys()), tuple(flatten(kwargs)), ) def value_fn(self, layer, *args, **kwargs): return infer_meta_for_layer(layer, *args, **kwargs) class ConstrainedInputSpec(InputSpec): def __init__(self, dynamic_axes: list[int], *args, **kwargs): self.ranges: list[ tuple[int, int | None, int | None] ] = [] # (idx of dim, min, max) super().__init__(*args, **kwargs) min_non_specialized_number = get_min_non_specialized_number() for i in dynamic_axes: self.ranges.append((i, min_non_specialized_number, None))