6.15. PySpark

6.15.1. Distributed Data Joining with Shuffle Joins in PySpark

!pip install 'pyspark[sql]'

A shuffle join in PySpark redistributes rows across worker nodes so that rows with the same join key land in the same partition. Each partition can then be joined in parallel, which lets Spark join datasets that are too large to process on a single node.

Here’s an example of performing a shuffle join in PySpark:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
employees = spark.createDataFrame(
    [(1, "John", "Sales"), (2, "Jane", "Marketing"), (3, "Bob", "Engineering")],
    ["id", "name", "department"],
)

salaries = spark.createDataFrame([(1, 5000), (2, 6000), (4, 7000)], ["id", "salary"])

# Perform an inner join using the join key "id"
joined_df = employees.join(salaries, "id", "inner")

joined_df.show()
+---+----+----------+------+
| id|name|department|salary|
+---+----+----------+------+
|  1|John|     Sales|  5000|
|  2|Jane| Marketing|  6000|
+---+----+----------+------+
                                                                                

In this example, PySpark performs a shuffle join behind the scenes to combine the two DataFrames. The process involves partitioning the data based on the join key ("id"), shuffling the partitions across the worker nodes, performing local joins on each worker node, and finally merging the results.
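The four phases described above can be imitated in plain Python. This is only a conceptual sketch, not how Spark is implemented: the partition count `NUM_PARTITIONS` and the helper names `shuffle` and `local_inner_join` are assumptions made for illustration.

```python
from collections import defaultdict

NUM_PARTITIONS = 4  # hypothetical number of partitions (stand-in for workers)

employees = [(1, "John", "Sales"), (2, "Jane", "Marketing"), (3, "Bob", "Engineering")]
salaries = [(1, 5000), (2, 6000), (4, 7000)]


def shuffle(rows, num_partitions):
    """Route each row to a partition chosen by hashing its join key (first field)."""
    partitions = defaultdict(list)
    for row in rows:
        partitions[hash(row[0]) % num_partitions].append(row)
    return partitions


def local_inner_join(left_rows, right_rows):
    """Inner-join the rows that landed in the same partition."""
    salaries_by_id = defaultdict(list)
    for emp_id, salary in right_rows:
        salaries_by_id[emp_id].append(salary)
    return [
        (emp_id, name, department, salary)
        for emp_id, name, department in left_rows
        for salary in salaries_by_id.get(emp_id, [])
    ]


# Phases 1 and 2: partition both sides by the join key and "shuffle" them.
left_parts = shuffle(employees, NUM_PARTITIONS)
right_parts = shuffle(salaries, NUM_PARTITIONS)

# Phases 3 and 4: join each partition independently, then merge the results.
joined = []
for partition_id in range(NUM_PARTITIONS):
    joined.extend(
        local_inner_join(left_parts[partition_id], right_parts[partition_id])
    )

print(sorted(joined))
```

Because matching keys always hash to the same partition, each local join needs only the rows of its own partition, which is what makes the per-partition work parallelizable.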

6.15.2. Spark DataFrame: Avoid Out-of-Memory Errors with Lazy Evaluation

!pip install 'pyspark[sql]'

Retrieving all rows from a large dataset into memory can cause out-of-memory errors. Spark DataFrames are evaluated lazily: transformations such as filter() and groupBy() only build a query plan, and nothing is computed until an action such as show() or collect() is invoked. This allows you to reduce the size of the DataFrame through operations such as filtering or aggregating before bringing the result into memory.

As a result, you can manage memory usage more efficiently and avoid unnecessary computations.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.read.parquet('test_data.parquet')
df.show(5)
                                                                                
+---+----+----+
|cat|val1|val2|
+---+----+----+
|  b|   0|  34|
|  a|  58|  12|
|  c|  24|  72|
|  a|  20|  58|
|  b|  13|  17|
+---+----+----+
only showing top 5 rows
processed_df = df.filter(df["val1"] >= 50).groupBy("cat").agg({"val2": "mean"})
processed_df.collect()
                                                                                
[Row(cat='c', avg(val2)=49.54095055783208),
 Row(cat='b', avg(val2)=49.46593810642427),
 Row(cat='a', avg(val2)=49.52092805080465)]

6.15.3. Pandas-Friendly Big Data Processing with Spark

!pip install "pyspark[pandas_on_spark]"

Spark enables scaling of your pandas workloads across multiple nodes. However, learning PySpark syntax can be daunting for pandas users.

Pandas API on Spark enables leveraging Spark’s capabilities for big data while retaining a familiar pandas-like syntax.

The following code compares the syntax between PySpark and the Pandas API on Spark.

import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

Pandas API on Spark:

import numpy as np
import pyspark.pandas as ps
psdf = ps.DataFrame(
    {
        "A": ["foo", "bar", "foo"],
        "B": ["one", "one", "two"],
        "C": [0.1, 0.3, 0.5],
        "D": [0.2, 0.4, 0.6],
    }
)
psdf.sort_values(by='B')
A B C D
0 foo one 0.1 0.2
1 bar one 0.3 0.4
2 foo two 0.5 0.6
psdf.groupby('A').sum()
C D
A
foo 0.6 0.8
bar 0.3 0.4
psdf.query("C > 0.4")
A B C D
2 foo two 0.5 0.6
psdf[["C", "D"]].abs()
C D
0 0.1 0.2
1 0.3 0.4
2 0.5 0.6

PySpark:

from pyspark.sql.functions import col
from pyspark.sql.functions import abs
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
spark_data = spark.createDataFrame([
    ("foo", "one", 0.1, 0.2),
    ("bar", "one", 0.3, 0.4),
    ("foo", "two", 0.5, 0.6),
], ["A", "B", "C", "D"])
spark_data.sort(col('B')).show()
+---+---+---+---+
|  A|  B|  C|  D|
+---+---+---+---+
|foo|one|0.1|0.2|
|bar|one|0.3|0.4|
|foo|two|0.5|0.6|
+---+---+---+---+
spark_data.groupBy('A').sum().show()
+---+------+------+
|  A|sum(C)|sum(D)|
+---+------+------+
|foo|   0.6|   0.8|
|bar|   0.3|   0.4|
+---+------+------+
                                                                                
spark_data.filter(col('C') > 0.4).show()
+---+---+---+---+
|  A|  B|  C|  D|
+---+---+---+---+
|foo|two|0.5|0.6|
+---+---+---+---+
spark_data.select(abs(spark_data["C"]).alias("C"), abs(spark_data["D"]).alias("D"))
DataFrame[C: double, D: double]

6.15.4. PySpark SQL: Enhancing Reusability with Parameterized Queries

!pip install "pyspark[sql]"

In PySpark, parameterized queries enable the same query structure to be reused with different inputs, without rewriting the SQL.

Additionally, they safeguard against SQL injection attacks by treating input data as parameters rather than as executable code.

from pyspark.sql import SparkSession
import pandas as pd 

spark = SparkSession.builder.getOrCreate()
# Create a Spark DataFrame
item_price_pandas = pd.DataFrame({"item_id": [1, 2, 3, 4], "price": [4, 2, 5, 1]})
item_price = spark.createDataFrame(item_price_pandas)
item_price.show()
                                                                                
+-------+-----+
|item_id|price|
+-------+-----+
|      1|    4|
|      2|    2|
|      3|    5|
|      4|    1|
+-------+-----+
query = """SELECT item_id, price 
FROM {item_price} 
WHERE item_id = {id_val} 
"""

spark.sql(query, id_val=1, item_price=item_price).show()
                                                                                
+-------+-----+
|item_id|price|
+-------+-----+
|      1|    4|
+-------+-----+
spark.sql(query, id_val=2, item_price=item_price).show()
+-------+-----+
|item_id|price|
+-------+-----+
|      2|    2|
+-------+-----+

6.15.5. Working with Arrays Made Easier in Spark 3.5

!pip install "pyspark[sql]"

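Recent Spark releases added array helper functions such as array_append (Spark 3.4) and array_prepend (Spark 3.5) that simplify working with array data. Below are a few examples showcasing these functions alongside the longer-standing array_distinct and array_contains.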

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
from pyspark.sql import Row

df = spark.createDataFrame(
    [
        Row(customer="Alex", orders=["🍋", "🍋"]),
        Row(customer="Bob", orders=["🍊"]),
    ]
)

df.show()
                                                                                
+--------+--------+
|customer|  orders|
+--------+--------+
|    Alex|[🍋, 🍋]|
|     Bob|    [🍊]|
+--------+--------+
from pyspark.sql.functions import (
    col,
    array_append,
    array_prepend,
    array_contains,
    array_distinct,
)

df.withColumn("orders", array_append(col("orders"), "🍇")).show()
+--------+------------+
|customer|      orders|
+--------+------------+
|    Alex|[🍋, 🍋, 🍇]|
|     Bob|    [🍊, 🍇]|
+--------+------------+
df.withColumn("orders", array_prepend(col("orders"), "🍇")).show()
+--------+------------+
|customer|      orders|
+--------+------------+
|    Alex|[🍇, 🍋, 🍋]|
|     Bob|    [🍇, 🍊]|
+--------+------------+
df.withColumn("orders", array_distinct(col("orders"))).show()
+--------+------+
|customer|orders|
+--------+------+
|    Alex|  [🍋]|
|     Bob|  [🍊]|
+--------+------+
df.withColumn("has_🍋", array_contains(col("orders"), "🍋")).show()
+--------+--------+------+
|customer|  orders|has_🍋|
+--------+--------+------+
|    Alex|[🍋, 🍋]|  true|
|     Bob|    [🍊]| false|
+--------+--------+------+
+--------+--------+------+

View other array functions.

6.15.6. Simplify Complex SQL Queries with PySpark UDFs

!pip install "pyspark[sql]"
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

SQL queries can often become complex and challenging to comprehend.

df = spark.createDataFrame(
    [(1, "John Doe"), (2, "Jane Smith"), (3, "Bob Johnson")], ["id", "name"]
)

# Register the DataFrame as a temporary table or view
df.createOrReplaceTempView("df")

# Complex SQL query
spark.sql(
    """
    SELECT id, CONCAT(UPPER(SUBSTRING(name, 1, 1)), LOWER(SUBSTRING(name, 2))) AS modified_name
    FROM df
"""
).show()
                                                                                
+---+-------------+
| id|modified_name|
+---+-------------+
|  1|     John doe|
|  2|   Jane smith|
|  3|  Bob johnson|
+---+-------------+

Using PySpark UDFs simplifies complex SQL queries by encapsulating complex operations into a single function call, resulting in cleaner queries. UDFs also allow for the reuse of complex logic across different queries.

In the code example below, we define a UDF called modify_name that capitalizes the first letter of the name and lowercases the rest.

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType


# Define a UDF to modify the name
@udf(returnType=StringType())
def modify_name(name):
    return name[0].upper() + name[1:].lower()

spark.udf.register('modify_name', modify_name)

# Apply the UDF in the spark.sql query
df.createOrReplaceTempView("df")

spark.sql("""
    SELECT id, modify_name(name) AS modified_name
    FROM df
"""
).show()
24/03/30 14:36:24 WARN SimpleFunctionRegistry: The function modify_name replaced a previously registered function.
                                                                                
+---+-------------+
| id|modified_name|
+---+-------------+
|  1|     John doe|
|  2|   Jane smith|
|  3|  Bob johnson|
+---+-------------+

Learn more about PySpark UDFs.

6.15.7. Leverage Spark UDFs for Reusable Complex Logic in SQL Queries

!pip install "pyspark[sql]"
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType

# Create SparkSession
spark = SparkSession.builder.getOrCreate()

Duplicated code in SQL queries can lead to inconsistencies if changes are made to one instance of the duplicated code but not to others.

# Sample DataFrame
df = spark.createDataFrame(
    [("Product 1", 10.0, 5), ("Product 2", 15.0, 3), ("Product 3", 8.0, 2)],
    ["name", "price", "quantity"],
)

# Use df within Spark SQL queries
df.createOrReplaceTempView("products")

# Select Statement 1
result1 = spark.sql(
    """
    SELECT name, price, quantity,
        CASE
            WHEN price < 10.0 THEN 'Low'
            WHEN price >= 10.0 AND price < 15.0 THEN 'Medium'
            ELSE 'High'
        END AS category
    FROM products
"""
)

# Select Statement 2
result2 = spark.sql(
    """
    SELECT name,
        CASE
            WHEN price < 10.0 THEN 'Low'
            WHEN price >= 10.0 AND price < 15.0 THEN 'Medium'
            ELSE 'High'
        END AS category
    FROM products
    WHERE quantity > 3
"""
)

# Display the results
result1.show()
result2.show()
+---------+-----+--------+--------+
|     name|price|quantity|category|
+---------+-----+--------+--------+
|Product 1| 10.0|       5|  Medium|
|Product 2| 15.0|       3|    High|
|Product 3|  8.0|       2|     Low|
+---------+-----+--------+--------+

+---------+--------+
|     name|category|
+---------+--------+
|Product 1|  Medium|
+---------+--------+

Spark UDFs (User-Defined Functions) can help address these issues by encapsulating complex logic that is reused across multiple SQL queries.

In the code example below, we define a UDF assign_category_label that assigns category labels based on price. This UDF is then reused in two different SQL statements.

# Define UDF to assign category label based on price
@udf(returnType=StringType())
def assign_category_label(price):
    if price < 10.0:
        return "Low"
    elif price >= 10.0 and price < 15.0:
        return "Medium"
    else:
        return "High"


# Register UDF
spark.udf.register("assign_category_label", assign_category_label)

# Select Statement 1
result1 = spark.sql(
    """
    SELECT name, price, quantity, assign_category_label(price) AS category
    FROM products
"""
)

# Select Statement 2
result2 = spark.sql(
    """
    SELECT name, assign_category_label(price) AS category
    FROM products
    WHERE quantity > 3
"""
)

# Display the results
result1.show()
result2.show()
24/04/15 09:28:11 WARN SimpleFunctionRegistry: The function assign_category_label replaced a previously registered function.
+---------+-----+--------+--------+
|     name|price|quantity|category|
+---------+-----+--------+--------+
|Product 1| 10.0|       5|  Medium|
|Product 2| 15.0|       3|    High|
|Product 3|  8.0|       2|     Low|
+---------+-----+--------+--------+

+---------+--------+
|     name|category|
+---------+--------+
|Product 1|  Medium|
+---------+--------+

6.15.8. Simplify Unit Testing of SQL Queries with PySpark

!pip install ipytest "pyspark[sql]"

Testing your SQL queries helps to ensure that they are correct and functioning as intended.

PySpark enables users to parameterize queries, which simplifies unit testing of SQL queries. In this example, the df and amount variables are parameterized to verify whether the actual_df matches the expected_df.

%%ipytest -qq
import pytest
from pyspark.sql import SparkSession
from pyspark.testing import assertDataFrameEqual


@pytest.fixture
def query():
    return "SELECT * from {df} where price > {amount} AND name LIKE '%Product%';"


def test_query_return_correct_number_of_rows(query):

    spark = SparkSession.builder.getOrCreate()

    # Create a sample DataFrame
    df = spark.createDataFrame(
        [
            ("Product 1", 10.0, 5),
            ("Product 2", 15.0, 3),
            ("Product 3", 8.0, 2),
        ],
        ["name", "price", "quantity"],
    )

    # Execute the query
    actual_df = spark.sql(query, df=df, amount=10)

    # Assert the result
    expected_df = spark.createDataFrame(
        [
            ("Product 2", 15.0, 3),
        ],
        ["name", "price", "quantity"],
    )
    assertDataFrameEqual(actual_df, expected_df)
                                                                                
.                                                                                            [100%]

6.15.9. Update Multiple Columns in Spark 3.3 and Later

!pip install -U "pyspark[sql]"
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.getOrCreate()
from pyspark.sql.functions import col, trim

# Create a sample DataFrame
data = [("   John   ", 35), ("Jane", 28)]
columns = ["first_name", "age"]
df = spark.createDataFrame(data, columns)
df.show()
+----------+---+
|first_name|age|
+----------+---+
|   John   | 35|
|      Jane| 28|
+----------+---+

Prior to PySpark 3.3, appending multiple columns to a Spark DataFrame required chaining multiple withColumn calls.

# Before Spark 3.3 
new_df = (df
          .withColumn("first_name", trim(col("first_name")))
          .withColumn("age_after_10_years", col("age") + 10)
         )

new_df.show()
+----------+---+------------------+
|first_name|age|age_after_10_years|
+----------+---+------------------+
|      John| 35|                45|
|      Jane| 28|                38|
+----------+---+------------------+

In PySpark 3.3 and later, you can use the withColumns method in a dictionary style to append multiple columns to a DataFrame. This syntax is more user-friendly for pandas users.

new_df = df.withColumns(
    {
        "first_name": trim(col("first_name")),
        "age_after_10_years": col("age") + 10,
    }
)

new_df.show()
                                                                                
+----------+---+------------------+
|first_name|age|age_after_10_years|
+----------+---+------------------+
|      John| 35|                45|
|      Jane| 28|                38|
+----------+---+------------------+

6.15.10. Vectorized Operations in PySpark: pandas_udf vs Standard UDF

!pip install -U pyspark
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.getOrCreate()

Standard UDFs process data row by row, which incurs Python function call overhead for every row.

In contrast, pandas_udf receives batches of rows as pandas Series and applies pandas' vectorized operations to each batch, which can significantly improve performance.

# Sample DataFrame
data = [(1.0,), (2.0,), (3.0,), (4.0,)]
df = spark.createDataFrame(data, ["val1"])

df.show()
                                                                                
+----+
|val1|
+----+
| 1.0|
| 2.0|
| 3.0|
| 4.0|
+----+
from pyspark.sql.functions import udf

# Standard UDF
@udf('double')
def plus_one(val):
    return val + 1

# Apply the Standard UDF
df.withColumn('val2', plus_one(df.val1)).show()
                                                                                
+----+----+
|val1|val2|
+----+----+
| 1.0| 2.0|
| 2.0| 3.0|
| 3.0| 4.0|
| 4.0| 5.0|
+----+----+
from pyspark.sql.functions import pandas_udf
import pandas as pd


# Pandas UDF
@pandas_udf("double")
def pandas_plus_one(val: pd.Series) -> pd.Series:
    return val + 1


# Apply the Pandas UDF
df.withColumn("val2", pandas_plus_one(df.val1)).show()
                                                                                
+----+----+
|val1|val2|
+----+----+
| 1.0| 2.0|
| 2.0| 3.0|
| 3.0| 4.0|
| 4.0| 5.0|
+----+----+

Learn more about pandas_udf.
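The difference between the two calling conventions can be seen with pandas alone, without starting Spark. In this sketch, `plus_one_batch` mirrors the body of a pandas UDF (one call per batch) and `plus_one_row` mirrors a standard UDF (one call per value); both function names are hypothetical.

```python
import pandas as pd


def plus_one_batch(values: pd.Series) -> pd.Series:
    """Vectorized: a single call handles the whole batch of values."""
    return values + 1


def plus_one_row(value: float) -> float:
    """Row-wise: called once for every value."""
    return value + 1


batch = pd.Series([1.0, 2.0, 3.0, 4.0])

vectorized = plus_one_batch(batch)    # one Python call for four values
row_by_row = batch.apply(plus_one_row)  # four Python calls

print(vectorized.tolist())
print(row_by_row.tolist())
```

Both approaches return the same values; the vectorized version simply makes far fewer Python-level function calls, which is where most of the speedup comes from on large batches.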

6.15.11. Optimizing PySpark Queries: DataFrame API or SQL?

!pip install "pyspark[sql]"
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

PySpark queries with different syntax (DataFrame API or parameterized SQL) can have the same performance, as the physical plan is identical. Here is an example:

from pyspark.sql.functions import col

fruits = spark.createDataFrame(
    [("apple", 4), ("orange", 3), ("banana", 2)], ["item", "price"]
)
fruits.show()
+------+-----+
|  item|price|
+------+-----+
| apple|    4|
|orange|    3|
|banana|    2|
+------+-----+

Use the DataFrame API to filter rows where the price is greater than 3.

fruits.where(col("price") > 3).explain()
== Physical Plan ==
*(1) Filter (isnotnull(price#80L) AND (price#80L > 3))
+- *(1) Scan ExistingRDD[item#79,price#80L]

Use the spark.sql() method to execute an equivalent SQL query.

spark.sql("select * from {df} where price > 3", df=fruits).explain()
== Physical Plan ==
*(1) Filter (isnotnull(price#80L) AND (price#80L > 3))
+- *(1) Scan ExistingRDD[item#79,price#80L]

The physical plan for both queries is the same, indicating identical performance.

Thus, the choice between DataFrame API and spark.sql() depends on the following:

  • Familiarity: Use spark.sql() if your team prefers SQL syntax. Use the DataFrame API if chained method calls are more intuitive for your team.

  • Complexity of Transformations: The DataFrame API is more flexible for complex manipulations, while SQL is more concise for simpler queries.

6.15.12. Enhance Code Modularity and Reusability with Temporary Views in PySpark

!pip install -U 'pyspark[sql]'
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.getOrCreate()

In PySpark, temporary views are virtual tables that can be queried using SQL, enabling code reusability and modularity.

To demonstrate this, let’s create a PySpark DataFrame called orders_df.

# Create a sample DataFrame
data = [
    (1001, "John Doe", 500.0),
    (1002, "Jane Smith", 750.0),
    (1003, "Bob Johnson", 300.0),
    (1004, "Sarah Lee", 400.0),
    (1005, "Tom Wilson", 600.0),
]

columns = ["customer_id", "customer_name", "revenue"]
orders_df = spark.createDataFrame(data, columns)

Next, create a temporary view called orders from the orders_df DataFrame using the createOrReplaceTempView method.

# Create a temporary view
orders_df.createOrReplaceTempView("orders")

With the temporary view created, we can perform various operations on it using SQL queries.

# Perform operations on the temporary view
total_revenue = spark.sql("SELECT SUM(revenue) AS total_revenue FROM orders")
order_count = spark.sql("SELECT COUNT(*) AS order_count FROM orders")

# Display the results
print("Total Revenue:")
total_revenue.show()

print("\nNumber of Orders:")
order_count.show()
Total Revenue:
                                                                                
+-------------+
|total_revenue|
+-------------+
|       2550.0|
+-------------+


Number of Orders:
+-----------+
|order_count|
+-----------+
|          5|
+-----------+