# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import Iterator, Optional import pyarrow as pa import pyspark.sql from pyspark.sql.types import StructType, StructField, BinaryType from pyspark.sql.pandas.types import to_arrow_schema def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pa.RecordBatch]: """Return all the partitions as Arrow arrays in an Iterator.""" # We will be using mapInArrow to convert each partition to Arrow RecordBatches. # The return type of the function will be a single binary column containing # the serialized RecordBatch in Arrow IPC format. binary_schema = StructType([StructField("arrow_ipc_bytes", BinaryType(), nullable=False)]) def batch_to_bytes_iter(batch_iter: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]: """ A generator function that converts RecordBatches to serialized Arrow IPC format. Spark sends each partition as an iterator of RecordBatches. In order to return the entire partition as a stream of Arrow RecordBatches, we need to serialize each RecordBatch to Arrow IPC format and yield it as a single binary blob. """ # The size of the batch can be controlled by the Spark config # `spark.sql.execution.arrow.maxRecordsPerBatch`. for arrow_batch in batch_iter: # We create an in-memory byte stream to hold the serialized batch sink = pa.BufferOutputStream() # Write the batch to the stream using Arrow IPC format with pa.ipc.new_stream(sink, arrow_batch.schema) as writer: writer.write_batch(arrow_batch) buf = sink.getvalue() # The second buffer contains the offsets we are manually creating. offset_buf = pa.array([0, len(buf)], type=pa.int32()).buffers()[1] null_bitmap = None # Wrap the bytes in a new 1-row, 1-column RecordBatch to satisfy mapInArrow return # signature. This serializes the whole batch into a single pyarrow serialized cell. storage_arr = pa.Array.from_buffers( type=pa.binary(), length=1, buffers=[null_bitmap, offset_buf, buf] ) yield pa.RecordBatch.from_arrays([storage_arr], names=["arrow_ipc_bytes"]) # Convert all partitions to Arrow RecordBatches and map to binary blobs. byte_df = df.mapInArrow(batch_to_bytes_iter, binary_schema) # A row is actually a batch of data in Arrow IPC format. Fetch the batches one by one. for row in byte_df.toLocalIterator(): with pa.ipc.open_stream(row.arrow_ipc_bytes) as reader: for batch in reader: # Each batch corresponds to a chunk of data in the partition. yield batch class SparkArrowCStreamer: """ A class that implements that __arrow_c_stream__ protocol for Spark partitions. This class is implemented in a way that allows consumers to consume each partition one at a time without materializing all partitions at once on the driver side. """ def __init__(self, df: pyspark.sql.DataFrame): self._df = df self._schema = to_arrow_schema(df.schema) def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object: """ Return the Arrow C stream for the dataframe partitions. """ reader: pa.RecordBatchReader = pa.RecordBatchReader.from_batches( self._schema, _get_arrow_array_partition_stream(self._df) ) return reader.__arrow_c_stream__(requested_schema=requested_schema)