#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Worker that receives input from Piped RDD.
"""
import os
import sys
import dataclasses
import time
import inspect
import itertools
import json
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, Union

from pyspark.accumulators import (
    SpecialAccumulatorIds,
    _accumulatorRegistry,
    _deserialize_accumulator,
)
from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
from pyspark.taskcontext import BarrierTaskContext, TaskContext
from pyspark.util import PythonEvalType
from pyspark.serializers import (
    write_int,
    read_long,
    read_bool,
    write_long,
    read_int,
    SpecialLengths,
    CPickleSerializer,
    BatchedSerializer,
)
from pyspark.sql.conversion import (
    LocalDataToArrowConversion,
    ArrowTableToRowsConversion,
    ArrowBatchTransformer,
    PandasToArrowConversion,
)
from pyspark.sql.functions import SkipRestOfInputTableException
from pyspark.sql.pandas.serializers import (
    ArrowStreamSerializer,
    ArrowStreamGroupSerializer,
    ArrowStreamPandasUDFSerializer,
    ArrowStreamPandasUDTFSerializer,
    ArrowStreamGroupUDFSerializer,
    GroupPandasUDFSerializer,
    CogroupArrowUDFSerializer,
    CogroupPandasUDFSerializer,
    ApplyInPandasWithStateSerializer,
    TransformWithStateInPandasSerializer,
    TransformWithStateInPandasInitStateSerializer,
    TransformWithStateInPySparkRowSerializer,
    TransformWithStateInPySparkRowInitStateSerializer,
    ArrowStreamAggPandasUDFSerializer,
    ArrowStreamUDTFSerializer,
    ArrowStreamArrowUDTFSerializer,
)
from pyspark.sql.pandas.types import to_arrow_schema, to_arrow_type
from pyspark.sql.types import (
    ArrayType,
    BinaryType,
    DataType,
    MapType,
    Row,
    StringType,
    StructField,
    StructType,
    _create_row,
    _parse_datatype_json_string,
)
from pyspark.util import (
    fail_on_stopiteration,
    handle_worker_exception,
    with_faulthandler,
    start_faulthandler_periodic_traceback,
)
from pyspark import _NoValue, shuffle
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
from pyspark.worker_util import (
    check_python_version,
    get_sock_file_to_executor,
    read_command,
    pickleSer,
    send_accumulator_updates,
    setup_broadcasts,
    setup_memory_limits,
    setup_spark_files,
    utf8_deserializer,
    Conf,
)
from pyspark.logger.worker_io import capture_outputs


class RunnerConf(Conf):
    @property
    def assign_cols_by_name(self) -> bool:
        return (
            self.get("spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")
            == "true"
        )

    @property
    def use_large_var_types(self) -> bool:
        return self.get("spark.sql.execution.arrow.useLargeVarTypes", "false") == "true"
== "true" ) @property def use_legacy_pandas_udtf_conversion(self) -> bool: return ( self.get("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "false") == "true" ) @property def binary_as_bytes(self) -> bool: return self.get("spark.sql.execution.pyspark.binaryAsBytes", "true") == "true" @property def safecheck(self) -> bool: return self.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false") == "true" @property def int_to_decimal_coercion_enabled(self) -> bool: return ( self.get("spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled", "false") == "true" ) @property def prefer_int_ext_dtype(self) -> bool: return ( self.get("spark.sql.execution.pythonUDF.pandas.preferIntExtensionDtype", "false") == "true" ) @property def timezone(self) -> Optional[str]: return self.get("spark.sql.session.timeZone", None, lower_str=False) @property def arrow_max_records_per_batch(self) -> int: return int(self.get("spark.sql.execution.arrow.maxRecordsPerBatch", 10000)) @property def arrow_max_bytes_per_batch(self) -> int: return int(self.get("spark.sql.execution.arrow.maxBytesPerBatch", 2**31 - 1)) @property def arrow_concurrency_level(self) -> int: return int(self.get("spark.sql.execution.pythonUDF.arrow.concurrency.level", -1)) @property def udf_profiler(self) -> Optional[str]: return self.get("spark.sql.pyspark.udf.profiler", None) @property def data_source_profiler(self) -> Optional[str]: return self.get("spark.sql.pyspark.dataSource.profiler", None) class EvalConf(Conf): @property def state_value_schema(self) -> Optional[StructType]: schema = self.get("state_value_schema", None) if schema is None: return None return StructType.fromJson(json.loads(schema)) @property def grouping_key_schema(self) -> Optional[StructType]: schema = self.get("grouping_key_schema", None) if schema is None: return None return StructType.fromJson(json.loads(schema)) @property def state_server_socket_port(self) -> Optional[int | str]: port = self.get("state_server_socket_port", None) try: return int(port) except ValueError: return port def report_times(outfile, boot, init, finish, processing_time_ms): write_int(SpecialLengths.TIMING_DATA, outfile) write_long(int(1000 * boot), outfile) write_long(int(1000 * init), outfile) write_long(int(1000 * finish), outfile) write_long(processing_time_ms, outfile) def chain(f, g): """chain two functions together""" return lambda *a: g(f(*a)) def verify_result(expected_type: type) -> Callable[[Any], Iterator]: """ Create a result verifier that checks both iterability and element types. Returns a function that takes a UDF result, verifies it is iterable, and lazily type-checks each element via map. Parameters ---------- expected_type : type The expected Python/PyArrow type for each element (e.g. pa.RecordBatch, pa.Array). 
""" package = getattr(inspect.getmodule(expected_type), "__package__", "") label: str = f"{package}.{expected_type.__name__}" def check_element(element: Any) -> Any: if not isinstance(element, expected_type): raise PySparkTypeError( errorClass="UDF_RETURN_TYPE", messageParameters={ "expected": f"iterator of {label}", "actual": f"iterator of {type(element).__name__}", }, ) return element def check(result: Any) -> Iterator: if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): raise PySparkTypeError( errorClass="UDF_RETURN_TYPE", messageParameters={ "expected": f"iterator of {label}", "actual": type(result).__name__, }, ) return map(check_element, result) return check def verify_result_row_count(result_length: int, expected: int) -> None: """Raise if the result row count doesn't match the expected input row count.""" if result_length != expected: raise PySparkRuntimeError( errorClass="RESULT_ROWS_MISMATCH", messageParameters={ "output_length": str(result_length), "input_length": str(expected), }, ) def verify_scalar_result(result: Any, num_rows: int) -> Any: """ Verify a scalar UDF result is array-like and has the expected number of rows. Parameters ---------- result : Any The UDF result to verify. num_rows : int Expected number of rows (must match input batch size). """ try: result_length = len(result) except TypeError: raise PySparkTypeError( errorClass="UDF_RETURN_TYPE", messageParameters={ "expected": "array-like object", "actual": type(result).__name__, }, ) if result_length != num_rows: # TODO: change error class to RESULT_ROWS_MISMATCH raise PySparkRuntimeError( errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF", messageParameters={ "udf_type": "arrow_udf", "expected": str(num_rows), "actual": str(result_length), }, ) return result def verify_iterator_exhausted(iterator: Iterator, error_class: str) -> None: """Verify that an iterator has been fully consumed.""" try: next(iterator) except StopIteration: pass else: raise PySparkRuntimeError(errorClass=error_class, messageParameters={}) def verify_output_row_limit( iterator: Iterator, max_rows: Union[int, Callable[[], int]], error_class: str, ) -> Iterator: """Yield elements while verifying total rows do not exceed a limit (fail-fast).""" total_rows = 0 for element in iterator: total_rows += len(element) if total_rows > (max_rows() if callable(max_rows) else max_rows): raise PySparkRuntimeError(errorClass=error_class, messageParameters={}) yield element def verify_output_row_count( iterator: Iterator, expected_rows: Union[int, Callable[[], int]], error_class: str, ) -> Iterator: """Yield elements and verify final row count matches expected exactly.""" actual_rows = 0 for element in iterator: actual_rows += len(element) yield element expected = expected_rows() if callable(expected_rows) else expected_rows if actual_rows != expected: raise PySparkRuntimeError( errorClass=error_class, messageParameters={ "output_length": str(actual_rows), "input_length": str(expected), }, ) def wrap_udf(f, args_offsets, kwargs_offsets, return_type): func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) if return_type.needConversion(): toInternal = return_type.toInternal return args_kwargs_offsets, lambda *a: toInternal(func(*a)) else: return args_kwargs_offsets, lambda *a: func(*a) def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) def verify_result_type(result): if not hasattr(result, "__len__"): 


def verify_result_row_count(result_length: int, expected: int) -> None:
    """Raise if the result row count doesn't match the expected input row count."""
    if result_length != expected:
        raise PySparkRuntimeError(
            errorClass="RESULT_ROWS_MISMATCH",
            messageParameters={
                "output_length": str(result_length),
                "input_length": str(expected),
            },
        )


def verify_scalar_result(result: Any, num_rows: int) -> Any:
    """
    Verify a scalar UDF result is array-like and has the expected number of rows.

    Parameters
    ----------
    result : Any
        The UDF result to verify.
    num_rows : int
        Expected number of rows (must match input batch size).
    """
    try:
        result_length = len(result)
    except TypeError:
        raise PySparkTypeError(
            errorClass="UDF_RETURN_TYPE",
            messageParameters={
                "expected": "array-like object",
                "actual": type(result).__name__,
            },
        )
    if result_length != num_rows:
        # TODO: change error class to RESULT_ROWS_MISMATCH
        raise PySparkRuntimeError(
            errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
            messageParameters={
                "udf_type": "arrow_udf",
                "expected": str(num_rows),
                "actual": str(result_length),
            },
        )
    return result


def verify_iterator_exhausted(iterator: Iterator, error_class: str) -> None:
    """Verify that an iterator has been fully consumed."""
    try:
        next(iterator)
    except StopIteration:
        pass
    else:
        raise PySparkRuntimeError(errorClass=error_class, messageParameters={})


def verify_output_row_limit(
    iterator: Iterator,
    max_rows: Union[int, Callable[[], int]],
    error_class: str,
) -> Iterator:
    """Yield elements while verifying total rows do not exceed a limit (fail-fast)."""
    total_rows = 0
    for element in iterator:
        total_rows += len(element)
        if total_rows > (max_rows() if callable(max_rows) else max_rows):
            raise PySparkRuntimeError(errorClass=error_class, messageParameters={})
        yield element


def verify_output_row_count(
    iterator: Iterator,
    expected_rows: Union[int, Callable[[], int]],
    error_class: str,
) -> Iterator:
    """Yield elements and verify final row count matches expected exactly."""
    actual_rows = 0
    for element in iterator:
        actual_rows += len(element)
        yield element
    expected = expected_rows() if callable(expected_rows) else expected_rows
    if actual_rows != expected:
        raise PySparkRuntimeError(
            errorClass=error_class,
            messageParameters={
                "output_length": str(actual_rows),
                "input_length": str(expected),
            },
        )


def wrap_udf(f, args_offsets, kwargs_offsets, return_type):
    func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)

    if return_type.needConversion():
        toInternal = return_type.toInternal
        return args_kwargs_offsets, lambda *a: toInternal(func(*a))
    else:
        return args_kwargs_offsets, lambda *a: func(*a)


def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
    func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)

    def verify_result_type(result):
        if not hasattr(result, "__len__"):
            pd_type = "pandas.DataFrame" if isinstance(return_type, StructType) else "pandas.Series"
            raise PySparkTypeError(
                errorClass="UDF_RETURN_TYPE",
                messageParameters={
                    "expected": pd_type,
                    "actual": type(result).__name__,
                },
            )
        return result

    def verify_result_length(result, length):
        if len(result) != length:
            raise PySparkRuntimeError(
                errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
                messageParameters={
                    "udf_type": "pandas_udf",
                    "expected": str(length),
                    "actual": str(len(result)),
                },
            )
        return result

    return (
        args_kwargs_offsets,
        lambda *a: (
            verify_result_length(verify_result_type(func(*a)), len(a[0])),
            return_type,
        ),
    )


def wrap_pandas_batch_iter_udf(f, return_type, runner_conf):
    iter_type_label = "pandas.DataFrame" if isinstance(return_type, StructType) else "pandas.Series"

    def verify_result(result):
        if not isinstance(result, Iterator) and not hasattr(result, "__iter__"):
            raise PySparkTypeError(
                errorClass="UDF_RETURN_TYPE",
                messageParameters={
                    "expected": "iterator of {}".format(iter_type_label),
                    "actual": type(result).__name__,
                },
            )
        return result

    def verify_element(elem):
        import pandas as pd

        if not isinstance(elem, pd.DataFrame if isinstance(return_type, StructType) else pd.Series):
            raise PySparkTypeError(
                errorClass="UDF_RETURN_TYPE",
                messageParameters={
                    "expected": "iterator of {}".format(iter_type_label),
                    "actual": "iterator of {}".format(type(elem).__name__),
                },
            )
        verify_pandas_result(
            elem, return_type, assign_cols_by_name=True, truncate_return_schema=True
        )
        return elem

    return lambda *iterator: map(
        lambda res: (res, return_type), map(verify_element, verify_result(f(*iterator)))
    )


def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_return_schema):
    import pandas as pd

    if isinstance(return_type, StructType):
        if not isinstance(result, pd.DataFrame):
            raise PySparkTypeError(
                errorClass="UDF_RETURN_TYPE",
                messageParameters={
                    "expected": "pandas.DataFrame",
                    "actual": type(result).__name__,
                },
            )

        # check the schema of the result only if it is not empty or has columns
        if not result.empty or len(result.columns) != 0:
            # if any column name of the result is a string
            # the column names of the result have to match the return type
            # see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer
            field_names = set([field.name for field in return_type.fields])
            # only the first len(field_names) result columns are considered
            # when truncating the return schema
            result_columns = (
                result.columns[: len(field_names)] if truncate_return_schema else result.columns
            )
            column_names = set(result_columns)
            if (
                assign_cols_by_name
                and any(isinstance(name, str) for name in result.columns)
                and column_names != field_names
            ):
                missing = sorted(list(field_names.difference(column_names)))
                missing = f" Missing: {', '.join(missing)}." if missing else ""
                extra = sorted(list(column_names.difference(field_names)))
                extra = f" Unexpected: {', '.join(extra)}." if extra else ""
                raise PySparkRuntimeError(
                    errorClass="RESULT_COLUMN_NAMES_MISMATCH",
                    messageParameters={
                        "missing": missing,
                        "extra": extra,
                    },
                )
            # otherwise the number of columns of result have to match the return type
            elif len(result_columns) != len(return_type):
                raise PySparkRuntimeError(
                    errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
                    messageParameters={
                        "expected": str(len(return_type)),
                        "actual": str(len(result.columns)),
                    },
                )
    else:
        if not isinstance(result, pd.Series):
            raise PySparkTypeError(
                errorClass="UDF_RETURN_TYPE",
                messageParameters={"expected": "pandas.Series", "actual": type(result).__name__},
            )
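

# Illustrative note (comment only, hypothetical values): with a return type of
# STRUCT<a: INT, b: INT> and assign_cols_by_name=True, a DataFrame whose columns
# are ["a", "c"] fails verify_pandas_result with RESULT_COLUMN_NAMES_MISMATCH
# reporting ' Missing: b.' and ' Unexpected: c.', while a DataFrame with columns
# ["b", "a"] passes, since matching is by name rather than by position.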
if extra else "" raise PySparkRuntimeError( errorClass="RESULT_COLUMN_NAMES_MISMATCH", messageParameters={ "missing": missing, "extra": extra, }, ) # otherwise the number of columns of result have to match the return type elif len(result_columns) != len(return_type): raise PySparkRuntimeError( errorClass="RESULT_COLUMN_SCHEMA_MISMATCH", messageParameters={ "expected": str(len(return_type)), "actual": str(len(result.columns)), }, ) else: if not isinstance(result, pd.Series): raise PySparkTypeError( errorClass="UDF_RETURN_TYPE", messageParameters={"expected": "pandas.Series", "actual": type(result).__name__}, ) def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf): if runner_conf.assign_cols_by_name: expected_cols_and_types = { col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields } else: expected_cols_and_types = [ (col.name, to_arrow_type(col.dataType, timezone="UTC")) for col in return_type.fields ] def wrapped(left_key_table, left_value_table, right_key_table, right_value_table): if len(argspec.args) == 2: result = f(left_value_table, right_value_table) elif len(argspec.args) == 3: key_table = left_key_table if left_key_table.num_rows > 0 else right_key_table key = tuple(c[0] for c in key_table.columns) result = f(key, left_value_table, right_value_table) verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types) return result.to_batches() return lambda kl, vl, kr, vr: ( wrapped(kl, vl, kr, vr), to_arrow_type(return_type, timezone="UTC"), ) def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf): def wrapped(left_key_series, left_value_series, right_key_series, right_value_series): import pandas as pd left_df = pd.concat(left_value_series, axis=1) right_df = pd.concat(right_value_series, axis=1) if len(argspec.args) == 2: result = f(left_df, right_df) elif len(argspec.args) == 3: key_series = left_key_series if not left_df.empty else right_key_series key = tuple(s[0] for s in key_series) result = f(key, left_df, right_df) verify_pandas_result( result, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False ) return result return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), return_type)] def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types): # the types of the fields have to be identical to return type # an empty table can have no columns; if there are columns, they have to match if result.num_columns != 0 or result.num_rows != 0: # columns are either mapped by name or position if assign_cols_by_name: actual_cols_and_types = { name: dataType for name, dataType in zip(result.schema.names, result.schema.types) } missing = sorted( list(set(expected_cols_and_types.keys()).difference(actual_cols_and_types.keys())) ) extra = sorted( list(set(actual_cols_and_types.keys()).difference(expected_cols_and_types.keys())) ) if missing or extra: missing = f" Missing: {', '.join(missing)}." if missing else "" extra = f" Unexpected: {', '.join(extra)}." 
if extra else "" raise PySparkRuntimeError( errorClass="RESULT_COLUMN_NAMES_MISMATCH", messageParameters={ "missing": missing, "extra": extra, }, ) column_types = [ (name, expected_cols_and_types[name], actual_cols_and_types[name]) for name in sorted(expected_cols_and_types.keys()) ] else: actual_cols_and_types = [ (name, dataType) for name, dataType in zip(result.schema.names, result.schema.types) ] column_types = [ (expected_name, expected_type, actual_type) for (expected_name, expected_type), (actual_name, actual_type) in zip( expected_cols_and_types, actual_cols_and_types ) ] type_mismatch = [ (name, expected, actual) for name, expected, actual in column_types if actual != expected ] if type_mismatch: raise PySparkRuntimeError( errorClass="RESULT_TYPE_MISMATCH_FOR_ARROW_UDF", messageParameters={ "mismatch": ", ".join( "column '{}' (expected {}, actual {})".format(name, expected, actual) for name, expected, actual in type_mismatch ) }, ) def verify_arrow_table(table, assign_cols_by_name, expected_cols_and_types): import pyarrow as pa if not isinstance(table, pa.Table): raise PySparkTypeError( errorClass="UDF_RETURN_TYPE", messageParameters={ "expected": "pyarrow.Table", "actual": type(table).__name__, }, ) verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types) def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types): import pyarrow as pa if not isinstance(batch, pa.RecordBatch): raise PySparkTypeError( errorClass="UDF_RETURN_TYPE", messageParameters={ "expected": "pyarrow.RecordBatch", "actual": type(batch).__name__, }, ) verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types) def wrap_grouped_map_arrow_udf(f, return_type, argspec, runner_conf): import pyarrow as pa if runner_conf.assign_cols_by_name: expected_cols_and_types = { col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields } else: expected_cols_and_types = [ (col.name, to_arrow_type(col.dataType, timezone="UTC")) for col in return_type.fields ] def wrapped(key_batch, value_batches): value_table = pa.Table.from_batches(value_batches) if len(argspec.args) == 1: result = f(value_table) elif len(argspec.args) == 2: key = tuple(c[0] for c in key_batch.columns) result = f(key, value_table) verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types) yield from result.to_batches() arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) return lambda k, v: (wrapped(k, v), arrow_return_type) def wrap_grouped_map_arrow_iter_udf(f, return_type, argspec, runner_conf): if runner_conf.assign_cols_by_name: expected_cols_and_types = { col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields } else: expected_cols_and_types = [ (col.name, to_arrow_type(col.dataType, timezone="UTC")) for col in return_type.fields ] def wrapped(key_batch, value_batches): if len(argspec.args) == 1: result = f(value_batches) elif len(argspec.args) == 2: key = tuple(c[0] for c in key_batch.columns) result = f(key, value_batches) def verify_element(batch): verify_arrow_batch(batch, runner_conf.assign_cols_by_name, expected_cols_and_types) return batch yield from map(verify_element, result) arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) return lambda k, v: (wrapped(k, v), arrow_return_type) def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): def wrapped(key_series, value_series): import 


def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
    def wrapped(key_series, value_series):
        import pandas as pd

        value_df = pd.concat(value_series, axis=1)
        if len(argspec.args) == 1:
            result = f(value_df)
        elif len(argspec.args) == 2:
            # Extract key from pandas Series, preserving numpy types
            key = tuple(s.iloc[0] for s in key_series)
            result = f(key, value_df)

        verify_pandas_result(
            result, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False
        )

        yield result

    def flatten_wrapper(k, v):
        # Return Iterator[[(df, spark_type)]] directly
        for df in wrapped(k, v):
            yield [(df, return_type)]

    return flatten_wrapper


def wrap_grouped_map_pandas_iter_udf(f, return_type, argspec, runner_conf):
    def wrapped(key_series, value_batches):
        import pandas as pd

        # value_batches is an Iterator[list[pd.Series]] (one list per batch)
        # Convert each list of Series into a DataFrame
        def dataframe_iter():
            for value_series in value_batches:
                yield pd.concat(value_series, axis=1)

        if len(argspec.args) == 1:
            result = f(dataframe_iter())
        elif len(argspec.args) == 2:
            # Extract key from pandas Series, preserving numpy types
            key = tuple(s.iloc[0] for s in key_series)
            result = f(key, dataframe_iter())

        def verify_element(df):
            verify_pandas_result(
                df, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False
            )
            return df

        yield from map(verify_element, result)

    def flatten_wrapper(k, v):
        # Return Iterator[[(df, spark_type)]] directly
        for df in wrapped(k, v):
            yield [(df, return_type)]

    return flatten_wrapper


def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
    def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
        result_iter = f(stateful_processor_api_client, mode, key, value_series_gen)

        # TODO(SPARK-49100): add verification that elements in result_iter are
        # indeed of type pd.DataFrame and conform to assigned cols
        return result_iter

    return lambda p, m, k, v: [(wrapped(p, m, k, v), return_type)]


def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runner_conf):
    def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
        # Split the generator into two using itertools.tee
        state_values_gen, init_states_gen = itertools.tee(value_series_gen, 2)

        # Extract just the data DataFrames (first element of each tuple)
        state_values = (data_df for data_df, _ in state_values_gen if not data_df.empty)
        # Extract just the init DataFrames (second element of each tuple)
        init_states = (init_df for _, init_df in init_states_gen if not init_df.empty)

        result_iter = f(stateful_processor_api_client, mode, key, state_values, init_states)

        # TODO(SPARK-49100): add verification that elements in result_iter are
        # indeed of type pd.DataFrame and conform to assigned cols
        return result_iter

    return lambda p, m, k, v: [(wrapped(p, m, k, v), return_type)]
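

# Illustrative note (comment only, hypothetical frames): in the init-state
# wrapper above, each element of value_series_gen is a (data_df, init_df) pair.
# itertools.tee lets the two derived generators be consumed independently by
# the user function, e.g.
#
#     pairs = iter([(df1, empty_df), (empty_df, init1)])
#     a, b = itertools.tee(pairs, 2)
#     # (df for df, _ in a if not df.empty) yields df1
#     # (df for _, df in b if not df.empty) yields init1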


def wrap_grouped_transform_with_state_udf(f, return_type, runner_conf):
    def wrapped(stateful_processor_api_client, mode, key, values):
        result_iter = f(stateful_processor_api_client, mode, key, values)

        # TODO(SPARK-XXXXX): add verification that elements in result_iter are
        # indeed of type Row and conform to assigned cols
        return result_iter

    return lambda p, m, k, v: [(wrapped(p, m, k, v), return_type)]


def wrap_grouped_transform_with_state_init_state_udf(f, return_type, runner_conf):
    def wrapped(stateful_processor_api_client, mode, key, values):
        if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
            values_gen = values[0]
            init_states_gen = values[1]
        else:
            values_gen = iter([])
            init_states_gen = iter([])

        result_iter = f(stateful_processor_api_client, mode, key, values_gen, init_states_gen)

        # TODO(SPARK-XXXXX): add verification that elements in result_iter are
        # indeed of type pd.DataFrame and conform to assigned cols
        return result_iter

    return lambda p, m, k, v: [(wrapped(p, m, k, v), return_type)]


def wrap_grouped_map_pandas_udf_with_state(f, return_type, runner_conf):
    """
    Provides a new lambda instance wrapping the user function of applyInPandasWithState.

    The lambda instance receives (key series, iterator of value series, state) and performs
    the conversions needed to adapt the inputs to the signature of the user function. See
    the function doc of the inner function `wrapped` for more details on what the adapter
    does.

    See the function doc of the `mapper` function for
    `eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE` for more details on
    the input parameters of the lambda function.

    The lambda instance returns a tuple (iterator, return_type).
    """

    def wrapped(key_series, value_series_gen, state):
        """
        Provide an adapter of the user function performing the following:

        - Extract the first value of all columns in the key series and produce it as a tuple.
        - If the state has timed out, call the user function with an empty pandas DataFrame.
        - If not, construct a new generator which converts each element of the value series
          to a pandas DataFrame (lazy evaluation), and call the user function with the
          generator.
        - Verify each element of the returned iterator to check the schema of the pandas
          DataFrame.
        """
        import pandas as pd

        key = tuple(s[0] for s in key_series)

        if state.hasTimedOut:
            # Timeout processing passes an empty iterator. Here we return an empty
            # DataFrame instead.
            values = [
                pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns),
            ]
        else:
            values = (pd.concat(x, axis=1) for x in value_series_gen)

        result_iter = f(key, values, state)

        def verify_element(result):
            if not isinstance(result, pd.DataFrame):
                raise PySparkTypeError(
                    errorClass="UDF_RETURN_TYPE",
                    messageParameters={
                        "expected": "iterator of pandas.DataFrame",
                        "actual": "iterator of {}".format(type(result).__name__),
                    },
                )
            # the number of columns of result have to match the return type
            # but it is fine for result to have no columns at all if it is empty
            if not (
                len(result.columns) == len(return_type)
                or (len(result.columns) == 0 and result.empty)
            ):
                raise PySparkRuntimeError(
                    errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
                    messageParameters={
                        "expected": str(len(return_type)),
                        "actual": str(len(result.columns)),
                    },
                )

            return result

        if isinstance(result_iter, pd.DataFrame):
            raise PySparkTypeError(
                errorClass="UDF_RETURN_TYPE",
                messageParameters={
                    "expected": "iterable of pandas.DataFrame",
                    "actual": type(result_iter).__name__,
                },
            )

        try:
            iter(result_iter)
        except TypeError:
            raise PySparkTypeError(
                errorClass="UDF_RETURN_TYPE",
                messageParameters={"expected": "iterable", "actual": type(result_iter).__name__},
            )

        result_iter_with_validation = (verify_element(x) for x in result_iter)

        return (
            result_iter_with_validation,
            state,
        )

    return lambda k, v, s: [(wrapped(k, v, s), return_type)]


def wrap_grouped_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
    func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)

    def wrapped(*series):
        import pandas as pd

        result = func(*series)
        return pd.Series([result])

    return (
        args_kwargs_offsets,
        lambda *a: (wrapped(*a), return_type),
    )
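

# Illustrative sketch (comment only, hypothetical user code): a grouped
# aggregate pandas UDF returns one scalar per group; the wrapper above lifts it
# into a length-1 pd.Series so the serializer can emit a single-row batch:
#
#     def mean_udf(v):             # v: pd.Series for one group
#         return v.mean()          # scalar, e.g. 3.0
#     # wrapped(v) -> pd.Series([3.0])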


def wrap_grouped_agg_pandas_iter_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
    func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)

    def wrapped(series_iter):
        import pandas as pd

        # series_iter: Iterator[pd.Series] (single column) or
        # Iterator[Tuple[pd.Series, ...]] (multiple columns)
        # This has already been adapted by the mapper function in read_udfs
        result = func(series_iter)
        return pd.Series([result])

    return (
        args_kwargs_offsets,
        lambda *a: (wrapped(*a), return_type),
    )


def wrap_window_agg_pandas_udf(
    f, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index
):
    window_bound_types_str = runner_conf.get("window_bound_types")
    window_bound_type = [t.strip().lower() for t in window_bound_types_str.split(",")][udf_index]
    if window_bound_type == "bounded":
        return wrap_bounded_window_agg_pandas_udf(
            f, args_offsets, kwargs_offsets, return_type, runner_conf
        )
    elif window_bound_type == "unbounded":
        return wrap_unbounded_window_agg_pandas_udf(
            f, args_offsets, kwargs_offsets, return_type, runner_conf
        )
    else:
        raise PySparkRuntimeError(
            errorClass="INVALID_WINDOW_BOUND_TYPE",
            messageParameters={
                "window_bound_type": window_bound_type,
            },
        )


def wrap_unbounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
    func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)

    # This is similar to grouped_agg_pandas_udf, the only difference
    # is that window_agg_pandas_udf needs to repeat the return value
    # to match window length, where grouped_agg_pandas_udf just returns
    # the scalar value.
    def wrapped(*series):
        import pandas as pd

        result = func(*series)
        return pd.Series([result]).repeat(len(series[0]))

    return (
        args_kwargs_offsets,
        lambda *a: (wrapped(*a), return_type),
    )
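

# Illustrative note (comment only): for an unbounded window every row in the
# partition shares the same aggregate value, so the wrapper repeats the scalar
# to the window length, e.g. pd.Series([3.0]).repeat(4) yields
# [3.0, 3.0, 3.0, 3.0], one value per input row.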


def wrap_bounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
    # args_offsets should have at least 2 for begin_index, end_index.
    assert len(args_offsets) >= 2, len(args_offsets)
    func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets[2:], kwargs_offsets)

    def wrapped(begin_index, end_index, *series):
        import pandas as pd

        result = []

        # Index operation is faster on np.ndarray,
        # so we turn the index series into np arrays
        # here for performance
        begin_array = begin_index.values
        end_array = end_index.values

        for i in range(len(begin_array)):
            # Note: Creating a slice from a series for each window is
            #       actually pretty expensive. However, there
            #       is no easy way to reduce the cost here.
            # Note: s.iloc[i : j] is about 30% faster than s[i: j], with
            #       the caveat that the created slices share the same
            #       memory with s. Therefore, users are not allowed to
            #       change the value of the input series inside the window
            #       function. It is rare that a user needs to modify the
            #       input series in the window function, so this is
            #       a reasonable restriction.
            # Note: Calling reset_index on the slices will increase the cost
            #       of creating slices by about 100%. Therefore, for performance
            #       reasons we don't do it here.
            series_slices = [s.iloc[begin_array[i] : end_array[i]] for s in series]
            result.append(func(*series_slices))
        return pd.Series(result)

    return (
        args_offsets[:2] + args_kwargs_offsets,
        lambda *a: (wrapped(*a), return_type),
    )


def wrap_kwargs_support(f, args_offsets, kwargs_offsets):
    if len(kwargs_offsets):
        keys = list(kwargs_offsets.keys())

        len_args_offsets = len(args_offsets)
        if len_args_offsets > 0:

            def func(*args):
                return f(*args[:len_args_offsets], **dict(zip(keys, args[len_args_offsets:])))

        else:

            def func(*args):
                return f(**dict(zip(keys, args)))

        return func, args_offsets + [kwargs_offsets[key] for key in keys]
    else:
        return f, args_offsets
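

# Illustrative sketch (comment only, hypothetical offsets): `wrap_kwargs_support`
# appends the keyword-argument columns after the positional ones and remaps the
# call accordingly:
#
#     func, offsets = wrap_kwargs_support(f, [0], {"x": 2, "y": 1})
#     # offsets == [0, 2, 1], and func(a, b, c) calls f(a, x=b, y=c),
#     # so columns 2 and 1 are forwarded as keyword arguments x and y.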


def _is_iter_based(eval_type: int) -> bool:
    return eval_type in (
        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
        PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
        PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
        PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
        PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
    )


def wrap_perf_profiler(f, eval_type, result_id):
    from pyspark.sql.profiler import ProfileResultsParam, WorkerPerfProfiler

    accumulator = _deserialize_accumulator(
        SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
    )

    if _is_iter_based(eval_type):

        def profiling_func(*args, **kwargs):
            iterator = iter(f(*args, **kwargs))
            while True:
                try:
                    with WorkerPerfProfiler(accumulator, result_id):
                        item = next(iterator)
                    yield item
                except StopIteration:
                    break

    else:

        def profiling_func(*args, **kwargs):
            with WorkerPerfProfiler(accumulator, result_id):
                ret = f(*args, **kwargs)
            return ret

    return profiling_func


def wrap_memory_profiler(f, eval_type, result_id):
    from pyspark.sql.profiler import ProfileResultsParam, WorkerMemoryProfiler
    import pyspark.memory_profiler_ext

    if not pyspark.memory_profiler_ext.has_memory_profiler:
        return f

    accumulator = _deserialize_accumulator(
        SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
    )

    if _is_iter_based(eval_type):

        def profiling_func(*args, **kwargs):
            g = f(*args, **kwargs)
            iterator = iter(g)
            while True:
                try:
                    with WorkerMemoryProfiler(accumulator, result_id, g.gi_code):
                        item = next(iterator)
                    yield item
                except StopIteration:
                    break

    else:

        def profiling_func(*args, **kwargs):
            with WorkerMemoryProfiler(accumulator, result_id, f):
                ret = f(*args, **kwargs)
            return ret

    return profiling_func


def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
    num_arg = read_int(infile)

    args_offsets = []
    kwargs_offsets = {}
    for _ in range(num_arg):
        offset = read_int(infile)
        if read_bool(infile):
            name = utf8_deserializer.loads(infile)
            kwargs_offsets[name] = offset
        else:
            args_offsets.append(offset)

    chained_func = None
    for i in range(read_int(infile)):
        f, return_type = read_command(pickleSer, infile)
        if chained_func is None:
            chained_func = f
        else:
            chained_func = chain(chained_func, f)

    result_id = read_long(infile)

    # If chained_func is from pyspark.sql.worker, it is to read/write data source.
    # In this case, we check the data_source_profiler config.
    if getattr(chained_func, "__module__", "").startswith("pyspark.sql.worker."):
        profiler = runner_conf.data_source_profiler
    else:
        profiler = runner_conf.udf_profiler

    if profiler == "perf":
        profiling_func = wrap_perf_profiler(chained_func, eval_type, result_id)
    elif profiler == "memory":
        profiling_func = wrap_memory_profiler(chained_func, eval_type, result_id)
    else:
        profiling_func = chained_func

    if eval_type in (
        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
        PythonEvalType.SQL_ARROW_BATCHED_UDF,
    ):
        func = profiling_func
    else:
        # make sure StopIteration's raised in the user code are not ignored
        # when they are processed in a for loop, raise them as RuntimeError's instead
        func = fail_on_stopiteration(profiling_func)

    # the last returnType will be the return type of UDF
    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
        return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf)
    elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF:
        return func, args_offsets, kwargs_offsets, return_type
    elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
        return func, args_offsets, kwargs_offsets, return_type
    elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
        return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, runner_conf)
    elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF:
        return func, args_offsets, kwargs_offsets, return_type
    elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
        return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, runner_conf)
    elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
        return func, None, None, None
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
        return args_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
        argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
        return args_offsets, wrap_grouped_map_pandas_iter_udf(
            func, return_type, argspec, runner_conf
        )
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
        argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
        return args_offsets, wrap_grouped_map_arrow_udf(func, return_type, argspec, runner_conf)
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF:
        argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
        return args_offsets, wrap_grouped_map_arrow_iter_udf(
            func, return_type, argspec, runner_conf
        )
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
        return args_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type, runner_conf)
    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
        return args_offsets, wrap_grouped_transform_with_state_pandas_udf(
            func, return_type, runner_conf
        )
    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF:
        return args_offsets, wrap_grouped_transform_with_state_pandas_init_state_udf(
            func, return_type, runner_conf
        )
    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:
        return args_offsets, wrap_grouped_transform_with_state_udf(func, return_type, runner_conf)
    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF:
        return args_offsets, wrap_grouped_transform_with_state_init_state_udf(
            func, return_type, runner_conf
        )
    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
        return args_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
        argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
        return args_offsets, wrap_cogrouped_map_arrow_udf(func, return_type, argspec, runner_conf)
    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
        return wrap_grouped_agg_pandas_udf(
            func, args_offsets, kwargs_offsets, return_type, runner_conf
        )
    elif eval_type in (
        PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
        PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
    ):
        return func, args_offsets, kwargs_offsets, return_type
    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF:
        return wrap_grouped_agg_pandas_iter_udf(
            func, args_offsets, kwargs_offsets, return_type, runner_conf
        )
    elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
        return wrap_window_agg_pandas_udf(
            func, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index
        )
    elif eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF:
        return func, args_offsets, kwargs_offsets, return_type
    elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
        return wrap_udf(func, args_offsets, kwargs_offsets, return_type)
    else:
        raise ValueError("Unknown eval type: {}".format(eval_type))


# Read and process a serialized user-defined table function (UDTF) from a socket.
# It expects the UDTF to be in a specific format and performs various checks to
# ensure the UDTF is valid. This function also prepares a mapper function for applying
# the UDTF logic to input rows.
def read_udtf(pickleSer, infile, eval_type, runner_conf):
    if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
        input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
        if runner_conf.use_legacy_pandas_udtf_conversion:
            # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
            ser = ArrowStreamPandasUDTFSerializer(
                timezone=runner_conf.timezone,
                safecheck=runner_conf.safecheck,
                input_type=input_type,
                prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
                int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
            )
        else:
            ser = ArrowStreamUDTFSerializer()
    elif eval_type == PythonEvalType.SQL_ARROW_UDTF:
        # Read the table argument offsets
        num_table_arg_offsets = read_int(infile)
        table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)]
        # Use PyArrow-native serializer for Arrow UDTFs with potential UDT support
        ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets)
    else:
        # Each row is a group so do not batch but send one by one.
        ser = BatchedSerializer(CPickleSerializer(), 1)

    # See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
    num_arg = read_int(infile)
    args_offsets = []
    kwargs_offsets = {}
    for _ in range(num_arg):
        offset = read_int(infile)
        if read_bool(infile):
            name = utf8_deserializer.loads(infile)
            kwargs_offsets[name] = offset
        else:
            args_offsets.append(offset)
    num_partition_child_indexes = read_int(infile)
    partition_child_indexes = [read_int(infile) for i in range(num_partition_child_indexes)]
    has_pickled_analyze_result = read_bool(infile)
    if has_pickled_analyze_result:
        pickled_analyze_result = pickleSer._read_with_length(infile)
    else:
        pickled_analyze_result = None

    # Initially we assume that the UDTF __init__ method accepts the pickled AnalyzeResult,
    # although we may set this to false later if we find otherwise.
    handler = read_command(pickleSer, infile)
    if not isinstance(handler, type):
        raise PySparkRuntimeError(
            f"Invalid UDTF handler type. Expected a class (type 'type'), but "
            f"got an instance of {type(handler).__name__}."
        )

    return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
    if not isinstance(return_type, StructType):
        raise PySparkRuntimeError(
            f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
        )
    udtf_name = utf8_deserializer.loads(infile)

    # Update the handler that creates a new UDTF instance to first try calling the UDTF
    # constructor with one argument containing the previous AnalyzeResult. If that fails,
    # then try a constructor with no arguments. In this way each UDTF class instance can
    # decide if it wants to inspect the AnalyzeResult.
    udtf_init_args = inspect.getfullargspec(handler)
    if has_pickled_analyze_result:
        if len(udtf_init_args.args) > 2:
            raise PySparkRuntimeError(
                errorClass="UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD",
                messageParameters={"name": udtf_name},
            )
        elif len(udtf_init_args.args) == 2:
            prev_handler = handler

            def construct_udtf():
                # Here we pass the AnalyzeResult to the UDTF's __init__ method.
                return prev_handler(dataclasses.replace(pickled_analyze_result))

            handler = construct_udtf
    elif len(udtf_init_args.args) > 1:
        raise PySparkRuntimeError(
            errorClass="UDTF_CONSTRUCTOR_INVALID_NO_ANALYZE_METHOD",
            messageParameters={"name": udtf_name},
        )
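
    # Illustrative note (comment only, hypothetical user code): a UDTF whose
    # analyze() produced an AnalyzeResult can opt in to receiving it by declaring
    # a two-argument constructor; otherwise only a no-argument constructor is
    # allowed:
    #
    #     class MyUDTF:
    #         def __init__(self, analyze_result):  # invoked via construct_udtf()
    #             self.schema = analyze_result.schema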
""" self._create_udtf: Callable = create_udtf self._udtf = create_udtf() self._prev_arguments: list = list() self._partition_child_indexes: list = partition_child_indexes self._eval_raised_skip_rest_of_input_table: bool = False def eval(self, *args, **kwargs) -> Iterator: changed_partitions = self._check_partition_boundaries( list(args) + list(kwargs.values()) ) if changed_partitions: if hasattr(self._udtf, "terminate"): result = self._udtf.terminate() if result is not None: for row in result: yield row self._udtf = self._create_udtf() self._eval_raised_skip_rest_of_input_table = False if self._udtf.eval is not None and not self._eval_raised_skip_rest_of_input_table: # Filter the arguments to exclude projected PARTITION BY values added by Catalyst. filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] filtered_kwargs = { key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() } try: result = self._udtf.eval(*filtered_args, **filtered_kwargs) if result is not None: for row in result: yield row except SkipRestOfInputTableException: # If the 'eval' method raised this exception, then we should skip the rest of # the rows in the current partition. Set this field to True here and then for # each subsequent row in the partition, we will skip calling the 'eval' method # until we see a change in the partition boundaries. self._eval_raised_skip_rest_of_input_table = True def terminate(self) -> Iterator: if hasattr(self._udtf, "terminate"): return self._udtf.terminate() return iter(()) def cleanup(self) -> None: if hasattr(self._udtf, "cleanup"): self._udtf.cleanup() def _check_partition_boundaries(self, arguments: list) -> bool: result = False if len(self._prev_arguments) > 0: cur_table_arg = self._get_table_arg(arguments) prev_table_arg = self._get_table_arg(self._prev_arguments) cur_partitions_args = [] prev_partitions_args = [] for i in self._partition_child_indexes: cur_partitions_args.append(cur_table_arg[i]) prev_partitions_args.append(prev_table_arg[i]) result = any(k != v for k, v in zip(cur_partitions_args, prev_partitions_args)) self._prev_arguments = arguments return result def _get_table_arg(self, inputs: list) -> Row: return [x for x in inputs if type(x) is Row][0] def _remove_partition_by_exprs(self, arg: Any) -> Any: if isinstance(arg, Row): new_row_keys = [] new_row_values = [] for i, (key, value) in enumerate(zip(arg.__fields__, arg)): if i not in self._partition_child_indexes: new_row_keys.append(key) new_row_values.append(value) return _create_row(new_row_keys, new_row_values) else: return arg class ArrowUDTFWithPartition: """ Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument with one or more PARTITION BY expressions. Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row objects. This wrapper ensures the UDTF's eval() method is called separately for each unique partition key value combination. How Catalyst handles PARTITION BY and ORDER BY: ------------------------------------------------ When a UDTF is called with PARTITION BY and/or ORDER BY clauses, Catalyst adds operations to the physical plan to ensure correct data organization: Example SQL: SELECT * FROM my_udtf(TABLE(t) PARTITION BY key1, key2 ORDER BY value DESC) Physical Plan generated by Catalyst: 1. Project: Adds partition_by_0 = key1, partition_by_1 = key2 columns 2. Exchange: hashpartitioning(partition_by_0, partition_by_1, 200) - Shuffles data so rows with same partition keys go to same worker 3. 

    class ArrowUDTFWithPartition:
        """
        Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument
        with one or more PARTITION BY expressions. Arrow UDTFs receive data as PyArrow
        RecordBatch objects instead of individual Row objects. This wrapper ensures the
        UDTF's eval() method is called separately for each unique partition key value
        combination.

        How Catalyst handles PARTITION BY and ORDER BY:
        ------------------------------------------------
        When a UDTF is called with PARTITION BY and/or ORDER BY clauses, Catalyst adds
        operations to the physical plan to ensure correct data organization:

        Example SQL:
            SELECT * FROM my_udtf(TABLE(t) PARTITION BY key1, key2 ORDER BY value DESC)

        Physical Plan generated by Catalyst:
        1. Project: Adds partition_by_0 = key1, partition_by_1 = key2 columns
        2. Exchange: hashpartitioning(partition_by_0, partition_by_1, 200)
           - Shuffles data so rows with same partition keys go to same worker
        3. Sort: [partition_by_0 ASC, partition_by_1 ASC, value DESC], local=true
           - First sorts by partition keys to group them together
           - Then sorts by ORDER BY expressions within each partition
           - Local sort (not global) within each worker's data
        4. Project: Creates struct with all columns including partition_by_* columns
        5. ArrowEvalPythonUDTF: Executes this Python UDTF wrapper

        Key guarantee: After the Sort operation, all rows with the same partition key
        values are contiguous within each RecordBatch, allowing efficient boundary
        detection.

        Example queries:
            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1);
            partition_child_indexes: [2] (refers to partition_by_0 column at index 2)
            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2);
            partition_child_indexes: [2, 3] (partition_by_0 and partition_by_1 columns)
            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4);
            partition_child_indexes: [2, 3] (adds a projection for "c2 + 4")
        """

        def __init__(self, create_udtf: Callable, partition_child_indexes: list):
            """
            Create a new instance that wraps the provided Arrow UDTF with partitioning
            logic.

            Parameters
            ----------
            create_udtf: function
                Function that creates a new instance of the Arrow UDTF to invoke.
            partition_child_indexes: list
                Zero-based indexes of input-table columns that contain projected
                partitioning expressions.
            """
            self._create_udtf: Callable = create_udtf
            self._udtf = create_udtf()
            self._partition_child_indexes: list = partition_child_indexes
            # Track last partition key from previous batch
            self._last_partition_key: Optional[Tuple[Any, ...]] = None
            self._eval_raised_skip_rest_of_input_table: bool = False

        def eval(self, *args, **kwargs) -> Iterator:
            """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects."""
            import pyarrow as pa

            # Get the original batch with partition columns
            original_batch = self._get_table_arg(list(args) + list(kwargs.values()))
            if not isinstance(original_batch, pa.RecordBatch):
                # Arrow UDTFs with PARTITION BY must have a TABLE argument that
                # results in a PyArrow RecordBatch
                raise PySparkRuntimeError(
                    errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT",
                    messageParameters={
                        "actual_type": (
                            str(type(original_batch)) if original_batch is not None else "None"
                        )
                    },
                )

            # Remove partition columns to get the filtered arguments
            filtered_args = [self._remove_partition_by_exprs(arg) for arg in args]
            filtered_kwargs = {
                key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items()
            }

            # Get the filtered RecordBatch (without partition columns)
            filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values()))

            # Process the RecordBatch by partitions
            yield from self._process_arrow_batch_by_partitions(
                original_batch, filtered_batch, filtered_args, filtered_kwargs
            )

        def _process_arrow_batch_by_partitions(
            self, original_batch, filtered_batch, filtered_args, filtered_kwargs
        ) -> Iterator:
            """Process an Arrow RecordBatch that may contain multiple partition key values.

            When using PARTITION BY with Arrow UDTFs, a single RecordBatch from Spark may
            contain rows with different partition key values. For example, with 10 distinct
            partition keys and 2 workers, each worker might receive a batch containing 5
            different partition key values.

            According to UDTF PARTITION BY semantics, the UDTF's eval() method must be
            called separately for each unique partition key value, not for the entire
            batch. This method handles splitting the batch by partition boundaries and
            calling the UDTF appropriately.

            The implementation leverages two key properties:
            1. Catalyst guarantees rows with the same partition key are contiguous
               (pre-sorted)
            2. Arrow's columnar format allows efficient boundary detection

            Parameters:
            -----------
            original_batch : pa.RecordBatch
                The original batch including partition columns, used for detecting
                boundaries
            filtered_batch : pa.RecordBatch
                The batch with partition columns removed, to be passed to the UDTF
            filtered_args : list
                Arguments with partition columns filtered out
            filtered_kwargs : dict
                Keyword arguments with partition columns filtered out

            Yields:
            -------
            Iterator of pa.Table objects returned by the UDTF's eval() method
            """
            import pyarrow as pa

            # This class should only be used when partition_child_indexes is non-empty
            assert self._partition_child_indexes, (
                "ArrowUDTFWithPartition should only be instantiated when "
                "len(partition_child_indexes) > 0"
            )

            # Detect partition boundaries.
            boundaries = self._detect_partition_boundaries(original_batch)

            # Process each contiguous partition
            for i in range(len(boundaries) - 1):
                start_idx = boundaries[i]
                end_idx = boundaries[i + 1]

                # Get the partition key for this segment
                partition_key = tuple(
                    original_batch.column(idx)[start_idx].as_py()
                    for idx in self._partition_child_indexes
                )

                # Check if this is a continuation of the previous batch's partition
                # TODO: This check is only necessary for the first boundary in each batch.
                # The following boundaries are always for new partitions within the same
                # batch. This could be optimized by only checking i == 0.
                is_new_partition = (
                    self._last_partition_key is not None
                    and partition_key != self._last_partition_key
                )
                if is_new_partition:
                    # Previous partition ended, call terminate
                    if hasattr(self._udtf, "terminate"):
                        terminate_result = self._udtf.terminate()
                        if terminate_result is not None:
                            yield from terminate_result
                    # Create new UDTF instance for new partition
                    self._udtf = self._create_udtf()
                    self._eval_raised_skip_rest_of_input_table = False

                # Slice the filtered batch for this partition
                partition_batch = filtered_batch.slice(start_idx, end_idx - start_idx)

                # Update the last partition key
                self._last_partition_key = partition_key

                # Update filtered args to use the partition batch
                partition_filtered_args = []
                for arg in filtered_args:
                    if isinstance(arg, pa.RecordBatch):
                        partition_filtered_args.append(partition_batch)
                    else:
                        partition_filtered_args.append(arg)

                partition_filtered_kwargs = {}
                for key, value in filtered_kwargs.items():
                    if isinstance(value, pa.RecordBatch):
                        partition_filtered_kwargs[key] = partition_batch
                    else:
                        partition_filtered_kwargs[key] = value

                # Call the UDTF with this partition's data
                if not self._eval_raised_skip_rest_of_input_table:
                    try:
                        result = self._udtf.eval(
                            *partition_filtered_args, **partition_filtered_kwargs
                        )
                        if result is not None:
                            yield from result
                    except SkipRestOfInputTableException:
                        # Skip remaining rows in this partition
                        self._eval_raised_skip_rest_of_input_table = True
                # Don't terminate here - let the next batch or final terminate handle it

        def terminate(self) -> Iterator:
            if hasattr(self._udtf, "terminate"):
                return self._udtf.terminate()
            return iter(())

        def cleanup(self) -> None:
            if hasattr(self._udtf, "cleanup"):
                self._udtf.cleanup()
""" import pyarrow as pa # Find all RecordBatch arguments batches = [arg for arg in inputs if isinstance(arg, pa.RecordBatch)] if len(batches) == 0: # No RecordBatch found - this shouldn't happen for Arrow UDTFs with TABLE arguments return None elif len(batches) == 1: return batches[0] else: # Multiple RecordBatch arguments found - this is unexpected raise RuntimeError( f"Expected exactly one pa.RecordBatch argument for TABLE parameter, " f"but found {len(batches)}. Received types: " f"{[type(arg).__name__ for arg in inputs]}" ) def _detect_partition_boundaries(self, batch) -> list: """ Efficiently detect partition boundaries in a batch with contiguous partitions. Since Catalyst ensures rows with the same partition key are contiguous, we only need to find where partition values change. Returns: List of indices where each partition starts, plus the total row count. For example: [0, 3, 8, 10] means partitions are rows [0:3), [3:8), [8:10) """ boundaries = [0] # First partition starts at index 0 if batch.num_rows <= 1: boundaries.append(batch.num_rows) return boundaries # Get partition column arrays partition_arrays = [batch.column(i) for i in self._partition_child_indexes] # Find boundaries by comparing consecutive rows for row_idx in range(1, batch.num_rows): # Check if any partition column changed from previous row partition_changed = False for col_array in partition_arrays: if col_array[row_idx].as_py() != col_array[row_idx - 1].as_py(): partition_changed = True break if partition_changed: boundaries.append(row_idx) boundaries.append(batch.num_rows) # Last boundary at end return boundaries def _remove_partition_by_exprs(self, arg: Any) -> Any: """ Remove partition columns from the RecordBatch argument. Why this is needed: When a UDTF is called with TABLE(t) PARTITION BY expressions, Catalyst transforms the data: 1. Adds complex partition expressions as new columns (e.g., "c2 + 4" becomes a new column) 2. Repartitions data by partition columns using hash partitioning 3. Sends ALL columns (including partition columns) to the Python worker Partition columns serve two purposes: - Routing: decide which worker processes which partition - Boundary detection: know when one partition ends and another begins However, the user's UDTF should only receive the actual table data, not the partition columns. This method filters out partition columns before passing data to the user's UDTF eval() method. Example: - User writes: SELECT * FROM udtf(TABLE(t) PARTITION BY c1, c2) - Catalyst sends: RecordBatch with [c1, c2, c3, c4], partition_child_indexes=[0, 1] - This method removes columns at indexes 0, 1 if they are pure partition columns - UDTF.eval() receives: RecordBatch with only the non-partition columns """ import pyarrow as pa if isinstance(arg, pa.RecordBatch): # Remove partition columns from the RecordBatch keep_indices = [ i for i in range(len(arg.schema.names)) if i not in self._partition_child_indexes ] if keep_indices: # Select only the columns we want to keep keep_arrays = [arg.column(i) for i in keep_indices] keep_names = [arg.schema.names[i] for i in keep_indices] return pa.RecordBatch.from_arrays(keep_arrays, names=keep_names) else: # If no columns remain, return an empty RecordBatch with the same number of rows return pa.RecordBatch.from_arrays( [], schema=pa.schema([]), num_rows=arg.num_rows ) # For non-RecordBatch arguments (like scalar pa.Arrays), return unchanged return arg # Instantiate the UDTF class. 

    # Instantiate the UDTF class.
    try:
        if len(partition_child_indexes) > 0:
            # Determine if this is an Arrow UDTF
            is_arrow_udtf = eval_type == PythonEvalType.SQL_ARROW_UDTF
            if is_arrow_udtf:
                udtf = ArrowUDTFWithPartition(handler, partition_child_indexes)
            else:
                udtf = UDTFWithPartitions(handler, partition_child_indexes)
        else:
            udtf = handler()
    except Exception as e:
        raise PySparkRuntimeError(
            errorClass="UDTF_EXEC_ERROR",
            messageParameters={"method_name": "__init__", "error": str(e)},
        )

    # Validate the UDTF
    if not hasattr(udtf, "eval"):
        raise PySparkRuntimeError(
            "Failed to execute the user defined table function because it has not "
            "implemented the 'eval' method. Please add the 'eval' method and try "
            "the query again."
        )

    # Check that the arguments provided to the UDTF call match the expected parameters
    # defined in the 'eval' method signature.
    try:
        inspect.signature(udtf.eval).bind(*args_offsets, **kwargs_offsets)
    except TypeError as e:
        raise PySparkRuntimeError(
            errorClass="UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE",
            messageParameters={"name": udtf_name, "reason": str(e)},
        ) from None
return None def check(row): if isinstance(row, tuple): for i, (value, checker, nullable) in enumerate(zip(row, field_checkers, nullables)): if value is None: if not nullable: raise_(i) elif checker is not None: checker(value) return check check_output_row_against_schema = build_null_checker(return_type) if ( eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF and runner_conf.use_legacy_pandas_udtf_conversion ): def wrap_arrow_udtf(f, return_type): import pandas as pd return_type_size = len(return_type) def verify_result(result): if not isinstance(result, pd.DataFrame): raise PySparkTypeError( errorClass="INVALID_ARROW_UDTF_RETURN_TYPE", messageParameters={ "return_type": type(result).__name__, "value": str(result), "func": f.__name__, }, ) # Validate the output schema when the result dataframe has either output # rows or columns. Note that we avoid using `df.empty` here because the # result dataframe may contain an empty row. For example, when a UDTF is # defined as follows: def eval(self): yield tuple(). if len(result) > 0 or len(result.columns) > 0: if len(result.columns) != return_type_size: raise PySparkRuntimeError( errorClass="UDTF_RETURN_SCHEMA_MISMATCH", messageParameters={ "expected": str(return_type_size), "actual": str(len(result.columns)), "func": f.__name__, }, ) # Verify the type and the schema of the result. verify_pandas_result( result, return_type, assign_cols_by_name=False, truncate_return_schema=False ) return result # Wrap the exception thrown from the UDTF in a PySparkRuntimeError. def func(*a: Any) -> Any: try: return f(*a) except SkipRestOfInputTableException: raise except Exception as e: raise PySparkRuntimeError( errorClass="UDTF_EXEC_ERROR", messageParameters={"method_name": f.__name__, "error": str(e)}, ) def check_return_value(res): # Check whether the result of an arrow UDTF is iterable before # using it to construct a pandas DataFrame. if res is not None: if not isinstance(res, Iterable): raise PySparkRuntimeError( errorClass="UDTF_RETURN_NOT_ITERABLE", messageParameters={ "type": type(res).__name__, "func": f.__name__, }, ) if check_output_row_against_schema is not None: for row in res: if row is not None: check_output_row_against_schema(row) yield row else: yield from res def evaluate(*args: pd.Series, num_rows=1): if len(args) == 0: for _ in range(num_rows): yield ( verify_result(pd.DataFrame(list(check_return_value(func())))), return_type, ) else: # Create tuples from the input pandas Series, each tuple # represents a row across all Series. row_tuples = zip(*args) for row in row_tuples: yield ( verify_result(pd.DataFrame(list(check_return_value(func(*row))))), return_type, ) return evaluate eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support( getattr(udtf, "eval"), args_offsets, kwargs_offsets ) eval = wrap_arrow_udtf(eval_func_kwargs_support, return_type) if hasattr(udtf, "terminate"): terminate = wrap_arrow_udtf(getattr(udtf, "terminate"), return_type) else: terminate = None cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else None def mapper(_, it): try: for a in it: # The eval function yields an iterator. Each element produced by this # iterator is a tuple in the form of (pandas.DataFrame, arrow_return_type). 
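# Illustrative shape (assumed): for a two-column input batch, `a` is roughly
# [pd.Series([...]), pd.Series([...])]; args_kwargs_offsets selects which
# columns feed eval(), and num_rows=len(a[0]) lets a zero-argument eval()
# still run once per input row of the batch.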
yield from eval(*[a[o] for o in args_kwargs_offsets], num_rows=len(a[0])) if terminate is not None: yield from terminate() except SkipRestOfInputTableException: if terminate is not None: yield from terminate() finally: if cleanup is not None: cleanup() return mapper, None, ser, ser elif ( eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF and not runner_conf.use_legacy_pandas_udtf_conversion ): def wrap_arrow_udtf(f, return_type): import pyarrow as pa arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) return_type_size = len(return_type) def verify_result(result): if not isinstance(result, pa.Table): raise PySparkTypeError( errorClass="INVALID_ARROW_UDTF_RETURN_TYPE", messageParameters={ "return_type": type(result).__name__, "value": str(result), "func": f.__name__, }, ) # Validate the output schema when the result table has either output # rows or columns. Note that we avoid a plain emptiness check here because the # result table may contain an empty row. For example, when a UDTF is # defined as follows: def eval(self): yield tuple(). if result.num_rows > 0 or result.num_columns > 0: if result.num_columns != return_type_size: raise PySparkRuntimeError( errorClass="UDTF_RETURN_SCHEMA_MISMATCH", messageParameters={ "expected": str(return_type_size), "actual": str(result.num_columns), "func": f.__name__, }, ) # Verify the type and the schema of the result. verify_arrow_result( result, assign_cols_by_name=False, expected_cols_and_types=[ (field.name, field.type) for field in arrow_return_type ], ) return result # Wrap the exception thrown from the UDTF in a PySparkRuntimeError. def func(*a: Any) -> Any: try: return f(*a) except SkipRestOfInputTableException: raise except Exception as e: raise PySparkRuntimeError( errorClass="UDTF_EXEC_ERROR", messageParameters={"method_name": f.__name__, "error": str(e)}, ) def check_return_value(res): # Check whether the result of an arrow UDTF is iterable before # converting it to Arrow record batches.
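# e.g. an eval() implemented as `yield (1, "a")` passes this check because
# generators are Iterable; returning a bare scalar would raise
# UDTF_RETURN_NOT_ITERABLE below. (Values are illustrative only.)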
if res is not None: if not isinstance(res, Iterable): raise PySparkRuntimeError( errorClass="UDTF_RETURN_NOT_ITERABLE", messageParameters={ "type": type(res).__name__, "func": f.__name__, }, ) for row in res: if not isinstance(row, tuple) and return_type_size == 1: row = (row,) if check_output_row_against_schema is not None: if row is not None: check_output_row_against_schema(row) yield row def convert_to_arrow(data: Iterable): data = list(check_return_value(data)) if len(data) == 0: # Return one empty RecordBatch to match the left side of the lateral join return [ pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type))) ] def raise_conversion_error(original_exception): raise PySparkRuntimeError( errorClass="UDTF_ARROW_DATA_CONVERSION_ERROR", messageParameters={ "data": str(data), "schema": return_type.simpleString(), "arrow_schema": str(arrow_return_type), }, ) from original_exception try: table = LocalDataToArrowConversion.convert( data, return_type, runner_conf.use_large_var_types ) except PySparkValueError as e: if e.getErrorClass() == "AXIS_LENGTH_MISMATCH": raise PySparkRuntimeError( errorClass="UDTF_RETURN_SCHEMA_MISMATCH", messageParameters={ "expected": e.getMessageParameters()["expected_length"], # type: ignore[index] "actual": e.getMessageParameters()["actual_length"], # type: ignore[index] "func": f.__name__, }, ) from e # Fall through to general conversion error raise_conversion_error(e) except Exception as e: raise_conversion_error(e) return verify_result(table).to_batches() def evaluate(*args: list, num_rows=1): if len(args) == 0: for _ in range(num_rows): for batch in convert_to_arrow(func()): yield batch, arrow_return_type else: for row in zip(*args): for batch in convert_to_arrow(func(*row)): yield batch, arrow_return_type return evaluate eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support( getattr(udtf, "eval"), args_offsets, kwargs_offsets ) eval = wrap_arrow_udtf(eval_func_kwargs_support, return_type) if hasattr(udtf, "terminate"): terminate = wrap_arrow_udtf(getattr(udtf, "terminate"), return_type) else: terminate = None cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else None def mapper(_, it): try: converters = [ ArrowTableToRowsConversion._create_converter( f.dataType, none_on_identity=True, binary_as_bytes=runner_conf.binary_as_bytes, ) for f in input_type ] for a in it: pylist = [ ( [conv(v) for v in column.to_pylist()] if conv is not None else column.to_pylist() ) for column, conv in zip(a.columns, converters) ] # The eval function yields an iterator. Each element produced by this # iterator is a tuple in the form of (pyarrow.RecordBatch, arrow_return_type). 
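# Illustrative (assumed) shape: for an input batch with columns [1, 2] and
# ["x", "y"], `pylist` becomes [[1, 2], ["x", "y"]], and eval() is then
# driven row-wise over the zipped columns: (1, "x") then (2, "y").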
yield from eval(*[pylist[o] for o in args_kwargs_offsets], num_rows=a.num_rows) if terminate is not None: yield from terminate() except SkipRestOfInputTableException: if terminate is not None: yield from terminate() finally: if cleanup is not None: cleanup() return mapper, None, ser, ser elif eval_type == PythonEvalType.SQL_ARROW_UDTF: def wrap_pyarrow_udtf(f, return_type): import pyarrow as pa arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) return_type_size = len(return_type) def verify_result(result): # Validate the output schema when the result has columns if result.num_columns != return_type_size: raise PySparkRuntimeError( errorClass="UDTF_RETURN_SCHEMA_MISMATCH", messageParameters={ "expected": str(return_type_size), "actual": str(result.num_columns), "func": f.__name__, }, ) # We verify the type of the result and do type coercion # in the serializer return result # Wrap the exception thrown from the UDTF in a PySparkRuntimeError. def func(*a: Any) -> Any: try: return f(*a) except SkipRestOfInputTableException: raise except Exception as e: raise PySparkRuntimeError( errorClass="UDTF_EXEC_ERROR", messageParameters={"method_name": f.__name__, "error": str(e)}, ) def check_return_value(res): # Check whether the result of a PyArrow UDTF is iterable before processing if res is not None: if not isinstance(res, Iterable): raise PySparkRuntimeError( errorClass="UDTF_RETURN_NOT_ITERABLE", messageParameters={ "type": type(res).__name__, "func": f.__name__, }, ) return res else: return iter([]) def convert_to_arrow(data: Iterable): data_iter = check_return_value(data) # Handle PyArrow Tables/RecordBatches directly is_empty = True for item in data_iter: is_empty = False if isinstance(item, pa.Table): yield from item.to_batches() elif isinstance(item, pa.RecordBatch): yield item else: # Arrow UDTF should only return Arrow types (RecordBatch/Table) raise PySparkRuntimeError( errorClass="UDTF_ARROW_TYPE_CONVERSION_ERROR", messageParameters={}, ) if is_empty: yield pa.RecordBatch.from_pylist([], schema=pa.schema(list(arrow_return_type))) def evaluate(*args: pa.RecordBatch): # For Arrow UDTFs, unpack the RecordBatches and pass them to the function for batch in convert_to_arrow(func(*args)): yield verify_result(batch), arrow_return_type return evaluate eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support( getattr(udtf, "eval"), args_offsets, kwargs_offsets ) eval = wrap_pyarrow_udtf(eval_func_kwargs_support, return_type) if hasattr(udtf, "terminate"): terminate = wrap_pyarrow_udtf(getattr(udtf, "terminate"), return_type) else: terminate = None cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else None def mapper(_, it): try: for a in it: # For PyArrow UDTFs, pass RecordBatches directly (no row conversion needed) yield from eval(*[a[o] for o in args_kwargs_offsets]) if terminate is not None: yield from terminate() except SkipRestOfInputTableException: if terminate is not None: yield from terminate() finally: if cleanup is not None: cleanup() return mapper, None, ser, ser else: def wrap_udtf(f, return_type): assert return_type.needConversion() toInternal = return_type.toInternal return_type_size = len(return_type) def verify_and_convert_result(result): if result is not None: if hasattr(result, "__UDT__"): # UDT object should not be returned directly.
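# Hypothetical illustration (not from the original): an eval() that yields a
# UDT instance directly, e.g. `yield point` rather than `yield (point,)`,
# carries a __UDT__ attribute and lands in this branch.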
raise PySparkRuntimeError( errorClass="UDTF_INVALID_OUTPUT_ROW_TYPE", messageParameters={ "type": type(result).__name__, "func": f.__name__, }, ) if hasattr(result, "__len__") and len(result) != return_type_size: raise PySparkRuntimeError( errorClass="UDTF_RETURN_SCHEMA_MISMATCH", messageParameters={ "expected": str(return_type_size), "actual": str(len(result)), "func": f.__name__, }, ) if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")): raise PySparkRuntimeError( errorClass="UDTF_INVALID_OUTPUT_ROW_TYPE", messageParameters={ "type": type(result).__name__, "func": f.__name__, }, ) if check_output_row_against_schema is not None: check_output_row_against_schema(result) return toInternal(result) # Evaluate the function and return a tuple back to the executor. def evaluate(*a) -> tuple: try: res = f(*a) except SkipRestOfInputTableException: raise except Exception as e: raise PySparkRuntimeError( errorClass="UDTF_EXEC_ERROR", messageParameters={"method_name": f.__name__, "error": str(e)}, ) if res is None: # If the function returns None or does not have an explicit return statement, # an empty tuple is returned to the executor. # This is because directly constructing tuple(None) results in an exception. return tuple() if not isinstance(res, Iterable): raise PySparkRuntimeError( errorClass="UDTF_RETURN_NOT_ITERABLE", messageParameters={ "type": type(res).__name__, "func": f.__name__, }, ) # If the function returns a result, we map it to the internal representation and # return the results as a tuple. return tuple(map(verify_and_convert_result, res)) return evaluate eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support( getattr(udtf, "eval"), args_offsets, kwargs_offsets ) eval = wrap_udtf(eval_func_kwargs_support, return_type) if hasattr(udtf, "terminate"): terminate = wrap_udtf(getattr(udtf, "terminate"), return_type) else: terminate = None cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else None # Return an iterator of iterators.
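# Lifecycle sketch (assumed): for each input row `a`, eval() contributes one
# tuple of output rows; terminate() runs once after the last input row (or
# after SkipRestOfInputTableException), and cleanup() always runs, mirroring
# the Arrow-based mappers above.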
def mapper(_, it): try: for a in it: yield eval(*[a[o] for o in args_kwargs_offsets]) if terminate is not None: yield terminate() except SkipRestOfInputTableException: if terminate is not None: yield terminate() finally: if cleanup is not None: cleanup() return mapper, None, ser, ser def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf): if eval_type in ( PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_SCALAR_ARROW_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF, PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF, PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF, PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF, PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF, PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF, PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF, ): # NOTE: if timezone is set here, that implies respectSessionTimeZone is True if ( eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF ): ser = ArrowStreamGroupUDFSerializer(assign_cols_by_name=runner_conf.assign_cols_by_name) elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF, PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF, ): ser = ArrowStreamGroupSerializer(write_start_stream=True) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF: ser = ArrowStreamGroupSerializer(write_start_stream=True) elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, ): ser = ArrowStreamAggPandasUDFSerializer( timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, assign_cols_by_name=runner_conf.assign_cols_by_name, prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) elif ( eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF or eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF ): ser = GroupPandasUDFSerializer( timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, assign_cols_by_name=runner_conf.assign_cols_by_name, prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: ser = CogroupArrowUDFSerializer(assign_cols_by_name=runner_conf.assign_cols_by_name) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: ser = CogroupPandasUDFSerializer( timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, assign_cols_by_name=runner_conf.assign_cols_by_name, prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, arrow_cast=True, ) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: ser = ApplyInPandasWithStateSerializer( 
timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, assign_cols_by_name=runner_conf.assign_cols_by_name, prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, state_object_schema=eval_conf.state_value_schema, arrow_max_records_per_batch=runner_conf.arrow_max_records_per_batch, prefers_large_var_types=runner_conf.use_large_var_types, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: ser = TransformWithStateInPandasSerializer( timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, assign_cols_by_name=runner_conf.assign_cols_by_name, prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, arrow_max_records_per_batch=runner_conf.arrow_max_records_per_batch, arrow_max_bytes_per_batch=runner_conf.arrow_max_bytes_per_batch, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF: ser = TransformWithStateInPandasInitStateSerializer( timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, assign_cols_by_name=runner_conf.assign_cols_by_name, prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, arrow_max_records_per_batch=runner_conf.arrow_max_records_per_batch, arrow_max_bytes_per_batch=runner_conf.arrow_max_bytes_per_batch, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF: ser = TransformWithStateInPySparkRowSerializer( arrow_max_records_per_batch=runner_conf.arrow_max_records_per_batch ) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF: ser = TransformWithStateInPySparkRowInitStateSerializer( arrow_max_records_per_batch=runner_conf.arrow_max_records_per_batch ) elif eval_type in ( PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_SCALAR_ARROW_UDF, PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF, ): ser = ArrowStreamSerializer(write_start_stream=True) elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF: input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile)) ser = ArrowStreamSerializer(write_start_stream=True) else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. df_for_struct = ( eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF ) ser = ArrowStreamPandasUDFSerializer( timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, assign_cols_by_name=runner_conf.assign_cols_by_name, df_for_struct=df_for_struct, struct_in_pandas="dict", ndarray_as_list=False, prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, arrow_cast=True, input_type=None, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, prefers_large_types=runner_conf.use_large_var_types, ) else: batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100")) ser = BatchedSerializer(CPickleSerializer(), batch_size) # Read all UDFs num_udfs = read_int(infile) udfs = [ read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i) for i in range(num_udfs) ] if eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: import pyarrow as pa assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here." 
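# The Arrow stream carries each row as a single struct column; the func
# below first flattens that struct into top-level columns for the UDF and
# re-wraps the UDF's output batches before they are serialized back.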
udf_func: Callable[[Iterator[pa.RecordBatch]], Iterator[pa.RecordBatch]] = udfs[0][0] def func(split_index: int, batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]: """Apply mapInArrow UDF""" # Pre-processing input_batches: Iterator[pa.RecordBatch] = map( ArrowBatchTransformer.flatten_struct, batches ) # invoke the UDF output_batches = udf_func(input_batches) # Post-processing verified: Iterator[pa.RecordBatch] = verify_result(pa.RecordBatch)(output_batches) yield from map(ArrowBatchTransformer.wrap_struct, verified) # profiling is not supported for UDF return func, None, ser, ser if eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF: import pyarrow as pa col_names = ["_%d" % i for i in range(len(udfs))] combined_arrow_schema = to_arrow_schema( StructType([StructField(n, rt) for n, (_, _, _, rt) in zip(col_names, udfs)]), timezone="UTC", prefers_large_types=runner_conf.use_large_var_types, ) def func(split_index: int, batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]: """Apply scalar Arrow UDFs""" for input_batch in batches: output_batch = pa.RecordBatch.from_arrays( [ udf_func( *[input_batch.column(o) for o in args_offsets], **{k: input_batch.column(v) for k, v in kwargs_offsets.items()}, ) for udf_func, args_offsets, kwargs_offsets, _ in udfs ], col_names, ) output_batch = ArrowBatchTransformer.enforce_schema( output_batch, combined_arrow_schema ) verify_scalar_result(output_batch, input_batch.num_rows) yield output_batch # profiling is not supported for UDF return func, None, ser, ser if eval_type == PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF: import pyarrow as pa assert num_udfs == 1, "One SCALAR_ARROW_ITER UDF expected here." udf_func, args_offsets, kwargs_offsets, return_type = udfs[0] # Pre-compute target Arrow type for output coercion arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) def func(split_index: int, batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]: """Apply scalar Arrow iterator UDF""" num_input_rows = 0 def extract_args(batch: pa.RecordBatch): nonlocal num_input_rows args = tuple(batch.column(o) for o in args_offsets) num_input_rows += batch.num_rows return args[0] if len(args) == 1 else args # Extract args from input batches (streaming) args_iter = map(extract_args, batches) # Call UDF and verify result type (iterator of pa.Array) verified_iter = verify_result(pa.Array)(udf_func(args_iter)) # Process results: enforce schema and assemble into RecordBatch target_schema = pa.schema([pa.field("_0", arrow_return_type)]) def process_results(): for result in verified_iter: batch = pa.RecordBatch.from_arrays([result], ["_0"]) yield ArrowBatchTransformer.enforce_schema(batch, target_schema, safecheck=True) # Apply row limit check (fail-fast) limited = verify_output_row_limit( process_results(), lambda: num_input_rows, error_class="OUTPUT_EXCEEDS_INPUT_ROWS", ) # Apply row count match check (final) matched = verify_output_row_count( limited, lambda: num_input_rows, error_class="RESULT_ROWS_MISMATCH", ) # Yield batches yield from matched # Verify iterator consumed verify_iterator_exhausted( args_iter, error_class="INPUT_NOT_FULLY_CONSUMED", ) # profiling is not supported for UDF return func, None, ser, ser if eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF: import pyarrow as pa # Pre-compute target schema for output coercion col_names = ["_%d" % i for i in range(len(udfs))] return_schema = to_arrow_schema( StructType([StructField(name, rt) for name, (_, _, _, rt) in 
zip(col_names, udfs)]), timezone="UTC", prefers_large_types=runner_conf.use_large_var_types, ) def func(split_index: int, batches: Iterator[Any]) -> Iterator[pa.RecordBatch]: for group_batches in batches: batch_list = list(group_batches) if not batch_list: continue if hasattr(pa, "concat_batches"): concatenated = pa.concat_batches(batch_list) else: # pyarrow.concat_batches not supported before 19.0.0 # remove this once we drop support for old versions concatenated = pa.RecordBatch.from_struct_array( pa.concat_arrays([b.to_struct_array() for b in batch_list]) ) results = [ udf_func( *[concatenated.column(o) for o in args_offsets], **{k: concatenated.column(v) for k, v in kwargs_offsets.items()}, ) for udf_func, args_offsets, kwargs_offsets, _ in udfs ] result_arrays = [pa.array([r]) for r in results] batch = pa.RecordBatch.from_arrays(result_arrays, col_names) yield ArrowBatchTransformer.enforce_schema(batch, return_schema) # profiling is not supported for UDF return func, None, ser, ser if eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF: import pyarrow as pa assert num_udfs == 1, "One GROUPED_AGG_ARROW_ITER UDF expected here." udf_func, args_offsets, kwargs_offsets, return_type = udfs[0] return_schema = to_arrow_schema( StructType([StructField("_0", return_type)]), timezone="UTC", prefers_large_types=runner_conf.use_large_var_types, ) def extract_args(batch): args = tuple(batch.column(o) for o in args_offsets) return args[0] if len(args) == 1 else args def func(split_index: int, batches: Iterator[Any]) -> Iterator[pa.RecordBatch]: for group_batches in batches: batch_iter = map(extract_args, group_batches) result = udf_func(batch_iter) # Drain remaining batches to maintain stream position for _ in batch_iter: pass batch = pa.RecordBatch.from_arrays([pa.array([result])], ["_0"]) yield ArrowBatchTransformer.enforce_schema(batch, return_schema) # profiling is not supported for UDF return func, None, ser, ser if eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF: import pyarrow as pa window_bound_types_str = runner_conf.get("window_bound_types") window_bound_types = [t.strip().lower() for t in window_bound_types_str.split(",")] col_names = ["_%d" % i for i in range(len(udfs))] return_schema = to_arrow_schema( StructType([StructField(name, rt) for name, (_, _, _, rt) in zip(col_names, udfs)]), timezone="UTC", prefers_large_types=runner_conf.use_large_var_types, ) def func(split_index: int, batches: Iterator[Any]) -> Iterator[pa.RecordBatch]: for group_batches in batches: batch_list = list(group_batches) if not batch_list: continue if hasattr(pa, "concat_batches"): concatenated = pa.concat_batches(batch_list) else: # pyarrow.concat_batches not supported before 19.0.0 # remove this once we drop support for old versions concatenated = pa.RecordBatch.from_struct_array( pa.concat_arrays([b.to_struct_array() for b in batch_list]) ) num_rows = concatenated.num_rows result_arrays = [] for udf_index, (udf_func, args_offsets, kwargs_offsets, _) in enumerate(udfs): bound_type = window_bound_types[udf_index] if bound_type == "unbounded": result = udf_func( *[concatenated.column(o) for o in args_offsets], **{k: concatenated.column(v) for k, v in kwargs_offsets.items()}, ) result_arrays.append(pa.repeat(result, num_rows)) elif bound_type == "bounded": begin_col = concatenated.column(args_offsets[0]) end_col = concatenated.column(args_offsets[1]) results = [] for i in range(num_rows): offset = begin_col[i].as_py() length = end_col[i].as_py() - offset slices = [ 
concatenated.column(o).slice(offset=offset, length=length) for o in args_offsets[2:] ] kw_slices = { k: concatenated.column(v).slice(offset=offset, length=length) for k, v in kwargs_offsets.items() } results.append(udf_func(*slices, **kw_slices)) result_arrays.append(pa.array(results)) else: raise PySparkRuntimeError( errorClass="INVALID_WINDOW_BOUND_TYPE", messageParameters={"window_bound_type": bound_type}, ) batch = pa.RecordBatch.from_arrays(result_arrays, col_names) yield ArrowBatchTransformer.enforce_schema(batch, return_schema) # profiling is not supported for UDF return func, None, ser, ser if ( eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF and not runner_conf.use_legacy_pandas_udf_conversion ): import pyarrow as pa # --- UDF preparation --- udf_infos = [] for udf_func, udf_args_offsets, udf_kwargs_offsets, udf_return_type in udfs: wrapped_func, args_kwargs_offsets = wrap_kwargs_support( udf_func, udf_args_offsets, udf_kwargs_offsets ) zero_arg = len(args_kwargs_offsets) == 0 udf_infos.append( ( wrapped_func, args_kwargs_offsets or (0,), zero_arg, to_arrow_type( udf_return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types, ), LocalDataToArrowConversion._create_converter( udf_return_type, none_on_identity=True, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ), ) ) col_names = [f"_{i}" for i in range(len(udfs))] # --- Input preparation --- arrow_to_py_converters = [ ArrowTableToRowsConversion._create_converter( f.dataType, none_on_identity=True, binary_as_bytes=runner_conf.binary_as_bytes ) for f in input_type ] @fail_on_stopiteration def _evaluate_batch_udf(udf_func, rows): if runner_conf.arrow_concurrency_level <= 0: return [udf_func(*row) for row in rows] from concurrent.futures import ThreadPoolExecutor with ThreadPoolExecutor(max_workers=runner_conf.arrow_concurrency_level) as pool: return list(pool.map(lambda row: udf_func(*row), rows)) def func(split_index: int, batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]: for input_batch in batches: num_rows = input_batch.num_rows # --- Input: Arrow -> Python columns --- columns = [ [conv(v) for v in col.to_pylist()] if conv is not None else col.to_pylist() for col, conv in zip(input_batch.itercolumns(), arrow_to_py_converters) ] if not columns: columns = [[_NoValue] * num_rows] # --- Process: evaluate each UDF row-by-row --- output_arrays = [] for udf_func, offsets, zero_arg, arrow_return_type, result_conv in udf_infos: rows = ( [() for _ in range(num_rows)] if zero_arg else list(zip(*[columns[o] for o in offsets])) ) results = _evaluate_batch_udf(udf_func, rows) verify_result_row_count(len(results), num_rows) # --- Output: Python -> Arrow --- converted = ( [result_conv(r) for r in results] if result_conv is not None else results ) try: arr = pa.array(converted, type=arrow_return_type) except pa.lib.ArrowInvalid: arr = pa.array(converted).cast( target_type=arrow_return_type, safe=runner_conf.safecheck ) output_arrays.append(arr) yield pa.RecordBatch.from_arrays(output_arrays, col_names) # profiling is not supported for UDF return func, None, ser, ser if ( eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF and runner_conf.use_legacy_pandas_udf_conversion ): import pandas as pd import pyarrow as pa # --- UDF preparation --- udf_infos = [] for udf_func, udf_args_offsets, udf_kwargs_offsets, udf_return_type in udfs: wrapped_func, args_kwargs_offsets = wrap_kwargs_support( udf_func, udf_args_offsets, udf_kwargs_offsets ) zero_arg = len(args_kwargs_offsets) == 0 # Legacy 
coerces String/Binary for Arrow compatibility coerce = ( str if isinstance(udf_return_type, StringType) else bytes if isinstance(udf_return_type, BinaryType) else None ) udf_infos.append( ( wrapped_func, args_kwargs_offsets or (0,), zero_arg, udf_return_type, coerce, ) ) col_names = [f"_{i}" for i in range(len(udfs))] return_schema = StructType( [StructField(name, info[3]) for name, info in zip(col_names, udf_infos)] ) @fail_on_stopiteration def _evaluate_batch_udf_legacy(udf_func, rows): if runner_conf.arrow_concurrency_level <= 0: return [udf_func(*row) for row in rows] from concurrent.futures import ThreadPoolExecutor with ThreadPoolExecutor(max_workers=runner_conf.arrow_concurrency_level) as pool: return list(pool.map(lambda row: udf_func(*row), rows)) def func(split_index: int, batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]: for input_batch in batches: # --- Input: Arrow -> pandas columns --- pandas_columns = ArrowBatchTransformer.to_pandas( input_batch, timezone=runner_conf.timezone, schema=input_type, struct_in_pandas="row", ndarray_as_list=True, prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, df_for_struct=False, ) num_rows = len(pandas_columns[0]) if pandas_columns else input_batch.num_rows if not pandas_columns: pandas_columns = [pd.Series([_NoValue] * num_rows)] # --- Process: evaluate each UDF row-by-row --- result_series = [] for udf_func, offsets, zero_arg, _, coerce in udf_infos: rows = ( [() for _ in range(num_rows)] if zero_arg else list(zip(*[pandas_columns[o].tolist() for o in offsets])) ) results = _evaluate_batch_udf_legacy(udf_func, rows) verify_result_row_count(len(results), num_rows) if coerce: results = [coerce(v) if v is not None else v for v in results] result_series.append(pd.Series(results)) # --- Output: pandas -> Arrow --- yield PandasToArrowConversion.convert( result_series, return_schema, timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, arrow_cast=True, prefers_large_types=runner_conf.use_large_var_types, assign_cols_by_name=runner_conf.assign_cols_by_name, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) # profiling is not supported for UDF return func, None, ser, ser is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF if is_scalar_iter or is_map_pandas_iter: # TODO: Better error message for num_udfs != 1 if is_scalar_iter: assert num_udfs == 1, "One SCALAR_ITER UDF expected here." if is_map_pandas_iter: assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here." arg_offsets, udf = udfs[0] def func(_, iterator): # type: ignore[misc] num_input_rows = 0 def map_batch(batch): nonlocal num_input_rows udf_args = [batch[offset] for offset in arg_offsets] num_input_rows += len(udf_args[0]) if len(udf_args) == 1: return udf_args[0] else: return tuple(udf_args) iterator = map(map_batch, iterator) result_iter = udf(iterator) num_output_rows = 0 for result_batch, result_type in result_iter: num_output_rows += len(result_batch) # This check is for Scalar Iterator UDF to fail fast. # The length of the entire input can only be explicitly known # by consuming the input iterator on the user side. Therefore, # it's very unlikely the output length is higher than # input length.
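# Two-stage validation (see also the checks after this loop): this per-batch
# comparison fails fast while streaming, and the exact
# num_output_rows == num_input_rows check runs once the UDF's iterator is
# exhausted. E.g. (illustrative) a UDF that emits every batch twice trips
# the fast check on the first duplicated batch rather than at the end.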
if is_scalar_iter and num_output_rows > num_input_rows: raise PySparkRuntimeError( errorClass="OUTPUT_EXCEEDS_INPUT_ROWS", messageParameters={} ) yield (result_batch, result_type) if is_scalar_iter: try: next(iterator) except StopIteration: pass else: raise PySparkRuntimeError( errorClass="INPUT_NOT_FULLY_CONSUMED", messageParameters={}, ) if num_output_rows != num_input_rows: raise PySparkRuntimeError( errorClass="RESULT_ROWS_MISMATCH", messageParameters={ "output_length": str(num_output_rows), "input_length": str(num_input_rows), }, ) # profiling is not supported for UDF return func, None, ser, ser def extract_key_value_indexes(grouped_arg_offsets): """ Helper function to extract the key and value indexes from arg_offsets for the grouped and cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code. Parameters ---------- grouped_arg_offsets: list List containing the key and value indexes of columns of the DataFrames to be passed to the udf. It consists of n repeating groups where n is the number of DataFrames. Each group has the following format: group[0]: length of the group, excluding this length element group[1]: number of key indexes group[2 .. group[1]+2): key attributes group[group[1]+2 .. group[0]+1): value attributes For example, [4, 1, 0, 1, 2] is a single group parsed as key offsets [0] and value offsets [1, 2]. """ parsed = [] idx = 0 while idx < len(grouped_arg_offsets): offsets_len = grouped_arg_offsets[idx] idx += 1 offsets = grouped_arg_offsets[idx : idx + offsets_len] split_index = offsets[0] + 1 offset_keys = offsets[1:split_index] offset_values = offsets[split_index:] parsed.append([offset_keys, offset_values]) idx += offsets_len return parsed if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: import pyarrow as pa # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) key_offsets = parsed_offsets[0][0] value_offsets = parsed_offsets[0][1] def mapper(batch_iter): # Collect all Arrow batches and merge at Arrow level all_batches = list(batch_iter) if all_batches: table = pa.Table.from_batches(all_batches).combine_chunks() else: table = pa.table({}) # Convert to pandas once for the entire group all_series = ArrowBatchTransformer.to_pandas( table, timezone=ser._timezone, prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, ) key_series = [all_series[o] for o in key_offsets] value_series = [all_series[o] for o in value_offsets] yield from f(key_series, value_series) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF: import pyarrow as pa # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs.
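# Unlike the non-iterator branch above, which concatenates all Arrow batches
# of a group before converting to pandas, the mapper below hands the UDF a
# lazy generator of per-batch value Series, so a group larger than memory
# can still be processed batch by batch.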
assert num_udfs == 1 # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) def mapper(batch_iter): # Convert first Arrow batch to pandas to extract keys first_series = ArrowBatchTransformer.to_pandas( next(batch_iter), timezone=ser._timezone, prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, ) key_series = [first_series[o] for o in parsed_offsets[0][0]] # Lazily convert remaining Arrow batches to pandas Series def value_series_gen(): yield [first_series[o] for o in parsed_offsets[0][1]] for batch in batch_iter: series = ArrowBatchTransformer.to_pandas( batch, timezone=ser._timezone, prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, ) yield [series[o] for o in parsed_offsets[0][1]] yield from f(key_series, value_series_gen()) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 # See TransformWithStateInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) ser.key_offsets = parsed_offsets[0][0] stateful_processor_api_client = StatefulProcessorApiClient( eval_conf.state_server_socket_port, eval_conf.grouping_key_schema ) def mapper(a): mode = a[0] if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: key = a[1] def values_gen(): for x in a[2]: retVal = x[1].iloc[:, parsed_offsets[0][1]] yield retVal # This must be generator comprehension - do not materialize. return f(stateful_processor_api_client, mode, key, values_gen()) else: # mode == PROCESS_TIMER or mode == COMPLETE return f(stateful_processor_api_client, mode, None, iter([])) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF: # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 # See TransformWithStateInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = udfs[0] # parsed offsets: # [ # [groupingKeyOffsets, dedupDataOffsets], # [initStateGroupingOffsets, dedupInitDataOffsets] # ] parsed_offsets = extract_key_value_indexes(arg_offsets) ser.key_offsets = parsed_offsets[0][0] ser.init_key_offsets = parsed_offsets[1][0] stateful_processor_api_client = StatefulProcessorApiClient( eval_conf.state_server_socket_port, eval_conf.grouping_key_schema ) def mapper(a): mode = a[0] if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: key = a[1] def values_gen(): for x in a[2]: retVal = x[1] initVal = x[2] yield retVal, initVal # This must be generator comprehension - do not materialize. return f(stateful_processor_api_client, mode, key, values_gen()) else: # mode == PROCESS_TIMER or mode == COMPLETE return f(stateful_processor_api_client, mode, None, iter([])) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF: # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. 
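# Mode dispatch sketch (assumed, mirroring the pandas variants above):
# PROCESS_DATA arrives as (mode, key, iterator-of-rows), while PROCESS_TIMER
# and COMPLETE invoke the same user function with key=None and an empty
# iterator.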
assert num_udfs == 1 # See TransformWithStateInPySparkExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) ser.key_offsets = parsed_offsets[0][0] stateful_processor_api_client = StatefulProcessorApiClient( eval_conf.state_server_socket_port, eval_conf.grouping_key_schema ) def mapper(a): mode = a[0] if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: key = a[1] values = a[2] # This must be generator comprehension - do not materialize. return f(stateful_processor_api_client, mode, key, values) else: # mode == PROCESS_TIMER or mode == COMPLETE return f(stateful_processor_api_client, mode, None, iter([])) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF: # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 # See TransformWithStateInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = udfs[0] # parsed offsets: # [ # [groupingKeyOffsets, dedupDataOffsets], # [initStateGroupingOffsets, dedupInitDataOffsets] # ] parsed_offsets = extract_key_value_indexes(arg_offsets) ser.key_offsets = parsed_offsets[0][0] ser.init_key_offsets = parsed_offsets[1][0] stateful_processor_api_client = StatefulProcessorApiClient( eval_conf.state_server_socket_port, eval_conf.grouping_key_schema ) def mapper(a): mode = a[0] if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: key = a[1] values = a[2] # This must be generator comprehension - do not materialize. return f(stateful_processor_api_client, mode, key, values) else: # mode == PROCESS_TIMER or mode == COMPLETE return f(stateful_processor_api_client, mode, None, iter([])) elif ( eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF ): import pyarrow as pa # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) def batch_from_offset(batch, offsets): return pa.RecordBatch.from_arrays( arrays=[batch.columns[o] for o in offsets], names=[batch.schema.names[o] for o in offsets], ) def mapper(batches): # Flatten struct column into separate columns flattened = map(ArrowBatchTransformer.flatten_struct, batches) # Need to materialize the first batch to get the keys first_batch = next(flattened) keys = batch_from_offset(first_batch, parsed_offsets[0][0]) value_batches = ( batch_from_offset(batch, parsed_offsets[0][1]) for batch in itertools.chain((first_batch,), flattened) ) return f(keys, value_batches) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) def mapper(a): """ The function receives (iterator of data, state) and performs extraction of key and value from the data, with retaining lazy evaluation. 
See `load_stream` in `ApplyInPandasWithStateSerializer` for more details on the input and see `wrap_grouped_map_pandas_udf_with_state` for more details on how output will be used. """ from itertools import chain state = a[1] data_gen = (x[0] for x in a[0]) # We know there should be at least one item in the iterator/generator. # Consume the first element to extract keys first_elem = next(data_gen) keys = [first_elem[o] for o in parsed_offsets[0][0]] # This must be generator comprehension - do not materialize. vals = ([x[o] for o in parsed_offsets[0][1]] for x in chain([first_elem], data_gen)) return f(keys, vals, state) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) def mapper(a): df1_keys = [a[0][o] for o in parsed_offsets[0][0]] df1_vals = [a[0][o] for o in parsed_offsets[0][1]] df2_keys = [a[1][o] for o in parsed_offsets[1][0]] df2_vals = [a[1][o] for o in parsed_offsets[1][1]] return f(df1_keys, df1_vals, df2_keys, df2_vals) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: import pyarrow as pa # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) def batch_from_offset(batch, offsets): return pa.RecordBatch.from_arrays( arrays=[batch.columns[o] for o in offsets], names=[batch.schema.names[o] for o in offsets], ) def table_from_batches(batches, offsets): return pa.Table.from_batches([batch_from_offset(batch, offsets) for batch in batches]) def mapper(a): df1_keys = table_from_batches(a[0], parsed_offsets[0][0]) df1_vals = table_from_batches(a[0], parsed_offsets[0][1]) df2_keys = table_from_batches(a[1], parsed_offsets[1][0]) df2_vals = table_from_batches(a[1], parsed_offsets[1][1]) return f(df1_keys, df1_vals, df2_keys, df2_vals) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF: # We assume there is only one UDF here because grouped agg doesn't # support combining multiple UDFs. 
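# Contract sketch (assumed): for a group whose batches yield column Series
# s1, s2, ..., a single-column UDF receives iter([s1, s2, ...]) and reduces
# it to one aggregate value; with multiple columns it receives an iterator
# of per-batch tuples of Series instead.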
assert num_udfs == 1 arg_offsets, f = udfs[0] # Convert to iterator of pandas Series: # - Iterator[pd.Series] for single column # - Iterator[Tuple[pd.Series, ...]] for multiple columns def mapper(batch_iter): # batch_iter is Iterator[Tuple[pd.Series, ...]] where each tuple represents one batch # Convert to Iterator[pd.Series] or Iterator[Tuple[pd.Series, ...]] based on arg_offsets if len(arg_offsets) == 1: # Single column: Iterator[Tuple[pd.Series, ...]] -> Iterator[pd.Series] series_iter = (batch_series[arg_offsets[0]] for batch_series in batch_iter) else: # Multiple columns: Iterator[Tuple[pd.Series, ...]] -> # Iterator[Tuple[pd.Series, ...]] series_iter = ( tuple(batch_series[o] for o in arg_offsets) for batch_series in batch_iter ) return f(series_iter) elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, ): import pandas as pd # For SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF, # convert iterator of batch tuples to concatenated pandas Series def mapper(batch_iter): # batch_iter is Iterator[Tuple[pd.Series, ...]] where each tuple represents one batch # Collect all batches and concatenate into single Series per column batches = list(batch_iter) if not batches: # Empty batches - determine num_columns from all UDFs' arg_offsets all_offsets = [o for arg_offsets, _ in udfs for o in arg_offsets] num_columns = max(all_offsets) + 1 if all_offsets else 0 concatenated = [pd.Series(dtype=object) for _ in range(num_columns)] else: # Use actual number of columns from the first batch num_columns = len(batches[0]) concatenated = [ pd.concat([batch[i] for batch in batches], ignore_index=True) for i in range(num_columns) ] result = tuple(f(*[concatenated[o] for o in arg_offsets]) for arg_offsets, f in udfs) # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. if len(result) == 1: return result[0] else: return result else: def mapper(a): result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs) # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. 
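# e.g. (illustrative) two UDFs producing 1 and "a" yield the tuple (1, "a"),
# while a single UDF producing 1 yields just 1, which is the shape the
# JVM-side reader expects in each case.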
if len(result) == 1: return result[0] else: return result def func(_, it): return map(mapper, it) # profiling is not supported for UDF return func, None, ser, ser @with_faulthandler def main(infile, outfile): try: boot_time = time.time() split_index = read_int(infile) if split_index == -1: # for unit tests sys.exit(-1) start_faulthandler_periodic_traceback() check_python_version(infile) memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1")) setup_memory_limits(memory_limit_mb) task_context_json = json.loads(utf8_deserializer.loads(infile)) if task_context_json["isBarrier"]: taskContext = BarrierTaskContext.from_json(task_context_json) else: taskContext = TaskContext.from_json(task_context_json) TaskContext._setTaskContext(taskContext) shuffle.MemoryBytesSpilled = 0 shuffle.DiskBytesSpilled = 0 setup_spark_files(infile) setup_broadcasts(infile) _accumulatorRegistry.clear() eval_type = read_int(infile) runner_conf = RunnerConf(infile) eval_conf = EvalConf(infile) if eval_type == PythonEvalType.NON_UDF: func, profiler, deserializer, serializer = read_command(pickleSer, infile) elif eval_type in ( PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF, PythonEvalType.SQL_ARROW_UDTF, ): func, profiler, deserializer, serializer = read_udtf( pickleSer, infile, eval_type, runner_conf ) else: func, profiler, deserializer, serializer = read_udfs( pickleSer, infile, eval_type, runner_conf, eval_conf ) init_time = time.time() def process(): iterator = deserializer.load_stream(infile) out_iter = func(split_index, iterator) try: serializer.dump_stream(out_iter, outfile) finally: if hasattr(out_iter, "close"): out_iter.close() processing_start_time = time.time() with capture_outputs(): if profiler: profiler.profile(process) else: process() processing_time_ms = int(1000 * (time.time() - processing_start_time)) # Reset the task context to None. This guards against residual context when the worker # is reused. TaskContext._setTaskContext(None) BarrierTaskContext._setTaskContext(None) except BaseException as e: handle_worker_exception(e, outfile) sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time, processing_time_ms) write_long(shuffle.MemoryBytesSpilled, outfile) write_long(shuffle.DiskBytesSpilled, outfile) # Mark the beginning of the accumulators section of the output write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) send_accumulator_updates(outfile) # check end of stream if read_int(infile) == SpecialLengths.END_OF_STREAM: write_int(SpecialLengths.END_OF_STREAM, outfile) else: # write a different value to tell the JVM not to reuse this worker write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) sys.exit(-1) if __name__ == "__main__": with get_sock_file_to_executor() as sock_file: main(sock_file, sock_file)