#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import inspect

from typing import TYPE_CHECKING, Any, List, Optional, Union

from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql.plot import (
    PySparkPlotAccessor,
    PySparkBoxPlotBase,
    PySparkKdePlotBase,
    PySparkHistogramPlotBase,
)
from pyspark.sql.types import NumericType

if TYPE_CHECKING:
    from pyspark.sql import DataFrame
    from plotly.graph_objs import Figure


def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
    import plotly

    if kind == "pie":
        return plot_pie(data, **kwargs)
    if kind == "box":
        return plot_box(data, **kwargs)
    if kind == "kde" or kind == "density":
        return plot_kde(data, **kwargs)
    if kind == "hist":
        return plot_histogram(data, **kwargs)

    if kind not in PySparkPlotAccessor.plot_data_map:
        raise PySparkValueError(
            errorClass="UNSUPPORTED_PLOT_KIND",
            messageParameters={
                "plot_type": kind,
                "supported_plot_types": ", ".join(
                    sorted(
                        list(PySparkPlotAccessor.plot_data_map.keys())
                        + ["pie", "box", "kde", "density", "hist"]
                    )
                ),
            },
        )

    return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)


def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure":
    from plotly import express

    pdf = PySparkPlotAccessor.plot_data_map["pie"](data)
    x = kwargs.pop("x", None)
    y = kwargs.pop("y", None)
    subplots = kwargs.pop("subplots", False)
    if y is None and not subplots:
        raise PySparkValueError(errorClass="UNSUPPORTED_PIE_PLOT_PARAM", messageParameters={})
    numeric_ys = process_column_param(y, data)

    if subplots:
        # One pie chart per numeric column
        from plotly.subplots import make_subplots

        fig = make_subplots(
            rows=1,
            cols=len(numeric_ys),
            # To accommodate domain-based trace - pie chart
            specs=[[{"type": "domain"}] * len(numeric_ys)],
        )
        for i, y_col in enumerate(numeric_ys):
            subplot_fig = express.pie(pdf, values=y_col, names=x, **kwargs)
            fig.add_trace(
                subplot_fig.data[0], row=1, col=i + 1
            )  # A single pie chart has only one trace
    else:
        fig = express.pie(pdf, values=numeric_ys[0], names=x, **kwargs)

    return fig
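

# Illustrative usage, not part of the module: with plotly installed and an active
# SparkSession `spark`, the pie path above is typically reached through the
# DataFrame plot accessor, which dispatches to plot_pyspark when the plotly
# backend is selected. Column names below are hypothetical.
#
#   df = spark.createDataFrame([("a", 3.0), ("b", 7.0)], ["category", "amount"])
#   fig = df.plot.pie(x="category", y="amount")
#   fig.show()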


def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure":
    import plotly.graph_objs as go

    # 'whis' isn't actually a plotly argument (it is one in matplotlib), and plotly
    # doesn't seem to expose how far the whiskers may reach beyond the first and
    # third quartiles; it appears to use the default of 1.5.
    whis = kwargs.pop("whis", 1.5)
    # 'precision' is PySpark-specific; it controls the precision of approx_percentile.
    precision = kwargs.pop("precision", 0.01)
    colnames = process_column_param(kwargs.pop("column", None), data)

    # Plotly options
    boxpoints = kwargs.pop("boxpoints", "suspectedoutliers")
    notched = kwargs.pop("notched", False)
    if boxpoints not in ["suspectedoutliers", False]:
        raise PySparkValueError(
            errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
            messageParameters={
                "backend": "plotly",
                "param": "boxpoints",
                "value": str(boxpoints),
                "supported_values": ", ".join(["suspectedoutliers", "False"]),
            },
        )
    if notched:
        raise PySparkValueError(
            errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
            messageParameters={
                "backend": "plotly",
                "param": "notched",
                "value": str(notched),
                "supported_values": ", ".join(["False"]),
            },
        )

    fig = go.Figure()

    results = PySparkBoxPlotBase.compute_box(
        data,
        colnames,
        whis,
        precision,
        boxpoints is not None,
    )
    assert len(results) == len(colnames)  # type: ignore

    for i, colname in enumerate(colnames):
        result = results[i]  # type: ignore

        fig.add_trace(
            go.Box(
                x=[i],
                name=colname,
                q1=[result["q1"]],
                median=[result["med"]],
                q3=[result["q3"]],
                mean=[result["mean"]],
                lowerfence=[result["lower_whisker"]],
                upperfence=[result["upper_whisker"]],
                y=[result["fliers"]] if result["fliers"] else None,
                boxpoints=boxpoints,
                notched=notched,
                **kwargs,
            )
        )

    fig["layout"]["yaxis"]["title"] = "value"
    return fig


def plot_kde(data: "DataFrame", **kwargs: Any) -> "Figure":
    from pyspark.testing.utils import have_numpy
    from pyspark.sql.pandas.utils import require_minimum_pandas_version

    require_minimum_pandas_version()

    import pandas as pd
    from plotly import express

    if "color" not in kwargs:
        kwargs["color"] = "names"

    bw_method = kwargs.pop("bw_method", None)
    colnames = process_column_param(kwargs.pop("column", None), data)
    ind = PySparkKdePlotBase.get_ind(data.select(*colnames), kwargs.pop("ind", None))

    if have_numpy:
        import numpy as np

        if isinstance(ind, np.ndarray):
            ind = [float(i) for i in ind]

    kde_cols = [
        PySparkKdePlotBase.compute_kde_col(
            input_col=data[col_name],
            ind=ind,
            bw_method=bw_method,
        ).alias(f"kde_{i}")
        for i, col_name in enumerate(colnames)
    ]
    kde_results = data.select(*kde_cols).first()

    pdf = pd.concat(
        [
            pd.DataFrame(  # type: ignore
                {
                    "Density": kde_result,
                    "names": col_name,
                    "index": ind,
                }
            )
            for col_name, kde_result in zip(colnames, list(kde_results))  # type: ignore[arg-type]
        ]
    )

    fig = express.line(pdf, x="index", y="Density", **kwargs)
    fig["layout"]["xaxis"]["title"] = None
    return fig
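

# Illustrative usage, not part of the module: `kind="density"` is an alias for
# `kind="kde"` in plot_pyspark above. A sketch assuming an active SparkSession
# `spark` and hypothetical numeric columns; the bandwidth value is arbitrary.
#
#   df = spark.createDataFrame([(1.0, 4.0), (2.0, 5.0), (3.0, 6.0)], ["x1", "x2"])
#   fig = plot_pyspark(df, kind="density", bw_method=0.3, column=["x1", "x2"])
#   fig.show()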


def plot_histogram(data: "DataFrame", **kwargs: Any) -> "Figure":
    import plotly.graph_objs as go

    bins = kwargs.get("bins", 10)
    colnames = process_column_param(kwargs.pop("column", None), data)
    numeric_data = data.select(*colnames)

    bins = PySparkHistogramPlotBase.get_bins(numeric_data, bins)
    assert len(bins) > 2, "the number of buckets must be higher than 2."
    output_series = PySparkHistogramPlotBase.compute_hist(numeric_data, bins)

    prev = float("%.9f" % bins[0])  # truncate to make it prettier.
    text_bins = []
    for b in bins[1:]:
        norm_b = float("%.9f" % b)
        text_bins.append("[%s, %s)" % (prev, norm_b))
        prev = norm_b
    text_bins[-1] = text_bins[-1][:-1] + "]"  # replace ')' with ']' for the last bucket.

    bins = [(bins[i] + bins[i + 1]) / 2 for i in range(0, len(bins) - 1)]
    output_series = list(output_series)
    bars = []
    for series in output_series:
        bars.append(
            go.Bar(
                x=bins,
                y=series,
                name=series.name,
                text=text_bins,
                hovertemplate=(
                    "variable=" + str(series.name) + "<br>value=%{text}<br>count=%{y}"
                ),
            )
        )

    layout_keys = inspect.signature(go.Layout).parameters.keys()
    layout_kwargs = {k: v for k, v in kwargs.items() if k in layout_keys}

    fig = go.Figure(data=bars, layout=go.Layout(**layout_kwargs))
    fig["layout"]["barmode"] = "stack"
    fig["layout"]["xaxis"]["title"] = "value"
    fig["layout"]["yaxis"]["title"] = "count"
    return fig
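

# Illustrative usage, not part of the module: keyword arguments whose names match
# go.Layout parameters (e.g. `title`) are forwarded to the figure layout by
# plot_histogram above. A sketch assuming an active SparkSession `spark` and a
# hypothetical numeric column.
#
#   df = spark.createDataFrame([(1.0,), (1.5,), (2.0,), (3.5,)], ["length"])
#   fig = plot_pyspark(df, kind="hist", bins=4, title="length distribution")
#   fig.show()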


def process_column_param(column: Optional[Union[str, List[str]]], data: "DataFrame") -> List[str]:
    """
    Processes the provided column parameter for a DataFrame.
    - If `column` is None, returns a list of numeric columns from the DataFrame.
    - If `column` is a string, converts it to a list first.
    - If `column` is a list, it checks if all specified columns exist in the DataFrame
      and are of NumericType.
    - Raises a PySparkTypeError if any column in the list is not present in the DataFrame
      or is not of NumericType.
    """
    fields_by_name = {f.name: f for f in data.schema.fields}

    if column is None:
        return [name for name, f in fields_by_name.items() if isinstance(f.dataType, NumericType)]

    if isinstance(column, str):
        column = [column]

    for col in column:
        field = fields_by_name.get(col)

        if not field or not isinstance(field.dataType, NumericType):
            raise PySparkTypeError(
                errorClass="PLOT_INVALID_TYPE_COLUMN",
                messageParameters={
                    "col_name": col,
                    "valid_types": NumericType.__name__,
                    "col_type": field.dataType.__class__.__name__ if field else "None",
                },
            )
    return column
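

# A minimal end-to-end smoke test, offered as an illustrative sketch rather than
# part of the module; it assumes pyspark and plotly are installed. The string
# column is ignored because process_column_param falls back to the numeric columns.
if __name__ == "__main__":
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("plotly-backend-sketch").getOrCreate()
    df = spark.createDataFrame(
        [("a", 1.0, 10.0), ("b", 2.5, 15.0), ("c", 4.0, 30.0)],
        ["category", "length", "width"],
    )
    # Box plot over the numeric columns only ("length" and "width").
    fig = plot_pyspark(df, kind="box")
    fig.show()
    spark.stop()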