# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle

from . import core, unique_name

MAX_INTEGER = 2**31 - 1
MIN_INTEGER = -(2**31)


def replace_ellipsis(var, item):
    """Expand a single ``Ellipsis`` in ``item`` into ``slice(None)`` entries.

    For ``var.shape == [3, 4, 5, 6]``::

        var[..., 1:2]    -> var[:, :, :, 1:2]
        var[0, ...]      -> var[0]
        var[0, ..., 1:2] -> var[0, :, :, 1:2]

    Args:
        var: Indexed object; only ``len(var.shape)`` is read.
        item: Iterable of index entries.

    Returns:
        list: A list copy of ``item`` with the ellipsis removed (when it is
        the last entry) or replaced by the matching number of
        ``slice(None)``.

    Raises:
        IndexError: If ``item`` contains more than one ``Ellipsis``.
    """
    item = list(item)

    # Locate Ellipsis by identity. ``list.count`` / ``list.index`` would call
    # ``==`` on every entry, which is not a plain boolean test for
    # tensor-like entries.
    ell_positions = [i for i, ele in enumerate(item) if ele is Ellipsis]
    if not ell_positions:
        return item
    if len(ell_positions) > 1:
        raise IndexError("An index can only have a single ellipsis ('...')")

    ell_idx = ell_positions[0]
    if ell_idx == len(item) - 1:
        # A trailing ellipsis adds nothing: just drop it.
        return item[:-1]

    # ``None`` entries do not consume an axis of ``var``, so they are added
    # back when computing how many axes the ellipsis stands for.
    none_count = sum(1 for ele in item if ele is None)
    item[ell_idx : ell_idx + 1] = [slice(None)] * (
        len(var.shape) - len(item) + none_count + 1
    )
    return item


def replace_ndarray_and_range(item):
    """Normalize ``np.ndarray`` and ``range`` entries of an index.

    ``np.ndarray`` entries are passed through ``paddle.assign`` and ``range``
    entries are turned into lists; every other entry is kept as is.

    Returns:
        list: The normalized index entries, in the original order.
    """
    new_item = []
    for slice_item in item:
        if isinstance(slice_item, np.ndarray):
            new_item.append(paddle.assign(slice_item))
        elif isinstance(slice_item, range):
            new_item.append(list(slice_item))
        else:
            new_item.append(slice_item)
    return new_item


def replace_none(item):
    """Split ``None`` entries out of an index.

    Returns:
        tuple: ``(new_item, none_axes)`` where ``new_item`` holds the
        non-``None`` entries and ``none_axes`` holds the positions (in
        ``item``) at which ``None`` appeared.
    """
    new_item = []
    none_axes = []
    for i, slice_item in enumerate(item):
        if slice_item is None:
            none_axes.append(i)
        else:
            new_item.append(slice_item)
    return new_item, none_axes


def is_scalar_tensor(ele):
    """Return True if ``ele`` is a 0-D, non-bool ``Variable`` / ``pir.Value``."""
    from .framework import Variable

    if isinstance(ele, Variable):
        if len(ele.shape) == 0 and ele.dtype != paddle.bool:
            return True
    elif isinstance(ele, paddle.pir.Value):
        if len(ele.shape) == 0 and ele.dtype != paddle.base.libpaddle.BOOL:
            return True
    return False


def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags):
    """Fill ``attrs`` / ``inputs`` for one of starts / ends / strides.

    If ``attr`` contains a tensor, the converted tensor list is stored in
    ``inputs[tensor_attr_name]``, ``-1`` is appended to ``attrs[attr_name]``
    for every tensor entry (and the matching ``infer_flags`` slot is set to
    ``-1``), and plain entries are appended unchanged. Otherwise
    ``attrs[attr_name]`` is simply set to ``attr``.

    All of ``attrs``, ``inputs`` and ``infer_flags`` are modified in place.
    """
    from .framework import Variable

    if paddle.utils._contain_var(attr):
        inputs[tensor_attr_name] = paddle.utils._convert_to_tensor_list(
            attr, dtype="int64"
        )
        for i, dim in enumerate(attr):
            if isinstance(dim, (Variable, paddle.pir.Value)):
                attrs[attr_name].append(-1)
                infer_flags[i] = -1
            else:
                attrs[attr_name].append(dim)
    else:
        attrs[attr_name] = attr


def get_value_for_bool_tensor(var, item):
    """Select elements of ``var`` using the bool tensor ``item``.

    Returns:
        ``paddle.masked_select(var, item)`` when ``item`` has the same rank
        as ``var``; otherwise ``paddle.gather_nd`` on ``paddle.nonzero(item)``.

    Raises:
        IndexError: If ``item`` has more dims than ``var``, or a statically
            known dim (not ``-1``) of ``item`` differs from that of ``var``.
    """
    if len(item.shape) > len(var.shape):
        raise IndexError(
            "The dims of bool index doesn't match indexed array, "
            "the dims of bool index except to be equal or less "
            f"than {len(var.shape)}, but received {len(item.shape)}."
        )

    for i, dim_len in enumerate(item.shape):
        # -1 marks a dim that is unknown at this point; skip the check then.
        if dim_len != -1 and var.shape[i] != -1 and dim_len != var.shape[i]:
            raise IndexError(
                "The dimension of bool index doesn't match indexed array along "
                f"dimension {i}, the target dimension is {var.shape[i]}, but received {dim_len}."
            )

    if len(item.shape) == len(var.shape):
        return paddle.masked_select(var, item)
    bool_2_idx = paddle.nonzero(item)
    return paddle.gather_nd(var, bool_2_idx)


def _setitem_for_tensor_array(var, item, value):
    """branches for tensor array setitem operation.
    A item can be a:
    (1) int/Variable, which is a simple number/variable such as [1], [-2]
    (2) Slice, which is represented by bounds such as [2:-1]
    (3) Tuple, which includes the above two cases such as [2:-1, 1]
    If item is case (1), we perform paddle.tensor.array_write,
    in other cases, we raise a NotImplementedError.

    Raises:
        AssertionError: If called in dynamic mode.
        NotImplementedError: If ``item`` is not an int / Variable / Value.
    """
    from .framework import Variable

    assert not paddle.in_dynamic_mode(), (
        "setitem for tensor_array must be called in static graph mode."
    )
    if isinstance(item, (Variable, paddle.pir.Value, int)):
        from paddle.jit.dy2static.convert_operators import to_static_variable
        from paddle.tensor import array_write

        item = paddle.cast(to_static_variable(item), dtype='int64')
        value = to_static_variable(value)
        return array_write(x=value, i=item, array=var)
    else:
        raise NotImplementedError(
            f"Only support __setitem__ by Int/Variable in tensor_array, but gets {type(item)}"
        )


def deal_advanced_index(
    ori_tensor, indices, is_for_setitem, values, out_is_view=True
):
    """
    Transpose origin Tensor and advanced indices to the front.

    Args:
        ori_tensor (Tensor): tensor to be indexed.
        indices (List): one slot per axis; a slot is either None or a
            ``(dim, index)`` pair as produced by ``parse_index``.
        is_for_setitem (bool): whether this is called from ``__setitem__``.
        values (Tensor|None): value tensor; only read when ``is_for_setitem``.
        out_is_view (bool): incoming view flag, forced to True when a
            transpose is performed.

    Returns:
        transed_tensor (Tensor): transposed tensor, corresponding with advanced indices
        transed_index (List): advanced indices transposed to the front
        trans_back_dim (List): order of axes to transpose back to original order. Only used in __setitem__.
        pos_of_new_dim (int):  axis of new dim in the result. Only used in __getitem__.
        rank_of_new_dim (int): rank of new dim in the result. Only used in __getitem__.
        transed_value_tensor (Tensor): value tensor transposed to the front. Only used in __setitem__.
        out_is_view (bool): the (possibly updated) view flag.
    """
    transed_dim = []
    transed_index = []

    # These flags indicate whether the result obtained by gather_nd requires
    # a second transpose. Only used in __getitem__.
    pos_of_new_dim = MAX_INTEGER
    rank_of_new_dim = 1

    for i, indice in enumerate(indices):
        if indice is not None:
            if i == 0:
                # case 1: advanced indices at axis 0, the new dim will be at first.
                pos_of_new_dim = 0
            if i > 0 and len(transed_dim) > 0 and transed_dim[-1] != i - 1:
                # case 2: there are not adjacent advanced indices, the new dim will be at first.
                pos_of_new_dim = 0
            else:
                pos_of_new_dim = min(pos_of_new_dim, i)
            rank_of_new_dim = max(rank_of_new_dim, indice[1].ndim)
            transed_dim.append(i)
            transed_index.append(indice[1])

    # Axes without an advanced index follow, in their original order.
    for i in range(ori_tensor.ndim):
        if indices[i] is None:
            transed_dim.append(i)

    trans_back_dim = np.argsort(transed_dim).tolist() if is_for_setitem else []
    transed_value_tensor = None

    if transed_dim == list(range(ori_tensor.ndim)):
        # Already in the wanted order: no transpose needed.
        transed_tensor = ori_tensor
        if is_for_setitem:
            transed_value_tensor = values
    else:
        out_is_view = True
        transed_tensor = ori_tensor.transpose(transed_dim)
        if is_for_setitem:
            if values.ndim > 1 and pos_of_new_dim != 0:
                # If the value tensor is not a scalar / 1-D Tensor, and the src tensor was
                # transposed at 1st dim, the value tensor should be transposed too.
                transed_value_tensor = values.transpose(transed_dim)
            else:
                transed_value_tensor = values

    return (
        transed_tensor,
        transed_index,
        trans_back_dim,
        pos_of_new_dim,
        rank_of_new_dim,
        transed_value_tensor,
        out_is_view,
    )


def slice_is_same_to_original(start, end, step):
    """Return True if ``start:end:step`` selects a whole axis unchanged.

    That is the case when all three are None, or when they equal
    ``0 : MAX_INTEGER : 1``. If any of them is a Variable / Value the answer
    cannot be determined here and False is returned.
    """
    if start is None and end is None and step is None:
        return True

    # If there is Variable, we cannot determine whether it is the same to original.
    tensor_types = (paddle.base.Variable, paddle.pir.Value)
    if isinstance(start, tensor_types):
        return False
    if isinstance(end, tensor_types):
        return False
    if isinstance(step, tensor_types):
        return False
    return start == 0 and end == MAX_INTEGER and step == 1


def is_tensor_array_type(value):
    """Return whether ``value`` is a dense tensor array.

    In PIR mode this delegates to ``value.is_dense_tensor_array_type()``;
    otherwise it checks ``value.desc.type()`` against
    ``DENSE_TENSOR_ARRAY`` (objects without ``desc`` yield False).
    """
    from .framework import in_pir_mode

    if in_pir_mode():
        return value.is_dense_tensor_array_type()
    else:
        return (
            hasattr(value, "desc")
            and value.desc.type() == core.VarDesc.VarType.DENSE_TENSOR_ARRAY
        )


def parse_index(x, indices):
    """Parse ``indices`` into the pieces needed by slice / set_value ops.

    Args:
        x: Tensor (or tensor array) being indexed.
        indices: A single index entry or a tuple of them.

    Returns:
        tuple: ``(starts, ends, steps, axes, none_axes, decrease_axes,
        advanced_index, has_advanced_index, use_strided_slice)``.
        ``advanced_index`` holds ``(dim, index)`` pairs in the slots where an
        advanced index was found and None elsewhere.

    Raises:
        IndexError: For an out-of-range int index, a bool index whose length
            does not match the indexed axis, or an unsupported entry type.
    """
    is_tensor_array = is_tensor_array_type(x)

    advanced_index = (
        [] if is_tensor_array else [None] * 2 * len(x.shape)
    )  # content is (dim, index)
    # for set_value / slice / strided_slice OP
    decrease_axes = []
    axes = []
    starts = []
    ends = []
    steps = []
    use_strided_slice = False
    has_advanced_index = False

    if not isinstance(indices, tuple):
        indices = (indices,)

    indices = replace_ndarray_and_range(indices)
    indices = replace_ellipsis(x, indices)
    indices, none_axes = replace_none(indices)

    # ``dim`` walks the axes of ``x``; ``estimated_dim`` walks the slots of
    # ``advanced_index`` (it is not advanced by int / 0-D tensor entries).
    estimated_dim = 0
    dim = 0
    for i, slice_item in enumerate(indices):
        start, end, step = None, None, None
        # ``type(...) is int`` on purpose: bool is a subclass of int and is
        # handled by its own branch below.
        if type(slice_item) is int:
            if (
                not is_tensor_array
                and x.shape[dim] is not None
                and x.shape[dim] >= 0
                and slice_item >= x.shape[dim]
            ):
                # For python, if users write a, b = var, the __getitem__
                # method will iterate through 0, 1, 2 ... until __getitem__
                # throws an IndexError, then stop. The var[0], var[1] will
                # be given to a, b respectively. If more values are given,
                # the unpack size would cause error.
                # We raises IndexError here to support grammar like `a, b = var`
                raise IndexError(
                    f"slice_item {slice_item} at dim {dim} should be >= 0 and < x.shape[{dim}]: {x.shape[dim]}"
                )
            # not calculate result to reduce call times for slice OP.
            decrease_axes.append(dim)
            start = slice_item
            step = 1
            end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
            dim += 1
        elif is_scalar_tensor(slice_item):
            # not calculate result to reduce call times for slice OP.
            decrease_axes.append(dim)
            start = slice_item
            step = 1
            end = slice_item + 1
            dim += 1
        elif isinstance(slice_item, bool):
            # single bool is advanced-indexing
            none_axes.append(dim)
            advanced_index[estimated_dim] = (
                estimated_dim,
                paddle.to_tensor([slice_item]),
            )
            has_advanced_index = True
            estimated_dim += 1
        elif isinstance(slice_item, slice):
            start = slice_item.start
            end = slice_item.stop
            step = slice_item.step

            if start is None and end is None and step is None:
                estimated_dim += 1
                dim += 1
                continue

            step = 1 if step is None else step
            if start is None:
                start = 0 if step > 0 else MAX_INTEGER
            if end is None:
                end = MAX_INTEGER if step > 0 else MIN_INTEGER

            if not (
                is_tensor_array
                or isinstance(end, (paddle.base.Variable, paddle.pir.Value))
                or isinstance(step, (paddle.base.Variable, paddle.pir.Value))
            ):
                if x.shape[dim] != -1 and end >= x.shape[dim]:
                    end = MAX_INTEGER if step > 0 else x.shape[dim]
            estimated_dim += 1
            dim += 1
        elif isinstance(slice_item, (list, tuple)):
            advanced_index[estimated_dim] = (
                estimated_dim,
                paddle.to_tensor(slice_item),
            )

            if (
                advanced_index[estimated_dim][1].dtype == paddle.bool
                and len(slice_item) != x.shape[dim]
            ):
                raise IndexError(
                    f"The shape of boolean index {len(slice_item)} did not match indexed tensor {x.shape[dim]} along axis {dim}"
                )

            has_advanced_index = True
            estimated_dim += 1
            dim += 1
        elif isinstance(slice_item, paddle.base.Variable):
            # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing.
            if (
                slice_item.dtype == paddle.bool
                or slice_item.dtype == paddle.base.libpaddle.BOOL
            ):
                if slice_item.ndim == 0:
                    # 0-D bool Tensor, same as single PY-bool.
                    none_axes.append(dim)
                elif slice_item.shape[0] != x.shape[dim]:
                    raise IndexError(
                        f"The shape of boolean index {slice_item.shape[0]} did not match indexed tensor {x.shape[dim]} along axis {dim}"
                    )
            advanced_index[estimated_dim] = (estimated_dim, slice_item)
            has_advanced_index = True
            estimated_dim += 1
            dim += 1
        elif isinstance(slice_item, paddle.pir.Value):
            # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing.
            if slice_item.dtype == paddle.pir.core.DataType.BOOL:
                if slice_item.ndim == 0:
                    # 0-D bool Tensor, same as single PY-bool.
                    none_axes.append(dim)
                elif slice_item.shape[0] != x.shape[dim]:
                    raise IndexError(
                        f"The shape of boolean index {slice_item.shape[0]} did not match indexed tensor {x.shape[dim]} along axis {dim}"
                    )
            advanced_index[estimated_dim] = (estimated_dim, slice_item)
            has_advanced_index = True
            estimated_dim += 1
            dim += 1
        else:
            raise IndexError(
                f"Valid index accept int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {slice_item}."
            )

        # Advanced-index entries leave start/end/step as None and are
        # therefore skipped here.
        if not slice_is_same_to_original(start, end, step):
            starts.append(start)
            ends.append(end)
            steps.append(step)
            axes.append(dim - 1)
            # The isinstance test comes first so that ``step != 1`` is never
            # evaluated on a tensor.
            if (
                isinstance(step, (paddle.base.Variable, paddle.pir.Value))
                or step != 1
            ):
                use_strided_slice = True

    return (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    )


def _int_tensor_list_if_contains_var(seq):
    """Convert ``seq`` with ``paddle.utils.get_int_tensor_list`` if needed.

    The conversion is applied only when ``seq`` is a list / tuple for which
    ``paddle.utils._contain_var`` is true; otherwise ``seq`` is returned
    unchanged.
    """
    if isinstance(seq, (list, tuple)) and paddle.utils._contain_var(seq):
        return paddle.utils.get_int_tensor_list(seq)
    return seq


def _append_set_value_op(helper, x, inputs, attrs):
    """Append a ``set_value`` op to the current block and return its output.

    The output variable is created through ``helper``; the op is appended to
    the current block of the default main program, and the mapping from
    ``x`` to the new output is recorded in the ``ProgramTranslator`` inplace
    map.
    """
    if helper.main_program.current_block_idx != 0:
        # not in global block, we should create a global variable.
        output = helper._create_global_variable_for_type_inference(
            dtype=x.dtype
        )
    else:
        output = helper.create_variable_for_type_inference(dtype=x.dtype)

    cur_block = paddle.static.default_main_program().current_block()
    cur_block.append_op(
        type="set_value",
        inputs=inputs,
        outputs={'Out': output},
        attrs=attrs,
        inplace_map={"Input": "Out"},
    )
    # map var to the new output
    paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
        cur_block.program, x.desc.id(), output
    )
    return output


def _setitem_static(x, indices, values):
    """
    In dynamic mode, this function will modify the value at input tensor, returning same Tensor as input.
    But it will return a new Tensor with assigned value in static mode.

    Args:
        x(Tensor): Tensor to be set value.
        indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched.
        values(Tensor|Number|Ndarray): values to be assigned to the x.

    Raises:
        TypeError: With basic indexing only, if ``values`` is neither a
            Python number, a ``numpy.ndarray`` nor a tensor.
    """
    # NOTE: local names in this function are forwarded to ``LayerHelper``
    # through ``**locals()`` below, so they are intentionally kept even when
    # they look unused.
    from . import in_dynamic_or_pir_mode
    from .framework import Variable, in_pir_mode

    is_tensor_array = is_tensor_array_type(x)

    if is_tensor_array:
        return _setitem_for_tensor_array(x, indices, values)

    # step1: parsing the index and recording them
    (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    ) = parse_index(x, indices)

    inputs = {'Input': x}
    attrs = {
        'axes': axes,
        'starts': starts,
        'ends': ends,
        'steps': steps,
        'decrease_axes': decrease_axes,
        'none_axes': none_axes,
    }

    value_tensor = None
    StartsTensorList = None
    EndsTensorList = None
    StepsTensorList = None
    shape = None

    # Bounds that contain tensors are passed as op inputs instead of attrs.
    if paddle.utils._contain_var(starts):
        StartsTensorList = paddle.utils._convert_to_tensor_list(starts)
        inputs['StartsTensorList'] = StartsTensorList
        del attrs['starts']

    if paddle.utils._contain_var(ends):
        EndsTensorList = paddle.utils._convert_to_tensor_list(ends)
        inputs['EndsTensorList'] = EndsTensorList
        del attrs['ends']

    if paddle.utils._contain_var(steps):
        StepsTensorList = paddle.utils._convert_to_tensor_list(steps)
        inputs['StepsTensorList'] = StepsTensorList
        del attrs['steps']

    if not has_advanced_index:
        # step2. Parse values
        dtype = x.dtype
        attrs['dtype'] = dtype

        from .data_feeder import convert_dtype

        if isinstance(values, (bool, int, float, complex)):
            values = np.array([values]).astype(convert_dtype(dtype))

        if isinstance(values, np.ndarray):
            shape = list(values.shape)
            values = values.ravel().tolist()
            attrs["values"] = values
            attrs["shape"] = shape
        elif isinstance(values, (Variable, paddle.pir.Value)):
            values = values.astype(dtype)
            inputs["ValueTensor"] = values
            value_tensor = values
        else:
            raise TypeError(
                "Only support to assign an integer, float, numpy.ndarray or "
                f"paddle.Tensor to a paddle.Tensor, but received {type(values)}"
            )

        # step3.1: Only basic indexing, use OP set_value to set value.
        if in_dynamic_or_pir_mode():
            if in_pir_mode():
                starts = _int_tensor_list_if_contains_var(starts)
                ends = _int_tensor_list_if_contains_var(ends)
                steps = _int_tensor_list_if_contains_var(steps)

            if value_tensor is None:
                output = paddle._C_ops.set_value_(
                    x,
                    starts,
                    ends,
                    steps,
                    axes,
                    decrease_axes,
                    none_axes,
                    shape,
                    values,
                )
            else:
                output = paddle._C_ops.set_value_with_tensor_(
                    x,
                    value_tensor,
                    starts,
                    ends,
                    steps,
                    axes,
                    decrease_axes,
                    none_axes,
                )
            if in_pir_mode():
                # map var to the new output, for dy2static
                from paddle.jit.dy2static.parameter_recorder import (
                    _global_inplace_map,
                )

                _global_inplace_map.add(
                    paddle.static.default_main_program(), x, output
                )
            return output
        else:
            helper = paddle.base.layer_helper.LayerHelper(
                'set_value', **locals()
            )
            return _append_set_value_op(helper, x, inputs, attrs)
    else:
        # step3.2: Case for there are advanced indexing.
        #   1. get __getitem__ result of basic indexing;
        #   2. transpose original tensor so that the axis with advanced indexing will come to the first;
        #   3. assign values to the sliced result by index_put OP;
        #   4. transpose back and assign the result to original tensor by set_value OP.

        if not isinstance(values, (Variable, paddle.pir.Value)):
            values = paddle.assign(values).astype(x.dtype)

        sub_tensor, is_view = get_tensor_with_basic_indexing(
            x,
            axes,
            starts,
            ends,
            steps,
            decrease_axes,
            none_axes,
            use_strided_slice,
        )
        (
            transed_sub_tensor,
            adjusted_advanced_index,
            transback_dim,
            _,
            _,
            values,
            is_view,
        ) = deal_advanced_index(
            sub_tensor, advanced_index, True, values, is_view
        )

        if values.dtype != transed_sub_tensor.dtype:
            values = values.astype(transed_sub_tensor.dtype)

        if paddle.in_dynamic_mode():
            # A single bool mask whose rank equals the rank of the indexed
            # tensor is applied with where_. The two ranks are compared;
            # the closing parenthesis of ``len(...)`` was previously
            # misplaced so that ``len`` was applied to a comparison result.
            if (
                len(adjusted_advanced_index) == 1
                and adjusted_advanced_index[0].dtype
                in (paddle.bool, paddle.base.libpaddle.BOOL)
                and len(adjusted_advanced_index[0].shape)
                == len(transed_sub_tensor.shape)
            ):
                if values.shape != transed_sub_tensor.shape:
                    values = values.expand(transed_sub_tensor.shape)
                transed_sub_tensor = paddle._C_ops.where_(
                    paddle.logical_not(adjusted_advanced_index[0]),
                    transed_sub_tensor,
                    values,
                )
                if not is_view:
                    return x
            else:
                # NOTE(zoooo0820): directly return result instead of another set_value, after backward bug fixed.
                transed_sub_tensor = transed_sub_tensor.index_put_(
                    adjusted_advanced_index, values
                )
                if not is_view:
                    return x
        else:
            transed_sub_tensor = transed_sub_tensor.index_put(
                adjusted_advanced_index, values
            )

        transback_sub_tensor = transed_sub_tensor.transpose(transback_dim)
        inputs["ValueTensor"] = transback_sub_tensor

        if in_dynamic_or_pir_mode():
            if in_pir_mode():
                starts = _int_tensor_list_if_contains_var(starts)
                ends = _int_tensor_list_if_contains_var(ends)
                # ``steps`` is converted into ``steps`` (the result used to
                # be assigned to ``ends`` by mistake).
                steps = _int_tensor_list_if_contains_var(steps)

            output = paddle._C_ops.set_value_with_tensor_(
                x,
                transback_sub_tensor,
                starts,
                ends,
                steps,
                axes,
                decrease_axes,
                none_axes,
            )
            if in_pir_mode():
                # map var to the new output, for dy2static; guarded in the
                # same way as the basic-indexing branch above.
                from paddle.jit.dy2static.parameter_recorder import (
                    _global_inplace_map,
                )

                _global_inplace_map.add(
                    paddle.static.default_main_program(), x, output
                )
        else:
            helper = paddle.base.layer_helper.LayerHelper(
                'set_value', **locals()
            )
            output = _append_set_value_op(helper, x, inputs, attrs)

        return output


def get_tensor_with_basic_indexing(
    x, axes, starts, ends, steps, decrease_axes, none_axes, use_strided_slice
):
    """Apply the basic (slice / int / None) part of an index to ``x``.

    Args:
        x: Tensor to be indexed.
        axes, starts, ends, steps: Slice description, as built by
            ``parse_index``.
        decrease_axes: Axes removed from the result by int indices.
        none_axes: Positions of new axes. NOTE: this list is rewritten in
            place (each entry is shifted by the number of smaller
            ``decrease_axes``), so the caller sees the adjusted values.
        use_strided_slice (bool): use ``strided_slice`` instead of ``slice``.

    Returns:
        tuple: ``(out, out_is_view)``. ``out_is_view`` is True when a slice
        or an unsqueeze was applied; the PIR tensor-array paths return
        early with False.
    """
    from .dygraph.base import in_to_static_mode

    out_is_view = False
    if in_to_static_mode() and hasattr(x, "is_view_var"):
        x.is_view_var = True

    if len(axes) == 0:
        out = x
    else:
        out_is_view = True
        op_type = "strided_slice" if use_strided_slice else "slice"
        inputs = {'Input': [x]}
        attrs = {
            'axes': axes,
            'starts': [],
            'ends': [],
            'decrease_axis': decrease_axes,
        }
        if use_strided_slice:
            attrs['strides'] = []

        infer_flags = [1] * len(axes)
        deal_attrs(
            attrs, starts, "starts", "StartsTensorList", inputs, infer_flags
        )
        deal_attrs(attrs, ends, "ends", "EndsTensorList", inputs, infer_flags)
        deal_attrs(
            attrs, steps, "strides", "StridesTensorList", inputs, infer_flags
        )
        attrs['infer_flags'] = infer_flags

        from . import in_dynamic_or_pir_mode, in_pir_mode

        if in_dynamic_or_pir_mode():
            # Prefer the tensor-list inputs when deal_attrs created them.
            if "StartsTensorList" in inputs.keys():
                st = inputs['StartsTensorList']
            else:
                st = attrs['starts']
            if "EndsTensorList" in inputs.keys():
                end = inputs['EndsTensorList']
            else:
                end = attrs['ends']
            if "StridesTensorList" in inputs.keys():
                stride = inputs['StridesTensorList']
            else:
                stride = attrs['strides']

            if use_strided_slice:
                # TODO(zoooo0820): support strided_slice_array until PIR API is ready
                if in_pir_mode():
                    st = _int_tensor_list_if_contains_var(st)
                    end = _int_tensor_list_if_contains_var(end)
                    stride = _int_tensor_list_if_contains_var(stride)
                out = paddle._C_ops.strided_slice(x, axes, st, end, stride)
                if len(decrease_axes) > 0:
                    out = paddle._C_ops.squeeze(out, decrease_axes)
            else:
                if in_pir_mode():
                    st = _int_tensor_list_if_contains_var(st)
                    end = _int_tensor_list_if_contains_var(end)
                    if x.is_dense_tensor_array_type():
                        if len(decrease_axes) > 0:
                            return (
                                paddle._pir_ops.slice_array_dense(x, st),
                                False,
                            )
                        else:
                            return (
                                paddle._pir_ops.slice_array(x, st, end),
                                False,
                            )
                out = paddle._C_ops.slice(
                    x,
                    axes,
                    st,
                    end,
                    attrs['infer_flags'],
                    attrs['decrease_axis'],
                )
        else:
            target_block = paddle.static.default_main_program().current_block()

            slice_out_var = target_block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    x.name + "_" + op_type
                ),
                dtype=x.dtype,
            )
            target_block.append_op(
                type=op_type,
                inputs=inputs,
                outputs={'Out': [slice_out_var]},
                attrs=attrs,
            )
            out = slice_out_var

    if len(none_axes) > 0:
        out_is_view = True
        # Deal with cases that decrease_axes is not empty
        # For example:
        # # x.shape: (2,3,4)
        # out = x[0, 0:2, None] # out.shape : (2, 1, 4)
        for idx, axis in enumerate(none_axes):
            num_decreased_before = len([i for i in decrease_axes if i < axis])
            none_axes[idx] = axis - num_decreased_before
        out = paddle.unsqueeze(out, axis=none_axes)

    if in_to_static_mode() and hasattr(out, "is_view_var"):
        out.is_view_var = True
    return out, out_is_view


def _getitem_static(x, indices):
    """
    Args:
        x(Tensor): Tensor to be indexing.
        indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched.

    Returns:
        The result of applying basic indexing first and advanced indexing
        (if any) afterwards.
    """
    # step1: parsing the index and recording them
    (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    ) = parse_index(x, indices)

    # step2: Dealing with basic indexing
    out, _ = get_tensor_with_basic_indexing(
        x,
        axes,
        starts,
        ends,
        steps,
        decrease_axes,
        none_axes,
        use_strided_slice,
    )

    # step3: Dealing with advanced indexing
    if has_advanced_index:
        (
            transed_tensor,
            adjusted_advanced_index,
            _,
            pos_of_new_dim,
            rank_of_new_dim,
            _,
            _,
        ) = deal_advanced_index(out, advanced_index, False, None)

        # TODO(zooooo0820): Replacing gather_nd to another advanced OP for handling of mixed indexes more efficiently
        if len(adjusted_advanced_index) == 1 and adjusted_advanced_index[
            0
        ].dtype in (paddle.bool, paddle.base.libpaddle.BOOL):
            # Note: now slice not support 0-size Tensor, so only one bool tensor can return empty 0-size.
            out = get_value_for_bool_tensor(
                transed_tensor, adjusted_advanced_index[0]
            )
        else:
            adjusted_advanced_index = parse_bool_and_broadcast_indices(
                adjusted_advanced_index
            )

            if len(adjusted_advanced_index) > 1:
                advanced_index_tensor = paddle.stack(
                    adjusted_advanced_index, axis=-1
                )
            else:
                # fast path for a single index tensor, since stack is much slower than unsqueeze
                advanced_index_tensor = adjusted_advanced_index[0].unsqueeze(-1)

            out = paddle.gather_nd(transed_tensor, advanced_index_tensor)

        if pos_of_new_dim != 0:
            # Move the leading ``rank_of_new_dim`` axes produced by the
            # advanced index back to position ``pos_of_new_dim``.
            perm = (
                list(range(rank_of_new_dim, pos_of_new_dim + rank_of_new_dim))
                + list(range(0, rank_of_new_dim))
                + list(range(pos_of_new_dim + rank_of_new_dim, out.ndim))
            )
            out = out.transpose(perm)

    return out


def parse_bool_and_broadcast_indices(indices):
    """Turn bool index tensors into int ones and broadcast all indices.

    Bool entries of ``indices`` are replaced, in place, by the first column
    of ``paddle.nonzero``. When there is more than one index, the result of
    ``paddle.broadcast_tensors`` is returned; otherwise ``indices`` itself.
    """
    # deal with multiple Tensors and translating bool tensor to int tensor.
    # In static mode, bool-tensor cannot be broadcasted since its corresponding int tensor's shape cannot be inferred.
    for i, indice in enumerate(indices):
        if (
            indice.dtype == paddle.bool
            or indice.dtype == paddle.base.libpaddle.BOOL
        ):
            indices[i] = paddle.nonzero(indice)[:, 0]
    if len(indices) > 1:
        indices = paddle.broadcast_tensors(indices)
    return indices