#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import IO

from pyspark.errors import PySparkAssertionError
from pyspark.logger.worker_io import capture_outputs
from pyspark.serializers import (
    read_bool,
    read_int,
    write_int,
)
from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage
from pyspark.sql.worker.utils import worker_run
from pyspark.worker_util import get_sock_file_to_executor, pickleSer


def _main(infile: IO, outfile: IO) -> None:
    """
    Main method for committing or aborting a data source write operation.

    This process is invoked from the
    `UserDefinedPythonDataSourceCommitRunner.runInPython` method in the
    BatchWrite implementation of the PythonTableProvider. It is responsible
    for invoking either the `commit` or the `abort` method on a data source
    writer instance, given a list of commit messages.

    Parameters
    ----------
    infile : IO
        Stream from the JVM carrying the pickled writer, the commit
        messages, and the abort flag.
    outfile : IO
        Stream back to the JVM; a zero status code is written on success.

    Raises
    ------
    PySparkAssertionError
        If the deserialized writer or any non-None commit message is not of
        the expected type.
    """
    # Receive the data source writer instance.
    writer = pickleSer._read_with_length(infile)
    # Validate with an explicit structured error rather than a bare
    # `assert`, which is stripped under `python -O`; this also matches the
    # error reporting used for the commit messages below.
    if not isinstance(writer, DataSourceWriter):
        raise PySparkAssertionError(
            errorClass="DATA_SOURCE_TYPE_MISMATCH",
            messageParameters={
                "expected": "an instance of DataSourceWriter",
                "actual": f"'{type(writer).__name__}'",
            },
        )

    # Receive the commit messages.
    num_messages = read_int(infile)
    commit_messages = []
    for _ in range(num_messages):
        message = pickleSer._read_with_length(infile)
        # A message may be None (e.g. for failed tasks); anything else must
        # be a WriterCommitMessage produced by the write tasks.
        if message is not None and not isinstance(message, WriterCommitMessage):
            raise PySparkAssertionError(
                errorClass="DATA_SOURCE_TYPE_MISMATCH",
                messageParameters={
                    "expected": "an instance of WriterCommitMessage",
                    "actual": f"'{type(message).__name__}'",
                },
            )
        commit_messages.append(message)

    # Receive a boolean to indicate whether to invoke `abort`.
    abort = read_bool(infile)

    with capture_outputs():
        # Commit or abort the Python data source write.
        # Note the commit messages can be None if there are failed tasks.
        if abort:
            writer.abort(commit_messages)
        else:
            writer.commit(commit_messages)

    # Send a status code back to JVM.
    write_int(0, outfile)


def main(infile: IO, outfile: IO) -> None:
    # Wrap `_main` with the shared worker harness (error handling, protocol
    # framing) used by all Python worker entry points.
    worker_run(_main, infile, outfile)


if __name__ == "__main__":
    # The executor hands this worker a single socket used for both reading
    # and writing.
    with get_sock_file_to_executor() as sock_file:
        main(sock_file, sock_file)