There are many different tools in the world, each solving a range of problems. Most of them are judged by how well and how correctly they solve a particular problem, but there are tools you simply like and want to use: they are well designed, they fit well in your hand, and you do not need to dig through the documentation to figure out how to do some simple action. This series of posts is about one such tool for me.

I will describe the optimization methods and tips that help me solve certain technical problems and achieve high efficiency using Apache Spark. This is my updated collection.

Many of the optimizations that I will describe will not affect JVM languages as much, but without these methods, many Python applications may simply not work.

Clusters will not be fully utilized unless you set the level of parallelism for each operation high enough. The general recommendation for Spark is to have four times as many partitions as there are cores available to the application in the cluster; as an upper bound on the partition count, each task should take at least 100 ms to execute. If tasks take less time than that, your partitions are too small and your application may be spending more time distributing tasks than doing useful work.

Both too few and too many partitions have their disadvantages. It is therefore recommended that you partition wisely, depending on the configuration and requirements of the cluster.

Too few partitions — not utilizing all cores available in the cluster.

Too many partitions — excessive overhead in managing many small tasks as well as data movement.

When in doubt, it is almost always better to err on the side of a larger number of tasks (and thus partitions).

Even though these recommendations are reasonable, tuning is still very case-specific, so the Spark configuration must be fine-tuned for the given scenario.
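
As a rough illustration of this rule of thumb, the parallelism settings can be passed when the session is created. This is only a sketch; the core count below is an assumed value, not a recommendation:

from pyspark.sql import SparkSession

cluster_cores = 50  # assumed number of cores available to the application
spark = (
    SparkSession.builder
    .config('spark.default.parallelism', cluster_cores * 4)     # RDD operations
    .config('spark.sql.shuffle.partitions', cluster_cores * 4)  # DataFrame shuffles
    .getOrCreate()
)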

Factors affecting partitioning

When you get down to this level of detail in working with Spark, you should understand the following aspects of your Spark pipeline, which will ultimately affect how you partition the data:

  • Business logic
  • Data
  • Environment

Next, I'll go through these aspects and give some tips on each of them.

Business logic

Let's start with the most varied part: business logic. It is difficult to give universal advice here, because the specifics of each pipeline are different.

Reduce the working dataset size

The most effective advice is usually the dumbest — to speed up your Spark pipeline, reduce the amount of data the Spark cluster needs to process. There are several ways to do this, and of course, it all depends on what really needs to be done with the data.

You can filter your data so that partitions that do not match your condition are skipped entirely (assuming, of course, that the data was partitioned). A properly chosen condition can significantly speed up reading and retrieval of the necessary data. In some cases (e.g. S3) it avoids unnecessary partition discovery, and in some cases it lets you benefit from built-in data format mechanisms (e.g. partition pruning).
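
As a minimal sketch of this idea (the path and the date partition column are assumptions, not part of any particular dataset), filtering on the partition column lets Spark skip whole partitions on read:

from pyspark.sql import functions as F

# the data is assumed to be stored partitioned by 'date',
# e.g. /path/to/events/date=2020-01-01/...
events = spark.read.parquet('/path/to/events')
# the filter on the partition column is pushed down,
# so only the matching date partitions are scanned
recent = events.where(F.col('date') >= '2020-01-01')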

Repartition before multiple joins

Join is one of the most expensive operations in Spark and also one of the most widely used, and the infamous shuffle is, as always, to blame. Shuffle deserves more than one post of its own; here we will discuss only the side of it related to partitions.

Shuffle

In order to join data, Spark needs the rows with the same join key values to be on the same partition. Since version 2.3, the default join implementation in Spark is the sort-merge join.

Sort merge join is executed in three basic steps:

  1. The rows with the same join key values have to end up on the same partition, so the partitions of both datasets have to be co-located (in this context that is the same as co-partitioned). This is done by shuffling the data.
  2. Sort the data within each partition in parallel.
  3. Join the sorted and partitioned data. This is basically the merging of a dataset by iterating over the elements and joining the rows having the same value for the join key.

Although this approach always works, it may be more expensive than necessary as it requires a shuffle. Shuffle can be avoided if:

  1. Both dataframes have a common Partitioner.
  2. One of the dataframes is small enough to fit into memory, in which case we can use a broadcast hash join (see the sketch below).
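
For the second case, here is a minimal sketch of an explicit broadcast hint; addresses is assumed to be a small lookup dataframe:

from pyspark.sql import functions as F

# hint Spark to broadcast the small dataframe so the join avoids a full shuffle
joined = users.join(F.broadcast(addresses), 'userId')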

You can repartition a dataframe right after loading it if you know that you will join it several times. Always persist after repartitioning.

users = spark.read.load('/path/to/users').repartition('userId').persist()
joined1 = users.join(addresses, 'userId')
joined1.show() # <-- 1st shuffle for repartition, then users is persisted
joined2 = users.join(salary, 'userId')
joined2.show() # <-- skips the shuffle for users since it has already been repartitioned

This way, you will shuffle the data once and then reuse the shuffled data for the next join. You can also use bucketing for that purpose.
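
A minimal bucketing sketch follows; the bucket count and the table name bucketed_users are arbitrary assumptions, and bucketing requires saving the data as a table:

# write users bucketed by the join key; later joins on userId between tables
# bucketed the same way can avoid the shuffle
users.write \
    .bucketBy(16, 'userId') \
    .sortBy('userId') \
    .mode('overwrite') \
    .saveAsTable('bucketed_users')

bucketed_users = spark.table('bucketed_users')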

Repartition after flatMap

The result of a flatMap operation is usually an RDD with many more rows, but the number of partitions remains the same. This means that the memory load per partition can become too large, and you get to see all the goodies of disk spills and GC pauses. In this case it is better to repartition the output of flatMap based on the predicted memory expansion.
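
A rough sketch of that idea, where lines stands for some input RDD and the 10x expansion factor is just an assumed estimate:

# each input line is expected to expand into many rows
words = lines.flatMap(lambda line: line.split())
# scale the partition count by the estimated expansion factor
expansion_factor = 10  # assumed estimate
words = words.repartition(words.getNumPartitions() * expansion_factor)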

Get rid of disk spills

From the Tuning Spark docs:

Sometimes, you will get an OutOfMemoryError, not because your RDDs don’t fit in memory, but because the working set of one of your tasks, such as one of the reduce tasks in groupByKey, was too large. Spark’s shuffle operations (sortByKey, groupByKey, reduceByKey, join, etc) build a hash table within each task to perform the grouping, which can often be large...

If spark.shuffle.spill is true (which is the default), Spark uses an ExternalAppendOnlyMap in the shuffle process for storing intermediate data. This structure can spill data to disk when there is not enough memory available; under memory pressure on the executor this results in additional disk I/O overhead and increased garbage collection. If you use PySpark, the memory pressure will also increase the chance of Python running out of memory.


To check whether disk spilling occurred, you can search the logs for entries similar to:

INFO ExternalSorter: Task 1 force spilling in-memory map to disk it will release 232.1 MB memory

The fixes can be the following:

  1. Reduce the size of the data, for example by selecting only the required columns or by moving filtering operations before the wide transformations.
  2. Increase the level of parallelism so that each task's input data is smaller, or manually repartition() the prior stage.
  3. Increase the shuffle buffer by increasing the memory of your executor processes (spark.executor.memory).
  4. If the available memory resources are sufficient, increase spark.shuffle.file.buffer to reduce the number of times the buffers overflow during the shuffle write process, which reduces the amount of disk I/O (a configuration sketch for points 3 and 4 follows below).
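
A hedged configuration sketch for points 3 and 4; the values below are arbitrary placeholders, not recommendations:

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # more executor memory -> bigger shuffle buffers (point 3);
    # usually set at submit time, shown here only for illustration
    .config('spark.executor.memory', '8g')
    # default is 32k; a larger file buffer means fewer flushes to disk (point 4)
    .config('spark.shuffle.file.buffer', '1m')
    .getOrCreate()
)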

Data

Data is an essential component of any pipeline, and it always makes its own adjustments. Without knowing your data you end up tweaking your Spark pipeline toward some common denominator, which usually does not work out well, either in terms of speed or in terms of utilization of the available resources.

Data Skew

Ideally, when Spark performs, for example, a join, the join keys should be evenly distributed among the partitions. However, real data is rarely ideal, ahem... or even somewhat good. We often come across skewed data, which leads to degraded performance of parallel processing or even OOM crashes.


This is not a problem specific to Spark, but rather a data problem: the performance of distributed systems depends heavily on how well distributed the data is. One way to ensure a more or less correct distribution is to explicitly repartition the data. Since this is a very expensive operation, we do not want to execute it where it is not needed. We can validate the number of RDD/DataFrame partitions right before performing any heavy operation. You can use the following code as an example:

from pyspark.sql import DataFrame


def repartition(
    df: DataFrame,
    min_partitions: int,
    max_partitions: int
) -> DataFrame:
    num_partitions = df.rdd.getNumPartitions()
    if num_partitions < min_partitions:
        df = df.repartition(min_partitions)
    elif num_partitions > max_partitions:
        df = df.coalesce(max_partitions)
    return df

Often data is broken down into partitions based on the key, e.g. day of the week, country, etc. If the values are not evenly distributed on this key, then more data will be placed in one partition than in another.

Let's take a look at the following snippet:

import pandas as pd
import numpy as np
from pyspark.sql import functions as F

# set smaller number of partitions so they can fit the screen
spark.conf.set('spark.sql.shuffle.partitions', 8)
# disable broadcast join to see the shuffle
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

length = 100
names = np.random.choice(['Bob', 'James', 'Marek', 'Johannes', None], length)
amounts = np.random.randint(0, 1000000, length)
# generate skewed data
country = np.random.choice(
    ['United Kingdom', 'Poland', 'USA', 'Germany', 'Russia'], 
    length,
    p = [0.05, 0.05, 0.8, 0.05, 0.05]
)
data = pd.DataFrame({'name': names, 'amount': amounts, 'country': country})

transactions = spark.createDataFrame(data).repartition('country')
countries = spark.createDataFrame(pd.DataFrame({
    'id': [11, 12, 13, 14, 15], 
    'country': ['United Kingdom', 'Poland', 'USA', 'Germany', 'Russia']
}))

df = transactions.join(countries, 'country')

# check the partitions data
for i, part in enumerate(df.rdd.glom().collect()):
    print({i: part})

Here we see that one partition out of 8 gets a lot of data and the rest gets very little or nothing. It smells skewed.

We can diagnose skewness by looking at the Spark UI and checking the time spent on each task. Some tasks are time-consuming, while others are not. Also, in the Spark UI, an indicator of data skew can be the huge difference between the minimum task time and the maximum task time.
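
Another quick way to spot skew before opening the Spark UI is simply to count rows per join key; this sketch uses the country column from the example above:

# one count that towers over the rest is a strong sign of skew
transactions.groupBy('country').count().orderBy(F.desc('count')).show()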

To resolve the problem of data skew, we can either:

  1. Repartition the data on a more evenly distributed key, if you have one.
  2. Broadcast the smaller dataframe if possible.
  3. Split the data into skewed and non-skewed parts and work with them in parallel by redistributing the skewed data (differential replication).
  4. Use an additional random key for better distribution of the data (salting).
  5. Use an iterative broadcast join.
  6. Use more complicated methods that I have never used in my life.

Differential replication

The rare keys don't need to be replicated as much as the skewed ones. The idea is to identify the frequent keys before replication and then use a different replication policy for them.

replication_high = 7
high = F.broadcast(spark.range(replication_high).withColumnRenamed('id', 'replica_id'))
replication_low = 2
low = F.broadcast(spark.range(replication_low).withColumnRenamed('id', 'replica_id'))

# determine which keys are highly over-represented, broadcast them
skewed_keys = F.broadcast(
    transactions.freqItems(['country'], 0.6)
    .select(F.explode('country_freqItems').alias('country_freqItems'))
)
# replicate uniform data, one copy of each row per bucket
countries_skewed_keys = (
    countries
    .join(
        skewed_keys, 
        countries.country == skewed_keys.country_freqItems, 
        how='inner'
    )
    .crossJoin(high)
    .withColumn('composite_key', F.concat('country', F.lit('@'), 'replica_id'))
)
countries_rest = (
    countries
    .join(
        skewed_keys, 
        countries.country == skewed_keys.country_freqItems, 
        how='leftanti'
    )
    .crossJoin(low)
    .withColumn('composite_key', F.concat('country', F.lit('@'), 'replica_id'))
    .withColumn('country_freqItems', F.lit(None).cast('string'))
)
# this is now the entire uniform dataset replicated differently
# (unionByName, because the column order differs between the two dataframes)
countries_replicated = countries_skewed_keys.unionByName(countries_rest)

transactions_tagged = (
    transactions
    .join(
        skewed_keys, 
        transactions.country == skewed_keys.country_freqItems, 
        how='left'
    )
    .withColumn('replica_id',
        F.when(
            F.isnull(F.col('country_freqItems')), 
            (F.rand() * replication_low).cast('int'),
        )
        .otherwise((F.rand() * replication_high).cast('int'))
    )
    .withColumn('composite_key', F.concat('country', F.lit('@'), 'replica_id'))
)

# now we can join on the composite key
df = transactions_tagged.join(countries_replicated, 'composite_key')

for i, part in enumerate(df.rdd.glom().collect()):
    print({i: part})

This lets you replicate very frequent keys more often, and the non-skewed data is replicated only if needed.

Salting

Here the idea is to salt the columns used in the join operation (add randomization to them) with some random number. You can see an example in the following snippet:

# number of salt buckets; a random value below the shuffle partition count
salt = np.random.randint(1, int(spark.conf.get('spark.sql.shuffle.partitions')) - 1)
# replicate every country row once per salt value
salted_countries = countries.withColumn(
    'salt',
    F.explode(F.array([F.lit(i) for i in range(salt)])))
# assign each transaction a random salt value in the same 0..salt-1 range
salted_transactions = transactions.withColumn(
    'salt', (F.rand() * salt).cast('int'))
df = salted_transactions.join(salted_countries, ['country', 'salt']) \
    .drop('salt')

# check the partitions data
for i, part in enumerate(df.rdd.glom().collect()):
    print({i: part})

Either way, in the end we will see a noticeably smoother distribution of data between the partitions.

Iterative broadcast join

This is a different method: here we try to do a broadcast join on chunks of the second dataframe, assuming it cannot easily be broadcast as a whole. Something like the following:

num_of_passes = 2
cur_pass = 0
df = None
# assign every country row to one of the passes
countries = countries.withColumn(
    'pass', (F.rand() * num_of_passes).cast('int'))
while cur_pass < num_of_passes:
    # broadcast only the chunk of countries that belongs to the current pass
    out = transactions.join(
        F.broadcast(countries.filter(F.col('pass') == F.lit(cur_pass))),
        'country'
    )
    df = out if df is None else df.union(out)
    cur_pass += 1

countries = countries.drop('pass')
df = df.drop('pass')

I've never used this method in real life so don't judge me here.

Coalesce after filtering

We have already mentioned that good filtering can give a huge boost in performance. But even if you managed, for example, to turn 2 billion rows of data (2 TB) spread over 15,000 partitions into 2 million rows, the number of partitions has not changed. Most of these partitions will simply be empty and will still take up resources.

We have already seen the coalesce operation, which reduces the number of partitions, and this is exactly the right moment to apply it.

df = huge_data.sample(withReplacement = False, fraction = 0.001)
df = df.coalesce(4)

It should be noted that this applies not only to filtering but also to aggregation.
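
For example, an aggregation that collapses the data to a handful of rows can be followed by the same trick; df, key, and the target partition count here are placeholders:

# after the aggregation only a few rows remain, but the result still has
# spark.sql.shuffle.partitions partitions, most of them nearly empty
counts = df.groupBy('key').count().coalesce(8)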

Environment

At this stage you should not only understand your business logic and everything that is going on inside your pipeline, but also the environment in which it is going to run. By this I mean the cluster core count, memory per executor, network topology (for example, the regions of the nodes), limitations of your cluster setup, third-party resources, and, finally, cost.

Repartition before writing to storage

Spark's DataFrameWriter provides the partitionBy method, which can be used to partition the data on write. It lays the data out in separate directories of files on write, based on the provided set of columns.

df.write.partitionBy('key').json('/path/to/foo.json')

This enables predicate pushdown (partition pruning) on read for queries that filter on key: it limits the number of files and partitions that Spark has to read.

df = spark.read.schema(schema).json('/path/to/foo.json')
df.where(df.key == 'bar')

Also, when you are saving a dataframe to disk, pay particular attention to partition sizes. On write, Spark produces one file per task (i.e. one file per partition) and will read at least one file per task when reading. The problem here is that if the cluster setup in which the dataframe was saved had more total memory and could therefore process large partitions without any problems, a smaller cluster may later have problems reading that saved dataframe.

But partitionBy is not an equivalent of repartition: an aggregation like the following will still require an Exchange, aka shuffle:

counts = df.groupBy('key').sum()

The partitionBy method itself does not trigger any shuffle.

That’s cool, right?

Yes, but there is a small detail here: partitionBy does not trigger any shuffle, but it may generate a hell of a lot of files.

Imagine we have 10 partitions and we want to partition the data by, let's say, days. Each Spark task will produce up to 365 files in HDFS (one per day), which leads to 365×10 = 3,650 files produced by the job in total. Now imagine we have 200 partitions right after the shuffle stage: we get 365×200 = 73,000 files. Not very efficient, and it may cause the so-called small files problem in HDFS.
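
One common way to keep the file count under control is to repartition by the same columns before writing, so that each output directory is produced by a single task. This is only a sketch, assuming a day column:

# one task per day value -> roughly one file per day directory
df.repartition('day') \
    .write \
    .partitionBy('day') \
    .json('/path/to/foo.json')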

Use all available cluster cores

If you have 200 cores in your cluster and only 10 partitions to read, you can use only 10 cores to read the data. The other 190 cores will be idle.

Another problem that can occur with partitioning is that there are too few partitions to properly cover the number of available executors. Imagine that you have 2 executors and 3 partitions. Executor 1 has an extra partition, so it takes twice as long to complete as executor 2, and executor 2 sits idle for half of that time.

The simplest solution to the above two problems is to increase the number of partitions used for processing. This reduces the effect of partition skew and also allows better utilization of the cluster resources.
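
In practice this usually means raising spark.sql.shuffle.partitions or repartitioning explicitly; the numbers below are assumptions for the 2-executor example, not recommendations:

# with 2 executors of, say, 4 cores each, aim for a few partitions per core
spark.conf.set('spark.sql.shuffle.partitions', 2 * 4 * 4)
df = df.repartition(2 * 4 * 4)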

Use a stage barrier between the shuffle stage and the write operation

The only way to change the number of files after a shuffle is to create a stage barrier. This can be done either by writing the dataframe to temporary storage or, more efficiently, by using localCheckpoint.

df.localCheckpoint() \
    .repartition(n) \
    .write \
    .save(...)

This is very helpful because it breaks the lineage at that point and creates a stage barrier, so the final coalesce or repartition does not propagate up your execution plan, and a parallel workflow that uses the same dataframe does not need to re-process it. Remember that Spark executes lazily: localCheckpoint() will trigger execution and materialize the dataframe.

Partitioning with JDBC sources

Traditional SQL databases cannot process huge amounts of data across different nodes the way Spark does. To allow Spark to read data from a database via JDBC in parallel, you must specify the level of parallel reads/writes, which is controlled by the option .option('numPartitions', parallelismLevel). The specified number controls the maximal number of concurrent JDBC connections. By default, data is read into a single partition, which usually does not fully utilize your SQL database, and obviously not Spark either.
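
A minimal sketch of a parallel JDBC read; the connection details, table, and key bounds are all made up. Note that for parallel reads you also need partitionColumn together with lowerBound and upperBound so that Spark can split the query:

df = (
    spark.read.format('jdbc')
    .option('url', 'jdbc:postgresql://db-host:5432/shop')  # assumed connection
    .option('dbtable', 'transactions')                     # assumed table
    .option('user', 'spark')
    .option('password', '...')
    .option('partitionColumn', 'id')  # numeric column to split the reads on
    .option('lowerBound', 1)          # assumed key range
    .option('upperBound', 1000000)
    .option('numPartitions', 10)      # max concurrent JDBC connections
    .load()
)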

Conclusion

  1. Balance is very important in distributed systems, in many different places.
  2. Sometimes an OOM is a better problem to have than a never-ending Spark job; at least you understand what the problem is.
  3. When in doubt, err on the side of more tasks (and thus partitions).

Additional materials

Apache Spark Core — Deep Dive — Proper Optimization, Daniel Tomes, Databricks

