#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
import json
from typing import IO, Iterator, Tuple
import dataclasses

from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import IllegalArgumentException, PySparkAssertionError
from pyspark.errors.exceptions.base import PySparkException
from pyspark.serializers import (
    read_int,
    write_int,
    write_with_length,
    SpecialLengths,
)
from pyspark.sql.datasource import (
    DataSource,
    DataSourceStreamReader,
)
from pyspark.sql.streaming.datasource import (
    SupportsTriggerAvailableNow,
)
from pyspark.sql.datasource_internal import (
    _SimpleStreamReaderWrapper,
    _streamReader,
    ReadLimitRegistry,
)
from pyspark.sql.pandas.serializers import ArrowStreamSerializer
from pyspark.sql.types import (
    _parse_datatype_json_string,
    StructType,
)
from pyspark.sql.worker.plan_data_source_read import records_to_arrow_batches
from pyspark.util import handle_worker_exception
from pyspark.worker_util import (
    get_sock_file_to_executor,
    check_python_version,
    read_command,
    pickleSer,
    send_accumulator_updates,
    setup_memory_limits,
    setup_spark_files,
    utf8_deserializer,
)

INITIAL_OFFSET_FUNC_ID = 884
LATEST_OFFSET_FUNC_ID = 885
PARTITIONS_FUNC_ID = 886
COMMIT_FUNC_ID = 887
CHECK_SUPPORTED_FEATURES_ID = 888
PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID = 889
LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID = 890
GET_DEFAULT_READ_LIMIT_FUNC_ID = 891
REPORT_LATEST_OFFSET_FUNC_ID = 892

PREFETCHED_RECORDS_NOT_FOUND = 0
NON_EMPTY_PYARROW_RECORD_BATCHES = 1
EMPTY_PYARROW_RECORD_BATCHES = 2

SUPPORTS_ADMISSION_CONTROL = 1 << 0
SUPPORTS_TRIGGER_AVAILABLE_NOW = 1 << 1

READ_LIMIT_REGISTRY = ReadLimitRegistry()


def initial_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
    offset = reader.initialOffset()
    write_with_length(json.dumps(offset).encode("utf-8"), outfile)


def latest_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
    offset = reader.latestOffset()  # type: ignore[call-arg]
    write_with_length(json.dumps(offset).encode("utf-8"), outfile)


def partitions_func(
    reader: DataSourceStreamReader,
    data_source: DataSource,
    schema: StructType,
    max_arrow_batch_size: int,
    infile: IO,
    outfile: IO,
) -> None:
    start_offset = json.loads(utf8_deserializer.loads(infile))
    end_offset = json.loads(utf8_deserializer.loads(infile))
    partitions = reader.partitions(start_offset, end_offset)
    # Return the serialized partition values.
    write_int(len(partitions), outfile)
    for partition in partitions:
        pickleSer._write_with_length(partition, outfile)
    if isinstance(reader, _SimpleStreamReaderWrapper):
        it = reader.getCache(start_offset, end_offset)
        if it is None:
            write_int(PREFETCHED_RECORDS_NOT_FOUND, outfile)
        else:
            send_batch_func(it, outfile, schema, max_arrow_batch_size, data_source)
    else:
        write_int(PREFETCHED_RECORDS_NOT_FOUND, outfile)


def commit_func(reader: DataSourceStreamReader, infile: IO, outfile: IO) -> None:
    end_offset = json.loads(utf8_deserializer.loads(infile))
    reader.commit(end_offset)
    write_int(0, outfile)


def send_batch_func(
    rows: Iterator[Tuple],
    outfile: IO,
    schema: StructType,
    max_arrow_batch_size: int,
    data_source: DataSource,
) -> None:
    batches = list(records_to_arrow_batches(rows, max_arrow_batch_size, schema, data_source))
    if len(batches) != 0:
        write_int(NON_EMPTY_PYARROW_RECORD_BATCHES, outfile)
        write_int(SpecialLengths.START_ARROW_STREAM, outfile)
        serializer = ArrowStreamSerializer()
        serializer.dump_stream(batches, outfile)
    else:
        write_int(EMPTY_PYARROW_RECORD_BATCHES, outfile)


def check_support_func(reader: DataSourceStreamReader, outfile: IO) -> None:
    support_flags = 0
    if isinstance(reader, _SimpleStreamReaderWrapper):
        # We consider the `read` method of the simple reader to already have admission control
        # built into it.
        support_flags |= SUPPORTS_ADMISSION_CONTROL
        if isinstance(reader.simple_reader, SupportsTriggerAvailableNow):
            support_flags |= SUPPORTS_TRIGGER_AVAILABLE_NOW
    else:
        import inspect

        sig = inspect.signature(reader.latestOffset)
        if len(sig.parameters) == 0:
            # old signature of latestOffset()
            pass
        else:
            # We don't strictly check the number/type of parameters here; Python will raise an
            # error when the method is called if the types do not match.
            support_flags |= SUPPORTS_ADMISSION_CONTROL
        if isinstance(reader, SupportsTriggerAvailableNow):
            support_flags |= SUPPORTS_TRIGGER_AVAILABLE_NOW
    write_int(support_flags, outfile)


def prepare_for_trigger_available_now_func(reader: DataSourceStreamReader, outfile: IO) -> None:
    if isinstance(reader, _SimpleStreamReaderWrapper):
        if isinstance(reader.simple_reader, SupportsTriggerAvailableNow):
            reader.simple_reader.prepareForTriggerAvailableNow()
        else:
            raise PySparkException(
                "prepareForTriggerAvailableNow is not supported by the underlying simple reader."
            )
    else:
        if isinstance(reader, SupportsTriggerAvailableNow):
            reader.prepareForTriggerAvailableNow()
        else:
            raise PySparkException(
                "prepareForTriggerAvailableNow is not supported by the stream reader."
            )
    write_int(0, outfile)


def latest_offset_admission_control_func(
    reader: DataSourceStreamReader, infile: IO, outfile: IO
) -> None:
    start_offset_dict = json.loads(utf8_deserializer.loads(infile))
    limit = json.loads(utf8_deserializer.loads(infile))
    limit_obj = READ_LIMIT_REGISTRY.get(limit)
    offset = reader.latestOffset(start_offset_dict, limit_obj)
    write_with_length(json.dumps(offset).encode("utf-8"), outfile)


def get_default_read_limit_func(reader: DataSourceStreamReader, outfile: IO) -> None:
    limit = reader.getDefaultReadLimit()
    limit_as_dict = dataclasses.asdict(limit) | {  # type: ignore[call-overload]
        "_type": limit.__class__.__name__
    }
    write_with_length(json.dumps(limit_as_dict).encode("utf-8"), outfile)


def report_latest_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
    if isinstance(reader, _SimpleStreamReaderWrapper):
        # We do not support reporting the latest offset for the simple stream reader.
        write_int(0, outfile)
    else:
        offset = reader.reportLatestOffset()
        if offset is None:
            write_int(0, outfile)
        else:
            write_with_length(json.dumps(offset).encode("utf-8"), outfile)


def main(infile: IO, outfile: IO) -> None:
    try:
        check_python_version(infile)
        setup_spark_files(infile)

        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
        setup_memory_limits(memory_limit_mb)

        _accumulatorRegistry.clear()

        # Receive the data source instance.
        data_source = read_command(pickleSer, infile)
        if not isinstance(data_source, DataSource):
            raise PySparkAssertionError(
                errorClass="DATA_SOURCE_TYPE_MISMATCH",
                messageParameters={
                    "expected": "a Python data source instance of type 'DataSource'",
                    "actual": f"'{type(data_source).__name__}'",
                },
            )

        # Receive the data source output schema.
        schema_json = utf8_deserializer.loads(infile)
        schema = _parse_datatype_json_string(schema_json)
        if not isinstance(schema, StructType):
            raise PySparkAssertionError(
                errorClass="DATA_SOURCE_TYPE_MISMATCH",
                messageParameters={
                    "expected": "an output schema of type 'StructType'",
                    "actual": f"'{type(schema).__name__}'",
                },
            )

        max_arrow_batch_size = read_int(infile)
        assert max_arrow_batch_size > 0, (
            "The maximum arrow batch size should be greater than 0, but got "
            f"'{max_arrow_batch_size}'"
        )

        # Instantiate data source reader.
        try:
            reader = _streamReader(data_source, schema)
            # Initialization succeeded.
            write_int(0, outfile)
            outfile.flush()

            # Handle method calls from the socket.
            while True:
                func_id = read_int(infile)

                if func_id == INITIAL_OFFSET_FUNC_ID:
                    initial_offset_func(reader, outfile)
                elif func_id == LATEST_OFFSET_FUNC_ID:
                    latest_offset_func(reader, outfile)
                elif func_id == PARTITIONS_FUNC_ID:
                    partitions_func(
                        reader, data_source, schema, max_arrow_batch_size, infile, outfile
                    )
                elif func_id == COMMIT_FUNC_ID:
                    commit_func(reader, infile, outfile)
                elif func_id == CHECK_SUPPORTED_FEATURES_ID:
                    check_support_func(reader, outfile)
                elif func_id == PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID:
                    prepare_for_trigger_available_now_func(reader, outfile)
                elif func_id == LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID:
                    latest_offset_admission_control_func(reader, infile, outfile)
                elif func_id == GET_DEFAULT_READ_LIMIT_FUNC_ID:
                    get_default_read_limit_func(reader, outfile)
                elif func_id == REPORT_LATEST_OFFSET_FUNC_ID:
                    report_latest_offset_func(reader, outfile)
                else:
                    raise IllegalArgumentException(
                        errorClass="UNSUPPORTED_OPERATION",
                        messageParameters={
                            "operation": "Function call id not recognized by stream reader"
                        },
                    )
                outfile.flush()
        finally:
            reader.stop()
    except BaseException as e:
        handle_worker_exception(e, outfile)
        # Ensure that the updates to the socket are flushed.
        outfile.flush()
        sys.exit(-1)

    send_accumulator_updates(outfile)

    # Check end of stream.
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # Write a different value to tell the JVM not to reuse this worker.
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)


if __name__ == "__main__":
    with get_sock_file_to_executor(timeout=None) as sock_file:
        main(sock_file, sock_file)