There are many tools in the world, each solving its own range of problems. Most are judged by how well and how correctly they solve a given problem, but some tools you simply enjoy using. They are well designed and fit comfortably in your hand, and you do not need to dig through the documentation to figure out how to perform a simple action. This series of posts is about one such tool for me.

I will describe the optimization methods and tips that help me solve certain technical problems and achieve high efficiency using Apache Spark. This is my updated collection.

Many of the optimizations that I will describe will not affect the JVM languages so much, but without these methods, many Python applications may simply not work.

Don't collect data on driver

If your RDD/DataFrame is so large that all its elements will not fit into the driver machine memory, do not do the following:

data = df.collect()

The collect action tries to move all the data in the RDD/DataFrame to the driver machine, which may run out of memory and crash.

Instead, limit the number of items returned by calling take or takeSample, or by filtering your RDD/DataFrame first.

Similarly, be careful with other actions if you are not sure that your dataset is small enough to fit into the driver memory:

  • countByKey
  • countByValue
  • collectAsMap
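
A minimal sketch of the bounded alternatives mentioned above. The helper names (preview, random_sample, top_keys) are mine and purely illustrative; they only wrap the standard take, takeSample, reduceByKey and takeOrdered calls, and they assume you pass in an existing DataFrame or RDD.

```python
def preview(df, n=20):
    """Bring at most n rows to the driver instead of the whole dataset."""
    return df.take(n)


def random_sample(rdd, n=100, seed=42):
    """Bring a fixed-size random sample of an RDD to the driver."""
    # takeSample(withReplacement, num, seed)
    return rdd.takeSample(False, n, seed)


def top_keys(pair_rdd, n=10):
    """Count per key on the executors and return only the n most frequent keys.

    Unlike countByKey, the full key -> count map never reaches the driver.
    """
    counts = pair_rdd.map(lambda kv: (kv[0], 1)).reduceByKey(lambda a, b: a + b)
    return counts.takeOrdered(n, key=lambda kv: -kv[1])
```

Each helper returns a result whose size you choose up front, so the memory needed on the driver does not grow with the dataset.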

Broadcasting large variables

From the docs:

Broadcast variables allow the programmer to keep a read-only variable cached on each machine rather than shipping a copy of it with tasks.

The use of broadcast variables available in SparkContext can significantly reduce the size of each serialized task, as well as the cost of running the task on the cluster. If your tasks use a large object from the driver program (e.g. a static search table, a large list), consider turning it into a broadcast variable.

If you don't, the same variable is serialized and sent to the executors with every task, i.e. once per partition. A broadcast variable is instead cached once on each machine, in deserialized form, and shared by all the tasks running there.

Spark prints the serialized size of each task on the application master, so you can check it to see whether your tasks are too large; in general, tasks over 20KB in size are probably worth optimizing.

Example of usage:

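import contextlib

@contextlib.contextmanager
def persisted_broadcast(sc, variable):
    # Broadcast the variable and hand the handle to the with-block.
    variable_b = sc.broadcast(variable)
    try:
        yield variable_b
    finally:
        # Release the executor-side copies even if the block raises.
        variable_b.unpersist()

with persisted_broadcast(sc, my_data) as broadcasted_data:
    pass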
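
The with-block above is left empty, so here is a rough sketch of what could go inside it. The function translate_codes and the lookup data are hypothetical; the only Spark pieces used are RDD.map and the .value attribute of a broadcast handle.

```python
def translate_codes(codes_rdd, lookup_b):
    """Replace each code with its name using a broadcast dict.

    The closure captures only the small broadcast handle; the dict itself
    is read through lookup_b.value on the executors.
    """
    return codes_rdd.map(lambda code: lookup_b.value.get(code, "unknown"))


# Usage sketch, assuming an existing SparkContext `sc`:
# with persisted_broadcast(sc, {"de": "Germany"}) as lookup_b:
#     print(translate_codes(sc.parallelize(["de", "xx"]), lookup_b).take(2))
```

Note that the action (take here) runs inside the with-block, so the job finishes before the broadcast is unpersisted.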

It should be noted that Spark has a ContextCleaner, which is run at periodic intervals to remove broadcast variables if they are not used.

Use the best suitable file format

To improve performance, choose your file formats wisely. The best format depends on the specific application and on what each of your Spark jobs does.

Each format has its own strengths. Avro, for example, is row-oriented and easy to serialize and deserialize, which makes it a good fit for ingestion processes, and it copes well with frequent schema changes. Parquet is columnar, so it works well when you select specific columns, and it can be effective for storing intermediate files. Parquet files are immutable, however: modifying the data requires overwriting the whole data set.
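
As a small illustration of the Parquet points, here is a sketch of storing an intermediate result and reading back only some of its columns. The function names and the paths you pass in are hypothetical, and `spark` is assumed to be an existing SparkSession.

```python
def save_intermediate(df, path):
    """Write an intermediate result as Parquet, replacing any previous run."""
    df.write.parquet(path, mode="overwrite")


def read_columns(spark, path, columns):
    """Read a Parquet dataset but keep only the named columns.

    Because Parquet is columnar, the columns that are not selected
    do not have to be read from storage.
    """
    return spark.read.parquet(path).select(*columns)
```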

When reading CSV and JSON files, you get better performance by specifying the schema instead of relying on schema inference, which needs an extra pass over the data. An explicit schema also reduces errors and is recommended for production code.

Example code:

from pyspark.sql.types import (StructType, StructField,
    DoubleType, IntegerType, StringType)

schema = StructType([
    StructField('A', IntegerType(), nullable=False),
    StructField('B', DoubleType(), nullable=False),
    StructField('C', StringType(), nullable=False)
])

# `spark` is assumed to be an existing SparkSession. Passing the schema
# explicitly means Spark does not have to infer it from the file.
df = spark.read.csv('/some/input/file.csv', schema=schema)

Use the right compression for files

The files we deal with can be divided into two types, depending on the compression codec:

  • Splittable (e.g. indexed lzo, bzip2, snappy inside a container format such as Parquet or Avro)
  • Non-splittable (e.g. gzip, zip, lz4)

For discussion purposes, "splittable files" means that they can be processed in parallel in a distributed manner rather than on a single machine (non-splittable).

Do not use large source files in zip/gzip format: they are not splittable, so Spark cannot read them in parallel. Spark first has to download the whole file to one executor, unpack it on a single core, and only then redistribute the partitions across the cluster nodes. As you can imagine, this becomes a huge bottleneck in your distributed processing. If the files are stored on HDFS, you should unpack them before loading them into Spark.

Bzip2 files have a similar problem. Even though they are splittable, they are compressed so heavily that you get very few partitions, and the work can end up poorly distributed. Bzip2 makes sense when there are no limits on compression time and CPU load, for example for one-time packaging of large amounts of data.

In the end, uncompressed files can clearly outperform compressed ones. Reading uncompressed files is I/O bound and reading compressed files is CPU bound, so when I/O is fast enough, as it is here, skipping decompression wins. Apache Spark supports this quite well, but other libraries and data warehouses may not.
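
If you cannot avoid receiving a large gzip file, one option is to pay the single-core read once and immediately rewrite the data in a splittable form. The sketch below is my own illustration, not a recipe from the Spark docs: the function name, the default of 64 partitions, and the choice of snappy-compressed Parquet are all assumptions, and `spark` is an existing SparkSession.

```python
def gzip_csv_to_parquet(spark, src_path, dst_path, schema, num_partitions=64):
    """Convert a gzip-compressed CSV into snappy-compressed Parquet."""
    # A .csv.gz file is not splittable, so this read is one single task.
    df = spark.read.csv(src_path, schema=schema)
    # Spread the rows over the cluster before writing, so that later jobs
    # get many Parquet files to read in parallel.
    (df.repartition(num_partitions)
       .write.parquet(dst_path, mode="overwrite", compression="snappy"))
```

All later jobs can then read from dst_path in parallel instead of going back to the gzip file.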

Avoid reduceByKey when the input and output value types are different

If for any reason you still have RDD-based jobs, use reduceByKey wisely.

Consider the job of creating a set of strings for each key:

rdd.map(lambda p: (p[0], {p[1]})) \
    .reduceByKey(lambda x, y: x | y) \
    .collect()

Note that the input values are strings and the output values are sets. The map operation creates a new single-element set for every record, which means lots of small temporary objects. A better way to handle this scenario is to use aggregateByKey:

def merge_vals(xs, x):
    xs.add(x)
    return xs
	
def combine(xs, ys):
    return xs | ys

rdd.aggregateByKey(set(), merge_vals, combine).collect()
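
To see why this allocates fewer objects, here is a pure-Python imitation of what aggregateByKey does with its three arguments. It is only a simulation under my own assumptions (partitions are plain lists, there is no shuffle, and simulate_aggregate_by_key is a made-up name), not Spark's implementation.

```python
import copy


def simulate_aggregate_by_key(partitions, zero_value, seq_func, comb_func):
    """Imitate aggregateByKey over a list of partitions (lists of pairs)."""
    partials = []
    for partition in partitions:
        acc = {}
        for key, value in partition:
            if key not in acc:
                # Each key starts from its own copy of the zero value,
                # so seq_func may safely mutate the accumulator in place.
                acc[key] = copy.deepcopy(zero_value)
            acc[key] = seq_func(acc[key], value)
        partials.append(acc)

    merged = {}
    for acc in partials:
        for key, partial in acc.items():
            merged[key] = comb_func(merged[key], partial) if key in merged else partial
    return merged


def merge_vals(xs, x):
    xs.add(x)
    return xs


def combine(xs, ys):
    return xs | ys


partitions = [
    [("a", "x"), ("a", "y"), ("b", "z")],  # pretend partition 0
    [("a", "x"), ("b", "w")],              # pretend partition 1
]
result = simulate_aggregate_by_key(partitions, set(), merge_vals, combine)
```

In this simulation a set is created once per key per partition, rather than once per record as in the map plus reduceByKey version.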

Don't use count when you don't need to return the exact number of rows

When you only need to know whether there are any rows at all, not the exact number of them, use:

df = sqlContext.read.json(...)
if not len(df.take(1)):
    ...
    
or
    
if df.rdd.isEmpty():
    ...

instead of

if not df.count():
    ...

If you are already working with an RDD, call isEmpty() on it directly.
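
The checks above can be wrapped in a small helper. This is a sketch with hypothetical function names; it relies on DataFrame.isEmpty(), which as far as I know exists from PySpark 3.3 on, and falls back to take(1) otherwise. The second helper uses RDD.countApprox for the case where a rough number is good enough; the default timeout and confidence are arbitrary.

```python
def is_empty(df):
    """Return True if the DataFrame has no rows, without a full count."""
    if hasattr(df, "isEmpty"):
        # Available on DataFrames in newer PySpark versions.
        return df.isEmpty()
    # Older versions: fetch at most one row.
    return len(df.take(1)) == 0


def approx_row_count(df, timeout_ms=1000, confidence=0.9):
    """Return a count that may be incomplete if timeout_ms runs out first."""
    return df.rdd.countApprox(timeout_ms, confidence)
```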

Materials:

Tuning Spark docs

Additional links:

Uber Case Study: Choosing the Right HDFS File Format for Your Apache Spark Jobs

