#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
from typing import IO

from pyspark.errors import PySparkAssertionError, PySparkTypeError
from pyspark.logger.worker_io import capture_outputs
from pyspark.serializers import (
    read_bool,
    read_int,
    write_int,
    write_with_length,
)
from pyspark.sql.datasource import DataSource, CaseInsensitiveDict
from pyspark.sql.types import _parse_datatype_json_string, StructType
from pyspark.sql.worker.utils import worker_run
from pyspark.worker_util import (
    get_sock_file_to_executor,
    read_command,
    pickleSer,
    utf8_deserializer,
)


def _main(infile: IO, outfile: IO) -> None:
    """
    Main method for creating a Python data source instance.

    This process is invoked from the `UserDefinedPythonDataSourceRunner.runInPython`
    method in JVM. This process is responsible for creating a `DataSource` object
    and send the information needed back to the JVM.

    The JVM sends the following information to this process:
    - a `DataSource` class representing the data source to be created.
    - a provider name in string.
    - an optional user-specified schema in json string.
    - a dictionary of options in string.

    This process then creates a `DataSource` instance using the above information
    and sends the pickled instance as well as the schema back to the JVM.
    """
    # Receive the data source class.
    data_source_cls = read_command(pickleSer, infile)
    if not (isinstance(data_source_cls, type) and issubclass(data_source_cls, DataSource)):
        raise PySparkAssertionError(
            errorClass="DATA_SOURCE_TYPE_MISMATCH",
            messageParameters={
                "expected": "a subclass of DataSource",
                "actual": f"'{type(data_source_cls).__name__}'",
            },
        )

    # Check the name method is a class method.
    if not inspect.ismethod(data_source_cls.name):
        raise PySparkTypeError(
            errorClass="DATA_SOURCE_TYPE_MISMATCH",
            messageParameters={
                "expected": "'name()' method to be a classmethod",
                "actual": f"'{type(data_source_cls.name).__name__}'",
            },
        )

    # Receive the provider name.
    provider = utf8_deserializer.loads(infile)

    # Everything below may run user-defined code (`name()`, `__init__`,
    # `schema()`), so redirect its stdout/stderr through the worker logger.
    with capture_outputs():
        # Check if the provider name matches the data source's name.
        name = data_source_cls.name()
        if provider.lower() != name.lower():
            raise PySparkAssertionError(
                errorClass="DATA_SOURCE_TYPE_MISMATCH",
                messageParameters={
                    "expected": f"provider with name {name}",
                    "actual": f"'{provider}'",
                },
            )

        # Receive the user-specified schema (optional, preceded by a presence flag).
        user_specified_schema = None
        if read_bool(infile):
            user_specified_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
            if not isinstance(user_specified_schema, StructType):
                raise PySparkAssertionError(
                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
                    messageParameters={
                        "expected": "the user-defined schema to be a 'StructType'",
                        # BUGFIX: report the type of the parsed schema object,
                        # not the data source class — the schema is what failed
                        # the isinstance check, so the error must name it.
                        "actual": f"'{type(user_specified_schema).__name__}'",
                    },
                )

        # Receive the options as length-prefixed UTF-8 key/value pairs.
        options = CaseInsensitiveDict()
        num_options = read_int(infile)
        for _ in range(num_options):
            key = utf8_deserializer.loads(infile)
            value = utf8_deserializer.loads(infile)
            options[key] = value

        # Instantiate a data source.
        data_source = data_source_cls(options=options)

        # Get the schema of the data source.
        # If user_specified_schema is not None, use user_specified_schema.
        # Otherwise, use the schema of the data source.
        # Throw exception if the data source does not implement schema().
        if user_specified_schema is None:
            schema = data_source.schema()
        else:
            assert isinstance(user_specified_schema, StructType)
            schema = user_specified_schema
        assert schema is not None

        # Return the pickled data source instance.
        pickleSer._write_with_length(data_source, outfile)

        # Return the schema of the data source; a leading int tags the format
        # (1 = DDL string, 0 = JSON-serialized StructType).
        if isinstance(schema, str):
            # Here we cannot use _parse_datatype_string to parse the DDL string schema.
            # as it requires an active Spark session.
            write_int(1, outfile)
            write_with_length(schema.encode("utf-8"), outfile)
        else:
            write_int(0, outfile)
            write_with_length(schema.json().encode("utf-8"), outfile)


def main(infile: IO, outfile: IO) -> None:
    """Entry point: run `_main` under the standard worker harness."""
    worker_run(_main, infile, outfile)


if __name__ == "__main__":
    # The executor hands this worker a single socket used for both directions.
    with get_sock_file_to_executor() as sock_file:
        main(sock_file, sock_file)