
Spark Tips. Partition Tuning


There are many different tools in the world, each of which solves a range of problems. Many of them are judged by how well and how correctly they solve this or that problem, but there are tools that you simply like and want to use. They are well designed and fit comfortably in your hand; you do not need to dig into the documentation to work out how to perform some simple action. This series of posts is about one such tool for me.

I will describe the optimization methods and tips that help me solve certain technical problems and achieve high efficiency using Apache Spark. This is my updated collection.

Many of the optimizations that I will describe will not affect the JVM languages so much, but without these methods, many Python applications may simply not work.


Clusters will not be fully utilized unless you set the level of parallelism for each operation high enough. The general recommendation for Spark is to have about 4x as many partitions as there are cores available to the application in the cluster. As an upper bound on the partition count, each task should take at least 100 ms to execute. If tasks take less time than that, your partitions are too small and the application may be spending more of its time distributing tasks than running them.

Both too few and too many partitions have disadvantages, so partition wisely, depending on the configuration and requirements of the cluster.

😢 Too few partitions — not utilizing all cores available in the cluster.

😢 Too many partitions — excessive overhead in managing many small tasks as well as data movement.

When in doubt, it's almost always better to err on the side of a larger number of tasks (and thus partitions).

Although these recommendations are reasonable, the right numbers are very case-dependent, so the Spark configuration must be fine-tuned for the given scenario.
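
As a rough illustration of the 4x rule of thumb, the helper below computes a starting partition count from the number of cores. The function name and the example cluster size are made up for this sketch; treat the result as a starting point for tuning rather than a fixed answer.

```python
def suggest_partitions(total_cores: int, factor: int = 4) -> int:
    """Rule-of-thumb partition count: `factor` partitions per available core."""
    if total_cores < 1:
        raise ValueError("total_cores must be positive")
    return total_cores * factor


# hypothetical cluster: 10 executors with 5 cores each
print(suggest_partitions(10 * 5))  # 200
```

With a live session, the result could be passed to df.repartition(...) or used as the value of spark.sql.shuffle.partitions.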

Factors affecting partitioning

Once you get down to this level of detail with Spark, you need to understand the following parts of your pipeline, each of which will eventually affect how the data should be partitioned:

  • Business logic
  • Data
  • Environment

Next, I'll go through these levels and give some tips on each of them.

Business logic

Let's start with the most varied point — business logic. Here it is difficult to give any specific recommendation as the specifics for each particular pipeline may be different.

Reduce the working dataset size

The most effective advice is usually the dumbest: to speed up a Spark application, reduce the amount of data the cluster needs to process. There are several ways to do this, and of course it all depends on what really needs to be done with the data.

We can filter the source data by skipping partitions that do not satisfy our condition (assuming, of course, that the data has been partitioned). The right condition can significantly speed up reading and retrieving the needed data. In some cases (e.g. on S3) we can avoid unnecessary partition discovery; in others, mechanisms built into the data layout (e.g. partition pruning) may help.
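
To make the idea of partition pruning concrete, here is a toy model in plain Python. It is not how Spark implements pruning; it only shows the principle that, with a Hive-style column=value directory layout, files in non-matching directories never need to be opened. The paths and the function name are invented for the example.

```python
def prune_partitions(paths, column, wanted):
    """Keep only the files whose `column=value` directory matches `wanted`."""
    kept = []
    for path in paths:
        # collect the 'column=value' directory segments of the path
        parts = dict(seg.split('=', 1) for seg in path.split('/') if '=' in seg)
        if parts.get(column) in wanted:
            kept.append(path)
    return kept


files = [
    'events/country=USA/part-0000.parquet',
    'events/country=Poland/part-0000.parquet',
    'events/country=Germany/part-0000.parquet',
]
# only one of the three files has to be read for this filter
print(prune_partitions(files, 'country', {'Poland'}))
```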

Repartition before multiple joins

join is one of the most expensive operations commonly used in Spark, and the infamous shuffle is, as always, to blame. Shuffle deserves more than one post of its own; here we will discuss only the side related to partitions.


In order to join data, Spark needs the rows with the same join key to be in the same partition. The default join implementation in Spark since version 2.3 is the sort-merge join.

Sort merge join is executed in three basic steps:

  1. Rows with the same key values have to end up in the same partition, so the partitions have to be co-located (in this context, the same as co-partitioned). This is done by shuffling the data.
  2. Sort the data within each partition in parallel.
  3. Join the sorted and partitioned data. This is basically a merge: iterate over the elements of both datasets and join the rows that have the same value for the join key.
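
The three steps can be sketched in plain Python. This is a toy model over lists of (key, value) pairs, written only to illustrate the idea; it is not Spark's implementation, and the function name and sample data are made up.

```python
def sort_merge_join(left, right, num_partitions=4):
    """Toy inner sort-merge join over lists of (key, value) pairs."""
    # 1. "shuffle": rows with the same key land in the same partition
    left_parts = [[] for _ in range(num_partitions)]
    right_parts = [[] for _ in range(num_partitions)]
    for key, value in left:
        left_parts[hash(key) % num_partitions].append((key, value))
    for key, value in right:
        right_parts[hash(key) % num_partitions].append((key, value))

    joined = []
    for lpart, rpart in zip(left_parts, right_parts):
        # 2. sort each partition by the join key
        lpart.sort(key=lambda row: row[0])
        rpart.sort(key=lambda row: row[0])
        # 3. merge: walk both sorted lists in step
        i = j = 0
        while i < len(lpart) and j < len(rpart):
            lkey, rkey = lpart[i][0], rpart[j][0]
            if lkey < rkey:
                i += 1
            elif lkey > rkey:
                j += 1
            else:
                # find the run of right rows sharing this key
                j_end = j
                while j_end < len(rpart) and rpart[j_end][0] == lkey:
                    j_end += 1
                # pair every left row with every right row of that run
                while i < len(lpart) and lpart[i][0] == lkey:
                    for _, rvalue in rpart[j:j_end]:
                        joined.append((lkey, lpart[i][1], rvalue))
                    i += 1
                j = j_end
    return joined


users = [(1, 'Ann'), (2, 'Bob'), (3, 'Cid')]
addresses = [(1, 'Oslo'), (3, 'Rome'), (3, 'Bonn'), (4, 'Kyiv')]
print(sorted(sort_merge_join(users, addresses)))
# [(1, 'Ann', 'Oslo'), (3, 'Cid', 'Bonn'), (3, 'Cid', 'Rome')]
```

The first loop is the expensive part in a real cluster, because there it means moving rows over the network; the rest happens locally within each partition.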

Although this approach always works, it may be more expensive than necessary as it requires a shuffle. Shuffle can be avoided if:

  1. Both dataframes have a common Partitioner.
  2. One of the dataframes is small enough to fit into memory, in which case we can use a broadcast hash join.

For example, if we know that a dataframe will be joined several times, we can avoid the additional shuffling operation by performing it ourselves. Please don't forget to cache the dataframe after repartitioning in such cases.

users = spark.read.load('/path/to/users')
users = users.repartition('userId').cache() # do not forget to cache!
# addresses and salary are other dataframes that share the userId column
joined1 = users.join(addresses, 'userId')
joined1.show() # 1st shuffle for repartition
joined2 = users.join(salary, 'userId')
joined2.show() # skips shuffle for users since it's already been repartitioned

This way, we will shuffle the data once and then reuse the shuffled data for the subsequent joins. We can also use bucketing for that purpose.

Repartition after flatMap

The result of a flatMap operation is usually an RDD with more rows than its input, but the number of partitions remains the same. This means that the memory load on each partition may become too large, and you may see all the delights of disk spills and GC pauses. In this case it is better to repartition the flatMap output based on the predicted memory expansion.
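
A minimal sketch of that calculation, assuming you can estimate how many output rows each input row produces on average. The helper name and the expansion factor are illustrative, not part of any Spark API.

```python
import math


def partitions_after_expansion(current_partitions: int, expansion_factor: float) -> int:
    """Scale the partition count by the expected row growth,
    so each partition keeps roughly the same amount of data."""
    scaled = math.ceil(current_partitions * expansion_factor)
    # never go below the current count; shrinking is a separate decision
    return max(current_partitions, scaled)


# hypothetical case: every input row explodes into 3.5 rows on average
print(partitions_after_expansion(200, 3.5))  # 700
```

With an RDD at hand, the result would be used along the lines of rdd.flatMap(f).repartition(partitions_after_expansion(rdd.getNumPartitions(), 3.5)).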

Get rid of disk spills

From the Tuning Spark docs:

Sometimes, you will get an OutOfMemoryError, not because your RDDs don’t fit in memory, but because the working set of one of your tasks, such as one of the reduce tasks in groupByKey, was too large. Spark’s shuffle operations (sortByKey, groupByKey, reduceByKey, join, etc) build a hash table within each task to perform the grouping, which can often be large...

If spark.shuffle.spill is true (which is the default), Spark uses ExternalAppendOnlyMap to store intermediate data during the shuffle; in Spark 1.6 and later the setting is ignored and spilling is always enabled. This structure can spill data to disk when there isn't enough memory available. Spilling is a symptom of memory pressure on the executor and brings the additional overhead of disk I/O and increased garbage collection. If we use PySpark, the memory pressure also increases the chance of Python running out of memory.


To check whether disk spilling occurred, we can search the logs for entries like this one:

INFO ExternalSorter: Task 1 force spilling in-memory map to disk it will release 232.1 MB memory

Possible fixes include the following:

👍 Reduce the size of the data, for example by selecting only the required columns or by moving filtering operations before the wide transformations.

👍 Increase the level of parallelism so that each task’s input data is smaller, or manually repartition() before the heavy stage.

👍 Give the shuffle more memory to work with by increasing the memory of the executor processes (spark.executor.memory).

👍 If the available memory resources are sufficient, increase spark.shuffle.file.buffer to reduce the number of times the buffers overflow during the shuffle write, which reduces the number of disk I/O operations.

More configuration optimizations can be found with this tool.


It's all because of data


In data management, the data is an essential component of any work, and it always makes its own adjustments. Without knowing your data, you end up tuning your Spark application to some common denominator, which usually doesn't work out well, either in terms of speed or in terms of utilization of the available resources.

Data Skew

Ideally, when Spark performs, for example, a join, the join keys should be evenly distributed among partitions. However, real-world data is rarely ideal; in the best case it is at least somewhat well distributed. Practitioners often encounter skewed data, which results in poor performance or even OOM failures.


This is not a problem specific to Spark but rather a data problem: the performance of distributed systems depends heavily on how evenly the data is distributed.

Often data is partitioned based on a key, such as day of week, country, etc. If the values for that key are not evenly distributed, then one partition will contain more data than another.

Let's take a look at the following snippet:

import pandas as pd
import numpy as np
from pyspark.sql import functions as F

# set smaller number of partitions so they can fit the screen
spark.conf.set('spark.sql.shuffle.partitions', 8)
# disable broadcast join to see the shuffle
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
length = 100
names = np.random.choice(['Bob', 'James', 'Marek', 'Johannes', None], length)
amounts = np.random.randint(0, 1000000, length)

# generate skewed data: about 80% of the rows belong to one country
country = np.random.choice(
    ['United Kingdom', 'Poland', 'USA', 'Germany', 'Russia'],
    length,
    p=[0.05, 0.05, 0.8, 0.05, 0.05]
)
data = pd.DataFrame({'name': names, 'amount': amounts, 'country': country})
transactions = spark.createDataFrame(data).repartition('country')
countries = spark.createDataFrame(pd.DataFrame({
    'id': [11, 12, 13, 14, 15],
    'country': ['United Kingdom', 'Poland', 'USA', 'Germany', 'Russia']
}))
df = transactions.join(countries, 'country')

# check the partitions data
for i, part in enumerate(df.rdd.glom().collect()):
    print({i: part})

Here we see that one partition out of 8 gets a lot of data and the rest gets very little or nothing.

It smells skewed.

We can diagnose skewness by looking at the Spark UI and checking the time spent on each task. Some tasks are time-consuming, while others are not. Also, in the Spark UI, an indicator of data skew can be the huge difference between the minimum task time and the maximum task time.
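
Besides the Spark UI, a quick numeric check can help on small datasets. The sketch below is a hypothetical helper, not a Spark function: it takes a list of partition sizes and reports how much larger the biggest partition is than the average one.

```python
def skew_ratio(partition_sizes):
    """Largest partition size divided by the mean size; 1.0 means perfectly even."""
    total = sum(partition_sizes)
    if total == 0:
        return 1.0
    mean = total / len(partition_sizes)
    return max(partition_sizes) / mean


# sizes similar to the skewed example: one partition holds most of the rows
sizes = [80, 5, 5, 5, 5, 0, 0, 0]
print(skew_ratio(sizes))  # 6.4
```

For a small dataframe, the sizes could be collected with df.rdd.glom().map(len).collect(); on large data this is expensive, and the task metrics in the Spark UI are the better source.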

To resolve the problem of data skew, we can either:

👍 Redistribute the data on more evenly distributed keys, or simply increase the number of partitions

👍 Broadcast the smaller dataframe if possible

👍 Differential replication

👍 Use an additional random key for better distribution of the data (salting)

👍 Iterative broadcast join


Repartition

One way to ensure a more or less even distribution is to explicitly repartition the data. As this is a very expensive operation, we don't want to execute it where it is not needed. Instead, we can check the number of RDD/DataFrame partitions right before performing any heavy operation and change it only when it falls outside the expected range.

You can use the following code as an example:

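from pyspark.sql import DataFrame

def repartition(
    df: DataFrame,
    min_partitions: int,
    max_partitions: int
) -> DataFrame:
    num_partitions = df.rdd.getNumPartitions()
    if num_partitions < min_partitions:
        # too few partitions: full shuffle into more partitions
        df = df.repartition(min_partitions)
    elif num_partitions > max_partitions:
        # too many partitions: merge them without a full shuffle
        df = df.coalesce(max_partitions)
    return df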

Broadcast smaller dataframe

Broadcasting in Spark is the process of loading data onto each of the cluster nodes as a dataframe.

A broadcast join joins a smaller dataframe to a larger one: the smaller dataframe is broadcast to every executor, and the join is performed locally without moving the larger dataframe.

df = transactions.join(F.broadcast(countries), 'country')

Broadcasting avoids shuffling the data and results in relatively little network traffic.

Differential replication

If simple broadcasting is not possible (the data is too large), we can split the data into skewed and non-skewed dataframes and work with them in parallel by redistributing the skewed data. This process is called differential replication.

The rare keys don't need to be replicated as much as skewed ones. The idea here is to identify the frequent keys before replication, then use a different replication policy for them.

replication_high = 7
high = F.broadcast(
    spark.range(replication_high)
        .withColumnRenamed('id', 'replica_id')
)
replication_low = 2
low = F.broadcast(
    spark.range(replication_low)
        .withColumnRenamed('id', 'replica_id')
)

# determine which keys are highly over-represented, broadcast them
skewed_keys = F.broadcast(
    transactions.freqItems(['country'], 0.6)
        .select(F.explode('country_freqItems').alias('country_freqItems'))
)

# replicate uniform data, one copy of each row per bucket
countries_skewed_keys = (
    countries
        .join(
            skewed_keys,
            countries.country == skewed_keys.country_freqItems,
            how='inner')
        .crossJoin(high)
        .withColumn('composite_key', F.concat('country', F.lit('@'), 'replica_id'))
)
countries_rest = (
    countries
        .join(
            skewed_keys,
            countries.country == skewed_keys.country_freqItems,
            how='leftanti')
        .crossJoin(low)
        .withColumn('composite_key', F.concat('country', F.lit('@'), 'replica_id'))
        .withColumn('country_freqItems', F.lit(None).cast('string'))
)

# this is now the entire uniform dataset replicated differently
# (unionByName because the two sides list their columns in a different order)
countries_replicated = countries_skewed_keys.unionByName(countries_rest)
transactions_tagged = (
    transactions
        .join(
            skewed_keys,
            transactions.country == skewed_keys.country_freqItems,
            how='left')
        .withColumn('replica_id',
            F.when(
                F.col('country_freqItems').isNull(),
                (F.rand() * replication_low).cast('int'))
            .otherwise((F.rand() * replication_high).cast('int')))
        .withColumn('composite_key', F.concat('country', F.lit('@'), 'replica_id'))
)

# now we can join on the composite key
df = transactions_tagged.join(countries_replicated, 'composite_key')

for i, part in enumerate(df.rdd.glom().collect()): 
    print({i: part})

This allows for more frequent replication of very common keys, and data without a skew is replicated only when necessary.


Salting

The idea behind this technique is to add a random value (salt) to the columns used in the join operation so that the rows are spread out more evenly.

salt = int(np.random.randint(1, int(spark.conf.get('spark.sql.shuffle.partitions')) - 1))
# replicate every row of the smaller dataframe once per salt value
salted_countries = countries.withColumn(
    'salt',
    F.explode(F.array([F.lit(i) for i in range(salt)])))
# give each row of the larger dataframe a random salt in [0, salt)
salted_transactions = transactions.withColumn(
    'salt', F.floor(F.rand() * salt).cast('int'))
df = salted_transactions.join(salted_countries, ['country', 'salt'])

# check the partitions data
for i, part in enumerate(df.rdd.glom().collect()): 
    print({i: part})

Either way, in the end we will see a noticeably smoother distribution of data between partitions.
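
The effect of salting on a hot key can be checked without a cluster. The plain-Python sketch below counts how many rows share one join key before and after salting. To keep the numbers reproducible it assigns salts round-robin instead of randomly, which is an assumption of this example and not what the Spark code does.

```python
from collections import Counter


def largest_group(keys):
    """Size of the biggest join-key group, a proxy for the heaviest task."""
    return max(Counter(keys).values())


def salt_keys(keys, num_salts):
    """Pair every key with a salt; round-robin stands in for a random salt."""
    return [(key, i % num_salts) for i, key in enumerate(keys)]


# skewed input: 80 of 100 rows share the same key
keys = ['USA'] * 80 + ['Poland'] * 10 + ['Germany'] * 10
print(largest_group(keys))                # 80
print(largest_group(salt_keys(keys, 4)))  # 20
```

The price is paid on the other side of the join: every row of the smaller dataframe has to be replicated once per salt value so that each salted key still finds its match.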

Iterative broadcast join

Here we do a broadcast join on chunks of the second dataframe, on the assumption that it can't easily be broadcast as a whole. Something like the following:

num_of_passes = 2
df = None
# assign every row of the smaller dataframe to one of the passes; cache the
# result so the random assignment is not recomputed on each pass
countries = countries.withColumn(
    'pass', F.floor(F.rand(seed=42) * num_of_passes)).cache()
for cur_pass in range(num_of_passes):
    out = transactions.join(
        F.broadcast(countries.filter(F.col('pass') == cur_pass)),
        'country')
    df = out if df is None else df.union(out)

countries = countries.drop('pass')
df = df.drop('pass')

I've never used this method in real production projects so don't judge my implementation here.

Coalesce after filtering

We have already mentioned that good filtering can give a huge boost in performance. But even if you managed, for example, to turn 2 billion rows of data (2 TB) spread across 15,000 partitions into 2 million rows, the number of partitions has not changed. Most of these partitions will simply be empty or nearly empty and will still take up resources.

We have already seen the coalesce operation, which reduces the number of partitions, and this is just the right moment to apply it.

df = huge_data.sample(withReplacement = False, fraction = 0.001)
df = df.coalesce(4)

It should be noted that this applies not only to filtering but also to aggregation.
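
How many partitions to coalesce to depends on how much data is left. One simple way to pick the number is to divide the remaining data size by a target partition size. In the sketch below the 128 MB target is an assumption chosen for illustration, and the helper is hypothetical.

```python
import math


def partitions_for_size(total_bytes, target_partition_bytes=128 * 1024 ** 2):
    """Number of partitions needed to keep each one near the target size."""
    return max(1, math.ceil(total_bytes / target_partition_bytes))


# hypothetical case: about 2 GB of data left after filtering
print(partitions_for_size(2 * 1024 ** 3))  # 16
```

The result would then be passed to coalesce, for example df.coalesce(partitions_for_size(estimated_bytes)), where estimated_bytes is your own estimate of the filtered data size.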


Environment

At this stage, you should not only understand your business logic and everything that is going on inside your application but also understand the environment in which it must run. By this I mean resource manager, node sizes, cluster core count, memory per executor, network topology (for example, regions of the nodes), limitations for your cluster setup, third-party resources, and finally cost.

Repartition before writing to storage

Spark's DataFrameWriter provides the partitionBy method, which can be used to partition data on write. It splits the data into separate files on write using the provided set of columns.


This enables downstream jobs to use partition pruning and predicate pushdown when reading the data, which limits the number of files and partitions that Spark reads when querying.

df = spark.read.schema(schema).json('/path/to/foo.json')
df.where(df.key == 'bar')

Also, when you're saving a dataframe to disk, pay particular attention to partition sizes. On write, Spark produces one file per task (i.e. one file per partition), and on read it will read at least one file per task. The problem is that if the cluster that saved the dataframe had more total memory, and could therefore process larger partitions without any problems, a smaller downstream cluster may have problems reading that saved data.

But partitionBy is not equivalent to repartition for aggregations like:

counts = df.groupBy('key').sum()

This will still require an Exchange, that is, a shuffle. The partitionBy method itself does not trigger any shuffle.

That’s cool, right?

Yes, but there is a small detail here. The partitionBy method does not trigger any shuffle, but it may generate a huge number of files.

Imagine we have 10 partitions and we want to partition the data by, let's say, day. Each Spark task will produce 365 files in HDFS (1 per day), which leads to 365×10=3650 files produced by the job in total. But if we have 200 partitions right after the shuffle stage, we will get 365×200=73,000 files. This is not very efficient and may cause a small files problem in HDFS.
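
The arithmetic is easy to replay. The helper below is only an illustration of the worst case, in which every task holds rows for every partition value; the function name is made up.

```python
def files_written(num_tasks, distinct_partition_values):
    """Worst case for partitionBy: each task writes one file per value it holds."""
    return num_tasks * distinct_partition_values


print(files_written(10, 365))   # 3650
print(files_written(200, 365))  # 73000
```

A common remedy, stated here as an assumption rather than something taken from the text above, is to call df.repartition('day') before write.partitionBy('day'). All rows for a given day then sit in a single task, so at most one file per day is written, at the cost of an extra shuffle.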

Use all available cluster cores

If you have 200 cores in your cluster and only have 10 partitions to read, you can only use 10 cores to read the data. All other 190 cores will be idle.

Another problem that can occur with partitioning is that there are too few partitions to properly cover the number of available executors. Imagine that you have 2 executors and 3 partitions. Executor 1 has an extra partition, so it takes twice as long to complete as executor 2. As a result, executor 2 sits idle for half of the job.

The simplest solution to the above two problems is to increase the number of partitions used for processing. This will reduce the effect of partition skew and allow better utilization of the cluster resources.
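
Both situations come down to the same arithmetic. The sketch below assumes that every task takes the same time and that each core runs one task at a time, which is a simplification; the function names are invented for the example.

```python
import math


def scheduling_waves(num_partitions, total_cores):
    """Number of rounds needed when each core runs one task at a time."""
    return math.ceil(num_partitions / total_cores)


def idle_slots(num_partitions, total_cores):
    """Core slots left without work across all waves."""
    waves = scheduling_waves(num_partitions, total_cores)
    return waves * total_cores - num_partitions


# 10 partitions on 200 cores: one wave, 190 cores never get a task
print(scheduling_waves(10, 200), idle_slots(10, 200))  # 1 190
# 3 partitions on 2 single-core executors: two waves, one slot idle
print(scheduling_waves(3, 2), idle_slots(3, 2))        # 2 1
```

With many more partitions than cores, the idle slots in the last wave become a small fraction of the total work, which is why raising the partition count helps in both cases.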

Use stage barrier between shuffle stage and the write operation

The only way to change the number of files after the shuffle is to create a stage barrier. This can be done either by writing the dataframe to temporary storage or, more efficiently, by using localCheckpoint.

df.localCheckpoint(...) \
    .repartition(n) \
    .write.save(...)

This is very helpful because the checkpoint acts as a stage barrier: coalesce or repartition will not propagate up your execution pipeline, and a parallel workflow that uses the same dataframe does not need to re-process it. Remember that Spark executes lazily, but localCheckpoint() is eager by default and will trigger execution to materialize the dataframe.

Partitioning with JDBC sources

Traditional SQL databases cannot process a huge amount of data on different nodes the way Spark does. To allow Spark to read data from a database via JDBC in parallel, you must specify the level of parallelism for reads/writes, which is controlled by the option .option('numPartitions', parallelismLevel). The specified number sets the maximum number of concurrent JDBC connections. By default, data is read into a single partition, which usually doesn't fully utilize either your SQL database or Spark.
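
For reads, numPartitions is normally combined with partitionColumn, lowerBound and upperBound, which the Spark JDBC data source requires to be set together so that it can split the table into ranges. The sketch below only assembles such an options dictionary; the helper name, the connection URL, the table and the column are all hypothetical.

```python
def jdbc_read_options(url, table, partition_column,
                      lower_bound, upper_bound, num_partitions):
    """Collect the options for a range-partitioned JDBC read."""
    if num_partitions < 1:
        raise ValueError('num_partitions must be at least 1')
    if lower_bound >= upper_bound:
        raise ValueError('lower_bound must be smaller than upper_bound')
    return {
        'url': url,
        'dbtable': table,
        'partitionColumn': partition_column,
        'lowerBound': str(lower_bound),
        'upperBound': str(upper_bound),
        'numPartitions': str(num_partitions),
    }


# hypothetical connection details, for illustration only
options = jdbc_read_options(
    'jdbc:postgresql://db.example.com:5432/shop',
    'transactions', 'id', 1, 1_000_000, 8)
print(options['numPartitions'])  # 8
```

With a live session, the dictionary would be passed as spark.read.format('jdbc').options(**options).load(). Note that the bounds only determine how the ranges are cut; they do not filter rows, so values outside the bounds still end up in the first or last partition.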


Conclusion

🤓 I would like to point out that balance is very important in many different places in distributed systems.

🤓 Sometimes the OOM is a better problem than a never-ending Spark job, at least you understand the problem.

🤓 When in doubt, err on the side of more tasks (and thus partitions).
