# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

import numpy as np

import paddle
from paddle import _legacy_C_ops
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.base import backward, core, framework, program_guard
from paddle.base.compiler import BuildStrategy
from paddle.base.data_feeder import check_type, convert_dtype
from paddle.base.dygraph.base import switch_to_static_graph
from paddle.base.framework import get_flags
from paddle.optimizer.lr import LRScheduler

from . import logging_utils
from .utils import (
    RETURN_NO_VALUE_MAGIC_NUM,
    backend_guard,
    construct_grad_names,
)

__all__ = []


class NestSequence:
    """
    A wrapper class that makes it easy to flatten and restore the nested
    structure of a given sequence.
    """

    def __init__(self, raw_input, need_check=False):
        self.__raw_input = raw_input
        self.__input_list = self.tolist()
        self.__var_ids = self._get_var_ids()
        self._check_non_variable(need_check)

    def tolist(self):
        """
        Flattens the nested sequence into a single list.
        """
        return paddle.utils.flatten(self.__raw_input)

    def restore(self, value_list):
        """
        Restores the nested sequence from a value list.
        """
        assert len(self.__input_list) == len(value_list)
        return paddle.utils.pack_sequence_as(self.__raw_input, value_list)

    def _get_var_ids(self):
        var_ids = []
        for idx, var in enumerate(self.__input_list):
            if isinstance(var, (framework.Variable, core.eager.Tensor)):
                var_ids.append(idx)

        return var_ids

    def _check_non_variable(self, need_check):
        """
        Raises a warning if the output of the traced function contains
        non-tensor type values.
        """
        if need_check:
            warning_types = set()
            for var in self.__input_list:
                if not isinstance(var, (framework.Variable, core.eager.Tensor)):
                    warning_types.add(type(var))
            if warning_types:
                logging_utils.warn(
                    f"Output of traced function contains non-tensor type values: {list(warning_types)}. "
                    "Currently, we don't support updating them while training and will return "
                    "what we first saw. Please try to return them as tensors."
                )

    @property
    def var_ids(self):
        return self.__var_ids

    def __getitem__(self, item):
        return self.__input_list[item]
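
# A minimal usage sketch of ``NestSequence`` (illustrative only; ``x`` and ``y``
# below are hypothetical tensors, not defined in this module):
#
#     ns = NestSequence([x, (1, y)])
#     flat = ns.tolist()             # [x, 1, y] -- leaves in flattened order
#     nested = ns.restore(flat)      # [x, (1, y)] -- original nesting restored
#     tensor_positions = ns.var_ids  # [0, 2] -- indices of the tensor-typed leaves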


class LazyInitialized:
    """
    Descriptor to implement lazy initialization of a property.
    """

    def __init__(self, function):
        self.function = function

    def __get__(self, instance, cls):
        val = self.function(instance)
        setattr(instance, self.function.__name__, val)
        return val
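
# ``LazyInitialized`` computes the wrapped function once on first attribute
# access and then shadows the descriptor with the computed value, so later
# accesses are plain instance-attribute lookups. A minimal sketch (``Holder``
# and ``build_expensive_object`` are hypothetical, not part of this module):
#
#     class Holder:
#         @LazyInitialized
#         def expensive(self):
#             return build_expensive_object()  # runs only on first access
#
#     h = Holder()
#     h.expensive  # triggers the build and caches the result on ``h``
#     h.expensive  # reuses the cached instance attribute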


class ProgramInfo:
    """
    A helper class to record Program information.
    """

    def __init__(self):
        self.op_size = {
            'fp32': -1,
            'amp': -1,
            'fp16': -1,
        }
        self.programs = {}
        self.mode = "infer"

    def __call__(self, key, prog_creator):
        """
        Record the infer program and its op size.
        """
        assert key in ['fp32', 'amp', 'fp16']
        if key not in self.programs:
            infer_prog = prog_creator(is_infer_mode=True)
            self.programs[key] = infer_prog
            self.op_size[key] = infer_prog.desc.block(0).op_size()

        return self.programs[key], self.op_size[key]
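
# ``ProgramInfo`` memoizes one inference program per precision key. A minimal
# sketch (``make_fp32_program`` is a hypothetical creator that accepts the
# ``is_infer_mode`` keyword):
#
#     info = ProgramInfo()
#     prog, op_size = info('fp32', make_fp32_program)  # builds and caches
#     prog2, _ = info('fp32', make_fp32_program)       # returns the cached program
#     assert prog is prog2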


class PartialProgramLayerHook:
    def before_append_backward(self, forward_program): ...

    def after_append_backward(self, whole_program, backward_start_idx): ...

    def after_infer(self, infer_program): ...
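
# A hook lets callers observe or rewrite programs at fixed points of program
# construction; each method must return values of the same shape as it was
# given. A minimal sketch of a custom hook (illustrative only):
#
#     class LoggingHook(PartialProgramLayerHook):
#         def before_append_backward(self, forward_program):
#             print("ops before backward:", len(forward_program.block(0).ops))
#             return forward_program
#
#         def after_append_backward(self, whole_program, backward_start_idx):
#             return whole_program, backward_start_idx
#
#         def after_infer(self, infer_program):
#             return infer_program
#
#     # layer.set_hooker(LoggingHook())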


class PartialProgramLayer:
    """
    PartialProgramLayer wraps all the ops from layers decorated by `@to_static`
    and executes them as a static subgraph.

    .. note::
        **1. This is a very low level API. Users should not use this API
        directly. Please use `partial_program_from(concrete_program)`
        to create it.
        **2. DenseTensorArray is not currently supported in the output.

    Args:
        main_program(Program): The main program that contains the ops to be executed.
        inputs(list[Variable]): The input list of the function decorated by `@to_static`.
        outputs(list[Variable]): The output list of the function decorated by `@to_static`.
        parameters(list[Tensor]|None): All trainable parameters included in the program. Default None.

    Returns:
        Layer: A Layer object that runs all ops internally in static graph mode.
    """

    def __init__(
        self, main_program, inputs, outputs, parameters=None, **kwargs
    ):
        super().__init__()
        self._inputs = NestSequence(inputs)
        self._outputs = NestSequence(outputs, need_check=True)
        self._params = parameters if parameters is not None else []

        self._build_strategy = kwargs.get('build_strategy', BuildStrategy())
        assert isinstance(self._build_strategy, BuildStrategy)

        self._origin_main_program = self._verify_program(main_program)
        with paddle.base.framework._dygraph_guard(paddle.base.dygraph.Tracer()):
            self._cuda_graph_vec = self._create_cuda_graph_vec()
        # Set default mode to train
        self.training = True
        self._infer_info = ProgramInfo()
        self._forward_end_index_map = {}

        amp_dtype, custom_white_list, custom_black_list = None, None, None
        tracer = framework._dygraph_tracer()
        if tracer:
            custom_white_list, custom_black_list = tracer._get_amp_op_list()
            amp_dtype = tracer._amp_dtype
        if amp_dtype is not None and amp_dtype in ['float16', 'bfloat16']:
            # For AMP training
            self._amp_list = (
                paddle.static.amp.fp16_lists.AutoMixedPrecisionLists(
                    custom_white_list=custom_white_list,
                    custom_black_list=custom_black_list,
                    dtype=amp_dtype,
                )
            )

        # program_id -> list(scope)
        self._pir_scope_cache = {}
        self._legacy_scope_cache = {}
        self._hooker = None
        self._backend = kwargs.get('backend', None)
        self._grad_var_names = {}

        self._in_var_names = []
        for var in self._inputs:
            if isinstance(var, framework.Variable):
                self._in_var_names.append(var.desc.name())
        self._out_var_descs = [
            self._outputs[var_id].desc for var_id in self._outputs.var_ids
        ]

    def __call__(self, inputs):
        """
        Execute the static graph by the Interpreter and return dynamic Tensors.
        """
        in_vars, in_var_names = self._prepare_inputs(inputs)
        out_vars = self._prepare_outputs()
        self._cast_fp16_if_pure_fp16(in_vars)
        attrs = self._prepare_attributes()
        attrs.extend(["x_names", in_var_names])

        self._sync_lr_value_with_scheduler()

        _legacy_C_ops.run_program(
            self._valid_vars(in_vars),
            self._valid_vars(self._params),
            self._valid_vars(out_vars),
            self._create_scope_vec(
                program_id=self.program_id, use_scope_cache=True
            ),
            self._cuda_graph_vec,
            *attrs,
        )

        restored_nest_out = self._restore_out(out_vars)
        restored_nest_out = self._remove_no_value(restored_nest_out)

        return restored_nest_out

    def sot_call(self, inputs):
        """
        In SOT, the inputs and outputs of the partial program only contain
        tensors, so we can skip some steps to speed up execution.
        """
        out_vars = self._prepare_outputs()
        self._cast_fp16_if_pure_fp16(inputs)
        attrs = self._prepare_attributes()
        attrs.extend(["x_names", self._in_var_names])

        self._sync_lr_value_with_scheduler()

        _legacy_C_ops.run_program(
            self._valid_vars(inputs),
            self._valid_vars(self._params),
            self._valid_vars(out_vars),
            self._create_scope_vec(
                program_id=self.program_id, use_scope_cache=True
            ),
            self._cuda_graph_vec,
            *attrs,
        )

        return out_vars

    def _sync_lr_value_with_scheduler(self):
        """Update the lr_var value with the value calculated by the lr_scheduler."""
        main_program = self._origin_main_program
        if hasattr(main_program, 'lr_scheduler') and hasattr(
            main_program, 'lr_var'
        ):
            lr_scheduler = main_program.lr_scheduler
            lr_var = main_program.lr_var

            assert isinstance(lr_scheduler, LRScheduler), "must be LRScheduler"
            lr_scheduler = self._origin_main_program.lr_scheduler
            lr_value = lr_scheduler()
            data = np.array(lr_value).astype(convert_dtype(lr_var.dtype))
            lr_var.set_value(data)

    def set_hooker(self, hooker):
        self._hooker = hooker

    def _get_scope(self, program_id=None, use_scope_cache=False):
        if self._in_pir_pt_mode or self._enable_pir_in_executor:
            _scope_cache = self._pir_scope_cache
        else:
            _scope_cache = self._legacy_scope_cache
        if not use_scope_cache:
            return core.Scope()
        if program_id not in _scope_cache:
            _scope_cache[program_id] = []
        cached_scopes = _scope_cache[program_id]
        for scope in cached_scopes:
            if scope._can_reused:
                return scope
        scope = core.Scope()
        cached_scopes.append(scope)
        return scope

    # whole
    @switch_to_static_graph
    def _create_program(self, is_infer_mode=False):
        if is_infer_mode:
            infer_program = self._origin_main_program.clone(
                for_test=is_infer_mode
            )
            if self._hooker:
                infer_program = self._hooker.after_infer(infer_program)
            return infer_program
        else:
            train_program = self._append_backward_desc(
                self._origin_main_program
            )
            # Note: Only set grad type once after initializing train program. So we put it here.
            self._set_grad_type(self._params, train_program)
            return train_program

    @switch_to_static_graph
    def _create_amp_program(self, is_infer_mode=False):
        amp_program = self._origin_main_program.clone(for_test=is_infer_mode)
        with program_guard(amp_program):
            paddle.static.amp.fp16_utils.cast_model_to_fp16(
                amp_program, self._amp_list, use_fp16_guard=False, level='O1'
            )
        if is_infer_mode:
            if self._hooker:
                amp_program = self._hooker.after_infer(amp_program)
            return amp_program
        else:
            train_amp_program = self._append_backward_desc(amp_program)
            self._set_grad_type(self._params, train_amp_program)
            return train_amp_program

    @switch_to_static_graph
    def _create_pure_fp16_program(self, is_infer_mode=False):
        pure_fp16_program = self._origin_main_program.clone(
            for_test=is_infer_mode
        )
        with program_guard(pure_fp16_program):
            paddle.static.amp.fp16_utils.cast_model_to_fp16(
                pure_fp16_program, self._amp_list, use_fp16_guard=False
            )

        if is_infer_mode:
            if self._hooker:
                pure_fp16_program = self._hooker.after_infer(pure_fp16_program)
            return pure_fp16_program
        else:
            train_pure_fp16_program = self._append_backward_desc(
                pure_fp16_program
            )
            self._set_grad_type(self._params, train_pure_fp16_program)
            return train_pure_fp16_program

    @switch_to_static_graph
    def _create_forward_backward_train_program(self):
        whole_program = self._train_program
        forward_end_op_index = self.get_forward_end_op_idx(whole_program)
        assert forward_end_op_index >= 0

        return self._get_forward_backward_program_form(
            whole_program, forward_end_op_index
        )

    @switch_to_static_graph
    def _create_forward_backward_train_amp_program(self):
        whole_program = self._train_amp_program
        forward_end_op_index = self.get_forward_end_op_idx(whole_program)
        assert forward_end_op_index >= 0

        return self._get_forward_backward_program_form(
            whole_program, forward_end_op_index
        )

    @switch_to_static_graph
    def _create_forward_backward_train_pure_fp16_program(self):
        whole_program = self._train_pure_fp16_program
        forward_end_op_index = self.get_forward_end_op_idx(whole_program)
        assert forward_end_op_index >= 0

        return self._get_forward_backward_program_form(
            whole_program, forward_end_op_index
        )

    @LazyInitialized
    def _train_program(self):
        return self._create_program()

    @LazyInitialized
    def _infer_program(self):
        program, op_size = self._infer_info('fp32', self._create_program)
        return self._build_infer_program(program, op_size)

    @LazyInitialized
    def _train_amp_program(self):
        return self._create_amp_program()

    @LazyInitialized
    def _infer_amp_program(self):
        program, op_size = self._infer_info('amp', self._create_amp_program)
        return self._build_infer_program(program, op_size)

    @LazyInitialized
    def _train_pure_fp16_program(self):
        return self._create_pure_fp16_program()

    @LazyInitialized
    def _infer_pure_fp16_program(self):
        program, op_size = self._infer_info(
            'fp16', self._create_pure_fp16_program
        )
        return self._build_infer_program(program, op_size)

    @LazyInitialized
    def _train_forward_backward_program(self):
        program = self._create_forward_backward_train_program()
        return program

    @LazyInitialized
    def _train_amp_forward_backward_program(self):
        program = self._create_forward_backward_train_amp_program()
        return program

    @LazyInitialized
    def _empty_backward_program_for_eval(self):
        return paddle.static.Program()

    @LazyInitialized
    def _train_pure_fp16_forward_backward_program(self):
        program = self._create_forward_backward_train_pure_fp16_program()
        return program

    @LazyInitialized
    def _train_program_id(self):
        program_id = paddle.utils._hash_with_id(self._train_program, self)
        return program_id

    @LazyInitialized
    def _infer_program_id(self):
        return paddle.utils._hash_with_id(self._infer_program, self)

    @LazyInitialized
    def _train_amp_program_id(self):
        program_id = paddle.utils._hash_with_id(self._train_amp_program, self)
        return program_id

    @LazyInitialized
    def _infer_amp_program_id(self):
        return paddle.utils._hash_with_id(self._infer_amp_program, self)

    @LazyInitialized
    def _train_pure_fp16_program_id(self):
        program_id = paddle.utils._hash_with_id(
            self._train_pure_fp16_program, self
        )
        return program_id

    @LazyInitialized
    def _infer_pure_fp16_program_id(self):
        return paddle.utils._hash_with_id(self._infer_pure_fp16_program, self)

    def get_forward_end_op_idx(self, program):
        return self._forward_end_index_map[
            paddle.utils._hash_with_id(program, self)
        ]

    @property
    def program(self):
        """
        Return the current train or eval program.
        """
        if self.training:
            return self.train_program
        else:
            return self.infer_program

    @property
    def program_id(self):
        """
        Return the hash id of the current train or eval program.
        """
        if self.training:
            if _in_amp_guard():
                return self._train_amp_program_id
            elif _in_pure_fp16_guard():
                return self._train_pure_fp16_program_id
            else:
                return self._train_program_id
        else:
            if _in_amp_guard():
                return self._infer_amp_program_id
            elif _in_pure_fp16_guard():
                return self._infer_pure_fp16_program_id
            else:
                return self._infer_program_id

    @property
    def train_program(self):
        if _in_amp_guard():
            return self._train_amp_program
        elif _in_pure_fp16_guard():
            return self._train_pure_fp16_program
        else:
            return self._train_program

    @property
    def infer_program(self):
        if _in_amp_guard():
            infer_program = self._infer_amp_program
        elif _in_pure_fp16_guard():
            infer_program = self._infer_pure_fp16_program
        else:
            infer_program = self._infer_program
        return infer_program

    @property
    def forward_program(self):
        forward_program = None
        if self.training:
            if _in_amp_guard():
                progs = self._train_amp_forward_backward_program
            elif _in_pure_fp16_guard():
                progs = self._train_pure_fp16_forward_backward_program
            else:
                progs = self._train_forward_backward_program
            return progs[0]
        else:
            forward_program = self.infer_program
            return forward_program

    @property
    def backward_program(self):
        if self.training:
            if _in_amp_guard():
                progs = self._train_amp_forward_backward_program
            elif _in_pure_fp16_guard():
                progs = self._train_pure_fp16_forward_backward_program
            else:
                progs = self._train_forward_backward_program
            return progs[1]
        else:
            """
            We can't just return paddle.static.Program() here. Because
            self.backward_program is a property, a temporary Program() object
            would be created on every call and garbage-collected immediately
            after the following line in PartialProgramLayer.__call__ executes:

            >>> self.backward_program.desc.block(0),

            When we access RunProgramAPI, it's then possible to get an invalid
            backward_program address.
            """
            return self._empty_backward_program_for_eval

    def _verify_program(self, main_program):
        """
        Verify that the program parameter is initialized, prune some unused params,
        and remove redundant op callstack.
        """
        # 1. Check all params from main program can be found in self._params
        self._check_params_all_inited(main_program)
        # 2. Prune the parameters not used anywhere in the program.
        self._prune_unused_params(main_program)

        return main_program

    def prepare_gradient_aggregation(
        self, start_idx, main_program, target_program
    ):
        """
        Why do we need to add a gradient aggregation operation?
        In some cases, if a non-leaf node is used as an output, gradient
        overwriting will occur, such as:

            def forward(self, in):
                x = 2 * in  # <---- x is a non-leaf node in program.
                y = x + 3
                return x, y

            loss = forward(in)[0].sum()
            loss.backward()  # <----- x@grad will be overwritten by elementwise_add_grad Op
        """

        def _need_aggregation(var):
            """
            Return True if there exists an op whose input is `var`.
            """
            if not isinstance(var, framework.Variable) or var.type not in [
                core.VarDesc.VarType.DENSE_TENSOR,
                core.VarDesc.VarType.SELECTED_ROWS,
            ]:
                return False
            if var.dtype not in [paddle.float32, paddle.float64]:
                return False
            for op in main_program.block(0).ops:
                for in_arg in op.input_arg_names:
                    if in_arg == var.name:
                        return True
            return False

        def _insert_aggregation_ops_for_var(target_program, var):
            suffix = "@dy2static"
            var_grad_name = var.grad_name
            new_grad_name = var.name + suffix + "@GRAD"
            found_ops = list(
                filter(
                    lambda x: (
                        x[0] >= start_idx
                        and any(
                            out_arg == var_grad_name
                            for out_arg in x[1].output_arg_names
                        )
                    ),
                    enumerate(target_program.block(0).ops),
                )
            )

            # len(found_ops) may equal zero when stop_gradient works.
            # len(found_ops) may be > 1, because we may have a fill_constant op.
            if len(found_ops) == 0:
                return None
            # step1: create a new var named var.name@GRAD
            target_program.block(0).create_var(
                name=new_grad_name,
                type=var.type,
                dtype=var.dtype,
                shape=var.shape,
            )
            # step2: rename the var.name@GRAD to var.name@GRAD@dy2static
            for idx, op in found_ops:
                op._rename_input(var_grad_name, new_grad_name)
                op._rename_output(var_grad_name, new_grad_name)
            # step3: insert sum op to aggregate the gradient.
            # var.name@GRAD = sum(var.name@dy2static@GRAD, var.name@GRAD)
            target_program.block(0)._insert_op(
                found_ops[-1][0] + 1,
                type='sum',
                inputs={'X': [var_grad_name, new_grad_name]},
                outputs={"Out": var_grad_name},
            )
            return None

        to_processed_vars = list(
            filter(_need_aggregation, self._outputs.tolist())
        )
        for _var in to_processed_vars:
            target_program: paddle.static.Program
            target_var = target_program.global_block().var(_var.name)
            _insert_aggregation_ops_for_var(target_program, target_var)

    @switch_to_static_graph
    def _append_backward_desc(self, main_program):
        program = main_program.clone(for_test=False)
        if self._hooker:
            program = self._hooker.before_append_backward(program)
        targets = []
        for out in self._outputs.tolist():
            if isinstance(out, framework.Variable):
                targets.append(program.global_block().var(out.name))

        start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
        if targets:
            start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
            with backend_guard(self._backend):
                check_type(
                    targets,
                    'targets',
                    (framework.Variable, list, tuple),
                    'paddle.static.gradients',
                )
                grad_info_map = backward.calc_gradient_helper(
                    targets=targets, inputs=[]
                )

                x_vars = [
                    program.block(0).var(var.name)
                    for var in self._inputs
                    if isinstance(var, framework.Variable)
                ]
                param_vars = [
                    program.block(0).var(param.name) for param in self._params
                ]
                out_vars = [
                    program.block(0).var(var.name)
                    for var in self._outputs
                    if isinstance(var, framework.Variable)
                ]

                self._grad_var_names = construct_grad_names(
                    grad_info_map, x_vars, param_vars, out_vars
                )

        if self._hooker:
            program, start_idx = self._hooker.after_append_backward(
                program, start_idx
            )
        self.prepare_gradient_aggregation(
            start_idx + 1, main_program, program
        )

        self._forward_end_index_map[
            paddle.utils._hash_with_id(program, self)
        ] = start_idx - len(self._outputs.tolist())
        return program

    def _prune_unused_params(self, program):
        """
        Prune the parameters not used anywhere in the program.
        `@to_static` may only decorate a sub-function that
        contains some unused parameters created in `__init__`.
        So prune these parameters to avoid unnecessary operations in
        `run_program_op`.
        """
        required_params = []
        for param in self._params:
            found_param = False
            for block in program.blocks:
                for op in block.ops:
                    if (
                        param.name in op.input_arg_names
                        or param.name in op.output_arg_names
                    ):
                        required_params.append(param)
                        found_param = True
                        break
                if found_param:
                    break

        self._params = required_params

    def _cast_fp16_if_pure_fp16(self, in_vars):
        if _in_pure_fp16_guard():
            for i, var in enumerate(in_vars):
                name = var.name
                if (
                    self.program.global_block().has_var(name)
                    and self.program.global_block().var(name).dtype
                    == paddle.float16
                ):
                    in_vars[i] = var.astype('float16')
                    in_vars[i].name = name

    @property
    def _in_pir_pt_mode(self):
        pir_dy2st_flag = 'FLAGS_enable_pir_with_pt_in_dy2st'
        in_pir_pt_mode = get_flags(pir_dy2st_flag)[pir_dy2st_flag]
        is_prim_enabled = (
            core._is_fwd_prim_enabled() or core._is_bwd_prim_enabled()
        )
        in_cinn_backend = self._backend.is_cinn()
        is_cinn_enabled = self._build_strategy.build_cinn_pass
        if is_prim_enabled or in_cinn_backend or is_cinn_enabled:
            in_pir_pt_mode = False
        return in_pir_pt_mode

    @property
    def _enable_pir_in_executor(self):
        enable_pir_in_executor_flag = 'FLAGS_enable_pir_in_executor'
        enable_pir_in_executor = get_flags(enable_pir_in_executor_flag)[
            enable_pir_in_executor_flag
        ]
        return enable_pir_in_executor

    def _prepare_attributes(self):
        attrs = [
            'forward_global_block',
            self.forward_program.desc.block(0),
            'backward_global_block',
            self.backward_program.desc.block(0),
            'is_test',
            not self.training,
            'program_id',
            self.program_id,
        ]

        if self.training:
            # NOTE: In the case of higher-order gradients, the names of the parameter grads may be like
            # `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`, so we get
            # the correct names of the parameter grads from the program. Out grads are handled similarly.
            attrs.extend(
                (
                    'param_grad_names',
                    self._grad_var_names.get('param', []),
                    'out_grad_names',
                    self._grad_var_names.get('out', []),
                    'x_grad_names',
                    self._grad_var_names.get('x', []),
                )
            )

        in_pir_pt_mode = self._in_pir_pt_mode
        attrs.extend(['in_pir_pt_mode', in_pir_pt_mode])

        return attrs

    @switch_to_static_graph
    def _build_infer_program(self, infer_program, forward_end_op_index):
        forward_skip_vars = self._parse_skip_gc_vars(infer_program)
        built_infer_program = add_build_strategy_for(
            infer_program,
            0,
            forward_end_op_index,
            self._build_strategy,
            forward_skip_vars,
        )
        self._apply_inplace_pass(built_infer_program, None)
        return built_infer_program

    @switch_to_static_graph
    def _get_forward_backward_program_form(
        self, whole_program, forward_end_op_index
    ):
        # NOTE(dev): We apply build_strategy for the backward part first to
        # avoid skipping more gc variables.
        backward_start_op_index = forward_end_op_index + len(
            self._outputs.var_ids
        )
        backward_end_op_index = whole_program.desc.block(0).op_size()
        # For the backward process in CINN, all param@GRAD vars should be skipped
        # for GC, because they will be shared in the scope and used by the optimizer.
        backward_skip_vars = self._parse_skip_gc_vars(
            whole_program
        ) + self._grad_var_names.get('param', [])
        backward_built_program = add_build_strategy_for(
            whole_program,
            backward_start_op_index,
            backward_end_op_index,
            self._build_strategy,
            backward_skip_vars,
        )

        forward_skip_vars = self._parse_skip_gc_vars(
            whole_program, backward_built_program
        )
        forward_built_program = add_build_strategy_for(
            whole_program,
            0,
            forward_end_op_index,
            self._build_strategy,
            forward_skip_vars,
        )

        self._apply_inplace_pass(forward_built_program, backward_built_program)

        return [forward_built_program, backward_built_program]

    def _apply_inplace_pass(self, forward_program, backward_program):
        attr_types = {
            "use_cuda": "bool",
            "mem_opt_skip_vars": "list[str]",
            "for_partial_block": "bool",
        }
        empty_startup_program = paddle.static.Program()
        use_cuda = True if core.is_compiled_with_cuda() else False
        # skip data var
        forward_mem_opt_skip_vars = self._parse_skip_gc_vars(
            forward_program, backward_program
        )
        backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program)
        if forward_program:
            attrs = {
                "use_cuda": use_cuda,
                "mem_opt_skip_vars": forward_mem_opt_skip_vars,
                "for_partial_block": True,
            }
            # if not (self._in_pir_pt_mode or self._enable_pir_in_executor):
            #     _apply_pass(
            #         forward_program,
            #         empty_startup_program,
            #         "buffer_shared_inplace_pass",
            #         attrs,
            #         attr_types,
            #     )
        if backward_program:
            attrs = {
                "use_cuda": use_cuda,
                "mem_opt_skip_vars": backward_mem_opt_skip_vars,
                "for_partial_block": True,
            }
            # if not (self._in_pir_pt_mode or self._enable_pir_in_executor):
            #     _apply_pass(
            #         backward_program,
            #         empty_startup_program,
            #         "buffer_shared_inplace_pass",
            #         attrs,
            #         attr_types,
            #     )

    @LazyInitialized
    def _inout_var_names(self):
        """
        Returns variable names from self._inputs and self._outputs.
        """
        var_names = []
        for var in self._inputs:
            if isinstance(var, paddle.base.framework.Variable):
                var_names.append(var.desc.name())
        for var in self._outputs:
            if isinstance(var, paddle.base.framework.Variable):
                var_names.append(var.desc.name())
        return var_names

    def _parse_skip_gc_vars(self, program, backward_program=None):
        """
        Parse the variables that need to skip GC after executing the program.
        If backward_program is specified, the variables used in backward are kept as well.
        """
        # skip data var, DO NOT ignore this deepcopy
        skip_vars = deepcopy(self._inout_var_names)
        for var_name, var in program.global_block().vars.items():
            if var.is_data:
                skip_vars.append(var_name)

        if backward_program:
            for var_name in core.parse_safe_eager_deletion_skip_vars(
                backward_program.desc, True
            ):
                skip_vars.append(var_name)
        return skip_vars

    def _prepare_inputs(self, inputs):
        """
        Prepare the input variables and their names.
        """
        assert isinstance(inputs, (tuple, list))
        # Flatten inputs with nested structure into a single list.
        flatten_inputs = paddle.utils.flatten(inputs)
        # Convert each variable into a Tensor and feed in the training data.
        input_vars = []
        input_var_names = []
        expected_place = framework._current_expected_place()
        for i, value in enumerate(flatten_inputs):
            if isinstance(value, np.ndarray):
                var = None
                var = core.eager.Tensor(
                    value=value,
                    name=self._inputs[i].desc.name(),
                    persistable=False,
                    place=expected_place,
                    zero_copy=True,
                )
            elif isinstance(value, core.eager.Tensor):
                # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multiple
                # times into CUDAPlace when it is used as the input of multiple Ops, so we
                # move it in advance to avoid this problem.
                if value.stop_gradient and not value.place._equals(
                    expected_place
                ):
                    var = value._copy_to(expected_place, False)
                    var.stop_gradient = True
                else:
                    var = value
            else:
                continue
            input_var_names.append(self._inputs[i].desc.name())
            input_vars.append(var)

        return input_vars, input_var_names

    def _prepare_outputs(self):
        return paddle.framework.core.create_empty_tensors_with_var_descs(
            self._out_var_descs
        )

    def _create_scope_vec(self, program_id=None, use_scope_cache=False):
        inner_scope = self._get_scope(
            program_id=program_id, use_scope_cache=use_scope_cache
        )
        return [inner_scope]

    def _create_cuda_graph_vec(self):
        var = core.eager.Tensor(
            core.VarDesc.VarType.FP32,
            [],
            "cuda_graph",
            core.VarDesc.VarType.RAW,
            True,
        )
        var.stop_gradient = True
        return var

    def _update_stop_gradient(self, out_vars):
        # Update stop_gradient for all outputs
        def set_stop_gradient(var_id, eager_tensor):
            var = self._outputs[var_id]
            assert isinstance(var, framework.Variable)
            eager_tensor.stop_gradient = var.stop_gradient

        for idx, var in zip(self._outputs.var_ids, out_vars):
            set_stop_gradient(idx, var)

    def _restore_out(self, out_vars):
        """
        Restores the same nested outputs by replacing each Variable with a Tensor.
        """

        flatten_outputs = self._outputs.tolist()
        for i, idx in enumerate(self._outputs.var_ids):
            flatten_outputs[idx] = out_vars[i]
        outs = self._outputs.restore(flatten_outputs)
        if outs is not None and len(outs) == 1:
            outs = outs[0]

        return outs

    @switch_to_static_graph
    def _clone_for_test(self, main_program):
        return main_program.clone(for_test=True)

    def _is_no_value(self, var):
        if isinstance(var, core.eager.Tensor) and var.shape == [1]:
            # NOTE: .numpy() will insert a MemcpySync operation, which hurts performance.
            if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
                return True
        return False

    def _remove_no_value(self, out_vars):
        """
        Removes invalid values for variable-length return statements.
        """
        if isinstance(out_vars, core.eager.Tensor):
            if self._is_no_value(out_vars):
                return None
            return out_vars
        elif isinstance(out_vars, (tuple, list)):
            if isinstance(out_vars, tuple):
                res = tuple(
                    var for var in out_vars if not self._is_no_value(var)
                )
            else:
                # isinstance(out_vars, list)
                res = [var for var in out_vars if not self._is_no_value(var)]

            has_removed = len(out_vars) > len(res)
            # len(out_vars) > len(res) means we have removed vars. This prevents
            # out_vars from being empty or a single element at the beginning.
            if len(res) == 0 and has_removed:
                return None
            elif len(res) == 1 and has_removed:
                return res[0]
            return res

        return out_vars

    def _set_grad_type(self, params, train_program):
        # NOTE: if the user sets sparse gradient mode, the param's gradient
        # will be SelectedRows, not DenseTensor. But the tracer will just
        # set the param grad Tensor by the forward Tensor (DenseTensor).
        # If we don't change the grad_var type here, RunProgramOp would need
        # to transform SelectedRows to DenseTensor forcibly, which may not
        # be the result the user wants.
        for param in params:
            grad_name = param.name + core.grad_var_suffix()
            grad_var = train_program.desc.block(0).find_var(grad_name.encode())
            # NOTE: cannot find var desc maybe no problem, such as in batch_norm
            if grad_var is None:
                continue
            param._set_grad_type(grad_var.type())

    def _check_params_all_inited(self, main_program):
        """
        Check that all params from the main program are already initialized. Details:
        1. all parameters in self._params should be of type `framework.EagerParamBase`,
           which are created in dygraph.
        2. all parameters from the transformed program can be found in self._params,
           because they share the same data with the EagerParamBase of the original dygraph.
        """
        if not isinstance(self._params, (list, tuple)):
            raise TypeError(
                f"Type of self._params in PartialProgramLayer should be list or tuple, but received {type(self._params)}."
            )

        param_and_buffer_names_set = set()
        for i, var in enumerate(self._params):
            # self._params contains parameters and buffers with persistable=True.
            if not isinstance(var, core.eager.Tensor):
                raise TypeError(
                    f'Type of self._params[{i}] in PartialProgramLayer should be Parameter or Variable, but received {type(var)}.'
                )
            param_and_buffer_names_set.add(var.name)

        for block in main_program.blocks:
            for name, var in block.vars.items():
                if isinstance(var, framework.Parameter):
                    if name not in param_and_buffer_names_set:
                        raise ValueError(
                            "\n\tWe don't support defining a layer with parameters in the function decorated by `@to_static`."
                            f"\n\tBut we found parameter({name}) was created in the decorated function."
                            "\n"
                            "\n\tRevise suggestion: "
                            "\n\t\t1. Please ensure all your sublayers are inherited from nn.Layer."
                            "\n\t\t2. Please use nn.ParameterList and nn.LayerList as containers instead of using a native Python container such as List"
                        )

    def _valid_vars(self, vars):
        return vars if vars else None


def partial_program_from(concrete_program, from_method=False):
    inputs = concrete_program.inputs

    # NOTE(SigureMo): Remove the first arg `self` from method args.
    if inputs and from_method:
        inputs = inputs[1:]

    return PartialProgramLayer(
        concrete_program.main_program,
        inputs,
        concrete_program.outputs,
        concrete_program.parameters,
        **concrete_program.kwargs,
    )
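
# A minimal sketch of how a PartialProgramLayer is typically obtained and run
# (illustrative only; ``concrete_program`` and ``x`` are hypothetical objects
# produced by the dy2static machinery, not defined in this module):
#
#     layer = partial_program_from(concrete_program)
#     layer.training = True        # or False for inference
#     outputs = layer([x])         # executes the static subgraph via run_program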


@switch_to_static_graph
def add_build_strategy_for(
    program, start_op_index, end_op_index, build_strategy=None, skip_vars=None
):
    if start_op_index < end_op_index:
        compiled_program = paddle.static.CompiledProgram(
            core.Graph(program.desc, start_op_index, end_op_index),
            build_strategy=build_strategy,
        )
        if skip_vars:
            # TODO(Aurelius84): Need to unify name with C++, such as kSkipVarNames.
            compiled_program._graph.set("skip_gc_vars", set(skip_vars))
        compiled_program._compile(
            core.Scope(), framework._current_expected_place()
        )
        ir_graph = framework.IrGraph(compiled_program._graph)
        built_program = ir_graph.to_program()
        if hasattr(compiled_program._program, 'lr_scheduler'):
            built_program.lr_scheduler = compiled_program._program.lr_scheduler
    else:
        # We can't just create a new program; we need to copy the var descs.
        built_program = paddle.static.Program()
        for var in program.block(0).vars.values():
            built_program.block(0)._clone_variable(var, False)

    # set back the parent_idx of blocks
    for origin, current in zip(program.blocks, built_program.blocks):
        current.desc.set_parent_idx(origin.desc.parent)

    return built_program
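
# ``add_build_strategy_for`` slices the op range [start_op_index, end_op_index)
# out of a program, applies the build strategy over that sub-graph, and returns
# a new Program. A minimal sketch (``prog`` is a hypothetical whole program):
#
#     strategy = BuildStrategy()
#     op_count = prog.desc.block(0).op_size()
#     forward_part = add_build_strategy_for(prog, 0, op_count, strategy)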