# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

import logging
import typing
import warnings

import paddle
from paddle import pir
from paddle.autograd import ir_backward
from paddle.autograd.backward_utils import ValueDict, ValueSet
from paddle.base.core import (
    call_decomp_rule,
    call_decomp_vjp,
    decomp_ops_contain_unused_output,
    has_decomp_rule,
    has_decomp_vjp,
)
from paddle.base.framework import (
    pir_chunk_id_guard,
    pir_op_name_guard,
    pir_op_role_guard,
)
from paddle.base.libpaddle.pir import Block, Operation
from paddle.base.wrapped_decorator import signature_safe_contextmanager
from paddle.decomposition.recompute import DebugPrint, auto_recompute
from paddle.framework import core

from . import register

logger = logging.getLogger(__name__)


@signature_safe_contextmanager
def prim_guard():
    """Context manager that temporarily enables all prim modes, restoring
    the previous state on exit (even if the body raises)."""
    prim_state = core._is_all_prim_enabled()
    try:
        if not prim_state:
            core._set_prim_all_enabled(True)
        yield
    finally:
        if not prim_state:
            core._set_prim_all_enabled(False)


def _build_tensor_tuple(xs):
    """Normalize a decomp-rule result into a tuple of pir.Value.

    Raises:
        TypeError: if ``xs`` is neither a pir.Value nor a sequence.
    """
    if isinstance(xs, pir.Value):
        return (xs,)
    elif isinstance(xs, typing.Sequence):
        return tuple(xs)
    # BUG FIX: the exception was previously *returned* instead of raised,
    # so type errors surfaced later as confusing failures in callers.
    raise TypeError(f"Type {type(xs)} is not supported.")


def _analyse_decomp_results(orig_outs, decomp_outs, op):
    """Flatten sink-decomp results so they align one-to-one with the
    original op outputs; unused outputs (per decomp_ops_contain_unused_output)
    are allowed to be None."""
    assert len(orig_outs) == len(decomp_outs)
    res = []
    for idx, value in enumerate(decomp_outs):
        if isinstance(orig_outs[idx], pir.Value):
            if (
                op.name() in decomp_ops_contain_unused_output.keys()
                and idx in decomp_ops_contain_unused_output[op.name()]
            ):
                assert value[0] is None
            else:
                assert len(value) == 1 and isinstance(value[0], pir.Value)
            res.append(value[0])
        else:
            res.append(value)
    return res


def _prepare_python_api_arguments(op):
    """
    For standard api of operator, its inputs should keep consistent with organization of its inputs and attrs.

    Args:
    op (Operator): The target operator.
    """
    combine_op_name = "builtin.combine"
    inputs = []
    for x in op.operands():
        input = x.source()
        if input.initialized():
            prev_op = input.get_defining_op()
            if (
                isinstance(prev_op, Operation)
                and prev_op.name() == combine_op_name
            ):
                # a builtin.combine producer means this operand is a vector
                # input; unpack it back into a python list of Values
                input = [item.source() for item in prev_op.operands()]
            inputs.append(input)
        else:
            # for optional input, such as scale for layer_norm op,
            # if it is not set, there will be an empty Value which is not initialized in ops.operands
            # therefore append None for it.
            inputs.append(None)

    # The inputs of Pir op builtin.combine will be restored as list of tensor.
    if op.name() == combine_op_name:
        return (inputs,)

    api_arguments = inputs + [op.attrs()[x] for x in op.get_attr_names()]
    return tuple(api_arguments)


def _check_prim_dynamic(op):
    """Return True (with a warning) if any input of ``op`` has a dynamic
    dimension (-1), which the decomposition rules do not support."""
    combine_op_name = "builtin.combine"
    for x in op.operands():
        input = x.source()
        if input.initialized():
            prev_op = input.get_defining_op()
            if (
                isinstance(prev_op, Operation)
                and prev_op.name() == combine_op_name
            ):
                for item in prev_op.operands():
                    shape = item.source().shape
                    if -1 in shape:
                        warnings.warn(
                            f"Decomp op does not support dynamic shape -1, but got shape {item.source().shape} in inputs of op {op.name()} "
                        )
                        return True
            else:
                shape = input.shape
                if -1 in shape:
                    warnings.warn(
                        f"Decomp op does not support dynamic shape -1, but got shape {input.shape} in op {op.name()} "
                    )
                    return True
    # explicit falsy return (was an implicit None); callers test truthiness
    return False


def _check_op_results(
    op_name, orig_outs, new_outs, orig_vars=None, dst_vars=None
):
    """
    Check whether the replaced outputs are consistent with origin outputs.

    Args:
    op_name (str): The name of operator.
    orig_outs (tuple): The outputs of original operator.
    new_outs (tuple): The outputs of replaced operator.
    orig_vars (dict): Origin variables of original block.
    dst_vars (list): Corresponding replaced variables of Origin variables.
    """
    assert len(orig_outs) == len(new_outs), (
        f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, '
        f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}'
    )

    for orig_out, new_out in zip(
        orig_outs,
        new_outs,
    ):
        if (orig_out is None or new_out is None) and (
            op_name not in core.ops_contain_none
        ):
            raise ValueError(
                f"op {op_name} should not contain any None value. original outs={orig_outs} and its composite rule outs={new_outs}"
            )
        if orig_out is None:
            # to keep same as phi op definition, orig_out may receive None
            continue
        elif new_out is not None:
            if orig_vars is not None and dst_vars is not None:
                if orig_out in orig_vars:
                    dst_vars[orig_vars[orig_out]] = new_out
            orig_dtype = orig_out.dtype
            new_dtype = new_out.dtype
            orig_shape = orig_out.shape
            new_shape = new_out.shape
            assert orig_dtype == new_dtype, (
                f'when replace origin op {op_name} with composite rule, origin out dtype should be equal to new out dtype, '
                f'but orig_out dtype={orig_dtype} and new_out dtype={new_dtype}'
            )
            assert -1 not in new_shape, (
                f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
            )
            assert orig_shape == new_shape, (
                f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
                f'but orig_out shape={orig_shape} and new_out shape={new_shape}'
            )
        # both-None / one-None mismatch guard (redundant with the raise above
        # for ops outside core.ops_contain_none, kept for ops inside it)
        assert not (orig_out is None) ^ (new_out is None), (
            "orig_out and new_out should match."
        )
    return


def decompose(
    program,
    src_vars,
    blacklist=frozenset(),
    whitelist=frozenset(),
    start_index=0,
    end_index=-1,
):
    """
    Search nonbasic ops which have be registered composite rules and replace them with primitive ops.
    The operators in blacklist will be excluded from program when decomposed into primitives, and only the
    operators in whitelist will be decomposed. The priority of blacklist is higher than whitelist, it means
    an operator both in blacklist and whitelist will not be decomposed.

    The finally set that will be decomposed is:
        (block.ops & ops have decomposite rule & whitelist) - blacklist

    Note:
        All variables must be contained inside the given program.

    Args:
        program (Program): The program to be processed.
        src_vars (list[Value]): In program, once some operator is decomposed, its vars will be replaced by new ones.
            This argument means some vars will be used later and corresponding vars will be returned for later usage.
        blacklist (frozenset): The Operators that will be exclude when decomposed into primitives.
        whitelist (frozenset): Only the operators in whitelist will be decomposed into primitives.
        start_index (int): The start index of decomposed operator in global block, default 0;
        end_index (int): The end index of decomposed operator in global block, default -1 means all ops will be composed.
            start_index and end_index follow the principle of left closed and right open, that is [start_index, end_index).

    Returns:
        dst_vars (list): A list contains all vars which replace origin ones in src_vars.
    """
    blacklist = core.prim_config["forward_blacklist"] | blacklist
    assert isinstance(start_index, int)
    assert isinstance(end_index, int)
    return core.sinking_decomp(
        program, src_vars, blacklist, whitelist, start_index, end_index
    )


def _check_combine_inputs(input1, input2):
    '''check whether the inputs of two builtins.combine ops are the same'''
    builtin_combine_op1 = input1.get_defining_op()
    builtin_combine_op2 = input2.get_defining_op()
    if builtin_combine_op1.num_operands() != builtin_combine_op2.num_operands():
        return False
    else:
        for i in range(builtin_combine_op1.num_operands()):
            if not (
                builtin_combine_op1.operand_source(i).is_same(
                    builtin_combine_op2.operand_source(i)
                )
            ):
                return False
    return True


def _check_op(
    fwd_op: pir.Operation,
    bwd_op: pir.Operation,
):
    '''check whether the bwd_op is corresponding to fwd_op'''
    if fwd_op is None or fwd_op.name() + "_grad" != bwd_op.name():
        return False

    bwd_op_input_names = bwd_op.get_input_names()
    bwd_inputs = [x.source() for x in bwd_op.operands()]
    assert len(bwd_op_input_names) == len(bwd_inputs), (
        "backward op names do not match backward op inputs"
    )
    # non-"_grad" inputs of the backward op must come from the forward op's
    # inputs/outputs (or from trivially inserted full/full_int_array ops)
    fwd_op_related_inputs_outputs = []
    for idx, name in enumerate(bwd_op_input_names):
        if "_grad" not in name:
            fwd_op_related_inputs_outputs.append(bwd_inputs[idx])

    fwd_inputs = [x.source() for x in fwd_op.operands()]
    fwd_outputs = fwd_op.results()
    fwd_vec_inputs = [
        x.source()
        for x in fwd_op.operands()
        if x.source().initialized()
        and x.source().get_defining_op().name() == "builtin.combine"
    ]
    inserted_op_name_list = ["pd_op.full_int_array", "pd_op.full"]
    for operand in fwd_op_related_inputs_outputs:
        if (
            operand.initialized()
            and operand.get_defining_op().name() == "builtin.combine"
        ):
            # for pir::VectorType
            in_fwd = False
            for vec_input in fwd_vec_inputs:
                if _check_combine_inputs(operand, vec_input):
                    in_fwd = True
                    break
            if not in_fwd:
                return False
        else:
            # for non-vector (dense tensor) operands
            if not (
                operand in ValueSet(fwd_inputs)
                or operand in ValueSet(fwd_outputs)
                or operand.get_defining_op().name() in inserted_op_name_list
            ):
                return False
    return True


def _get_fwd_op(bwd_op, grad_var_to_var):
    """Locate the forward op that produced the output whose gradient feeds
    ``bwd_op``, via the grad_var_to_var mapping; returns None if not found."""
    bwd_op_input_names = bwd_op.get_input_names()
    out_grad_name = ["out_grad", "Out_grad", "loss_grad"]
    for idx, input_name in enumerate(bwd_op_input_names):
        if input_name in out_grad_name:
            out_grad = bwd_op.operand(idx).source()
            if out_grad in grad_var_to_var:
                out = grad_var_to_var[out_grad]
                fwd_op = out.get_defining_op()
                return fwd_op
    return None


def _decomp_fwd_op(
    block: Block, fwd_op: pir.Operation, grad_var_to_var: dict, prev_op=None
) -> tuple:
    '''
    Decompose the forward op into a list of primitive ops.

    Args:
        block (Block): the block to which the forward op belongs.
        fwd_op (pir.Operation): the forward op to be decomposed.
        grad_var_to_var (dict): a dict obtained from distributed processing,
            which maps the backward grad variable to its corresponding forward variable.
        prev_op (pir.Operation): the previous op of fwd_op in the block. If prev_op is builtin.combine,
            insertion point when decomposing fwd_op will be set to prev_op.
    Returns:
        new_outputs (tuple(Value)): the new outputs after decomposing.
        has_decomposed: whether the forward op has been successfully decomposed.
    '''
    with pir.core.program_guard(block.program):
        op_name = fwd_op.name()
        orig_outs = fwd_op.results()
        decom_rule = register.get_decomp_rule(op_name)
        has_sink_decomp_rule = has_decomp_rule(fwd_op)
        lower = decom_rule or has_sink_decomp_rule

        if lower:
            # step1: check dynamic shape, currently not supported
            if _check_prim_dynamic(fwd_op):
                return None, False

            # step2: check insertion point, if prev_op is builtin.combine (such as concat op), insertion point will be set to prev_op
            if prev_op is not None:
                pir.set_insertion_point(prev_op)
            else:
                pir.set_insertion_point(fwd_op)

            # step3: decompose op, and get new outputs
            input_args = _prepare_python_api_arguments(fwd_op)
            if has_sink_decomp_rule:
                decomp_outs = call_decomp_rule(fwd_op)
                new_outs = _analyse_decomp_results(
                    orig_outs, decomp_outs, fwd_op
                )
            else:
                new_outs = _build_tensor_tuple(decom_rule(*input_args))
            _check_op_results(op_name, orig_outs, new_outs)

            # step4: upgrade grad_var_to_var with new outputs
            _upgrade_grad_var_to_var(
                grad_var_to_var, orig_outs=orig_outs, new_outs=new_outs
            )

            # step5: replace original op with new ops, replace original output with new outputs
            if fwd_op.name() in decomp_ops_contain_unused_output.keys():
                for idx in range(len(orig_outs)):
                    if (
                        idx
                        not in decomp_ops_contain_unused_output[fwd_op.name()]
                    ):
                        orig_outs[idx].replace_all_uses_with(new_outs[idx])
            else:
                # NOTE(review): the inner membership test below is always
                # False in this branch (we are in the outer `else` of the
                # same condition), so this always falls through to
                # replace_all_uses_with on the op — TODO confirm intent.
                if fwd_op.name() in decomp_ops_contain_unused_output.keys():
                    orig_outs[0].replace_all_uses_with(new_outs[0])
                else:
                    fwd_op.replace_all_uses_with(new_outs)
            block.remove_op(fwd_op)

            # step6: remove redundant prev_op (builtin.combine)
            if prev_op is not None:
                remove_op = True
                for item in prev_op.results():
                    if item.has_one_use():
                        remove_op = False
                        break
                if remove_op:
                    block.remove_op(prev_op)
                prev_op = None
            return new_outs, True
        else:
            return tuple(orig_outs), False


def _prepare_inputs(fwd_op):
    """Collect forward-op inputs as lists: vector inputs (fed by
    builtin.combine) are expanded, dense inputs become single-element lists."""
    new_inputs = []
    for input in fwd_op.operands():
        if (
            input.source().initialized()
            and input.source().get_defining_op().name() == "builtin.combine"
        ):
            # for pir::VectorType
            builtin_combine_op = input.source().get_defining_op()
            new_input = [
                builtin_combine_op.operand_source(i)
                for i in range(0, builtin_combine_op.num_operands())
            ]
            new_inputs.append(new_input)
        else:
            new_inputs.append([input.source()])  # for DenseTensorType
    return new_inputs


def _prepare_grad_outputs(fwd_op, bwd_op):
    """Extract the gradient inputs of ``bwd_op`` (those operands that are not
    forward inputs/outputs), aligned with the forward op's outputs; outputs
    with no gradient receive a fake value placeholder."""
    # check forward outputs and backward inputs
    fwd_outputs = fwd_op.results()
    fwd_output_names = fwd_op.get_output_names()
    assert len(fwd_output_names) == len(fwd_outputs), (
        "forward op output names do not match forward op outputs"
    )
    bwd_inputs = [x.source() for x in bwd_op.operands()]
    bwd_input_names = bwd_op.get_input_names()
    assert len(bwd_input_names) == len(bwd_inputs), (
        "backward op input names do not match backward op inputs"
    )

    # cut gradients from backward op's inputs
    fwd_inputs = [x.source() for x in fwd_op.operands()]
    fwd_vec_inputs = [
        x.source()
        for x in fwd_op.operands()
        if x.source().initialized()
        and x.source().get_defining_op().name() == "builtin.combine"
    ]
    grad_outputs = []
    grad_output_names = []
    for i, bwd_input in enumerate(bwd_inputs):
        if (
            bwd_input.initialized()
            and bwd_input.get_defining_op().name() == "builtin.combine"
        ):
            # for pir::VectorType
            in_fwd = False
            for vec_input in fwd_vec_inputs:
                if _check_combine_inputs(bwd_input, vec_input):
                    in_fwd = True
                    break
            if not in_fwd:
                grad_outputs.append([bwd_input])
                grad_output_names.append(bwd_input_names[i])
        else:
            if not (
                bwd_input in ValueSet(fwd_inputs)
                or bwd_input in ValueSet(fwd_outputs)
            ):
                # for paddle::dialect::DenseTensorType
                grad_outputs.append([bwd_input])
                grad_output_names.append(bwd_input_names[i])

    # add fake grads for forward op's outputs which are not used in backward op
    # this is necessary for the call_vjp(), which ensures that len(out_grads) must be equal to len(outputs)
    new_grad_outputs = []
    index = 0
    for fwd_output_name in fwd_output_names:
        if (fwd_output_name + "_grad") in grad_output_names:
            new_grad_outputs.append(grad_outputs[index])
            index += 1
        else:
            new_grad_outputs.append([pir.fake_value()])
    return new_grad_outputs


def _prepare_stop_gradients(fwd_inputs, bwd_outputs):
    """Build stop-gradient masks for call_vjp(): an uninitialized backward
    output means no gradient is required for the matching forward input."""
    stop_gradients = []
    for idx, bwd_output in enumerate(bwd_outputs):
        if bwd_output.initialized():
            stop_gradient = [False] * len(fwd_inputs[idx])
        else:
            stop_gradient = [True] * len(fwd_inputs[idx])
        stop_gradients.append(stop_gradient)
    return stop_gradients


def _upgrade_grad_var_to_var(
    grad_var_to_var,
    orig_grads=None,
    new_grads=None,
    orig_outs=None,
    new_outs=None,
):
    """Rewrite grad_var_to_var after a replacement: re-key replaced gradient
    values and re-point entries whose forward value was replaced."""
    assert grad_var_to_var is not None, "grad_var_to_var should not be None"
    if orig_grads is not None and new_grads is not None:
        for idx, grad_input in enumerate(orig_grads):
            if grad_input in grad_var_to_var:
                grad_var_to_var[new_grads[idx]] = grad_var_to_var.pop(
                    grad_input
                )
    if orig_outs is not None and new_outs is not None:
        for grad_var, var in grad_var_to_var.items():
            for i, orin_var in enumerate(orig_outs):
                if var.is_same(orin_var):
                    grad_var_to_var[grad_var] = new_outs[i]


def _decomp_bwd_with_vjp(
    block: Block,
    fwd_op: pir.Operation,
    bwd_op: pir.Operation,
    grad_var_to_var: dict,
) -> tuple:
    '''
    Decompose the backward op into a list of primitive ops.
    If forward op has composite vjp rules (including custom vjp), call call_vjp()
    to get a list of primitive operators in backward graph, then replace backward op.
    '''
    # step1: prepare arguments for call_vjp()
    fwd_inputs_ = _prepare_inputs(fwd_op)
    fwd_outputs_ = [[fwd_output] for fwd_output in fwd_op.results()]
    grad_outputs_ = _prepare_grad_outputs(fwd_op, bwd_op)
    stop_gradients_ = _prepare_stop_gradients(fwd_inputs_, bwd_op.results())

    # step2: call call_vjp() to get a list of primitive operators which has the same meaning as the backward op
    bwd_op_idx = block.ops.index(bwd_op)
    before_num_ops = len(block.ops)
    new_grad_inputs = core.call_vjp(
        fwd_op, fwd_inputs_, fwd_outputs_, grad_outputs_, stop_gradients_
    )
    after_num_ops = len(block.ops)
    num_appended_ops = after_num_ops - before_num_ops

    # if forward op has no composite vjp rules, call_vjp() appends the same op as original backward op, skip decomposing, return False
    if num_appended_ops == 1 and block.ops[-1].name() == bwd_op.name():
        block.remove_op(block.ops[-1])
        return None, False
    else:
        # step3: record new outputs of the decomposed backward op
        if block.ops[-1].name() == "builtin.split":
            new_grad_inputs = [[block.ops[-1].operand(0).source()]]
        res = []
        for grad_input in new_grad_inputs:
            if grad_input[0] is not None and grad_input[0].initialized():
                res.append(grad_input[0])
            else:
                res.append(pir.fake_value())
        assert len(res) == len(bwd_op.results()), (
            "results of original backward op do not match results of decomposed backward op"
        )

        # step4: upgrade grad_var_to_var
        _upgrade_grad_var_to_var(
            grad_var_to_var, orig_grads=bwd_op.results(), new_grads=res
        )

        # step5: replace original backward op with new primitive ops
        insert_idx = bwd_op_idx
        for i in range(before_num_ops, after_num_ops):
            block.move_op(block.ops[i], insert_idx)
            insert_idx += 1
        bwd_op.replace_all_uses_with(res)
        block.remove_op(bwd_op)

        return tuple(res), True


def _decomp_bwd_without_vjp(
    block: Block,
    bwd_op: pir.Operation,
    grad_var_to_var: dict,
    fwd_inputs: list,
    fwd_outputs_after_decompose: tuple,
) -> tuple:
    '''
    Decompose the backward op into a list of primitive ops.
    If forward op has no composite vjp rules, and forward op has been decomposed
    to a list of primitive operators in forward graph previously, call grad() for
    the decomposed forward subgraph to get a list of primitive operators in
    backward graph, then replace backward op.
    '''
    if fwd_outputs_after_decompose is None:
        raise RuntimeError(
            "To decompose backward op, please decompose forward op firstly"
        )

    # step1: prepare arguments for grad()
    bwd_inputs = [x.source() for x in bwd_op.operands()]
    grad_inputs = bwd_op.results()
    grad_outputs = tuple(
        bwd_input
        for bwd_input in bwd_inputs
        if not (
            bwd_input in ValueSet(fwd_inputs)
            or bwd_input in ValueSet(fwd_outputs_after_decompose)
        )
    )
    fwd_outputs_ = tuple(
        grad_var_to_var[grad_output] for grad_output in grad_outputs
    )
    fwd_inputs_ = tuple(
        grad_var_to_var[grad_input]
        for grad_input in grad_inputs
        if grad_input.initialized()
    )

    # step2: call grad() to get a list of primitive operators which has the same meaning as the backward op
    bwd_op_idx = block.ops.index(bwd_op)
    before_num_ops = len(block.ops)
    new_grad_inputs = ir_backward.grad(fwd_outputs_, fwd_inputs_, grad_outputs)
    after_num_ops = len(block.ops)

    # step3: record new outputs of the decomposed backward op
    res = []
    input_grads_idx = 0
    for idx, grad_input in enumerate(grad_inputs):
        if grad_input.initialized():
            res.append(new_grad_inputs[input_grads_idx])
            input_grads_idx += 1
        else:
            res.append(pir.fake_value())

    # step4: upgrade grad_var_to_var
    _upgrade_grad_var_to_var(
        grad_var_to_var, orig_grads=grad_inputs, new_grads=res
    )

    # step5: replace original backward op with new primitive ops
    insert_idx = bwd_op_idx
    for i in range(before_num_ops, after_num_ops):
        block.move_op(block.ops[i], insert_idx)
        insert_idx += 1
    bwd_op.replace_all_uses_with(res)
    block.remove_op(bwd_op)
    has_decomposed = True

    return tuple(res), has_decomposed


def _decomp_bwd_op(
    block: Block,
    bwd_op: pir.Operation,
    grad_var_to_var: dict,
):
    '''
    Decompose a backward op in pir program.
    Get the corresponding forward op according to grad_var_to_var firstly, then
    (1) try to decompose backward op by calling _decompose_bwd_with_vjp,
    if forward op has composite vjp rules (including custom vjp),
    _decompose_bwd_with_vjp will call call_vjp() to get a list of primitive
    operators in backward graph, then replace backward op successfully and return True;
    (2) when _decompose_bwd_with_vjp return False, means there is no composite vjp rules,
    try to decompose forward op firstly by calling _decomp_fwd_op firstly
    and get corresponding primitive operators in backward graph by calling
    _decompose_bwd_without_vjp secondly, then replace backward op successfully and return True;
    (3) if the backward op is still not decomposed by the above two steps, returns False.

    Args:
        block (Block): the block to which the backward op belongs.
        bwd_op (pir.Operation): the backward op to be decomposed.
        grad_var_to_var (dict): a dict obtained from distributed processing,
            which maps the backward grad variable to its corresponding forward variable.
    Return:
        new_input_grads (tuple(Value)): new results of backward op after decomposing.
        has_decomposed: whether the backward op has been successfully decomposed.
    '''
    # get the corresponding forward op according to grad_var_to_var
    # check and ensure: bwd_inputs = out_grads + fwd_inputs[optional] + fwd_outputs[optional]
    fwd_op = _get_fwd_op(bwd_op, grad_var_to_var)
    if not _check_op(fwd_op, bwd_op):
        logger.debug(
            f'{bwd_op.name()} can not be decomposed due to the mismatch between forward op and backward op'
        )
        return None, False
    if _check_prim_dynamic(fwd_op) or _check_prim_dynamic(bwd_op):
        return None, False

    # try to decompose backward op directly
    (
        new_grads,
        bwd_has_decomposed,
    ) = _decomp_bwd_with_vjp(
        block,
        fwd_op,
        bwd_op,
        grad_var_to_var,
    )

    if not bwd_has_decomposed:
        # try to decompose the forward op
        fwd_inputs = [x.source() for x in fwd_op.operands()]
        (
            new_fwd_outputs,
            fwd_has_decomposed,
        ) = _decomp_fwd_op(
            block,
            fwd_op,
            grad_var_to_var,
        )
        if fwd_has_decomposed:
            # try to decompose the backward op
            (
                new_grads,
                bwd_has_decomposed,
            ) = _decomp_bwd_without_vjp(
                block,
                bwd_op,
                grad_var_to_var,
                fwd_inputs,
                new_fwd_outputs,
            )

    return new_grads, bwd_has_decomposed


def _get_all_bwd_ops(pir_program):
    """Return the distinct names of ops in the global block that look like
    backward ops (name ends with "_grad" or "_grad_")."""
    bwd_ops = []
    global_block = pir_program.global_block()
    for op in global_block.ops:
        if (
            op.name().endswith("_grad") or op.name().endswith("_grad_")
        ) and op.name() not in bwd_ops:
            bwd_ops.append(op.name())
    return bwd_ops


def _set_prim_state():
    """Enable forward/backward prim and pir-api flags, returning the previous
    [fwd_prim, bwd_prim, pir_api] state for _reset_prim_state()."""
    state = []
    prev_fwd_prim_state = core._is_fwd_prim_enabled()
    prev_bwd_prim_state = core._is_bwd_prim_enabled()
    state.append(prev_fwd_prim_state)
    state.append(prev_bwd_prim_state)
    core._set_prim_forward_enabled(True)
    core._set_prim_backward_enabled(True)
    prev_pir_api_flag = paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
        "FLAGS_enable_pir_api"
    ]
    paddle.framework.set_flags(
        {"FLAGS_enable_pir_api": True}
    )  # set in pir mode for operator overloading
    paddle.base.framework.global_var._use_pir_api_ = True
    state.append(prev_pir_api_flag)
    return state


def _reset_prim_state(state):
    """Restore the prim/pir-api flags saved by _set_prim_state()."""
    assert len(state) == 3, (
        "state should contain fwd_prim_state, bwd_prim_state and pir_api_state"
    )
    core._set_prim_forward_enabled(state[0])
    core._set_prim_backward_enabled(state[1])
    paddle.framework.set_flags({"FLAGS_enable_pir_api": state[2]})
    paddle.base.framework.global_var._use_pir_api_ = state[2]


def _translate_gradvartovar_to_pir(param_mapping, grad_var_to_var):
    '''translate grad_var_to_var (mapping VarDesc->VarDesc) to pir_grad_var_to_var (mapping Value->Value)'''
    pir_grad_var_to_var = ValueDict()
    for grad_var, var in grad_var_to_var.items():
        if grad_var in param_mapping.keys() and var in param_mapping.keys():
            # 1-to-1 mapping: direct translation
            if (
                len(param_mapping[grad_var]) == 1
                and len(param_mapping[var]) == 1
            ):
                new_grad_var = param_mapping[grad_var][0]
                new_var = param_mapping[var][0]
                pir_grad_var_to_var[new_grad_var] = new_var
            else:
                # 1-to-many mapping: pick the slice result when the second
                # value comes from builtin.slice, otherwise keep all of them
                new_grad_vars = []
                new_vars = []
                if len(param_mapping[grad_var]) == 1:
                    new_grad_vars.append(param_mapping[grad_var][0])
                elif (
                    len(param_mapping[grad_var]) == 2
                    and param_mapping[grad_var][1].get_defining_op().name()
                    == "builtin.slice"
                ):
                    new_grad_vars.append(param_mapping[grad_var][1])
                else:
                    for i in range(0, len(param_mapping[grad_var])):
                        new_grad_vars.append(param_mapping[grad_var][i])

                if len(param_mapping[var]) == 1:
                    new_vars.append(param_mapping[var][0])
                elif (
                    len(param_mapping[var]) == 2
                    and param_mapping[var][1].get_defining_op().name()
                    == "builtin.slice"
                ):
                    new_vars.append(param_mapping[var][1])
                else:
                    last_op = param_mapping[var][-1].get_defining_op()
                    if last_op.name().endswith("_"):
                        new_vars.append(param_mapping[var][0])

                assert len(new_vars) == 1, "translate pir_grad_var_to_var error"
                for i in range(0, len(new_grad_vars)):
                    pir_grad_var_to_var[new_grad_vars[i]] = new_vars[0]
    return pir_grad_var_to_var


def _decomp_bwd_program(pir_program, pir_grad_var_to_var):
    '''Traverse and decompose all backward OPs in program'''
    with paddle.pir.core.program_guard(pir_program):
        bwd_ops = _get_all_bwd_ops(pir_program)
        undecomposed_bwd_ops = []
        ops = pir_program.global_block().ops
        for op in ops:
            bwd_op_name = op.name()
            if op.name() in bwd_ops:
                _, bwd_has_decomposed = _decomp_bwd_op(
                    pir_program.global_block(), op, pir_grad_var_to_var
                )
                if (
                    not bwd_has_decomposed
                    and bwd_op_name not in undecomposed_bwd_ops
                ):
                    undecomposed_bwd_ops.append(bwd_op_name)
        logger.debug(
            f'Following backward ops can not be decomposed: {undecomposed_bwd_ops}'
        )


def _decomp_fwd_program(pir_program, pir_grad_var_to_var):
    '''Traverse and decompose all forward OPs in program'''
    with paddle.pir.core.program_guard(pir_program):
        ops = pir_program.global_block().ops
        bwd_ops = _get_all_bwd_ops(pir_program)

        # ops including compile-time infermeta, causing mismatched input shape and output shape, which is unsupported when decomposing.
        black_fwd_ops = ["pd_op.stack", "pd_op.squeeze"]
        undecomposed_fwd_ops = []
        prev_op = None
        for op in ops:
            fwd_op_name = op.name()
            if op.name() not in bwd_ops:
                if op.name() not in black_fwd_ops:
                    _, fwd_has_decomposed = _decomp_fwd_op(
                        pir_program.global_block(),
                        op,
                        pir_grad_var_to_var,
                        prev_op,
                    )
                    if (
                        not fwd_has_decomposed
                        and fwd_op_name not in undecomposed_fwd_ops
                    ):
                        undecomposed_fwd_ops.append(fwd_op_name)
                else:
                    if fwd_op_name not in undecomposed_fwd_ops:
                        undecomposed_fwd_ops.append(fwd_op_name)
            # remember a builtin.combine so the next op can redirect its
            # insertion point in front of it (see _decomp_fwd_op step2)
            prev_op = op if op.name() == "builtin.combine" else None
        logger.debug(
            f'Following forward ops can not be decomposed: {undecomposed_fwd_ops}'
        )


def decompose_dist_program(pir_program):
    '''
    Decompose all non-primitive ops into primitive ops in a pir program.
    It may contain forward ops and backward ops.
    '''
    # decomp forward composite ops
    decompose(pir_program, [])

    # decomp backward ops
    blacklist = core.prim_config["backward_blacklist"]
    block = pir_program.global_block()
    pre_combine_op = None
    with paddle.pir.core.program_guard(pir_program):
        ops = pir_program.global_block().ops
        for op in ops:
            bwd_op_name = op.name()
            if bwd_op_name.split(".")[-1] in blacklist:
                continue
            skip_decomp = False
            if has_decomp_vjp(op):
                if (
                    not core._enable_prim_dynamic_shape()
                ) and _check_prim_dynamic(op):
                    skip_decomp = True
                if not skip_decomp:
                    with (
                        pir_op_name_guard(op.name()),
                        pir_op_role_guard(op.op_role),
                        pir_chunk_id_guard(op.chunk_id),
                    ):
                        pir.set_insertion_point(op)
                        orig_outs = op.results()
                        is_next_split = False
                        decomp_outs = call_decomp_vjp(op)

                        for i in range(len(orig_outs)):
                            if orig_outs[i].has_one_use():
                                next_op = orig_outs[i].first_use().owner()
                                if next_op.name() == "builtin.split":
                                    is_next_split = True
                                    _check_op_results(
                                        next_op.name(),
                                        next_op.results(),
                                        decomp_outs[i],
                                    )
                                    next_op.replace_all_uses_with(
                                        decomp_outs[i]
                                    )
                                    block.remove_op(next_op)

                        if not is_next_split:
                            new_outs = _analyse_decomp_results(
                                orig_outs, decomp_outs, op
                            )
                            _check_op_results(op.name(), orig_outs, new_outs)
                            op.replace_all_uses_with(new_outs)
                        block.remove_op(op)

            # NOTE(review): this bookkeeping sets and clears pre_combine_op in
            # the same iteration, so a still-used combine op is never tracked
            # across iterations — looks ineffective except for combine ops
            # whose results are already unused; TODO confirm intent.
            if op.name() == "builtin.combine":
                pre_combine_op = op
            if pre_combine_op is not None:
                remove_op = True
                for item in pre_combine_op.results():
                    if item.has_one_use():
                        remove_op = False
                        break
                if remove_op:
                    block.remove_op(pre_combine_op)
                pre_combine_op = None
        paddle.pir.set_insertion_point_to_block_end(block)


def decompose_pir_program(pir_program, param_mapping, grad_var_to_var):
    '''
    Decompose all PHI ops into prim ops in a pir program.

    Args:
        pir_program (Program): the program to be decomposed
        param_mapping (dict): a map of program variables to pir program values
        grad_var_to_var (dict): a dict obtained from distributed processing,
            which maps the backward grad variable to its corresponding forward variable.
    '''
    # set prim flags and pir_api flags
    state = _set_prim_state()
    # translate grad_var_to_var to pir
    pir_grad_var_to_var = _translate_gradvartovar_to_pir(
        param_mapping, grad_var_to_var
    )
    # decompose
    _decomp_bwd_program(pir_program, pir_grad_var_to_var)
    _decomp_fwd_program(pir_program, pir_grad_var_to_var)
    # reset prim flags and pir_api flags
    _reset_prim_state(state)
    return pir_grad_var_to_var


def get_inputs_from_data_and_parameter(pir_program):
    """Collect the first result of every pd_op.data and builtin.parameter op."""
    results = []
    for op in pir_program.global_block().ops:
        if op.name() == "pd_op.data":
            results.append(op.results()[0])
        if op.name() == "builtin.parameter":
            results.append(op.results()[0])
    return results


def get_outputs_from_fetch_op(pir_program):
    """Collect the fetched values (first operand of every pd_op.fetch op)."""
    results = []
    for op in pir_program.global_block().ops:
        if op.name() == "pd_op.fetch":
            results.append(op.operand(0).source())
    return results


def get_grad_var_for_list(outputs, pir_grad_var_to_var):
    """Map each forward value in ``outputs`` back to its gradient value by
    inverting pir_grad_var_to_var; raises KeyError if an output has no grad."""
    results = []
    var2grad_var = ValueDict()
    for k, v in pir_grad_var_to_var.items():
        var2grad_var[v] = k
    for output in outputs:
        results.append(var2grad_var[output])
    return results


def get_defining_op_indices(program, output_values):
    """Return the global-block index of each value's defining op."""

    def getIdx(op):
        for idx, op_iter in enumerate(program.global_block().ops):
            if op == op_iter:
                return idx
        raise RuntimeError("op not found in program")

    results = []
    for output in output_values:
        results.append(getIdx(output.get_defining_op()))
    return results


def get_forward_op_idxs(program, is_forward_op_func):
    """Return the global-block indices of ops for which the user-supplied
    predicate ``is_forward_op_func`` is true."""

    def getIdx(op):
        for idx, op_iter in enumerate(program.global_block().ops):
            if op == op_iter:
                return idx
        raise RuntimeError("op not found in program")

    results = []
    for op in program.global_block().ops:
        if is_forward_op_func(op):
            results.append(getIdx(op))
    return results


def auto_recompute_pir_program(pir_program, is_forward_op_func=None):
    """Run auto_recompute on ``pir_program``.

    The forward/backward boundary is taken from the fetch outputs' defining
    ops, or — when ``is_forward_op_func`` is given — from the last op matching
    that predicate. Returns the recomputed program, or the input program
    unchanged when no forward boundary can be determined.
    """
    DebugPrint("Start Recompute Pir Program:")
    DebugPrint("Before Recompute: ", pir_program)
    # prepare essential inputs for auto_recompute
    inputs = get_inputs_from_data_and_parameter(pir_program)
    outputs = get_outputs_from_fetch_op(pir_program)
    fwd_op_end_idx = -1
    if len(outputs):
        fwd_op_end_idx = max(get_defining_op_indices(pir_program, outputs))
    logger = logging.getLogger("auto-recompute")
    logger.setLevel(logging.INFO)
    if is_forward_op_func is not None:
        try:
            fwd_op_end_idx = max(
                get_forward_op_idxs(pir_program, is_forward_op_func)
            )
        # BUG FIX: was a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit; keep the best-effort fallback but
        # only for ordinary exceptions (e.g. ValueError from max([])).
        except Exception:
            logger.info("No Forward Ops Found!")
    if fwd_op_end_idx == -1:
        logger.info("Skip Auto Recompute!")
        return pir_program
    backward_op_start_idx = fwd_op_end_idx + 1
    program, _ = auto_recompute(
        pir_program,
        inputs,
        outputs,
        [],
        fwd_op_end_idx,
        backward_op_start_idx,
    )
    return program