#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import contextlib
import io
import os
import tempfile
import unittest
import logging
import json
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from typing import Callable, Iterable, List, Union, Iterator, Tuple

from pyspark.errors import AnalysisException, PythonException
from pyspark.memory_profiler_ext import has_memory_profiler
from pyspark.sql.datasource import (
    CaseInsensitiveDict,
    DataSource,
    DataSourceArrowWriter,
    DataSourceReader,
    DataSourceWriter,
    EqualNullSafe,
    EqualTo,
    Filter,
    GreaterThan,
    GreaterThanOrEqual,
    In,
    InputPartition,
    IsNotNull,
    IsNull,
    LessThan,
    LessThanOrEqual,
    Not,
    StringContains,
    StringEndsWith,
    StringStartsWith,
    WriterCommitMessage,
)
from pyspark.sql.functions import spark_partition_id
from pyspark.sql.session import SparkSession
from pyspark.sql.types import Row, StructField, StructType, IntegerType, DecimalType, VariantVal
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import (
    SPARK_HOME,
    ReusedSQLTestCase,
)
from pyspark.testing.utils import (
    have_pyarrow,
    pyarrow_requirement_message,
)
from pyspark.util import is_remote_only


@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class BasePythonDataSourceTestsMixin:
    spark: SparkSession

    def test_basic_data_source_class(self):
        class MyDataSource(DataSource):
            ...
        options = dict(a=1, b=2)
        ds = MyDataSource(options=options)
        self.assertEqual(ds.options, options)
        self.assertEqual(ds.name(), "MyDataSource")
        with self.assertRaises(NotImplementedError):
            ds.schema()
        with self.assertRaises(NotImplementedError):
            ds.reader(None)
        with self.assertRaises(NotImplementedError):
            ds.writer(None, None)

    def test_basic_data_source_reader_class(self):
        class MyDataSourceReader(DataSourceReader):
            def read(self, partition):
                yield (None,)

        reader = MyDataSourceReader()
        self.assertEqual(list(reader.read(None)), [(None,)])

    def test_data_source_register(self):
        class TestReader(DataSourceReader):
            def read(self, partition):
                yield (
                    0,
                    1,
                    VariantVal.parseJson('{"c":1}'),
                    {"v": VariantVal.parseJson('{"d":2}')},
                    [VariantVal.parseJson('{"e":3}')],
                    {"v1": VariantVal.parseJson('{"f":4}'), "v2": VariantVal.parseJson('{"g":5}')},
                )

        class TestDataSource(DataSource):
            def schema(self):
                return (
                    "a INT, b INT, c VARIANT, d STRUCT<v VARIANT>, e ARRAY<VARIANT>,"
                    "f MAP<STRING, VARIANT>"
                )

            def reader(self, schema):
                return TestReader()

        self.spark.dataSource.register(TestDataSource)
        df = self.spark.read.format("TestDataSource").load()
        assertDataFrameEqual(
            df.selectExpr(
                "a", "b", "to_json(c) c", "to_json(d.v) d", "to_json(e[0]) e", "to_json(f['v2']) f"
            ),
            [Row(a=0, b=1, c='{"c":1}', d='{"d":2}', e='{"e":3}', f='{"g":5}')],
        )

        class MyDataSource(TestDataSource):
            @classmethod
            def name(cls):
                return "TestDataSource"

            def schema(self):
                return (
                    "c INT, d INT, e VARIANT, f STRUCT<v VARIANT>, g ARRAY<VARIANT>,"
                    "h MAP<STRING, VARIANT>"
                )

        # Should be able to register the data source with the same name.
        self.spark.dataSource.register(MyDataSource)
        df = self.spark.read.format("TestDataSource").load()
        assertDataFrameEqual(
            df.selectExpr(
                "c", "d", "to_json(e) e", "to_json(f.v) f", "to_json(g[0]) g", "to_json(h['v2']) h"
            ),
            [Row(c=0, d=1, e='{"c":1}', f='{"d":2}', g='{"e":3}', h='{"g":5}')],
        )

    def register_data_source(
        self,
        read_func: Callable,
        partition_func: Callable = None,
        output: Union[str, StructType] = "i int, j int",
        name: str = "test",
    ):
        class TestDataSourceReader(DataSourceReader):
            def __init__(self, schema):
                self.schema = schema

            def partitions(self):
                if partition_func is not None:
                    return partition_func()
                else:
                    return [InputPartition(None)]

            def read(self, partition):
                return read_func(self.schema, partition)

        class TestDataSource(DataSource):
            @classmethod
            def name(cls):
                return name

            def schema(self):
                return output

            def reader(self, schema) -> "DataSourceReader":
                return TestDataSourceReader(schema)

        self.spark.dataSource.register(TestDataSource)

    def test_data_source_read_output_tuple(self):
        self.register_data_source(read_func=lambda schema, partition: iter([(0, 1)]))
        df = self.spark.read.format("test").load()
        assertDataFrameEqual(df, [Row(0, 1)])

    def test_data_source_read_output_list(self):
        self.register_data_source(read_func=lambda schema, partition: iter([[0, 1]]))
        df = self.spark.read.format("test").load()
        assertDataFrameEqual(df, [Row(0, 1)])

    def test_data_source_read_output_row(self):
        self.register_data_source(read_func=lambda schema, partition: iter([Row(0, 1)]))
        df = self.spark.read.format("test").load()
        assertDataFrameEqual(df, [Row(0, 1)])

    def test_data_source_read_output_named_row(self):
        self.register_data_source(
            read_func=lambda schema, partition: iter([Row(j=1, i=0), Row(i=1, j=2)])
        )
        df = self.spark.read.format("test").load()
        assertDataFrameEqual(df, [Row(0, 1), Row(1, 2)])

    def test_data_source_read_output_named_row_with_wrong_schema(self):
        self.register_data_source(
            read_func=lambda schema, partition: iter([Row(i=1, j=2), Row(j=3, k=4)])
        )
        with self.assertRaisesRegex(
            PythonException,
            r"\[DATA_SOURCE_RETURN_SCHEMA_MISMATCH\] Return schema mismatch in the result",
        ):
            self.spark.read.format("test").load().show()

    def test_data_source_read_output_none(self):
        self.register_data_source(read_func=lambda schema, partition: None)
        df = self.spark.read.format("test").load()
        with self.assertRaisesRegex(PythonException, "DATA_SOURCE_INVALID_RETURN_TYPE"):
            assertDataFrameEqual(df, [])

    def test_data_source_read_output_empty_iter(self):
        self.register_data_source(read_func=lambda schema, partition: iter([]))
        df = self.spark.read.format("test").load()
        assertDataFrameEqual(df, [])

    def test_data_source_read_cast_output_schema(self):
        self.register_data_source(
            read_func=lambda schema, partition: iter([(0, 1)]), output="i long, j string"
        )
        df = self.spark.read.format("test").load()
        assertDataFrameEqual(df, [Row(i=0, j="1")])

    def test_data_source_read_output_with_partition(self):
        def partition_func():
            return [InputPartition(0), InputPartition(1)]

        def read_func(schema, partition):
            if partition.value == 0:
                return iter([])
            elif partition.value == 1:
                yield (0, 1)

        self.register_data_source(read_func=read_func, partition_func=partition_func)
        df = self.spark.read.format("test").load()
        assertDataFrameEqual(df, [Row(0, 1)])

    def test_data_source_read_output_with_schema_mismatch(self):
        self.register_data_source(read_func=lambda schema, partition: iter([(0, 1)]))
        df = self.spark.read.format("test").schema("i int").load()
        with self.assertRaisesRegex(PythonException, "DATA_SOURCE_RETURN_SCHEMA_MISMATCH"):
            df.collect()
        self.register_data_source(
            read_func=lambda schema, partition: iter([(0, 1)]), output="i int, j int, k int"
        )
        with self.assertRaisesRegex(PythonException, "DATA_SOURCE_RETURN_SCHEMA_MISMATCH"):
            df.collect()

    def test_read_with_invalid_return_row_type(self):
        self.register_data_source(read_func=lambda schema, partition: iter([1]))
        df = self.spark.read.format("test").load()
        with self.assertRaisesRegex(PythonException, "DATA_SOURCE_INVALID_RETURN_TYPE"):
            df.collect()

    def test_in_memory_data_source(self):
        class InMemDataSourceReader(DataSourceReader):
            DEFAULT_NUM_PARTITIONS: int = 3

            def __init__(self, options):
                self.options = options

            def partitions(self):
                if "num_partitions" in self.options:
                    num_partitions = int(self.options["num_partitions"])
                else:
                    num_partitions = self.DEFAULT_NUM_PARTITIONS
                return [InputPartition(i) for i in range(num_partitions)]

            def read(self, partition):
                yield partition.value, str(partition.value)

        class InMemoryDataSource(DataSource):
            @classmethod
            def name(cls):
                return "memory"

            def schema(self):
                return "x INT, y STRING"

            def reader(self, schema) -> "DataSourceReader":
                return InMemDataSourceReader(self.options)

        self.spark.dataSource.register(InMemoryDataSource)
        df = self.spark.read.format("memory").load()
        self.assertEqual(df.select(spark_partition_id()).distinct().count(), 3)
        assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1"), Row(x=2, y="2")])

        df = self.spark.read.format("memory").option("num_partitions", 2).load()
        assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
        self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2)

    def test_filter_pushdown(self):
        class TestDataSourceReader(DataSourceReader):
            def __init__(self):
                self.has_filter = False

            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
                assert set(filters) == {
                    IsNotNull(("x",)),
                    IsNotNull(("y",)),
                    EqualTo(("x",), 1),
                    EqualTo(("y",), 2),
                }, filters
                self.has_filter = True
                # pretend we support x = 1 filter but in fact we don't
                # so we only return y = 2 filter
                yield filters[filters.index(EqualTo(("y",), 2))]

            def partitions(self):
                assert self.has_filter
                return super().partitions()

            def read(self, partition):
                assert self.has_filter
                yield [1, 1]
                yield [1, 2]
                yield [2, 2]

        class TestDataSource(DataSource):
            @classmethod
            def name(cls):
                return "test"

            def schema(self):
                return "x int, y int"

            def reader(self, schema) -> "DataSourceReader":
                return TestDataSourceReader()

        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
            self.spark.dataSource.register(TestDataSource)
            df = self.spark.read.format("test").load().filter("x = 1 and y = 2")
            # only the y = 2 filter is applied post scan
            assertDataFrameEqual(df, [Row(x=1, y=2), Row(x=2, y=2)])

    def test_extraneous_filter(self):
        class TestDataSourceReader(DataSourceReader):
            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
                yield EqualTo(("x",), 1)

            def partitions(self):
                assert False

            def read(self, partition):
                assert False

        class TestDataSource(DataSource):
            @classmethod
            def name(cls):
                return "test"

            def schema(self):
                return "x int"

            def reader(self, schema) -> "DataSourceReader":
                return TestDataSourceReader()

        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
            self.spark.dataSource.register(TestDataSource)
            with self.assertRaisesRegex(Exception, "DATA_SOURCE_EXTRANEOUS_FILTERS"):
                self.spark.read.format("test").load().filter("x = 1").show()

    def test_filter_pushdown_error(self):
        error_str = "dummy error"

        class TestDataSourceReader(DataSourceReader):
            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
                raise Exception(error_str)

            def read(self, partition):
                yield [1]

        class TestDataSource(DataSource):
            def schema(self):
                return "x int"

            def reader(self, schema) -> "DataSourceReader":
                return TestDataSourceReader()

        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
            self.spark.dataSource.register(TestDataSource)
            df = self.spark.read.format("TestDataSource").load().filter("x = 1 or x is null")
            assertDataFrameEqual(df, [Row(x=1)])  # works when not pushing down filters
            with self.assertRaisesRegex(Exception, error_str):
                df.filter("x = 1").explain()

    def test_filter_pushdown_disabled(self):
        class TestDataSourceReader(DataSourceReader):
            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
                assert False

            def read(self, partition):
                assert False

        class TestDataSource(DataSource):
            def reader(self, schema) -> "DataSourceReader":
                return TestDataSourceReader()

        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": False}):
            self.spark.dataSource.register(TestDataSource)
            df = self.spark.read.format("TestDataSource").schema("x int").load()
            with self.assertRaisesRegex(Exception, "DATA_SOURCE_PUSHDOWN_DISABLED"):
                df.show()

    def _check_filters(self, sql_type, sql_filter, python_filters):
        """
        Parameters
        ----------
        sql_type: str
            The SQL type of the column x.
        sql_filter: str
            A SQL filter using the column x.
        python_filters: List[Filter]
            The expected python filters to be pushed down.
""" class TestDataSourceReader(DataSourceReader): def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]: actual = [f for f in filters if not isinstance(f, IsNotNull)] expected = python_filters assert actual == expected, (actual, expected) return filters def read(self, partition): yield from [] class TestDataSource(DataSource): def schema(self): return f"x {sql_type}" def reader(self, schema) -> "DataSourceReader": return TestDataSourceReader() with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}): self.spark.dataSource.register(TestDataSource) df = self.spark.read.format("TestDataSource").load().filter(sql_filter) df.count() def test_unsupported_filter(self): self._check_filters( "struct", "x.a = 1 and x.b = x.c", [EqualTo(("x", "a"), 1)] ) self._check_filters("int", "x = 1 or x > 2", []) self._check_filters("int", "(0 < x and x < 1) or x = 2", []) self._check_filters("int", "x % 5 = 1", []) self._check_filters("array", "x[0] = 1", []) self._check_filters("string", "x like 'a%a%'", []) self._check_filters("string", "x ilike 'a'", []) self._check_filters("string", "x = 'a' collate zh", []) def test_filter_value_type(self): self._check_filters("int", "x = 1", [EqualTo(("x",), 1)]) self._check_filters("int", "x = null", [EqualTo(("x",), None)]) self._check_filters("float", "x = 3 / 2", [EqualTo(("x",), 1.5)]) self._check_filters("string", "x = '1'", [EqualTo(("x",), "1")]) self._check_filters("array", "x = array(1, 2)", [EqualTo(("x",), [1, 2])]) self._check_filters( "struct", "x = named_struct('x', 1)", [EqualTo(("x",), {"x": 1})] ) self._check_filters( "decimal", "x in (1.1, 2.1)", [In(("x",), [Decimal(1.1), Decimal(2.1)])] ) self._check_filters( "timestamp_ntz", "x = timestamp_ntz '2020-01-01 00:00:00'", [EqualTo(("x",), datetime.strptime("2020-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"))], ) self._check_filters( "interval second", "x = interval '2' second", [], # intervals are not supported ) def test_filter_type(self): self._check_filters("boolean", "x", [EqualTo(("x",), True)]) self._check_filters("boolean", "not x", [Not(EqualTo(("x",), True))]) self._check_filters("int", "x is null", [IsNull(("x",))]) self._check_filters("int", "x <> 0", [Not(EqualTo(("x",), 0))]) self._check_filters("int", "x <=> 1", [EqualNullSafe(("x",), 1)]) self._check_filters("int", "1 < x", [GreaterThan(("x",), 1)]) self._check_filters("int", "1 <= x", [GreaterThanOrEqual(("x",), 1)]) self._check_filters("int", "x < 1", [LessThan(("x",), 1)]) self._check_filters("int", "x <= 1", [LessThanOrEqual(("x",), 1)]) self._check_filters("string", "x like 'a%'", [StringStartsWith(("x",), "a")]) self._check_filters("string", "x like '%a'", [StringEndsWith(("x",), "a")]) self._check_filters("string", "x like '%a%'", [StringContains(("x",), "a")]) self._check_filters( "string", "x like 'a%b'", [StringStartsWith(("x",), "a"), StringEndsWith(("x",), "b")] ) self._check_filters("int", "x in (1, 2)", [In(("x",), [1, 2])]) def test_filter_nested_column(self): self._check_filters("struct", "x.y = 1", [EqualTo(("x", "y"), 1)]) def _get_test_json_data_source(self): import json import os from dataclasses import dataclass class TestJsonReader(DataSourceReader): def __init__(self, options): self.options = options def read(self, partition): path = self.options.get("path") if path is None: raise Exception("path is not specified") with open(path, "r") as file: for line in file.readlines(): if line.strip(): data = json.loads(line) yield data.get("name"), data.get("age") @dataclass class 
            count: int

        class TestJsonWriter(DataSourceWriter):
            def __init__(self, options):
                self.options = options
                self.path = self.options.get("path")

            def write(self, iterator):
                from pyspark import TaskContext

                context = TaskContext.get()
                output_path = os.path.join(self.path, f"{context.partitionId()}.json")
                count = 0
                with open(output_path, "w") as file:
                    for row in iterator:
                        count += 1
                        if "id" in row and row.id > 5:
                            raise Exception("id > 5")
                        file.write(json.dumps(row.asDict()) + "\n")
                return TestCommitMessage(count=count)

            def commit(self, messages):
                total_count = sum(message.count for message in messages)
                with open(os.path.join(self.path, "_success.txt"), "a") as file:
                    file.write(f"count: {total_count}\n")

            def abort(self, messages):
                with open(os.path.join(self.path, "_failed.txt"), "a") as file:
                    file.write("failed")

        class TestJsonDataSource(DataSource):
            @classmethod
            def name(cls):
                return "my-json"

            def schema(self):
                return "name STRING, age INT"

            def reader(self, schema) -> "DataSourceReader":
                return TestJsonReader(self.options)

            def writer(self, schema, overwrite):
                return TestJsonWriter(self.options)

        return TestJsonDataSource

    def test_custom_json_data_source_read(self):
        data_source = self._get_test_json_data_source()
        self.spark.dataSource.register(data_source)
        path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
        path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
        assertDataFrameEqual(
            self.spark.read.format("my-json").load(path1),
            [Row(name="Michael", age=None), Row(name="Andy", age=30), Row(name="Justin", age=19)],
        )
        assertDataFrameEqual(
            self.spark.read.format("my-json").load(path2),
            [Row(name="Jonathan", age=None)],
        )

    def test_custom_json_data_source_write(self):
        data_source = self._get_test_json_data_source()
        self.spark.dataSource.register(data_source)
        input_path = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
        df = self.spark.read.json(input_path)
        with tempfile.TemporaryDirectory(prefix="test_custom_json_data_source_write") as d:
            df.write.format("my-json").mode("append").save(d)
            assertDataFrameEqual(self.spark.read.json(d), self.spark.read.json(input_path))

    def test_custom_json_data_source_commit(self):
        data_source = self._get_test_json_data_source()
        self.spark.dataSource.register(data_source)
        with tempfile.TemporaryDirectory(prefix="test_custom_json_data_source_commit") as d:
            self.spark.range(0, 5, 1, 3).write.format("my-json").mode("append").save(d)
            with open(os.path.join(d, "_success.txt"), "r") as file:
                text = file.read()
                assert text == "count: 5\n"

    def test_custom_json_data_source_abort(self):
        data_source = self._get_test_json_data_source()
        self.spark.dataSource.register(data_source)
        with tempfile.TemporaryDirectory(prefix="test_custom_json_data_source_abort") as d:
            with self.assertRaises(PythonException):
                self.spark.range(0, 8, 1, 3).write.format("my-json").mode("append").save(d)
            with open(os.path.join(d, "_failed.txt"), "r") as file:
                text = file.read()
                assert text == "failed"

    def test_case_insensitive_dict(self):
        d = CaseInsensitiveDict({"foo": 1, "Bar": 2})
        self.assertEqual(d["foo"], d["FOO"])
        self.assertEqual(d["bar"], d["BAR"])
        self.assertTrue("baR" in d)
        d["BAR"] = 3
        self.assertEqual(d["BAR"], 3)
        # Test update
        d.update({"BaZ": 3})
        self.assertEqual(d["BAZ"], 3)
        d.update({"FOO": 4})
        self.assertEqual(d["foo"], 4)
        # Test delete
        del d["FoO"]
        self.assertFalse("FOO" in d)
        # Test copy
        d2 = d.copy()
        self.assertEqual(d2["BaR"], 3)
        self.assertEqual(d2["baz"], 3)

    def test_arrow_batch_data_source(self):
        import pyarrow as pa

        class ArrowBatchDataSource(DataSource):
            """
            A data source for testing Arrow Batch Serialization
            """

            @classmethod
            def name(cls):
                return "arrowbatch"

            def schema(self):
                return "key int, value string"

            def reader(self, schema: str):
                return ArrowBatchDataSourceReader(schema, self.options)

        class ArrowBatchDataSourceReader(DataSourceReader):
            def __init__(self, schema, options):
                self.schema: str = schema
                self.options = options

            def read(self, partition):
                # Create Arrow Record Batch
                keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
                values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
                schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
                record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema)
                yield record_batch

            def partitions(self):
                # hardcoded number of partitions
                num_part = 1
                return [InputPartition(i) for i in range(num_part)]

        self.spark.dataSource.register(ArrowBatchDataSource)
        df = self.spark.read.format("arrowbatch").load()
        expected_data = [
            Row(key=1, value="one"),
            Row(key=2, value="two"),
            Row(key=3, value="three"),
            Row(key=4, value="four"),
            Row(key=5, value="five"),
        ]
        assertDataFrameEqual(df, expected_data)

        with self.assertRaisesRegex(
            PythonException,
            "PySparkRuntimeError: \\[DATA_SOURCE_RETURN_SCHEMA_MISMATCH\\] Return schema"
            " mismatch in the result from 'read' method\\. Expected: 1 columns, Found: 2 columns",
        ):
            self.spark.read.format("arrowbatch").schema("dummy int").load().show()

        with self.assertRaisesRegex(
            PythonException,
            "PySparkRuntimeError: \\[DATA_SOURCE_RETURN_SCHEMA_MISMATCH\\] Return schema mismatch"
            " in the result from 'read' method\\. Expected: \\['key', 'dummy'\\] columns, Found:"
            " \\['key', 'value'\\] columns",
        ):
            self.spark.read.format("arrowbatch").schema("key int, dummy string").load().show()

        # SPARK-55583: Validate Arrow schema types, not just column count/names.
        # The data source declares "key string, value string" but read() yields
        # a RecordBatch with (int32, string), causing an Arrow schema type mismatch.
        class MismatchedTypeDataSource(DataSource):
            @classmethod
            def name(cls):
                return "arrowbatch_type_mismatch"

            def schema(self):
                return "key string, value string"

            def reader(self, schema: str):
                return MismatchedTypeReader()

        class MismatchedTypeReader(DataSourceReader):
            def read(self, partition):
                keys = pa.array([1, 2], type=pa.int32())
                values = pa.array(["a", "b"], type=pa.string())
                batch = pa.RecordBatch.from_arrays(
                    [keys, values],
                    schema=pa.schema([("key", pa.int32()), ("value", pa.string())]),
                )
                yield batch

            def partitions(self):
                return [InputPartition(0)]

        self.spark.dataSource.register(MismatchedTypeDataSource)
        with self.assertRaisesRegex(
            PythonException,
            "DATA_SOURCE_RETURN_SCHEMA_MISMATCH",
        ):
            self.spark.read.format("arrowbatch_type_mismatch").load().show()

    def test_arrow_batch_sink(self):
        class TestDataSource(DataSource):
            @classmethod
            def name(cls):
                return "arrow_sink"

            def writer(self, schema, overwrite):
                return TestArrowWriter(self.options["path"])

        class TestArrowWriter(DataSourceArrowWriter):
            def __init__(self, path):
                self.path = path

            def write(self, iterator):
                from pyspark import TaskContext

                context = TaskContext.get()
                output_path = os.path.join(self.path, f"{context.partitionId()}.json")
                with open(output_path, "w") as file:
                    for batch in iterator:
                        df = batch.to_pandas()
                        df.to_json(file, orient="records", lines=True)
                return WriterCommitMessage()

        self.spark.dataSource.register(TestDataSource)
        df = self.spark.range(3)
        with tempfile.TemporaryDirectory(prefix="test_arrow_batch_sink") as d:
            df.write.format("arrow_sink").mode("append").save(d)
            df2 = self.spark.read.format("json").load(d)
            assertDataFrameEqual(df2, df)

    def test_data_source_type_mismatch(self):
        class TestDataSource(DataSource):
            @classmethod
            def name(cls):
                return "test"

            def schema(self):
                return "id int"

            def reader(self, schema):
                return TestReader()

            def writer(self, schema, overwrite):
                return TestWriter()

        class TestReader:
            def partitions(self):
                return []

            def read(self, partition):
                yield (0,)

        class TestWriter:
            def write(self, iterator):
                return WriterCommitMessage()

        self.spark.dataSource.register(TestDataSource)

        with self.assertRaisesRegex(
            AnalysisException,
            r"\[DATA_SOURCE_TYPE_MISMATCH\] Expected an instance of DataSourceReader",
        ):
            self.spark.read.format("test").load().show()

        df = self.spark.range(10)
        with self.assertRaisesRegex(
            AnalysisException,
            r"\[DATA_SOURCE_TYPE_MISMATCH\] Expected an instance of DataSourceWriter",
        ):
            df.write.format("test").mode("append").saveAsTable("test_table")

    def test_decimal_round(self):
        class SimpleDataSource(DataSource):
            @classmethod
            def name(cls) -> str:
                return "simple_decimal"

            def schema(self) -> StructType:
                return StructType(
                    [
                        StructField("i", IntegerType()),
                        StructField("d", DecimalType(38, 18)),
                    ]
                )

            def reader(self, schema: StructType) -> DataSourceReader:
                return SimpleDataSourceReader()

        class SimpleDataSourceReader(DataSourceReader):
            def read(self, partition: InputPartition) -> Iterator[Tuple]:
                yield (1, Decimal(1.234))

        self.spark.dataSource.register(SimpleDataSource)
        df = self.spark.read.format("simple_decimal").load()
        rounded = df.select("d").first().d
        self.assertEqual(rounded, Decimal("1.233999999999999986"))

    def test_data_source_segfault(self):
        import ctypes

        for enabled, expected in [
            (True, "Segmentation fault"),
            (False, "Consider setting .* for the better Python traceback."),
        ]:
            with (
                self.subTest(enabled=enabled),
                self.sql_conf({"spark.sql.execution.pyspark.udf.faulthandler.enabled": enabled}),
            ):
                with self.subTest(worker="pyspark.sql.worker.create_data_source"):

                    class TestDataSource(DataSource):
                        @classmethod
                        def name(cls):
                            return "test"

                        def schema(self):
                            return ctypes.string_at(0)

                    self.spark.dataSource.register(TestDataSource)

                    with self.assertRaisesRegex(Exception, expected):
                        self.spark.read.format("test").load().show()

                with self.subTest(worker="pyspark.sql.worker.plan_data_source_read"):

                    class TestDataSource(DataSource):
                        @classmethod
                        def name(cls):
                            return "test"

                        def schema(self):
                            return "x string"

                        def reader(self, schema):
                            return TestReader()

                    class TestReader(DataSourceReader):
                        def partitions(self):
                            ctypes.string_at(0)
                            return []

                        def read(self, partition):
                            return []

                    self.spark.dataSource.register(TestDataSource)

                    with self.assertRaisesRegex(Exception, expected):
                        self.spark.read.format("test").load().show()

                with self.subTest(worker="pyspark.worker"):

                    class TestDataSource(DataSource):
                        @classmethod
                        def name(cls):
                            return "test"

                        def schema(self):
                            return "x string"

                        def reader(self, schema):
                            return TestReader2()

                    class TestReader2(DataSourceReader):
                        def read(self, partition):
                            ctypes.string_at(0)
                            yield ("x",)

                    self.spark.dataSource.register(TestDataSource)

                    with self.assertRaisesRegex(Exception, expected):
                        self.spark.read.format("test").load().show()

                with self.subTest(worker="pyspark.sql.worker.write_into_data_source"):

                    class TestDataSource(DataSource):
                        @classmethod
                        def name(cls):
                            return "test"

                        def writer(self, schema, overwrite):
                            return TestWriter()

                    class TestWriter(DataSourceWriter):
                        def write(self, iterator):
                            ctypes.string_at(0)
                            return WriterCommitMessage()

                    self.spark.dataSource.register(TestDataSource)

                    with tempfile.TemporaryDirectory(prefix="test_segfault_") as d:
                        with self.assertRaisesRegex(Exception, expected):
                            self.spark.range(10).write.format("test").mode("append").save(d)

                with self.subTest(worker="pyspark.sql.worker.commit_data_source_write"):

                    class TestDataSource(DataSource):
                        @classmethod
                        def name(cls):
                            return "test"

                        def writer(self, schema, overwrite):
                            return TestWriter2()

                    class TestWriter2(DataSourceWriter):
                        def write(self, iterator):
                            return WriterCommitMessage()

                        def commit(self, messages):
                            ctypes.string_at(0)

                    self.spark.dataSource.register(TestDataSource)

                    with tempfile.TemporaryDirectory(prefix="test_segfault_") as d:
                        with self.assertRaisesRegex(Exception, expected):
                            self.spark.range(10).write.format("test").mode("append").save(d)

    @unittest.skipIf(is_remote_only(), "Requires JVM access")
    def test_data_source_reader_with_logging(self):
        logger = logging.getLogger("test_data_source_reader")

        class TestJsonReader(DataSourceReader):
            def __init__(self, options):
                logger.warning(f"TestJsonReader.__init__: {list(options)}")
                self.options = options

            def partitions(self):
                logger.warning("TestJsonReader.partitions")
                return super().partitions()

            def read(self, partition):
                logger.warning(f"TestJsonReader.read: {partition}")
                path = self.options.get("path")
                if path is None:
                    raise Exception("path is not specified")
                with open(path, "r") as file:
                    for line in file.readlines():
                        if line.strip():
                            data = json.loads(line)
                            yield data.get("name"), data.get("age")

        class TestJsonDataSource(DataSource):
            def __init__(self, options):
                super().__init__(options)
                logger.warning(f"TestJsonDataSource.__init__: {list(options)}")

            @classmethod
            def name(cls):
                logger.warning("TestJsonDataSource.name")
                return "my-json"

            def schema(self):
                logger.warning("TestJsonDataSource.schema")
                return "name STRING, age INT"

            def reader(self, schema) -> "DataSourceReader":
                logger.warning(f"TestJsonDataSource.reader: {schema.fieldNames()}")
                return TestJsonReader(self.options)

        self.spark.dataSource.register(TestJsonDataSource)
        path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")

        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
            assertDataFrameEqual(
                self.spark.read.format("my-json").load(path1),
                [
                    Row(name="Michael", age=None),
                    Row(name="Andy", age=30),
                    Row(name="Justin", age=19),
                ],
            )

            logs = self.spark.tvf.python_worker_logs()

            assertDataFrameEqual(
                logs.select("level", "msg", "context", "logger"),
                [
                    Row(
                        level="WARNING",
                        msg=msg,
                        context=context,
                        logger="test_data_source_reader",
                    )
                    for msg, context in [
                        (
                            "TestJsonDataSource.__init__: ['path']",
                            {"class_name": "TestJsonDataSource", "func_name": "__init__"},
                        ),
                        (
                            "TestJsonDataSource.name",
                            {"class_name": "TestJsonDataSource", "func_name": "name"},
                        ),
                        (
                            "TestJsonDataSource.schema",
                            {"class_name": "TestJsonDataSource", "func_name": "schema"},
                        ),
                        (
                            "TestJsonDataSource.reader: ['name', 'age']",
                            {"class_name": "TestJsonDataSource", "func_name": "reader"},
                        ),
                        (
                            "TestJsonReader.__init__: ['path']",
                            {"class_name": "TestJsonDataSource", "func_name": "reader"},
                        ),
                        (
                            "TestJsonReader.partitions",
                            {"class_name": "TestJsonReader", "func_name": "partitions"},
                        ),
                        (
                            "TestJsonReader.read: InputPartition(value=None)",
                            {"class_name": "TestJsonReader", "func_name": "read"},
                        ),
                    ]
                ],
            )

    @unittest.skipIf(is_remote_only(), "Requires JVM access")
    def test_data_source_reader_pushdown_with_logging(self):
        logger = logging.getLogger("test_data_source_reader_pushdown")

        class TestJsonReader(DataSourceReader):
            def __init__(self, options):
                logger.warning(f"TestJsonReader.__init__: {list(options)}")
                self.options = options

            def pushFilters(self, filters):
                logger.warning(f"TestJsonReader.pushFilters: {filters}")
                return super().pushFilters(filters)

            def partitions(self):
                logger.warning("TestJsonReader.partitions")
                return super().partitions()

            def read(self, partition):
                logger.warning(f"TestJsonReader.read: {partition}")
                path = self.options.get("path")
                if path is None:
                    raise Exception("path is not specified")
                with open(path, "r") as file:
                    for line in file.readlines():
                        if line.strip():
                            data = json.loads(line)
                            yield data.get("name"), data.get("age")

        class TestJsonDataSource(DataSource):
            def __init__(self, options):
                super().__init__(options)
                logger.warning(f"TestJsonDataSource.__init__: {list(options)}")

            @classmethod
            def name(cls):
                logger.warning("TestJsonDataSource.name")
                return "my-json"

            def schema(self):
                logger.warning("TestJsonDataSource.schema")
                return "name STRING, age INT"

            def reader(self, schema) -> "DataSourceReader":
                logger.warning(f"TestJsonDataSource.reader: {schema.fieldNames()}")
                return TestJsonReader(self.options)

        self.spark.dataSource.register(TestJsonDataSource)
        path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")

        with self.sql_conf(
            {
                "spark.sql.python.filterPushdown.enabled": "true",
                "spark.sql.pyspark.worker.logging.enabled": "true",
            }
        ):
            assertDataFrameEqual(
                self.spark.read.format("my-json").load(path1).filter("age is not null"),
                [
                    Row(name="Andy", age=30),
                    Row(name="Justin", age=19),
                ],
            )

            logs = self.spark.tvf.python_worker_logs()

            assertDataFrameEqual(
                logs.select("level", "msg", "context", "logger"),
                [
                    Row(
                        level="WARNING",
                        msg=msg,
                        context=context,
                        logger="test_data_source_reader_pushdown",
                    )
                    for msg, context in [
                        (
                            "TestJsonDataSource.__init__: ['path']",
                            {"class_name": "TestJsonDataSource", "func_name": "__init__"},
                        ),
                        (
                            "TestJsonDataSource.name",
                            {"class_name": "TestJsonDataSource", "func_name": "name"},
                        ),
                        (
                            "TestJsonDataSource.schema",
                            {"class_name": "TestJsonDataSource", "func_name": "schema"},
                        ),
"TestJsonDataSource", "func_name": "schema"}, ), ( "TestJsonDataSource.reader: ['name', 'age']", {"class_name": "TestJsonDataSource", "func_name": "reader"}, ), ( "TestJsonReader.pushFilters: [IsNotNull(attribute=('age',))]", {"class_name": "TestJsonReader", "func_name": "pushFilters"}, ), ( "TestJsonReader.__init__: ['path']", {"class_name": "TestJsonDataSource", "func_name": "reader"}, ), ( "TestJsonReader.partitions", {"class_name": "TestJsonReader", "func_name": "partitions"}, ), ( "TestJsonReader.read: InputPartition(value=None)", {"class_name": "TestJsonReader", "func_name": "read"}, ), ] ], ) @unittest.skipIf(is_remote_only(), "Requires JVM access") def test_data_source_writer_with_logging(self): logger = logging.getLogger("test_datasource_writer") @dataclass class TestCommitMessage(WriterCommitMessage): count: int class TestJsonWriter(DataSourceWriter): def __init__(self, options): logger.warning(f"TestJsonWriter.__init__: {list(options)}") self.options = options self.path = self.options.get("path") def write(self, iterator): from pyspark import TaskContext if self.options.get("abort", None): logger.warning("TestJsonWriter.write: abort test") raise Exception("abort test") context = TaskContext.get() output_path = os.path.join(self.path, f"{context.partitionId()}.json") count = 0 rows = [] with open(output_path, "w") as file: for row in iterator: count += 1 rows.append(row.asDict()) file.write(json.dumps(row.asDict()) + "\n") logger.warning(f"TestJsonWriter.write: {count}, {rows}") return TestCommitMessage(count=count) def commit(self, messages): total_count = sum(message.count for message in messages) with open(os.path.join(self.path, "_success.txt"), "w") as file: file.write(f"count: {total_count}\n") logger.warning(f"TestJsonWriter.commit: {total_count}") def abort(self, messages): with open(os.path.join(self.path, "_failed.txt"), "w") as file: file.write("failed") logger.warning("TestJsonWriter.abort") class TestJsonDataSource(DataSource): @classmethod def name(cls): logger.warning("TestJsonDataSource.name") return "my-json" def writer(self, schema, overwrite): logger.warning(f"TestJsonDataSource.writer: {schema.fieldNames(), {overwrite}}") return TestJsonWriter(self.options) # Register the data source self.spark.dataSource.register(TestJsonDataSource) with tempfile.TemporaryDirectory(prefix="test_datasource_write_logging") as d: with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}): # Create a simple DataFrame and write it using our custom datasource df = self.spark.createDataFrame( [("Charlie", 35), ("Diana", 28)], "name STRING, age INT" ).repartitionByRange(2, "age") df.write.format("my-json").mode("overwrite").save(d) # Verify the write worked by checking the success file with open(os.path.join(d, "_success.txt"), "r") as file: text = file.read() self.assertEqual(text, "count: 2\n") with self.assertRaises(Exception, msg="abort test"): df.write.format("my-json").mode("append").option("abort", "true").save(d) logs = self.spark.tvf.python_worker_logs() # We could get either 1 or 2 "TestJsonWriter.write: abort test" logs because # the operation is time sensitive. When the first partition gets aborted, # the executor will cancel the rest of the tasks. Whether we are able to get # the second log depends on whether the second partition starts before the # cancellation. When we use simple worker, the second log is often missing # because the spawn overhead is large. 
                non_abort_logs = logs.select("level", "msg", "context", "logger").filter(
                    "msg != 'TestJsonWriter.write: abort test'"
                )
                abort_logs = logs.select("level", "msg", "context", "logger").filter(
                    "msg == 'TestJsonWriter.write: abort test'"
                )

                assertDataFrameEqual(
                    non_abort_logs,
                    [
                        Row(
                            level="WARNING",
                            msg=msg,
                            context=context,
                            logger="test_datasource_writer",
                        )
                        for msg, context in [
                            (
                                "TestJsonDataSource.name",
                                {"class_name": "TestJsonDataSource", "func_name": "name"},
                            ),
                            (
                                "TestJsonDataSource.writer: (['name', 'age'], {True})",
                                {"class_name": "TestJsonDataSource", "func_name": "writer"},
                            ),
                            (
                                "TestJsonWriter.__init__: ['path']",
                                {"class_name": "TestJsonDataSource", "func_name": "writer"},
                            ),
                            (
                                "TestJsonWriter.write: 1, [{'name': 'Diana', 'age': 28}]",
                                {"class_name": "TestJsonWriter", "func_name": "write"},
                            ),
                            (
                                "TestJsonWriter.write: 1, [{'name': 'Charlie', 'age': 35}]",
                                {"class_name": "TestJsonWriter", "func_name": "write"},
                            ),
                            (
                                "TestJsonWriter.commit: 2",
                                {"class_name": "TestJsonWriter", "func_name": "commit"},
                            ),
                            (
                                "TestJsonDataSource.name",
                                {"class_name": "TestJsonDataSource", "func_name": "name"},
                            ),
                            (
                                "TestJsonDataSource.writer: (['name', 'age'], {False})",
                                {"class_name": "TestJsonDataSource", "func_name": "writer"},
                            ),
                            (
                                "TestJsonWriter.__init__: ['abort', 'path']",
                                {"class_name": "TestJsonDataSource", "func_name": "writer"},
                            ),
                            (
                                "TestJsonWriter.abort",
                                {"class_name": "TestJsonWriter", "func_name": "abort"},
                            ),
                        ]
                    ],
                )
                assertDataFrameEqual(
                    abort_logs.dropDuplicates(["msg"]),
                    [
                        Row(
                            level="WARNING",
                            msg="TestJsonWriter.write: abort test",
                            context={"class_name": "TestJsonWriter", "func_name": "write"},
                            logger="test_datasource_writer",
                        )
                    ],
                )

    def test_data_source_perf_profiler(self):
        with self.sql_conf({"spark.sql.pyspark.dataSource.profiler": "perf"}):
            self.test_custom_json_data_source_read()

        with contextlib.redirect_stdout(io.StringIO()) as stdout_io:
            self.spark.profile.show(type="perf")
            self.spark.profile.clear()
        stdout = stdout_io.getvalue()
        self.assertIn("Profile of create_data_source", stdout)
        self.assertIn("Profile of plan_data_source_read", stdout)
        self.assertIn("ncalls", stdout)
        self.assertIn("tottime", stdout)
        # We should also find UDF profile results for data source read
        self.assertIn("UDF