2022-04-08 14:39:01 +08:00
|
|
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
2024-06-19 11:29:05 +08:00
|
|
|
from __future__ import annotations
|
2022-04-08 14:39:01 +08:00
|
|
|
|
|
|
|
|
import re
|
2024-06-19 11:29:05 +08:00
|
|
|
from typing import TYPE_CHECKING
|
2022-11-28 11:52:40 +08:00
|
|
|
|
|
|
|
|
from paddle import _C_ops, _legacy_C_ops
|
|
|
|
|
|
2023-09-07 17:26:19 +08:00
|
|
|
from ..base.data_feeder import check_variable_and_dtype
|
remove fluid.initializer.UniformInitializer, ConstantInitializer, NormalInitializer, TruncatedNormalInitializer, XavierInitializer, BilinearInitializer, MSRAInitializer, NumpyArrayInitializer and calculate_gain.. (#49498)
* move UniformInitializer and ConstantInitializer
* more modify
* circular import resolved
* another circular import resolved?
* more circular import 2
* circular import 3
* change import paddle in metric.py
* BuildStrategy import from fluid
* modify the framework import path in common.py
* change rnn.py import, from static to original framework
* change import static in the nn folder
* default_main_program should import from common_ops_import
* add import paddle in param_attr.py
* use core not paddle module for using VarDesc
* another old uniform
* mistake that use Uniform instead of UniformInitializer
* modify UniformInitializer doc
* move fluid.NormalInitializer to nn.initializer.NormalInitializer
* remove import of Normal in fluid.layers.nn.py
* remove more import of old Normal
* remove more import of old Normal
* sample code modify and tests modify import
* is_listen_failed passing arg should be log file
* problem solved
* a mistake solved
* comments resoleved and remove paddle.fluid.initializer.TruncatedNormalInitializer
* remove paddle.fluid.initializer.XavierInitializer and paddle.fluid.initializer.MSRAInitializer
* remove paddle.fluid.initializer.BilinearInitializer NumpyArrayInitializer and set_global_initializer
* change fluid to static
* change static to fluid to avoid circular import in distributed_strategy.py
* fix example code and test_initializer
* ValueType
* sample code fix
* change set_global_initializer back to fluid
* put paddle.static.BuildStrategy.ReduceStrategy into the fuction to avoid circular import
* remove calculate_gain, delete BilinearInitializer and revert set_global_initializer
* change the time of using UniformInitializer, ConstantInitializer, NormalInitializer, TruncatedNormalInitializer, XavierInitializer, MSRAInitializer, NumpyArrayInitializer as few as possible
* fix argument incampatible
* fix more arg incompatible
* fix test_prelu_op_xpu.py Constant
* fix inaccurate doc
* more doc fix: default value
2023-02-01 21:38:27 +08:00
|
|
|
from ..common_ops_import import Variable
|
2022-10-23 20:01:27 +08:00
|
|
|
from ..framework import (
|
2022-11-28 11:52:40 +08:00
|
|
|
LayerHelper,
|
2022-10-23 20:01:27 +08:00
|
|
|
OpProtoHolder,
|
|
|
|
|
convert_np_dtype_to_dtype_,
|
|
|
|
|
core,
|
2023-05-22 20:56:38 +08:00
|
|
|
in_dynamic_mode,
|
2023-10-10 10:26:14 +08:00
|
|
|
in_dynamic_or_pir_mode,
|
2022-10-23 20:01:27 +08:00
|
|
|
)
|
2022-04-08 14:39:01 +08:00
|
|
|
|
2024-06-19 11:29:05 +08:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from paddle import Tensor
|
|
|
|
|
|
2022-04-08 14:39:01 +08:00
|
|
|
__all__ = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_(name):
|
|
|
|
|
"""
|
|
|
|
|
Formatting.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
name: The name/alias
|
|
|
|
|
|
|
|
|
|
This function takes in a name and converts it to a standard format of
|
|
|
|
|
group1_group2. Where as per the regular expression, group1 can have
|
|
|
|
|
alphabets and numbers and group2 has capital alphabets.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
|
|
|
|
|
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
|
|
|
|
|
|
|
|
|
|
|
2024-06-19 11:29:05 +08:00
|
|
|
def generate_layer_fn(op_type: str):
    """Register the Python layer for an Operator.

    Args:
        op_type: The name of the operator to be created.

    Returns:
        A function that appends the operator to the current program via
        ``LayerHelper`` and returns its single non-intermediate output.

    This function takes in the operator type (sigmoid, mean , average etc) and
    creates the operator functionality.
    """
    op_proto = OpProtoHolder.instance().get_op_proto(op_type)
    # Split outputs: exactly one non-intermediate output is required — it
    # becomes the return value of the generated layer function.
    not_intermediate_outputs = [
        output for output in op_proto.outputs if not output.intermediate
    ]
    intermediate_outputs = [
        output for output in op_proto.outputs if output.intermediate
    ]

    if len(not_intermediate_outputs) != 1:
        # Fixed: the two fragments previously concatenated without a space,
        # yielding "...can beautomatically generated.".
        raise ValueError(
            "Only one non intermediate output operator can be "
            f"automatically generated. {op_type}"
        )

    if not_intermediate_outputs[0].duplicable:
        raise ValueError(
            "Only non duplicable op can be automatically generated."
        )

    for output in intermediate_outputs:
        if output.duplicable:
            raise ValueError(
                "The op can be automatically generated only when "
                "all intermediate ops are not duplicable."
            )

    o_name = not_intermediate_outputs[0].name
    intermediate_output_names = [output.name for output in intermediate_outputs]

    def infer_and_check_dtype(op_proto, *args, **kwargs):
        """
        This function performs the sanity check for dtype and
        instance type.
        """
        dtype = None
        for ipt in op_proto.inputs:
            name = _convert_(ipt.name)
            val = kwargs.pop(name, [])
            if not isinstance(val, list) and not isinstance(val, tuple):
                val = [val]
            if len(val) == 0:
                # Input not given by keyword: consume one positional arg.
                if len(args) == 0:
                    continue
                val = [args[0]]
                args = args[1:]

            for each in val:
                if not isinstance(each, Variable):
                    raise ValueError(f"input of {op_type} must be variable")

                if dtype is None:
                    dtype = each.dtype
                elif dtype != each.dtype:
                    raise ValueError(
                        f"operator {op_type} must input same dtype. {dtype} vs {each.dtype}"
                    )

        if dtype is None:
            # No Variable inputs at all: fall back to an explicit ``dtype``
            # kwarg, or FP32 as the last resort.
            arg_dtype = kwargs.get("dtype")
            if arg_dtype:
                if not isinstance(arg_dtype, core.VarDesc.VarType):
                    dtype = convert_np_dtype_to_dtype_(arg_dtype)
                else:
                    dtype = arg_dtype
            else:
                dtype = core.VarDesc.VarType.FP32
        return dtype

    def func(*args, **kwargs) -> Tensor:
        helper = LayerHelper(op_type, **kwargs)

        dtype = infer_and_check_dtype(op_proto, *args, **kwargs)

        inputs = {}
        for ipt in op_proto.inputs:
            name = _convert_(ipt.name)
            val = kwargs.pop(name, [])
            if not isinstance(val, list) and not isinstance(val, tuple):
                val = [val]
            if len(val) == 0 and len(args) != 0:
                # NOTE(review): unlike infer_and_check_dtype, the positional
                # arg is NOT wrapped in a list here; LayerHelper appears to
                # accept both forms — confirm before changing.
                val = args[0]
                args = args[1:]
            inputs[ipt.name] = val

        outputs = {}
        # Caller may pre-supply the output variable by keyword; otherwise a
        # fresh variable is created for type inference.
        out = kwargs.pop(_convert_(o_name), [])
        if out:
            out_var = out[0] if isinstance(out, (list, tuple)) else out
        else:
            out_var = helper.create_variable_for_type_inference(dtype=dtype)
        outputs[o_name] = [out_var]
        for name in intermediate_output_names:
            outputs[name] = [
                helper.create_variable_for_type_inference(dtype=dtype)
            ]
        # Remaining kwargs (inputs/output were popped above) become op attrs.
        helper.append_op(
            type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs
        )
        return helper.append_activation(out_var)

    func.__name__ = op_type
    return func
|
|
|
|
|
|
|
|
|
|
|
2024-06-19 11:29:05 +08:00
|
|
|
def generate_activation_fn(op_type: str):
    """Register the Python layer for an Operator without Attribute.

    Args:
        op_type: The name of the operator to be created.

    This function takes in the operator type (sigmoid, exp , tanh etc) and
    creates the operator functionality.
    """

    def func(x, name: str | None = None) -> Tensor:
        if in_dynamic_or_pir_mode():
            # Dynamic / PIR mode: dispatch straight to the C++ op.
            if hasattr(_C_ops, op_type):
                op = getattr(_C_ops, op_type)
                return op(x)
            else:
                # TODO(dev): Because some ops' yaml has not been migrated.
                # Replace it with _C_ops while all yaml work is done.
                op = getattr(_legacy_C_ops, op_type)
                return op(x)
        else:
            # Legacy static-graph mode: validate dtype, then append the op.
            if op_type not in ["abs", "exp", "square"]:
                check_variable_and_dtype(
                    x, 'x', ['float16', 'float32', 'float64'], op_type
                )
            else:
                # abs exp square ops support dtype(int32, int64, float16, float32, float64)
                # (the accepted list below additionally includes complex64,
                # complex128 and uint16).
                check_variable_and_dtype(
                    x,
                    'x',
                    [
                        'int32',
                        'int64',
                        'float16',
                        'float32',
                        'float64',
                        'complex64',
                        'complex128',
                        'uint16',
                    ],
                    op_type,
                )

            # NOTE: **locals() forwards the names visible here (x, name,
            # op_type) to LayerHelper — do not add locals above this line.
            helper = LayerHelper(op_type, **locals())

            output = helper.create_variable_for_type_inference(dtype=x.dtype)
            helper.append_op(
                type=op_type, inputs={"X": x}, outputs={"Out": output}
            )
            return output

    return func
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_inplace_fn(inplace_op_type: str):
    """Register the Python layer for an Inplace Operator without Attribute.

    Args:
        inplace_op_type: The name of the inplace operator to be created,
            with a trailing underscore (exp_ , ceil_ etc).

    This function takes in the inplace operator type (exp_ , ceil_ etc) and
    creates the operator functionality.
    """
    # Strip the trailing underscore to recover the out-of-place op name.
    origin_op_type = inplace_op_type[:-1]

    def func(x: Tensor, name: str | None = None) -> Tensor:
        if in_dynamic_mode():
            if hasattr(_C_ops, inplace_op_type):
                op = getattr(_C_ops, inplace_op_type)
                return op(x)
            else:
                op = getattr(_legacy_C_ops, inplace_op_type)
                return op(x)
        # Fixed: this previously fell through and silently returned None
        # outside dynamic mode. There is no in-place execution when building
        # a static graph, so fall back to the out-of-place variant.
        return generate_activation_fn(origin_op_type)(x, name)

    func.__name__ = inplace_op_type
    func.__doc__ = f"""
Inplace version of ``{origin_op_type}`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_{origin_op_type}`.
"""
    return func
|