# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import linecache
import os
import re
import sys
import traceback

import numpy as np

from .origin_info import Location, OriginInfo, global_origin_info_map
from .utils import (
    RE_PYMODULE,
    is_api_in_module_helper,
)

__all__ = []

# Name of the attribute under which an `ErrorData` instance is attached to an
# exception object.
ERROR_DATA = "Error data about original source code information and traceback."

# A flag to set whether to simplify the error value of runtime errors
# (read in `ErrorData.create_message`).
SIMPLIFY_ERROR_ENV_NAME = "TRANSLATOR_SIMPLIFY_NEW_ERROR"
DEFAULT_SIMPLIFY_NEW_ERROR = 1

# A flag to set whether to disable the dygraph2static error reporting module
# and raise the origin error instead (read in `ErrorData.raise_new_exception`).
DISABLE_ERROR_ENV_NAME = "TRANSLATOR_DISABLE_NEW_ERROR"
DEFAULT_DISABLE_NEW_ERROR = 0

# Number of source lines shown around the error line by `TraceBackFrameRange`.
SOURCE_CODE_RANGE = 5

# Number of spaces placed before each 'File "..."' line of a formatted frame.
BLANK_COUNT_BEFORE_FILE_STR = 4


def attach_error_data(error, in_runtime=False):
    """
    Attaches error data about original source code information and traceback to an error.

    Must be called while an exception is being handled, because the
    traceback is taken from `sys.exc_info()`.

    Args:
        error(Exception): A native error.
        in_runtime(bool): `error` is raised in runtime if in_runtime is True, otherwise in compile time
    Returns:
        An error attached data about original source code information and traceback.
    """

    e_type, e_value, e_traceback = sys.exc_info()
    # The first extracted frame is dropped; only the remaining frames are kept.
    tb = traceback.extract_tb(e_traceback)[1:]

    error_data = ErrorData(e_type, e_value, tb, global_origin_info_map)
    error_data.in_runtime = in_runtime

    setattr(error, ERROR_DATA, error_data)

    return error


class TraceBackFrame(OriginInfo):
    """
    Traceback frame information: a location, a function name and one line of
    source code.
    """

    def __init__(self, location, function_name, source_code):
        self.location = location
        self.function_name = function_name
        self.source_code = source_code
        self.error_line = ''

    def formatted_message(self):
        """
        Returns the frame as a 'File "...", line N, in func' line followed by
        the (left-stripped) source code on a tab-indented line.
        """
        # self.source_code may be empty in some functions.
        # For example, decorator generated function
        code = self.source_code
        if isinstance(code, str):
            code = code.lstrip()
        return (
            ' ' * BLANK_COUNT_BEFORE_FILE_STR
            + f'File "{self.location.filepath}", line {self.location.lineno}, in {self.function_name}\n\t{code}'
        )


class TraceBackFrameRange(OriginInfo):
    """
    Traceback frame information that shows a window of `SOURCE_CODE_RANGE`
    source lines around the error line, with a '<--- HERE' marker under it.
    """

    def __init__(self, location, function_name):
        self.location = location
        self.function_name = function_name
        self.source_code = []
        self.error_line = ''
        # Leading-whitespace width of each entry of `self.source_code`;
        # -1 marks an empty line.
        blank_count = []
        begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2))

        for i in range(begin_lineno, begin_lineno + SOURCE_CODE_RANGE):
            line = linecache.getline(self.location.filepath, i).rstrip('\n')
            line_lstrip = line.lstrip()
            self.source_code.append(line_lstrip)
            if not line_lstrip:  # empty line from source code
                blank_count.append(-1)
            else:
                blank_count.append(len(line) - len(line_lstrip))

            if i == self.location.lineno:
                self.error_line = self.source_code[-1]
                hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE'
                self.source_code.append(hint_msg)
                # The marker line shares the indentation of the error line.
                blank_count.append(blank_count[-1])
        # Note(gouzil): Under jupyter, files are read multiple times,
        # and we can't actively clean the read cache, which can cause subsequent reads to fail.
        # It is not possible to modify the contents of the file in the meantime,
        # so there is no need to clear the cache

        # remove top and bottom empty line in source code
        while len(self.source_code) > 0 and not self.source_code[0]:
            self.source_code.pop(0)
            blank_count.pop(0)
        while len(self.source_code) > 0 and not self.source_code[-1]:
            self.source_code.pop(-1)
            blank_count.pop(-1)

        # `default=0` keeps this from raising ValueError when no non-empty
        # source line could be read (e.g. linecache returned nothing for the
        # file), which would otherwise break error reporting itself.
        min_blank_count = min((c for c in blank_count if c >= 0), default=0)
        for i, code_line in enumerate(self.source_code):
            # if source_code[i] is empty line between two code line, don't add blank
            if code_line:
                indent_width = (
                    blank_count[i]
                    - min_blank_count
                    + BLANK_COUNT_BEFORE_FILE_STR * 2
                )
                self.source_code[i] = ' ' * indent_width + code_line

    def formatted_message(self):
        """
        Returns the 'File "...", line N, in func' line followed by the
        collected source window, one source line per output line.
        """
        msg = (
            ' ' * BLANK_COUNT_BEFORE_FILE_STR
            + f'File "{self.location.filepath}", line {self.location.lineno}, in {self.function_name}\n'
        )
        # add empty line after range code
        return msg + '\n'.join(self.source_code)


class SuggestionDict:
    """
    Read-only mapping from a tuple of keywords to a tuple of revise
    suggestions shown when all keywords occur in an error message.
    """

    def __init__(self):
        # {(keywords): (suggestions)}
        self.suggestion_dict = {
            ('is not initialized.', 'Hint:', 'IsInitialized'): (
                "Please ensure all your sublayers are inherited from nn.Layer.",
                "Please ensure there is no tensor created explicitly depended on external data, "
                + "we suggest to register it as buffer tensor. "
                + "See https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/jit/principle_cn.html#buffers for details",
            )
        }

    def keys(self):
        return self.suggestion_dict.keys()

    def __getitem__(self, key):
        return self.suggestion_dict[key]


class Dy2StKeyError(Exception):
    """Exception type used by `ErrorData.create_exception` in place of KeyError."""


class ErrorData:
    """
    Error data attached to an exception which is raised in un-transformed code.
    """

    def __init__(
        self, error_type, error_value, origin_traceback, origin_info_map
    ):
        self.error_type = error_type
        self.error_value = error_value
        self.origin_traceback = origin_traceback
        self.origin_info_map = origin_info_map
        self.in_runtime = False
        self.suggestion_dict = SuggestionDict()

    def create_exception(self):
        """
        Builds a new exception carrying the custom message from
        `create_message`, and attaches this `ErrorData` to it.

        The new exception has type `self.error_type`, except that KeyError is
        replaced with `Dy2StKeyError`.
        """
        message = self.create_message()
        if self.error_type is KeyError:
            new_exception = Dy2StKeyError(message)
        else:
            new_exception = self.error_type(message)
        setattr(new_exception, ERROR_DATA, self)
        return new_exception

    def numpy_api_check(self, format_exception, error_line):
        """
        Replaces the formatted exception with a hint to use Paddle API when a
        TypeError appears to come from calling a numpy API in `error_line`.

        Args:
            format_exception(list[str]): Lines produced by `traceback.format_exception_only`.
            error_line(str): Source code of the line where the error was reported.
        Returns:
            The hint lines if a numpy API call is detected, otherwise
            `format_exception` unchanged.
        """
        if self.error_type is not TypeError:
            return format_exception

        # Look for a (possibly module-qualified) call of any traceback frame's
        # function name inside the error line.
        func_str = None
        for frame in self.origin_traceback:
            searched_name = re.search(
                rf'({RE_PYMODULE})*{frame.name}',
                error_line,
            )
            if searched_name:
                func_str = searched_name.group(0)
                break

        # No frame name occurs in the error line: nothing to check.
        if func_str is None:
            return format_exception

        try:
            # NOTE: `func_str` is text taken from the source line that raised
            # the error and is evaluated with only `np` provided as a global.
            eval_globals = {'np': np}
            fn = eval(func_str, eval_globals)
            module_result = is_api_in_module_helper(fn, "numpy")
            is_numpy_api_err = module_result or (
                func_str.startswith("numpy.") or func_str.startswith("np.")
            )
        except Exception:
            # The name cannot be resolved here; treat it as not a numpy API.
            is_numpy_api_err = False

        if is_numpy_api_err and func_str:
            return [
                f"TypeError: Code '{error_line}' called numpy API {func_str}, please use Paddle API to replace it.",
                "           values will be changed to variables by dy2static, numpy api can not handle variables",
            ]
        else:
            return format_exception

    def create_message(self):
        """
        Creates a custom error message which includes trace stack with source code information of dygraph from user.
        """
        message_lines = []

        # Step1: Adds header message to prompt users that the following is the original information.
        header_message = "In transformed code:"
        message_lines.append(header_message)
        message_lines.append("")
        error_line = None

        # Simplify error value to improve readability if error is raised in runtime
        if self.in_runtime:
            try:
                if int(
                    os.getenv(
                        SIMPLIFY_ERROR_ENV_NAME, DEFAULT_SIMPLIFY_NEW_ERROR
                    )
                ):
                    self._simplify_error_value()
            except Exception:
                # Simplification is best-effort: on any failure fall through
                # and build the message from the origin traceback instead.
                # Only `Exception` is caught so that KeyboardInterrupt and
                # SystemExit are not swallowed.
                pass
            else:
                message_lines.append(str(self.error_value))
                return '\n'.join(message_lines)

        # Step2: Optimizes stack information with source code information of dygraph from user.
        # Indices of the frames that have an entry in `origin_info_map`.
        user_code_traceback_index = [
            i
            for i, (filepath, lineno, _funcname, _code) in enumerate(
                self.origin_traceback
            )
            if self.origin_info_map.get((filepath, lineno), None)
        ]

        # Add user code traceback
        for i in user_code_traceback_index:
            filepath, lineno, funcname, code = self.origin_traceback[i]
            dygraph_func_info = self.origin_info_map.get(
                (filepath, lineno), None
            )
            if i == user_code_traceback_index[-1]:
                # The last user frame is shown with surrounding source lines.
                traceback_frame = TraceBackFrameRange(
                    dygraph_func_info.location, dygraph_func_info.function_name
                )
            else:
                traceback_frame = TraceBackFrame(
                    dygraph_func_info.location,
                    dygraph_func_info.function_name,
                    dygraph_func_info.source_code,
                )

            message_lines.append(traceback_frame.formatted_message())
            error_line = traceback_frame.error_line
        message_lines.append("")

        # Add paddle traceback after user code traceback
        paddle_traceback_start_index = (
            user_code_traceback_index[-1] + 1
            if user_code_traceback_index
            else 0
        )
        for filepath, lineno, funcname, code in self.origin_traceback[
            paddle_traceback_start_index:
        ]:
            traceback_frame = TraceBackFrame(
                Location(filepath, lineno), funcname, code
            )
            message_lines.append(traceback_frame.formatted_message())
        message_lines.append("")

        # Step3: Adds error message like "TypeError: dtype must be int32, but received float32".
        # NOTE: `format_exception` is a list, its length is 1 in most cases, but sometimes its length
        # is greater than 1, for example, the error_type is IndentationError.
        format_exception = traceback.format_exception_only(
            self.error_type, self.error_value
        )
        if error_line is not None:
            format_exception = self.numpy_api_check(
                format_exception, error_line
            )

        error_message = [
            " " * BLANK_COUNT_BEFORE_FILE_STR + line
            for line in format_exception
        ]
        message_lines.extend(error_message)

        return '\n'.join(message_lines)

    def _create_revise_suggestion(self, bottom_error_message):
        """
        Builds the 'Revise suggestion' lines for an error message.

        Args:
            bottom_error_message(list[str]): Lines of the error message to search for keywords.
        Returns:
            A list of lines (blank line, title, numbered suggestions) if at
            least one suggestion applies, otherwise an empty list.
        """
        revise_suggestions = [
            '',
            ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'Revise suggestion: ',
        ]
        # Joined once; every keyword group is searched in the same text.
        error_text = ''.join(bottom_error_message)
        for keywords in self.suggestion_dict.keys():
            # all keywords should be in bottom_error_message
            if all(keyword in error_text for keyword in keywords):
                for suggestion in self.suggestion_dict[keywords]:
                    # The two header entries are excluded from the numbering,
                    # so the first suggestion is numbered 1.
                    suggestion_msg = (
                        ' ' * BLANK_COUNT_BEFORE_FILE_STR * 2
                        + f'{len(revise_suggestions) - 1}. {suggestion}'
                    )
                    revise_suggestions.append(suggestion_msg)
        return revise_suggestions if len(revise_suggestions) > 2 else []

    def _simplify_error_value(self):
        """
        Simplifies error value to improve readability if error is raised in runtime.

        NOTE(liym27): The op callstack information about transformed static code has been replaced with original dygraph code.

        TODO(liym27):
            1. Need a more robust way because the code of start_trace may change.
            2. Set the switch to determine whether to simplify error_value

        Raises:
            ValueError: If the start-of-trace line or the empty separator line
                is not found in the error value (from `list.index`).
        """

        assert self.in_runtime is True

        error_value_lines = str(self.error_value).split("\n")
        error_value_lines_strip = [mes.lstrip(" ") for mes in error_value_lines]

        # Everything up to and including this line is dropped.
        start_trace = "outputs = static_func(*inputs)"
        start_idx = error_value_lines_strip.index(start_trace)

        error_value_lines = error_value_lines[start_idx + 1 :]
        error_value_lines_strip = error_value_lines_strip[start_idx + 1 :]

        # use empty line to locate the bottom_error_message
        empty_line_idx = error_value_lines_strip.index('')
        bottom_error_message = error_value_lines[empty_line_idx + 1 :]
        revise_suggestion = self._create_revise_suggestion(bottom_error_message)

        error_traceback = []
        user_code_traceback_index = []
        # Matches lines such as: File "a.py", line 3, in foo
        pattern = re.compile(
            r'File "(?P<filepath>.+)", line (?P<lineno>.+), in (?P<function_name>.+)'
        )

        # Distinguish user code and framework code using static_info_map
        static_info_map = {}
        for k, v in self.origin_info_map.items():
            origin_filepath = v.location.filepath
            origin_lineno = v.location.lineno
            static_info_map[(origin_filepath, origin_lineno)] = k

        # Frames are read as pairs of lines: a 'File ...' line followed by
        # the source code line.
        for i in range(0, len(error_value_lines_strip), 2):
            if error_value_lines_strip[i].startswith("File "):
                re_result = pattern.search(error_value_lines_strip[i])
                if re_result is None:
                    # Starts with "File " but is not a traceback frame line.
                    continue
                tmp_filepath, lineno_str, function_name = re_result.groups()
                code = (
                    error_value_lines_strip[i + 1]
                    if i + 1 < len(error_value_lines_strip)
                    else ''
                )

                if static_info_map.get((tmp_filepath, int(lineno_str))):
                    user_code_traceback_index.append(len(error_traceback))

                error_traceback.append(
                    (tmp_filepath, int(lineno_str), function_name, code)
                )

        error_frame = []
        # Add user code traceback
        for i in user_code_traceback_index:
            filepath, lineno, funcname, code = error_traceback[i]
            if i == user_code_traceback_index[-1]:
                traceback_frame = TraceBackFrameRange(
                    Location(filepath, lineno), funcname
                )
            else:
                traceback_frame = TraceBackFrame(
                    Location(filepath, lineno), funcname, code
                )
            error_frame.append(traceback_frame.formatted_message())
        error_frame.append("")

        # Add paddle traceback after user code traceback
        paddle_traceback_start_index = (
            user_code_traceback_index[-1] + 1
            if user_code_traceback_index
            else 0
        )
        for filepath, lineno, funcname, code in error_traceback[
            paddle_traceback_start_index:
        ]:
            traceback_frame = TraceBackFrame(
                Location(filepath, lineno), funcname, code
            )
            error_frame.append(traceback_frame.formatted_message())
        error_frame.append("")

        error_frame.extend(bottom_error_message)
        error_frame.extend(revise_suggestion)
        error_value_str = '\n'.join(error_frame)
        self.error_value = self.error_type(error_value_str)

    def raise_new_exception(self):
        """
        Raises the exception built by `create_exception`, or the origin error
        value if the error module is disabled through `DISABLE_ERROR_ENV_NAME`.
        """
        # Raises the origin error if disable dygraph2static error module,
        if int(os.getenv(DISABLE_ERROR_ENV_NAME, DEFAULT_DISABLE_NEW_ERROR)):
            raise self.error_value

        new_exception = self.create_exception()
        # NOTE(liym27):
        # Why `raise new_exception from None`?
        #
        # In Python 3, by default, a new exception is raised with trace information of the caught exception.
        # This only raises new_exception and hides unwanted implementation details from tracebacks of the
        # caught exception.
        raise new_exception from None