State Storage in Spark Structured Streaming
The state is one of the most important parts of many streaming data pipelines. Depending on the use case, you might need the ability to keep the state of different objects while your streaming application is running.
In this tutorial, we are going to take a deep dive into techniques for working with state storage in Apache Spark 2.4.4 via the Structured Streaming APIs in Scala.
If you prefer the “just-show-me-the-code” approach, it’s provided here.
Please note that the APIs used in this post are mostly experimental, which means they might change significantly in future Spark releases.
Spark Structured Streaming
Spark Structured Streaming combines the power of Spark abstractions, such as DataFrames and typed Datasets, and a long list of extremely convenient data-handling functions with a very concise API.
The main idea behind Spark Structured Streaming is the unbounded table abstraction. Generally speaking, the data stream is represented as a DataFrame to which new rows are constantly appended. Data engineers define how these new rows should be handled.
Stateless vs Stateful
When processing data streams, every application implicitly uses one of these two approaches:
- Stateless: the logic for handling new data is independent of the previous data. For example, imagine writing users' web-click data directly to S3 storage, without any additional analysis of this data during stream processing.
- Stateful: in contrast to stateless processing, you need to combine the new data with old records or previous batches. For example, imagine that the web-click data is too large to store in full, so you combine it on the fly with previous records, calculate the number of visits per user, and update the values when needed.
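To make the difference concrete, here is a minimal sketch in plain Scala, without Spark. The names Click and StatelessVsStateful are hypothetical and used only for illustration. The stateless function handles each batch in isolation. The stateful one carries a map of per-user visit counts from one batch to the next.

```scala
// Hypothetical click record used only for this illustration; not a Spark type.
final case class Click(userId: Int, url: String)

object StatelessVsStateful {
  // Stateless: each batch is transformed on its own (e.g. formatted for storage).
  def stateless(batch: Seq[Click]): Seq[String] =
    batch.map(c => s"${c.userId},${c.url}")

  // Stateful: the new batch is combined with counts kept from previous batches.
  def stateful(state: Map[Int, Int], batch: Seq[Click]): Map[Int, Int] =
    batch.foldLeft(state) { (acc, c) =>
      acc.updated(c.userId, acc.getOrElse(c.userId, 0) + 1)
    }

  def main(args: Array[String]): Unit = {
    val batch1 = Seq(Click(1, "/a"), Click(1, "/b"), Click(2, "/a"))
    val batch2 = Seq(Click(1, "/c"))
    val afterFirst = stateful(Map.empty, batch1)
    val afterSecond = stateful(afterFirst, batch2)
    println(stateless(batch2)) // List(1,/c)
    println(afterSecond)       // Map(1 -> 3, 2 -> 1)
  }
}
```

The stateless output of the second batch knows nothing about the first batch. The stateful count for user 1 includes visits from both batches.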
Luckily for many data engineers, Spark Structured Streaming lets you implement stateful operations with a simple and concise API.
State operations in Structured Streaming
So, how do we implement state operations in Structured Streaming? To understand the logic, let’s take a look at the following schema:
In other words, we use the state store to save and share information about the state of arbitrary objects between batches. Let’s implement this with the Spark Structured Streaming API.
First, let’s define our data structures:
import java.sql.Timestamp
import java.time.Instant

case class PageVisit(
  id: Int,
  url: String,
  timestamp: Timestamp = Timestamp.from(Instant.now())
)

case class UserStatistics(
  userId: Int,
  visits: Seq[PageVisit],
  totalVisits: Int
)
PageVisit is a case class that represents the data coming from the external system. UserStatistics is our internal case class, used to store and manipulate data about page visits.
We also need implicit encoders for these classes, so let’s declare them as well:
import org.apache.spark.sql.{Encoder, Encoders}

implicit val pageVisitEncoder: Encoder[PageVisit] = Encoders.product[PageVisit]
implicit val userStatisticsEncoder: Encoder[UserStatistics] = Encoders.product[UserStatistics]
Let’s assume that the data arrives as a stream, which we simply read via the Spark Structured Streaming API. For local testing and debugging, it’s extremely convenient to use a simple memory stream:
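import org.apache.spark.sql.{Dataset, SQLContext}
import org.apache.spark.sql.execution.streaming.MemoryStream
implicit val sqlContext: SQLContext = spark.sqlContext // MemoryStream needs an implicit SQLContext
val visitsStream = MemoryStream[PageVisit]
val pageVisitsTypedStream: Dataset[PageVisit] = visitsStream.toDS()
val initialBatch = Seq(
  generateEvent(1),
  generateEvent(1),
  generateEvent(1),
  generateEvent(1),
  generateEvent(2)
)
visitsStream.addData(initialBatch)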
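The definition of generateEvent is not shown in this post. A minimal hypothetical version for local testing could look like the following. The Events object, the page list, and the random URL choice are all assumptions, not the original helper.

```scala
import java.sql.Timestamp
import java.time.Instant
import scala.util.Random

// Same shape as the PageVisit case class defined earlier in this post.
case class PageVisit(
  id: Int,
  url: String,
  timestamp: Timestamp = Timestamp.from(Instant.now())
)

object Events {
  // Hypothetical set of pages to pick from; not part of the original code.
  private val pages = Vector("/home", "/products", "/cart", "/checkout")

  // Hypothetical helper: a visit by user `id` to a random page,
  // timestamped with the current time through the case class default.
  def generateEvent(id: Int): PageVisit =
    PageVisit(id = id, url = pages(Random.nextInt(pages.size)))
}
```

Because the timestamp uses the case class default, every generated event carries the time at which it was created. This is enough to tell events apart when inspecting the state later.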
Next, we need to transform the page visits into user statistics:
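import org.apache.spark.sql.streaming.GroupStateTimeout
import spark.implicits._ // provides the Encoder[Int] needed for the groupByKey key

val noTimeout = GroupStateTimeout.NoTimeout
val userStatisticsStream = pageVisitsTypedStream
  .groupByKey(_.id)
  .mapGroupsWithState(noTimeout)(updateUserStatistics)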
Here we use the mapGroupsWithState method of the KeyValueGroupedDataset class to implement the aggregation.
The signature of this method is not complicated, but it takes some time to understand. You need to provide a GroupStateTimeout (which is simple) and a callback function with three parameters:
func: (K, Iterator[V], GroupState[S]) => U
These parameters are:
- K is the key of your grouped computation. In our case, it is the userId.
- Iterator[V] is an iterator over the new values arriving in the batch for this particular key.
- GroupState[S] is an object that gives you the state API.
The function should return the calculated aggregation result for each group.
In our example, we are implementing this function as follows:
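import org.apache.spark.sql.streaming.GroupState

def updateUserStatistics(
    id: Int,
    newEvents: Iterator[PageVisit],
    oldState: GroupState[UserStatistics]): UserStatistics = {
  // Start from the stored state, or from empty statistics for a new user
  var state: UserStatistics =
    if (oldState.exists) oldState.get else UserStatistics(id, Seq.empty, 0)
  for (event <- newEvents) {
    state = state.copy(visits = state.visits ++ Seq(event), totalVisits = state.totalVisits + 1)
  }
  // Save the updated state once, after all new events have been applied
  oldState.update(state)
  state
}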
Then, we need to define a StreamingQuery object:
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}

def printBatch(batchData: Dataset[UserStatistics], batchId: Long): Unit = {
  log.info(s"Started working with batch id $batchId")
  log.info(s"Successfully finished working with batch id $batchId, dataset size: ${batchData.count()}")
}

val query: StreamingQuery = userStatisticsStream.writeStream
  .outputMode(OutputMode.Update())
  .option("checkpointLocation", checkpointLocation)
  .foreachBatch(printBatch _)
  .start()
When you start this stream, the following steps are executed for each micro-batch:
- Read new data from the source
- Group the new records by userId
- If the state for the given userId exists, get it; otherwise, create a plain UserStatistics with the corresponding id and an empty visits sequence
- Update the statistics with the new data
- Return the result to the stream
Please note that mapGroupsWithState currently supports only the Update output mode. Therefore, only the updated records are provided in the output.
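To see these per-batch steps and the Update output mode without a cluster, here is a plain-Scala simulation. It is not Spark code. SimpleGroupState is a hypothetical stand-in that imitates only the exists/get/update part of GroupState, and a mutable map plays the role of the state store. The function is called once per key that has new data in the batch, and only those keys are emitted. This mirrors what Update mode returns.

```scala
import scala.collection.mutable

// Simplified copies of the post's case classes (the timestamp field is omitted).
case class PageVisit(id: Int, url: String)
case class UserStatistics(userId: Int, visits: Seq[PageVisit], totalVisits: Int)

// Hypothetical stand-in for Spark's GroupState: only exists/get/update are imitated.
final class SimpleGroupState[S](initial: Option[S]) {
  private var value: Option[S] = initial
  def exists: Boolean = value.isDefined
  def get: S = value.getOrElse(throw new NoSuchElementException("state is not set"))
  def update(newValue: S): Unit = value = Some(newValue)
  def getOption: Option[S] = value
}

object MicroBatchSimulation {
  // Plays the role of the state store: userId -> statistics, kept across batches.
  private val store = mutable.Map.empty[Int, UserStatistics]

  // Same logic as updateUserStatistics, written against the stand-in state.
  def updateUserStatistics(
      id: Int,
      newEvents: Iterator[PageVisit],
      state: SimpleGroupState[UserStatistics]): UserStatistics = {
    var current = if (state.exists) state.get else UserStatistics(id, Seq.empty, 0)
    for (event <- newEvents)
      current = current.copy(visits = current.visits :+ event, totalVisits = current.totalVisits + 1)
    state.update(current)
    current
  }

  // One micro-batch: group the new rows by key, call the function once per key with
  // new data, save the resulting state, and emit only those keys (as in Update mode).
  def runBatch(batch: Seq[PageVisit]): Seq[UserStatistics] =
    batch.groupBy(_.id).toSeq.sortBy(_._1).map { case (id, events) =>
      val state = new SimpleGroupState(store.get(id))
      val output = updateUserStatistics(id, events.iterator, state)
      state.getOption.foreach(saved => store(id) = saved)
      output
    }
}
```

Feed it batches shaped like the ones in this post: four visits by user 1 and one by user 2, then one visit by user 1 and two by user 3. The first run emits totals of 4 and 1. The second emits 5 for user 1 and 2 for user 3. User 2 is not emitted in the second batch because it received no new data.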
For debugging purposes, it’s convenient to implement simple locking logic:
def processDataWithLock(query: StreamingQuery): Unit = {
query.processAllAvailable()
while (query.status.message != "Waiting for data to arrive") {
log.info(s"Waiting for the query to finish processing, current status is ${query.status.message}")
Thread.sleep(1)
}
log.info("Locking the thread for another 5 seconds for state operations cleanup")
Thread.sleep(5000)
}
And then the final code looks like this:
val additionalBatch = Seq(
  generateEvent(1),
  generateEvent(3),
  generateEvent(3)
)
visitsStream.addData(additionalBatch)
processDataWithLock(query)
State Storage Persistence
Here comes another set of questions:
- Where is the state stored?
- How can we make it fault-tolerant?
- Can we read the state dataset from an external system?
To make the store fault-tolerant, you need to add the checkpointLocation option to your output configuration. In Spark 2.4.4, the only built-in StateStore implementation is HDFSBackedStateStore. It keeps all state-related data of the streaming application in a directory on an HDFS-compatible file system (for example HDFS, a local disk, or S3 through a Hadoop connector). If you use state operations, the checkpointLocation has the following structure:
.
|-- commits/
|-- offsets/
|-- sources/
|-- state/
`-- metadata
What happens if you don’t specify the checkpoint location? Perhaps surprisingly, the state is still stored on disk. If the checkpoint directory is not defined, the stream-related data (commits and offsets) and the state are written to a Spark temporary directory.
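If you want to inspect a checkpoint directory yourself, a small sketch using java.io.File can render it as an indented listing. CheckpointTree is a hypothetical helper, not part of Spark. It sorts entries by name and marks directories with a trailing slash.

```scala
import java.io.File

object CheckpointTree {
  // Hypothetical helper: returns one line per entry under `dir`, indented by depth,
  // with children sorted by name and directories suffixed with "/".
  def render(dir: File, depth: Int = 0): Seq[String] = {
    val children = Option(dir.listFiles()).map(_.sortBy(_.getName).toSeq).getOrElse(Seq.empty)
    children.flatMap { f =>
      val line = ("  " * depth) + f.getName + (if (f.isDirectory) "/" else "")
      if (f.isDirectory) line +: render(f, depth + 1) else Seq(line)
    }
  }

  def main(args: Array[String]): Unit =
    args.headOption.foreach(path => render(new File(path)).foreach(println))
}
```

Running it with your checkpointLocation as the only argument prints the commits, offsets, sources, and state entries, including every delta file.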
Now, let’s take a look at the generated state files.
State data representation
As we just discovered, the state is kept on the file system, inside the checkpointLocation. The state directory has the following structure:
state
`-- 0
    |-- 0
    |   |-- 1.delta
    |   |-- 2.delta
    |   `-- 3.delta
    `-- 1
        |-- 1.delta
        |-- 10.delta
        |-- 11.delta
        |-- 12.delta
        |-- 13.delta
        |-- ...
        `-- 26.delta
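Each file name is a state version. An N.delta file holds the changes committed for version N, and the provider's background maintenance may also write N.snapshot files. The listing above is sorted as strings (10.delta comes before 2.delta), so it is safer to parse the numbers. The following sketch returns the highest version found in each partition directory of one operator. StateVersions is a hypothetical helper, not a Spark API.

```scala
import java.io.File

object StateVersions {
  // Matches file names such as "26.delta" or "20.snapshot"; group 1 is the version.
  private val VersionFile = """(\d+)\.(delta|snapshot)""".r

  private def children(dir: File): Seq[File] =
    Option(dir.listFiles()).map(_.toSeq).getOrElse(Seq.empty)

  // Hypothetical helper: partitionId -> highest version in that partition directory.
  // `operatorDir` is one operator directory, e.g. <checkpointLocation>/state/0.
  def latestVersions(operatorDir: File): Map[Int, Long] =
    children(operatorDir)
      .filter(d => d.isDirectory && d.getName.forall(_.isDigit))
      .flatMap { partitionDir =>
        val versions = children(partitionDir).map(_.getName).collect {
          case VersionFile(version, _) => version.toLong
        }
        // Numeric max, so 26 wins over 3 even though "3" sorts after "26" as a string.
        if (versions.isEmpty) None else Some(partitionDir.getName.toInt -> versions.max)
      }
      .toMap
}
```

For the layout above, this sketch would report version 3 for partition 0 and version 26 for partition 1.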
The directory structure reflects the logic defined in the StateStoreId class:
/**
* Checkpoint directory to be used by a single state store, identified uniquely by the tuple
* (operatorId, partitionId, storeName). All implementations of [[StateStoreProvider]] should
* use this path for saving state data, as this ensures that distinct stores will write to
* different locations.
*/
def storeCheckpointLocation(): Path = {
if (storeName == StateStoreId.DEFAULT_STORE_NAME) {
// For reading state store data that was generated before store names were used (Spark <= 2.2)
new Path(checkpointRootLocation, s"$operatorId/$partitionId")
} else {
new Path(checkpointRootLocation, s"$operatorId/$partitionId/$storeName")
}
}
We are not overriding the store name, so StateStoreId.DEFAULT_STORE_NAME is used and our structure follows this pattern:
new Path(checkpointRootLocation, s"$operatorId/$partitionId")
The root location is the state subdirectory of the checkpointLocation we set in the output options. The operator ID identifies the stateful operator in the query (our query has only one, with ID 0). The partition ID is the index of the shuffle partition whose keys the store holds.
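To make the mapping explicit, here is a plain-string sketch of the same path logic. StatePaths is a hypothetical name. It assumes DEFAULT_STORE_NAME is "default", which is its value in Spark 2.4. The root argument is the directory that holds the operator folders, which for the layout above is the state subdirectory of the checkpoint location.

```scala
object StatePaths {
  // Assumed to mirror StateStoreId.DEFAULT_STORE_NAME ("default" in Spark 2.4).
  val DefaultStoreName = "default"

  // Plain-string version of StateStoreId.storeCheckpointLocation(), for illustration only.
  def storeCheckpointLocation(
      checkpointRootLocation: String,
      operatorId: Long,
      partitionId: Int,
      storeName: String = DefaultStoreName): String =
    if (storeName == DefaultStoreName) s"$checkpointRootLocation/$operatorId/$partitionId"
    else s"$checkpointRootLocation/$operatorId/$partitionId/$storeName"
}
```

For our single stateful operator (ID 0), partition 1 resolves to <checkpointLocation>/state/0/1, which matches the tree shown earlier. A named store would add one more directory level.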
Next comes the question of reading the state data externally. Spark 2.4.4 offers no simple way to read the state directory, so I’ve implemented a class called StateDatasetProvider for this:
case class StateDatasetProvider(spark: SparkSession, checkpointLocation: String, query: StreamingQuery)
The data in the state directory is stored as binary key-value pairs in Spark’s internal row format (UnsafeRow) inside compressed delta files, not as Parquet. To read this data, we need to define the key and value schemas:
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val keySchema = new StructType().add(StructField("id", IntegerType))
val valueSchema = new StructType().add(StructField("groupState", userStatisticsEncoder.schema))
The key schema is a StructType containing the key we used for the aggregation. In our case, it’s an integer id column.
For the value schema, we need to add a StructField named groupState, which holds our statistics data structure.
Then, we need to read the data from our store. Since most of the methods and classes in org.apache.spark.sql.execution.streaming.state are private, we need a small trick:
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}

val storeConf = StateStoreConf(spark.sessionState.conf)
val hadoopConf = spark.sessionState.newHadoopConf()
// Operator 0, partition 0 only: the keys of other partitions live in their own directories
val stateStoreId = StateStoreId(
  checkpointLocation + "/state",
  operatorId = 0,
  partitionId = 0
)
val storeProviderId = StateStoreProviderId(
  stateStoreId,
  query.runId
)
val store: StateStore = StateStore.get(
  storeProviderId,
  keySchema,
  valueSchema,
  None, // indexOrdinal
  query.lastProgress.batchId, // version to load
  storeConf,
  hadoopConf
)
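Finally, we can create a dataset on top of the store data:
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
case class UserGroupState(groupState: UserStatistics) // matches valueSchema; define at top level
val statisticsEncoder = ExpressionEncoder[UserGroupState]().resolveAndBind() // built once, not per row
val statistics: List[UserStatistics] = store.iterator()
  .map(rowPair => statisticsEncoder.fromRow(rowPair.value).groupState)
  .toList
val dataset: Dataset[UserStatistics] = spark.createDataset(statistics)(userStatisticsEncoder)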
We don’t read the key here, because the user ID is already part of the value.
Summary
- Spark Structured Streaming provides a set of instruments for stateful stream management. One of them is mapGroupsWithState, which provides an API for state management through your own implementation of a callback function.
- In Spark 2.4.4, the only built-in option to persist the state is a directory on an HDFS-compatible file system. The path is defined via the checkpointLocation option.
- It’s possible (to some extent) to read the data from the state, but the read path is not optimized.
Sources and related posts