2020-08-20 13:53:18 +08:00
|
|
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
2024-06-13 12:43:10 +08:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2023-08-16 19:53:23 +08:00
|
|
|
import os
|
2024-06-21 00:40:47 +08:00
|
|
|
from typing import TYPE_CHECKING, Any
|
2023-08-16 19:53:23 +08:00
|
|
|
|
2023-12-07 21:41:22 +08:00
|
|
|
import numpy as np
|
2024-06-21 00:40:47 +08:00
|
|
|
import numpy.typing as npt
|
|
|
|
|
from typing_extensions import Self
|
2023-12-07 21:41:22 +08:00
|
|
|
|
2020-08-27 18:20:05 +08:00
|
|
|
import paddle
|
2023-09-07 17:26:19 +08:00
|
|
|
from paddle.base import Variable, core
|
|
|
|
|
from paddle.base.data_feeder import check_type
|
2023-09-20 10:24:31 +08:00
|
|
|
from paddle.base.framework import (
|
|
|
|
|
convert_np_dtype_to_dtype_,
|
|
|
|
|
in_pir_mode,
|
|
|
|
|
static_only,
|
|
|
|
|
)
|
2023-09-07 17:26:19 +08:00
|
|
|
from paddle.base.layer_helper import LayerHelper
|
2023-11-09 16:23:09 +08:00
|
|
|
from paddle.base.libpaddle import DataType
|
2024-02-05 12:19:41 +08:00
|
|
|
from paddle.base.libpaddle.pir import (
|
|
|
|
|
get_current_insertion_point,
|
|
|
|
|
set_insertion_point,
|
|
|
|
|
)
|
2020-08-20 13:53:18 +08:00
|
|
|
|
2023-11-28 14:35:49 +08:00
|
|
|
from ..base.variable_index import _setitem_static
|
2023-07-19 10:31:56 +08:00
|
|
|
|
2024-06-21 00:40:47 +08:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from paddle import Tensor
|
2024-08-16 10:12:45 +08:00
|
|
|
from paddle._typing import (
|
|
|
|
|
DTypeLike,
|
|
|
|
|
ShapeLike,
|
|
|
|
|
Size1,
|
|
|
|
|
TensorIndex,
|
|
|
|
|
TensorLike,
|
|
|
|
|
)
|
2024-06-21 00:40:47 +08:00
|
|
|
|
2021-04-29 19:31:40 +08:00
|
|
|
__all__ = []
|
|
|
|
|
|
2020-08-20 13:53:18 +08:00
|
|
|
|
2023-08-23 10:55:44 +08:00
|
|
|
def evaluate_flag(val) -> bool:
    """Interpret *val* as a boolean environment flag.

    Any value whose lower-cased string form is one of the conventional
    "disabled" spellings ('false', 'off', '0', 'none') is treated as
    False; every other value (including the empty string) is True.
    """
    disabled_spellings = {'false', 'off', '0', 'none'}
    return str(val).lower() not in disabled_spellings
|
|
|
|
|
|
|
|
|
|
|
2020-09-25 21:35:40 +08:00
|
|
|
@static_only
def data(
    name: str,
    shape: ShapeLike,
    dtype: DTypeLike | None = None,
    lod_level: int = 0,
) -> paddle.Tensor:
    """

    This function creates a variable on the global block. The global variable
    can be accessed by all the following operators in the graph. The variable
    is a placeholder that could be fed with input, such as Executor can feed
    input into the variable. When `dtype` is None, the dtype
    will get from the global dtype by `paddle.get_default_dtype()`.

    Args:
        name (str): The name/alias of the variable, see :ref:`api_guide_Name`
            for more details.
        shape (list|tuple): List|Tuple of integers declaring the shape. You can
            set None or -1 at a dimension to indicate the dimension can be of any
            size. For example, it is useful to set changeable batch size as None or -1.
        dtype (np.dtype|str, optional): The type of the data. Supported
            dtype: bool, float16, float32, float64, int8, int16, int32, int64,
            uint8. Default: None. When `dtype` is not set, the dtype will get
            from the global dtype by `paddle.get_default_dtype()`.
        lod_level (int, optional): The LoD level of the DenseTensor. Usually users
            don't have to set this value. Default: 0.

    Returns:
        Variable: The global variable that gives access to the data.

    Examples:
        .. code-block:: pycon

            >>> # doctest: +SKIP("This has diff in xdoctest env")
            >>> import numpy as np
            >>> import paddle
            >>> paddle.enable_static()

            # Creates a variable with fixed size [3, 2, 1]
            # User can only feed data of the same shape to x
            # the dtype is not set, so it will set "float32" by
            # paddle.get_default_dtype(). You can use paddle.get_default_dtype() to
            # change the global dtype
            >>> x = paddle.static.data(name='x', shape=[3, 2, 1])

            # Creates a variable with changeable batch size -1.
            # Users can feed data of any batch size into y,
            # but size of each data sample has to be [2, 1]
            >>> y = paddle.static.data(name='y', shape=[-1, 2, 1], dtype='float32')

            >>> z = x + y

            # In this example, we will feed x and y with np-ndarray "1"
            # and fetch z, like implementing "1 + 1 = 2" in PaddlePaddle
            >>> feed_data = np.ones(shape=[3, 2, 1], dtype=np.float32)

            >>> exe = paddle.static.Executor(paddle.framework.CPUPlace())
            >>> out = exe.run(
            ...     paddle.static.default_main_program(),
            ...     feed={
            ...         'x': feed_data,
            ...         'y': feed_data,
            ...     },
            ...     fetch_list=[z.name],
            ... )

            # np-ndarray of shape=[3, 2, 1], dtype=float32, whose elements are 2
            >>> print(out)
            [array([[[2.],
                     [2.]],
                    [[2.],
                     [2.]],
                    [[2.],
                     [2.]]], dtype=float32)]

    """

    def _reset_data_op_insertion_point():
        # Move the PIR insertion point so that a new `pd_op.data` op is
        # inserted before the first non-data op, keeping all data ops
        # grouped at the top of the global block.
        default_main_program = paddle.pir.core.default_main_program()
        ops = default_main_program.global_block().ops
        if len(ops) == 0:
            return
        for op in ops:
            if op.name() != 'pd_op.data':
                paddle.pir.set_insertion_point(op)
                return

    helper = LayerHelper('data', **locals())
    check_type(name, 'name', (bytes, str), 'data')
    check_type(shape, 'shape', (list, tuple), 'data')

    # Normalize the shape: `None` dims mean "any size" and become -1;
    # any other negative value is rejected.
    shape = list(shape)
    for i in range(len(shape)):
        if shape[i] is None:
            shape[i] = -1
        if isinstance(shape[i], int) and shape[i] < 0 and shape[i] != -1:
            raise ValueError(
                f"Only -1 can be used in shape to indicate unknown dimension, but received {shape[i]}"
            )

    if dtype is None:
        dtype = paddle.get_default_dtype()

    # On iluvatar custom devices with FLAG_FORCE_FLOAT32 enabled, wide
    # floating-point dtypes are downcast to their supported counterparts
    # with a warning.
    if core.is_compiled_with_custom_device("iluvatar_gpu") and os.environ.get(
        'FLAG_FORCE_FLOAT32', ''
    ).lower() in ['1', 'true', 'on']:
        import warnings

        # `dtype_str` is always a `str`, so comparing against string
        # spellings of each dtype is sufficient.
        dtype_str = dtype if isinstance(dtype, str) else str(dtype)
        if dtype_str in ('float64', 'f8'):
            warnings.warn(
                f"Variable '{name}' dtype 'float64' is not supported on iluvatar gpu, "
                "forcibly using 'float32'.",
                UserWarning,
                stacklevel=2,
            )
            dtype = 'float32'
        elif dtype_str in ('complex128', 'c16'):
            warnings.warn(
                f"Variable '{name}' dtype 'complex128' is not supported on iluvatar gpu, "
                "forcibly using 'complex64'.",
                UserWarning,
                stacklevel=2,
            )
            dtype = 'complex64'

    if in_pir_mode():
        ir_dtype = dtype
        if not isinstance(ir_dtype, DataType):
            ir_dtype = paddle.pir.core.convert_np_dtype_to_dtype_(dtype)
        # Save the insertion point, hop to the top-of-block data-op group
        # to create the op, then restore — ordering matters here.
        prev_insertion_point = get_current_insertion_point()
        _reset_data_op_insertion_point()
        out = paddle._pir_ops.data(name, shape, ir_dtype, core.Place())
        set_insertion_point(prev_insertion_point)
        return out

    out = helper.create_global_variable(
        name=name,
        shape=shape,
        dtype=dtype,
        type=core.VarDesc.VarType.DENSE_TENSOR,
        stop_gradient=True,
        lod_level=lod_level,
        is_data=True,
        need_check_feed=True,
    )

    # When the old-IR program will be executed by the PIR executor, also
    # append an explicit `data` op describing the feed.
    is_pir_mode = os.environ.get("FLAGS_enable_pir_in_executor", None)
    if evaluate_flag(is_pir_mode):
        helper = LayerHelper('data', **locals())
        if not isinstance(dtype, core.VarDesc.VarType):
            dtype = convert_np_dtype_to_dtype_(dtype)
        helper.append_op(
            type='data',
            inputs={},
            outputs={'out': out},
            attrs={
                'shape': shape,
                'dtype': dtype,
                'place': 0,
                'name': name,
            },
        )
    return out
|
2023-08-16 19:53:23 +08:00
|
|
|
|
2020-08-22 16:54:26 +08:00
|
|
|
|
2022-11-08 11:29:41 +08:00
|
|
|
class InputSpec:
    """
    InputSpec describes the signature information of the model input, such as ``shape`` , ``dtype`` , ``name`` .

    This interface is often used to specify input tensor information of models in high-level API.
    It's also used to specify the tensor information for each input parameter of the forward function
    decorated by `@paddle.jit.to_static`.

    Args:
        shape (tuple(integers)|list[integers]): List|Tuple of integers
            declaring the shape. You can set "None" or -1 at a dimension
            to indicate the dimension can be of any size. For example,
            it is useful to set changeable batch size as "None" or -1.
        dtype (np.dtype|str, optional): The type of the data. Supported
            dtype: bool, float16, float32, float64, int8, int16, int32, int64,
            uint8. Default: float32.
        name (str): The name/alias of the variable, see :ref:`api_guide_Name`
            for more details.
        stop_gradient (bool, optional): A boolean that mentions whether gradient should flow. Default is False, means don't stop calculate gradients.

    Examples:
        .. code-block:: pycon

            >>> import paddle
            >>> from paddle.static import InputSpec

            >>> input = InputSpec([None, 784], 'float32', 'x')
            >>> label = InputSpec([None, 1], 'int64', 'label')

            >>> print(input)
            InputSpec(shape=(-1, 784), dtype=paddle.float32, name=x, stop_gradient=False)

            >>> print(label)
            InputSpec(shape=(-1, 1), dtype=paddle.int64, name=label, stop_gradient=False)
    """

    def __init__(
        self,
        shape: ShapeLike,
        dtype: DTypeLike = 'float32',
        name: str | None = None,
        stop_gradient: bool = False,
    ) -> None:
        # replace `None` in shape with -1
        self.shape = self._verify(shape)
        # convert dtype into united representation
        if dtype is not None:
            if isinstance(dtype, (np.dtype, str)):
                dtype = convert_np_dtype_to_dtype_(dtype)

        self.dtype = dtype
        self.name = name
        self.stop_gradient = stop_gradient

    def _create_feed_layer(self):
        # Materialize this spec as a feedable `data` variable in the
        # current static program.
        return data(self.name, shape=self.shape, dtype=self.dtype)

    def __repr__(self) -> str:
        return f'{type(self).__name__}(shape={self.shape}, dtype={self.dtype}, name={self.name}, stop_gradient={self.stop_gradient})'

    @classmethod
    def from_tensor(cls, tensor: Tensor, name: str | None = None) -> Self:
        """
        Generates a InputSpec based on the description of input tensor.

        Args:
            tensor(Tensor): the source tensor to generate a InputSpec instance

        Returns:
            A InputSpec instance generated from Tensor.

        Examples:
            .. code-block:: pycon

                >>> import paddle
                >>> from paddle.static import InputSpec

                >>> paddle.disable_static()

                >>> x = paddle.ones([2, 2], dtype="float32")
                >>> x_spec = InputSpec.from_tensor(x, name='x')
                >>> print(x_spec)
                InputSpec(shape=(2, 2), dtype=paddle.float32, name=x, stop_gradient=False)

        """
        if isinstance(tensor, (Variable, core.eager.Tensor, paddle.pir.Value)):
            return cls(tensor.shape, tensor.dtype, name or tensor.name)
        else:
            raise ValueError(
                f"Input `tensor` should be a Tensor, but received {type(tensor).__name__}."
            )

    @classmethod
    def from_numpy(
        cls, ndarray: npt.NDArray[Any], name: str | None = None
    ) -> Self:
        """
        Generates a InputSpec based on the description of input np.ndarray.

        Args:
            tensor(Tensor): the source numpy ndarray to generate a InputSpec instance

        Returns:
            A InputSpec instance generated from Tensor.

        Examples:
            .. code-block:: pycon

                >>> import numpy as np
                >>> from paddle.static import InputSpec

                >>> x = np.ones([2, 2], np.float32)
                >>> x_spec = InputSpec.from_numpy(x, name='x')
                >>> print(x_spec)
                InputSpec(shape=(2, 2), dtype=paddle.float32, name=x, stop_gradient=False)

        """
        return cls(ndarray.shape, ndarray.dtype, name)

    def batch(self, batch_size: int | Size1) -> Self:
        """
        Inserts `batch_size` in front of the `shape`.

        Args:
            batch_size(int): the inserted integer value of batch size.

        Returns:
            The original InputSpec instance by inserting `batch_size` in front of `shape`.

        Examples:
            .. code-block:: pycon

                >>> from paddle.static import InputSpec

                >>> x_spec = InputSpec(shape=[64], dtype='float32', name='x')
                >>> x_spec.batch(4)
                >>> print(x_spec)
                InputSpec(shape=(4, 64), dtype=paddle.float32, name=x, stop_gradient=False)

        """
        if isinstance(batch_size, (list, tuple)):
            # A single-element list/tuple is accepted as a convenience.
            if len(batch_size) != 1:
                raise ValueError(
                    f"Length of batch_size: {batch_size} shall be 1, but received {len(batch_size)}."
                )
            batch_size = batch_size[0]
        elif not isinstance(batch_size, int):
            raise TypeError(
                f"type(batch_size) shall be `int`, but received {type(batch_size).__name__}."
            )

        new_shape = [batch_size, *list(self.shape)]
        self.shape = tuple(new_shape)

        return self

    def unbatch(self) -> Self:
        """
        Removes the first element of `shape`.

        Returns:
            The original InputSpec instance by removing the first element of `shape` .

        Examples:
            .. code-block:: pycon

                >>> from paddle.static import InputSpec

                >>> x_spec = InputSpec(shape=[4, 64], dtype='float32', name='x')
                >>> x_spec.unbatch()
                >>> print(x_spec) # InputSpec(shape=(64,), dtype=paddle.float32, name=x)
                InputSpec(shape=(64,), dtype=paddle.float32, name=x, stop_gradient=False)

        """
        if len(self.shape) == 0:
            raise ValueError(
                "Not support to unbatch a InputSpec when len(shape) == 0."
            )

        self.shape = self._verify(self.shape[1:])
        return self

    def _verify(self, shape):
        """
        Verifies the input shape and modifies `None` into `-1`.
        """
        if not isinstance(shape, (list, tuple)):
            raise TypeError(
                f"Type of `shape` in InputSpec should be one of (tuple, list), but received {type(shape).__name__}."
            )

        # Work on a mutable copy so that tuple inputs containing `None` (or
        # negative sentinels) are supported too; assigning into a tuple
        # would raise TypeError.
        shape = list(shape)
        for i, ele in enumerate(shape):
            if ele is not None:
                if not isinstance(ele, int):
                    raise ValueError(
                        f"shape[{i}] should be an `int`, but received `{type(ele).__name__}`:{ele}."
                    )
            if ele is None or ele < -1:
                shape[i] = -1

        return tuple(shape)

    def __hash__(self) -> int:
        # Note(Aurelius84): `name` is not considered as a field to compute hashkey.
        # Because it's no need to generate a new program in following cases while using
        # @paddle.jit.to_static.
        #
        # Case 1:
        #      foo(x_var)
        #      foo(y_var)
        #  x_var and y_var hold same shape and dtype, they should share a same program.
        #
        #
        # Case 2:
        #      foo(x_var)
        #      foo(x_np)  # x_np is a numpy.ndarray.
        #  x_var and x_np hold same shape and dtype, they should also share a same program.
        return hash((tuple(self.shape), self.dtype, self.stop_gradient))

    def __eq__(self, other: Self) -> bool:
        slots = ['shape', 'dtype', 'name', 'stop_gradient']
        return type(self) is type(other) and all(
            getattr(self, attr) == getattr(other, attr) for attr in slots
        )

    def __ne__(self, other) -> bool:
        return not self == other
|
2023-07-19 10:31:56 +08:00
|
|
|
|
|
|
|
|
|
2024-08-16 10:12:45 +08:00
|
|
|
def setitem(
    x: Tensor,
    index: TensorIndex,
    value: TensorLike,
) -> Tensor:
    """
    Write ``value`` into ``x`` at the positions selected by ``index``.

    Args:
        x(Tensor): input Tensor.
        index(Scalar|Tuple|List|Tensor): where the value should be set.
        value(Scalar|Tensor): the value which is going to be set.

    How to spell the index:
        1. ':' -> slice(),
           (1) a[:]=v -> setitem(a, slice(None,None,None), v)
           (2) a[1::2] -> setitem(a, slice(1,None,2), v)

        2. if there are multiple indexes for axes, use TUPLE (Not LIST) to pack them.
           (1) a[1, 2]=v -> setitem(a, (1, 2), v)
           (2) a[[1,2],[2,3]]=v -> setitem(a, ([1,2],[2,3]), v)
           (3) a[1,:, 3] = v -> setitem(a, (1, slice(None,None,None),3), v)
           (4) a[1, ..., 2]=v -> setitem(a, (1, ..., 2), v)

        3. You can always use TUPLE as index input, even there is only one index.
           (1) a[Tensor([10,10])]=v -> setitem(a, (Tensor([10,10]),), v)
           (2) a[1] = v -> setitem(a, (1,), v)
    """
    # Delegate to the shared static-mode __setitem__ implementation.
    updated = _setitem_static(x, index, value)
    return updated
|