Spark tips. Don't collect data on driver

There are many different tools in the world, each of which solves a range of problems. Many of them are judged by how well and how correctly they solve a given problem, but there are also tools that you simply like and want to use. They are well designed and fit your hand: you do not need to dig into the documentation to figure out how to perform a simple action. This series of posts is about one such tool for me, Apache Spark.

I will describe the optimization methods and tips that help me solve certain technical problems and achieve high efficiency. This is my updated collection.

Many of the optimizations I describe matter less for JVM languages, but without them many Python applications may simply not work.

Don't collect data on driver

If your RDD/DataFrame is so large that all of its elements won't fit in memory on the driver machine, don't do this:

data = df.collect()

Collect will attempt to copy all the data in the RDD/DataFrame into the driver machine and may run out of memory and crash.

Instead, reduce the number of elements you return: call take or takeSample, or filter or sample your RDD/DataFrame first.

Similarly, be cautious with the following actions unless you are sure their result is small enough to fit in driver memory:

  • countByKey
  • countByValue
  • collectAsMap

Broadcasting large variables

From the docs:

Broadcast variables allow the programmer to keep a read-only variable cached on each machine rather than shipping a copy of it with tasks.

Using the broadcast functionality available in SparkContext can greatly reduce the size of each serialized task, and the cost of launching a job on a cluster. If your tasks use any large object from the driver program (e.g. a static lookup table, big enumerable), consider turning them into broadcast variables.

If you don't, the variable is serialized into the task closure and shipped separately with every task, i.e. once per partition. A broadcast variable is sent to each machine only once and kept there cached in deserialized form.

Spark prints the serialized size of each task on the master, so you can look at that to decide whether your tasks are too large; in general tasks larger than about 20 KB are probably worth optimizing.

Example of usage:

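import contextlib

@contextlib.contextmanager
def persisted_broadcast(sc, variable):
    variable_b = sc.broadcast(variable)
    try:
        yield variable_b
    finally:
        variable_b.unpersist()  # drop executor copies when the block exits

with persisted_broadcast(sc, my_data) as broadcasted_data:
    # assumes my_data is a dict; run actions inside the block (transformations are lazy)
    result = rdd.map(lambda k: broadcasted_data.value.get(k)).collect()
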
Note that Spark also has a ContextCleaner, which removes broadcast variables in the background once they are no longer referenced by the driver program; explicit cleanup mainly helps free executor memory earlier.

Use the best suitable file format

To improve performance, choose file formats wisely. The right format depends on the specifics of the application and on what the individual Spark jobs do.

Each format has its own strengths. For example, Avro has lightweight serialization/deserialization, which makes it well suited for building ingestion pipelines. Parquet, meanwhile, is efficient when you select individual columns and works well for storing intermediate files. However, Parquet files are immutable, so modifications require rewriting the whole dataset, whereas Avro files handle frequent schema changes easily.

When reading CSV and JSON files, you get better performance by specifying the schema instead of relying on inference, because inference needs an extra pass over the data. An explicit schema also reduces errors and is recommended for production code.

Example code:

from pyspark.sql.types import (StructType, StructField,
    DoubleType, IntegerType, StringType)

schema = StructType([
    StructField("A", IntegerType(), nullable=False),
    StructField("B", DoubleType(), nullable=False),
    StructField("C", StringType(), nullable=False),
])

# `spark` is an existing SparkSession; passing the schema skips inference.
df = spark.read.csv("/some/input/file.csv", schema=schema)

Use the right compression for files

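Compression codecs can be divided into two groups:

  • Splittable (e.g. bzip2, LZO with an index)
  • Non-splittable (e.g. gzip, zip, LZ4)

Snappy is a special case: a plain Snappy-compressed file is not splittable, but Snappy used inside container formats such as Parquet, ORC, or Avro is, because those formats compress individual blocks.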

Here, "splittable" means that a file can be cut into chunks that are processed in parallel across the cluster; a non-splittable file has to be processed by a single task on one machine.


Do not use big zip/gzip source files: they are not splittable, so Spark cannot read such a file in parallel. A single task on one core has to read and decompress the whole file, and only a subsequent repartition spreads the chunks across the cluster nodes. As you might expect, this becomes a huge bottleneck for distributed processing. If you come across such files in S3, it is a good idea to move them into HDFS and unzip them there; if the files are already stored in HDFS, unzip them before loading them into Spark.

Bzip2 files suffer from a similar problem. Although they are splittable, their compression ratio is so high that you get very few partitions, so the data can be poorly distributed. Bzip2 is a reasonable choice only when compression time and CPU load do not matter, for example for one-time archiving of a large amount of data.

In the end, uncompressed files often clearly outperform compressed ones. Reading uncompressed files is I/O-bound while reading compressed files is CPU-bound, and when I/O throughput is good enough, skipping decompression wins. Apache Spark supports this approach quite well, but other libraries and data stores may not.

Avoid reduceByKey when the input and output value types are different

If for any reason you still have RDD-based jobs, use reduceByKey wisely.

Consider the job of creating a set of strings for each key:

rdd.map(lambda p: (p[0], {p[1]})) \
    .reduceByKey(lambda x, y: x | y) \
    .collect()

Note that the input values are strings and the output values are sets. Because the types differ, the map step has to wrap every single value in its own one-element set, which creates lots of small temporary objects. A better way to handle this scenario is to use aggregateByKey:

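def merge_vals(xs, x):
    # Add one value to the per-partition accumulator (a copy of set()).
    xs.add(x)
    return xs

def combine(xs, ys):
    # Merge accumulators coming from different partitions.
    xs.update(ys)
    return xs

rdd.aggregateByKey(set(), merge_vals, combine).collect()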

Don't use count when you don't need to return the exact number of rows

When you only need to know whether there are any rows at all, use:

df = spark.read.json(...)  # spark: an existing SparkSession
if not df.take(1):         # fetches at most one row
    ...
if df.rdd.isEmpty():       # same check through the RDD API
    ...

instead of

if not df.count():         # scans and counts every row
    ...

On an RDD, you can call isEmpty() directly.


Additional links:

  • Tuning Spark (Apache Spark documentation)
  • Uber Case Study: Choosing the Right HDFS File Format for Your Apache Spark Jobs
