# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import unittest from datetime import datetime from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.testing.sqlutils import ReusedSQLTestCase from pyspark.testing.utils import ( have_plotly, plotly_requirement_message, have_pandas, pandas_requirement_message, ) if have_plotly and have_pandas: import pyspark.sql.plot # noqa: F401 @unittest.skipIf( not have_plotly or not have_pandas, plotly_requirement_message or pandas_requirement_message ) class DataFramePlotPlotlyTestsMixin: @property def sdf(self): data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] columns = ["category", "int_val", "float_val"] return self.spark.createDataFrame(data, columns) @property def sdf2(self): data = [(5.1, 3.5, "0"), (4.9, 3.0, "0"), (7.0, 3.2, "1"), (6.4, 3.2, "1"), (5.9, 3.0, "2")] columns = ["length", "width", "species"] return self.spark.createDataFrame(data, columns) @property def sdf3(self): data = [ (3, 5, 20, datetime(2018, 1, 31)), (2, 5, 42, datetime(2018, 2, 28)), (3, 6, 28, datetime(2018, 3, 31)), (9, 12, 62, datetime(2018, 4, 30)), ] columns = ["sales", "signups", "visits", "date"] return self.spark.createDataFrame(data, columns) @property def sdf4(self): data = [ ("A", 50, 55), ("B", 55, 60), ("C", 60, 65), ("D", 65, 70), ("E", 70, 75), # outliers ("F", 10, 15), ("G", 85, 90), ("H", 5, 150), ] columns = ["student", "math_score", "english_score"] return self.spark.createDataFrame(data, columns) def _check_fig_data(self, fig_data, **kwargs): for key, expected_value in kwargs.items(): if key in ["x", "y", "labels", "values"]: converted_values = [v.item() if hasattr(v, "item") else v for v in fig_data[key]] self.assertEqual(converted_values, expected_value) else: self.assertEqual(fig_data[key], expected_value) def test_line_plot(self): # single column as vertical axis fig = self.sdf.plot(kind="line", x="category", y="int_val") expected_fig_data = { "mode": "lines", "name": "", "orientation": "v", "x": ["A", "B", "C"], "xaxis": "x", "y": [10, 30, 20], "yaxis": "y", "type": "scatter", } self._check_fig_data(fig["data"][0], **expected_fig_data) # multiple columns as vertical axis fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"]) expected_fig_data = { "mode": "lines", "name": "int_val", "orientation": "v", "x": ["A", "B", "C"], "xaxis": "x", "y": [10, 30, 20], "yaxis": "y", "type": "scatter", } self._check_fig_data(fig["data"][0], **expected_fig_data) expected_fig_data = { "mode": "lines", "name": "float_val", "orientation": "v", "x": ["A", "B", "C"], "xaxis": "x", "y": [1.5, 2.5, 3.5], "yaxis": "y", "type": "scatter", } self._check_fig_data(fig["data"][1], **expected_fig_data) def test_bar_plot(self): # single column as vertical axis fig = self.sdf.plot(kind="bar", x="category", y="int_val") expected_fig_data = { "name": "", "orientation": "v", "x": ["A", "B", "C"], "xaxis": "x", "y": [10, 30, 20], "yaxis": "y", "type": "bar", } self._check_fig_data(fig["data"][0], **expected_fig_data) # multiple columns as vertical axis fig = self.sdf.plot.bar(x="category", y=["int_val", "float_val"]) expected_fig_data = { "name": "int_val", "orientation": "v", "x": ["A", "B", "C"], "xaxis": "x", "y": [10, 30, 20], "yaxis": "y", "type": "bar", } self._check_fig_data(fig["data"][0], **expected_fig_data) expected_fig_data = { "name": "float_val", "orientation": "v", "x": ["A", "B", "C"], "xaxis": "x", "y": [1.5, 2.5, 3.5], "yaxis": "y", "type": "bar", } self._check_fig_data(fig["data"][1], **expected_fig_data) def test_barh_plot(self): # single column as vertical axis fig = self.sdf.plot(kind="barh", x="category", y="int_val") expected_fig_data = { "name": "", "orientation": "h", "x": ["A", "B", "C"], "xaxis": "x", "y": [10, 30, 20], "yaxis": "y", "type": "bar", } self._check_fig_data(fig["data"][0], **expected_fig_data) # multiple columns as vertical axis fig = self.sdf.plot.barh(x="category", y=["int_val", "float_val"]) expected_fig_data = { "name": "int_val", "orientation": "h", "x": ["A", "B", "C"], "xaxis": "x", "y": [10, 30, 20], "yaxis": "y", "type": "bar", } self._check_fig_data(fig["data"][0], **expected_fig_data) expected_fig_data = { "name": "float_val", "orientation": "h", "x": ["A", "B", "C"], "xaxis": "x", "y": [1.5, 2.5, 3.5], "yaxis": "y", "type": "bar", } self._check_fig_data(fig["data"][1], **expected_fig_data) # multiple columns as horizontal axis fig = self.sdf.plot.barh(x=["int_val", "float_val"], y="category") expected_fig_data = { "name": "int_val", "orientation": "h", "y": ["A", "B", "C"], "xaxis": "x", "x": [10, 30, 20], "yaxis": "y", "type": "bar", } self._check_fig_data(fig["data"][0], **expected_fig_data) expected_fig_data = { "name": "float_val", "orientation": "h", "y": ["A", "B", "C"], "xaxis": "x", "x": [1.5, 2.5, 3.5], "yaxis": "y", "type": "bar", } self._check_fig_data(fig["data"][1], **expected_fig_data) def test_scatter_plot(self): fig = self.sdf2.plot(kind="scatter", x="length", y="width") expected_fig_data = { "name": "", "orientation": "v", "x": [5.1, 4.9, 7.0, 6.4, 5.9], "xaxis": "x", "y": [3.5, 3.0, 3.2, 3.2, 3.0], "yaxis": "y", "type": "scatter", } self._check_fig_data(fig["data"][0], **expected_fig_data) expected_fig_data = { "name": "", "orientation": "v", "y": [5.1, 4.9, 7.0, 6.4, 5.9], "xaxis": "x", "x": [3.5, 3.0, 3.2, 3.2, 3.0], "yaxis": "y", "type": "scatter", } fig = self.sdf2.plot.scatter(x="width", y="length") self._check_fig_data(fig["data"][0], **expected_fig_data) def test_area_plot(self): # single column as vertical axis fig = self.sdf3.plot(kind="area", x="date", y="sales") expected_x = [ datetime(2018, 1, 31, 0, 0), datetime(2018, 2, 28, 0, 0), datetime(2018, 3, 31, 0, 0), datetime(2018, 4, 30, 0, 0), ] expected_fig_data = { "name": "", "orientation": "v", "x": expected_x, "xaxis": "x", "y": [3, 2, 3, 9], "yaxis": "y", "mode": "lines", "type": "scatter", } self._check_fig_data(fig["data"][0], **expected_fig_data) # multiple columns as vertical axis fig = self.sdf3.plot.area(x="date", y=["sales", "signups", "visits"]) expected_fig_data = { "name": "sales", "orientation": "v", "x": expected_x, "xaxis": "x", "y": [3, 2, 3, 9], "yaxis": "y", "mode": "lines", "type": "scatter", } self._check_fig_data(fig["data"][0], **expected_fig_data) expected_fig_data = { "name": "signups", "orientation": "v", "x": expected_x, "xaxis": "x", "y": [5, 5, 6, 12], "yaxis": "y", "mode": "lines", "type": "scatter", } self._check_fig_data(fig["data"][1], **expected_fig_data) expected_fig_data = { "name": "visits", "orientation": "v", "x": expected_x, "xaxis": "x", "y": [20, 42, 28, 62], "yaxis": "y", "mode": "lines", "type": "scatter", } self._check_fig_data(fig["data"][2], **expected_fig_data) def test_pie_plot(self): # single column as 'y' fig = self.sdf3.plot(kind="pie", x="date", y="sales") expected_x = [ datetime(2018, 1, 31, 0, 0), datetime(2018, 2, 28, 0, 0), datetime(2018, 3, 31, 0, 0), datetime(2018, 4, 30, 0, 0), ] expected_fig_data_sales = { "name": "", "labels": expected_x, "values": [3, 2, 3, 9], "type": "pie", } self._check_fig_data(fig["data"][0], **expected_fig_data_sales) # all numeric columns as 'y' expected_fig_data_signups = { "name": "", "labels": expected_x, "values": [5, 5, 6, 12], "type": "pie", } expected_fig_data_visits = { "name": "", "labels": expected_x, "values": [20, 42, 28, 62], "type": "pie", } fig = self.sdf3.plot(kind="pie", x="date", subplots=True) self._check_fig_data(fig["data"][0], **expected_fig_data_sales) self._check_fig_data(fig["data"][1], **expected_fig_data_signups) self._check_fig_data(fig["data"][2], **expected_fig_data_visits) # not specify subplots with self.assertRaises(PySparkValueError) as pe: self.sdf3.plot(kind="pie", x="date") self.check_error( exception=pe.exception, errorClass="UNSUPPORTED_PIE_PLOT_PARAM", messageParameters={} ) # y is not a numerical column with self.assertRaises(PySparkTypeError) as pe: self.sdf.plot.pie(x="int_val", y="category") self.check_error( exception=pe.exception, errorClass="PLOT_INVALID_TYPE_COLUMN", messageParameters={ "col_name": "category", "valid_types": "NumericType", "col_type": "StringType", }, ) def test_box_plot(self): fig = self.sdf4.plot.box(column="math_score") expected_fig_data1 = { "boxpoints": "suspectedoutliers", "lowerfence": (5,), "mean": (50.0,), "median": (55,), "name": "math_score", "notched": False, "q1": (10,), "q3": (65,), "upperfence": (85,), "x": [0], "type": "box", } self._check_fig_data(fig["data"][0], **expected_fig_data1) fig = self.sdf4.plot(kind="box", column=["math_score", "english_score"]) self._check_fig_data(fig["data"][0], **expected_fig_data1) expected_fig_data2 = { "boxpoints": "suspectedoutliers", "lowerfence": (55,), "mean": (72.5,), "median": (65,), "name": "english_score", "notched": False, "q1": (55,), "q3": (75,), "upperfence": (90,), "x": [1], "y": [[150, 15]], "type": "box", } self._check_fig_data(fig["data"][1], **expected_fig_data2) fig = self.sdf4.plot(kind="box") self._check_fig_data(fig["data"][0], **expected_fig_data1) self._check_fig_data(fig["data"][1], **expected_fig_data2) with self.assertRaises(PySparkValueError) as pe: self.sdf4.plot.box(column="math_score", boxpoints=True) self.check_error( exception=pe.exception, errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", messageParameters={ "backend": "plotly", "param": "boxpoints", "value": "True", "supported_values": ", ".join(["suspectedoutliers", "False"]), }, ) with self.assertRaises(PySparkValueError) as pe: self.sdf4.plot.box(column="math_score", notched=True) self.check_error( exception=pe.exception, errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", messageParameters={ "backend": "plotly", "param": "notched", "value": "True", "supported_values": ", ".join(["False"]), }, ) def test_kde_plot(self): fig = self.sdf4.plot.kde(column="math_score", bw_method=0.3, ind=5) expected_fig_data1 = { "mode": "lines", "name": "math_score", "orientation": "v", "xaxis": "x", "yaxis": "y", "type": "scatter", } self._check_fig_data(fig["data"][0], **expected_fig_data1) fig = self.sdf4.plot.kde(column=["math_score", "english_score"], bw_method=0.3, ind=5) self._check_fig_data(fig["data"][0], **expected_fig_data1) expected_fig_data2 = { "mode": "lines", "name": "english_score", "orientation": "v", "xaxis": "x", "yaxis": "y", "type": "scatter", } self._check_fig_data(fig["data"][1], **expected_fig_data2) self.assertEqual(list(fig["data"][0]["x"]), list(fig["data"][1]["x"])) fig = self.sdf4.plot.kde(bw_method=0.3, ind=5) self._check_fig_data(fig["data"][0], **expected_fig_data1) self._check_fig_data(fig["data"][1], **expected_fig_data2) self.assertEqual(list(fig["data"][0]["x"]), list(fig["data"][1]["x"])) def test_hist_plot(self): fig = self.sdf2.plot.hist(column="length", bins=4) expected_fig_data = { "name": "length", "x": [5.1625000000000005, 5.6875, 6.2125, 6.7375], "y": [2, 1, 1, 1], "text": ("[4.9, 5.425)", "[5.425, 5.95)", "[5.95, 6.475)", "[6.475, 7.0]"), "type": "bar", } self._check_fig_data(fig["data"][0], **expected_fig_data) fig = self.sdf2.plot.hist(column=["length", "width"], bins=4) expected_fig_data1 = { "name": "length", "x": [3.5, 4.5, 5.5, 6.5], "y": [0, 1, 2, 2], "text": ("[3.0, 4.0)", "[4.0, 5.0)", "[5.0, 6.0)", "[6.0, 7.0]"), "type": "bar", } self._check_fig_data(fig["data"][0], **expected_fig_data1) expected_fig_data2 = { "name": "width", "x": [3.5, 4.5, 5.5, 6.5], "y": [5, 0, 0, 0], "text": ("[3.0, 4.0)", "[4.0, 5.0)", "[5.0, 6.0)", "[6.0, 7.0]"), "type": "bar", } self._check_fig_data(fig["data"][1], **expected_fig_data2) fig = self.sdf2.plot.hist(bins=4) self._check_fig_data(fig["data"][0], **expected_fig_data1) self._check_fig_data(fig["data"][1], **expected_fig_data2) def test_process_column_param_errors(self): with self.assertRaises(PySparkTypeError) as pe: self.sdf4.plot.box(column="math_scor") self.check_error( exception=pe.exception, errorClass="PLOT_INVALID_TYPE_COLUMN", messageParameters={ "col_name": "math_scor", "valid_types": "NumericType", "col_type": "None", }, ) with self.assertRaises(PySparkTypeError) as pe: self.sdf4.plot.box(column="student") self.check_error( exception=pe.exception, errorClass="PLOT_INVALID_TYPE_COLUMN", messageParameters={ "col_name": "student", "valid_types": "NumericType", "col_type": "StringType", }, ) class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): pass if __name__ == "__main__": from pyspark.testing import main main()