#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import glob
import math
import os
import shutil
import tempfile
from contextlib import contextmanager

from pyspark.sql import SparkSession
from pyspark.sql.types import Row
from pyspark.testing.utils import (
    ReusedPySparkTestCase,
    PySparkErrorTestUtils,
)
from pyspark.find_spark_home import _find_spark_home

SPARK_HOME = _find_spark_home()


def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix, return_first=False):
    """
    Search a built artifact JAR under a project directory, for both SBT and Maven layouts.

    Parameters
    ----------
    project_relative_path : str
        Relative path from SPARK_HOME to the project directory.
    sbt_jar_name_prefix : str
        JAR file name prefix used by the SBT build.
    mvn_jar_name_prefix : str
        JAR file name prefix used by the Maven build.
    return_first : bool, default False
        If True, return the first match instead of raising when multiple JARs are found.

    Returns
    -------
    str or None
        Path to the JAR, or None if no matching JAR exists.

    Raises
    ------
    RuntimeError
        If multiple candidate JARs are found and ``return_first`` is False.
    """
    # Note that 'sbt_jar_name_prefix' and 'mvn_jar_name_prefix' are used since the prefix can
    # vary for SBT or Maven specifically. See also SPARK-26856
    project_full_path = os.path.join(SPARK_HOME, project_relative_path)

    # We should ignore the following jars
    ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")

    # Search jar in the project dir using the jar name_prefix for both sbt build and maven
    # build because the artifact jars are in different directories.
    sbt_build = glob.glob(
        os.path.join(project_full_path, "target/scala-*/%s*.jar" % sbt_jar_name_prefix)
    )
    maven_build = glob.glob(os.path.join(project_full_path, "target/%s*.jar" % mvn_jar_name_prefix))
    jar_paths = sbt_build + maven_build
    # str.endswith accepts a tuple of suffixes, so one call filters all ignored kinds.
    jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)]

    if not jars:
        return None
    elif len(jars) > 1 and not return_first:
        raise RuntimeError("Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
    else:
        return jars[0]


def get_sbt_runtime_classpath(project_relative_path, project_name_map):
    """
    Get the runtime classpath for a project using SBT.

    This is used as a fallback when classpath.txt doesn't exist (SBT builds).

    Args:
        project_relative_path: Relative path from SPARK_HOME to the project directory
        project_name_map: Dict mapping project paths to SBT project names
            (e.g., {"connector/kafka-0-10-sql": "sql-kafka-0-10"})

    Returns:
        Comma-separated string of JAR paths, or None if SBT command fails
    """
    import subprocess
    import re

    sbt_project = project_name_map.get(project_relative_path)
    if not sbt_project:
        return None

    try:
        # Run SBT command to get runtime classpath
        sbt_cmd = os.path.join(SPARK_HOME, "build", "sbt")
        result = subprocess.run(
            [sbt_cmd, f"show {sbt_project}/Runtime/dependencyClasspath"],
            cwd=SPARK_HOME,
            capture_output=True,
            text=True,
            timeout=180,
        )
        if result.returncode != 0:
            return None

        # Parse the output to extract JAR paths.
        # SBT's `show dependencyClasspath` emits lines like:
        #   [info] * Attributed(/path/to/file.jar)
        # The Attributed(...) format is part of SBT's core attributed path API and has
        # been stable across all SBT versions used by Spark. If SBT changes this format,
        # update the regex below to match the new output.
        jar_paths = []
        for line in result.stdout.splitlines():
            match = re.search(r"Attributed\(([^)]+\.jar)\)", line)
            if match:
                jar_paths.append(match.group(1))

        if jar_paths:
            return ",".join(jar_paths)
        return None
    except Exception:
        # Best-effort: any failure (missing sbt launcher, timeout, bad output) falls
        # back to the caller's error path rather than crashing test collection.
        return None


def read_classpath(project_relative_path, project_name_map=None):
    """
    Read the classpath file for a project and return it as a comma-separated string.

    The classpath file is typically generated by Maven and contains colon-separated paths.
    For SBT builds (when classpath.txt doesn't exist), this function will query SBT for
    the runtime classpath.

    Args:
        project_relative_path: Relative path from SPARK_HOME to the project directory
        project_name_map: Optional dict mapping project paths to SBT project names.
            Required if classpath.txt doesn't exist and SBT query is needed.

    Returns:
        Comma-separated string of JAR paths

    Raises:
        RuntimeError: If the classpath cannot be determined

    Example:
        >>> classpath = read_classpath("connector/kafka-0-10-sql")
        >>> # Returns: "/path/to/jar1.jar,/path/to/jar2.jar,..."
    """
    project_full_path = os.path.join(SPARK_HOME, project_relative_path)
    classpath_file = os.path.join(project_full_path, "target", "classpath.txt")

    # First try to read classpath.txt (Maven builds)
    if os.path.exists(classpath_file):
        with open(classpath_file, "r") as f:
            classpath = f.read().strip()
        # Replace colon with comma for spark-submit --jars format
        return classpath.replace(":", ",")

    # Fall back to querying SBT for runtime classpath (SBT builds)
    if project_name_map is not None:
        sbt_classpath = get_sbt_runtime_classpath(project_relative_path, project_name_map)
        if sbt_classpath:
            return sbt_classpath

    # If both methods fail, raise an error
    raise RuntimeError(
        f"Classpath file not found at {classpath_file} and could not query SBT. "
        f"Please build the project with Maven: "
        f"'build/mvn -pl {project_relative_path} dependency:build-classpath "
        f"-Dmdep.outputFile=target/classpath.txt' "
        f"or with SBT: 'build/sbt Test/package'"
    )


# Probe whether test classes are compiled; tests that need them can check
# `test_compiled` and report `test_not_compiled_message` when skipping.
test_not_compiled_message = ""
try:
    from pyspark.sql.utils import require_test_compiled

    require_test_compiled()
except Exception as e:
    test_not_compiled_message = str(e)

test_compiled = not test_not_compiled_message


class SQLTestUtils:
    """
    This util assumes the instance of this to have 'spark' attribute, having a spark session.
    It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the
    implementation of this class has 'spark' attribute.
    """

    @contextmanager
    def sql_conf(self, pairs):
        """
        A convenient context manager to test some configuration specific logic. This sets
        `value` to the configuration `key` and then restores it back when it exits.
        """
        assert isinstance(pairs, dict), "pairs should be a dictionary."
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        keys = pairs.keys()
        new_values = pairs.values()
        old_values = [self.spark.conf.get(key, None) for key in keys]
        for key, new_value in zip(keys, new_values):
            self.spark.conf.set(key, new_value)
        try:
            yield
        finally:
            # Restore each key to its previous value; unset keys that had none.
            for key, old_value in zip(keys, old_values):
                if old_value is None:
                    self.spark.conf.unset(key)
                else:
                    self.spark.conf.set(key, old_value)

    @contextmanager
    def database(self, *databases):
        """
        A convenient context manager to test with some specific databases. This drops the given
        databases if they exist and sets current database to "default" when it exits.
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for db in databases:
                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
            self.spark.catalog.setCurrentDatabase("default")

    @contextmanager
    def table(self, *tables):
        """
        A convenient context manager to test with some specific tables. This drops the given
        tables if they exist.
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for t in tables:
                self.spark.sql("DROP TABLE IF EXISTS %s" % t)

    @contextmanager
    def view(self, *views):
        """
        A convenient context manager for persistent (catalog) views.

        On exit, runs ``DROP VIEW IF EXISTS`` for each name. For temporary views,
        use :meth:`temp_view`.
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for v in views:
                self.spark.sql("DROP VIEW IF EXISTS %s" % v)

    @contextmanager
    def temp_view(self, *views):
        """
        A convenient context manager to test with some specific views. This drops the given
        views if they exist.
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for v in views:
                self.spark.catalog.dropTempView(v)

    @contextmanager
    def temp_func(self, *functions):
        """
        A convenient context manager to test with some specific temporary functions. This drops
        the temporary functions if they exist.
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for f in functions:
                self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS %s" % f)

    @contextmanager
    def function(self, *functions):
        """
        A convenient context manager to test with some specific functions. This drops the given
        functions if they exist.
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for f in functions:
                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)

    @contextmanager
    def temp_env(self, pairs):
        """
        A convenient context manager to set environment variables temporarily. A value of
        ``None`` removes the variable; the previous environment is restored on exit.
        """
        assert isinstance(pairs, dict), "pairs should be a dictionary."

        keys = pairs.keys()
        new_values = pairs.values()
        old_values = [os.environ.get(key, None) for key in keys]
        for key, new_value in zip(keys, new_values):
            if new_value is None:
                if key in os.environ:
                    del os.environ[key]
            else:
                os.environ[key] = new_value
        try:
            yield
        finally:
            for key, old_value in zip(keys, old_values):
                if old_value is None:
                    if key in os.environ:
                        del os.environ[key]
                else:
                    os.environ[key] = old_value

    @staticmethod
    def assert_close(a, b):
        """
        Assert that the values in ``a`` are element-wise close to the first column of the
        rows in ``b`` (within 1e-6 for finite values, exact equality for non-finite ones).
        """
        c = [j[0] for j in b]
        diff = [abs(v - c[k]) < 1e-6 if math.isfinite(v) else v == c[k] for k, v in enumerate(a)]
        assert sum(diff) == len(a), f"sum: {sum(diff)}, len: {len(a)}"


class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils, PySparkErrorTestUtils):
    """
    A reusable test base that shares a SparkSession (built on the reused SparkContext)
    and a small sample DataFrame (``cls.df``) across all tests in the class.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._legacy_sc = cls.sc
        cls.spark = SparkSession(cls.sc)
        # Reserve a unique temp path name, then unlink so tests can create it themselves.
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(cls.tempdir.name)
        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
        cls.df = cls.spark.createDataFrame(cls.testData)

    @classmethod
    def tearDownClass(cls):
        try:
            cls.spark.stop()
            shutil.rmtree(cls.tempdir.name, ignore_errors=True)
        finally:
            super().tearDownClass()

    def tearDown(self):
        try:
            self.spark._jsparkSession.cleanupPythonWorkerLogs()
        finally:
            super().tearDown()