Building an image classification app with Databricks Connect “V2” and Dash
Adoption of intelligent data applications is growing rapidly across data platforms. With it, engineers are looking for efficient and flexible APIs that let them work with data inside such applications the same way they are used to working with it in pipelines.
In this blog post, I’ll demonstrate how the upgraded Databricks Connect protocol (based on Spark Connect) enables developers to write concise and expressive data applications on top of the Databricks Lakehouse Platform with the Dash framework.
⛓️ The Tale of Protocols
There are several ways to perform operations with the data stored in the Databricks Lakehouse from the data application.
One of them is the one I’ve described in one of my previous blog posts — it’s a straightforward and well-known approach that uses DBSQL and the SQLAlchemy ORM mechanism.
However, this time I would like to highlight a (relatively) new piece of Databricks functionality called Databricks Connect “V2”. The component may sound familiar to many Databricks users, but its new version offers a more flexible and faster approach with a thinner client that can potentially be used from applications written in various programming languages.
Why would a Data Engineer prefer using Databricks Connect over a standard ORM-like solution?
There are several reasons for that. One of the most important for me is that most Data Engineers know Apache Spark and its APIs very well, whilst SQLAlchemy is usually closer to the experience of Software Engineers.
Having been a Data Engineer myself, I clearly know which of the two APIs I would rather choose.
👁️ Solution overview
The architecture for such an app is relatively simple — there are two main components, namely:
- the frontend app (written in the robust Dash framework in almost pure Python)
- the backend (which turns out to be a Databricks Cluster)
Connectivity and data exchange between the backend and frontend are handled by the DB Connect “V2” protocol and the databricks-connect Python package.
The image metadata (e.g. classes) is stored in a Delta Table, while the images themselves are stored in a cloud storage as files.
The main benefits of this architecture are the following:
- Images can be stored and retrieved natively, without a need to put them into tables as blobs or binaries
- The metadata of the image can be easily changed thanks to the Delta Lake UPDATE functionality
- We don’t need to use a specific ORM and can perform the backend operations with the Spark API methods themselves
- Our frontend app can be fully written in Python, without a need to switch language contexts
💻 Setting up the development environment
Prerequisites for this step:
- [Local machine] JVM (I’m using JVM 11 in this example)
- [Local machine] Poetry
- [Local machine] Databricks profile configured via databricks-cli
- [Databricks] Databricks All-purpose Cluster with DBR 13.X or higher and Unity Catalog support
Let’s start by organizing the project and the dev environment. Our app is a poetry-based Python package, and we can easily kick off the project with all dependencies as follows:
# create project dir
mkdir db-connect-v2-image-classification
# move into project dir
cd db-connect-v2-image-classification
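# initialize poetry with the runtime dependencies:
# hydra-core for configurations, databricks-sdk for cluster ops, dash as the frontend framework
# (comments must not follow a trailing backslash, otherwise the line continuation breaks)
poetry init -n \
  --dependency="databricks-connect=13.*" \
  --dependency="hydra-core=1.3.*" \
  --dependency="databricks-sdk" \
  --dependency="dash"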
# add dev dependencies
poetry add -G dev black ruff isort
# poetry references this file, let's keep it empty for this moment
touch README.md
# add package directory
mkdir db_connect_v2_image_classification
touch db_connect_v2_image_classification/__init__.py
# install project and its deps
poetry install
With these steps, we have a fully prepared Python app skeleton. Let’s step into the project and add the first bit that is responsible for preparing the metadata table:
code . # open VS Code in the project directory
Add a .gitignore file with the following contents:
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# Distribution / packaging
*.egg-info/
build
dist
# Unit test / coverage reports
.coverage
coverage.xml
junit/*
htmlcov/*
# Caches
.pytest_cache/
# VSCode
.vscode/
# Idea
.idea/
*.iml
# MacOS
.DS_Store
And then init the git project:
git init .
Hint: you can immediately create a repo on GitHub if you have the GitHub CLI installed and you’re using a Bash-like shell:
gh repo create "${PWD##*/}" --private --source=. --remote=origin
Now we have a proper project skeleton that is convenient to work with. However, our app will manage several configurations: the connection-related ones as well as the ETL and data-related ones.
🪄 Adding configurations
One of the most convenient ways to handle configurations is to use the hydra-core library.
Let’s add the config-related logic:
# db_connect_v2_image_classification/configs.py
from dataclasses import dataclass, field
from typing import Optional

from databricks.sdk import WorkspaceClient


@dataclass
class ImageTableConfig:
    catalog: str = "main"
    database: str = "default"
    table: str = "db_connect_v2_example_images_metadata"

    # note - repr is used both inside print statements and when table name is provided to Spark APIs
    def __repr__(self) -> str:
        return f"{self.catalog}.{self.database}.{self.table}"


@dataclass
class AppConfig:
    cluster_name: str
    cluster_id: str = field(init=False)
    image_table: ImageTableConfig = field(default_factory=ImageTableConfig)
    profile: Optional[str] = "DEFAULT"
    debug: Optional[bool] = True
    # defines whether we want to overwrite the table, fail when it exists or append data into it
    table_saving_mode: Optional[str] = "error"

    def __post_init__(self):
        _w = WorkspaceClient(profile=self.profile)
        found = [c for c in _w.clusters.list() if c.cluster_name == self.cluster_name]
        # an empty list is never None, so check for emptiness explicitly
        assert found, f"No clusters found with name {self.cluster_name}"
        assert len(found) == 1, f"There is more than one cluster with name {self.cluster_name}"
        self.cluster_id = found[0].cluster_id
The code above introduces data classes and a logical structure for the application config. Since dataclasses have a convenient __post_init__ method, we can automatically resolve the provided cluster_name argument into a cluster_id by using the latest Databricks SDK for Python.
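The resolution pattern itself can be tried out without a workspace. Below is a minimal sketch using only the standard library; `KNOWN_CLUSTERS`, `ClusterRef` and the cluster ids are hypothetical stand-ins for the SDK call and for `AppConfig`, not part of the app:

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for what the workspace API would return:
# (cluster_name, cluster_id) pairs.
KNOWN_CLUSTERS = [
    ("dev-cluster", "0101-000000-aaaaaaaa"),
    ("prod-cluster", "0202-000000-bbbbbbbb"),
]


@dataclass
class ClusterRef:
    cluster_name: str
    # init=False: the field is not a constructor argument,
    # it is filled in by __post_init__ right after __init__ runs
    cluster_id: str = field(init=False)

    def __post_init__(self):
        found = [cid for name, cid in KNOWN_CLUSTERS if name == self.cluster_name]
        if len(found) != 1:
            raise ValueError(
                f"Expected exactly one cluster named {self.cluster_name!r}, found {len(found)}"
            )
        self.cluster_id = found[0]


ref = ClusterRef(cluster_name="dev-cluster")
print(ref.cluster_id)
```

The caller only ever provides the human-readable name; the id is derived once, at construction time, and a missing or ambiguous name fails early.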
🌯 Wrapping the entrypoints
With the configuration structure defined, we also need a bit of glue code so that hydra works with the configurations in a way that is convenient for our app:
# db_connect_v2_image_classification/main_wrapper.py
from typing import Callable

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf, SCMode

from db_connect_v2_image_classification.configs import AppConfig


def main_wrapper():
    # prepares the config store for any entrypoints that will use this decorator.
    # since we're using the same AppConfig both for table creation and
    # for the frontend app, we'll simply re-use this function.
    cs = ConfigStore.instance()
    cs.store(name="base_config", node=AppConfig)

    # a decorator that will wrap any function under @main_wrapper()
    def decorator(func: Callable[[AppConfig], None]):
        # hydra.main to enable flexible config provisioning
        @hydra.main(version_base=None, config_name="config")
        def _wrapped(cfg: DictConfig):
            # important - we're using dataclasses with additional functionality inside them
            # (e.g. methods or __post_init__).
            # by default hydra returns an untyped DictConfig.
            # with this we convert the config back into a Python dataclass with all methods.
            _cfg = OmegaConf.to_container(cfg, structured_config_mode=SCMode.INSTANTIATE)
            func(_cfg)

        return _wrapped

    return decorator
Whoosh, here things get a bit complex (a classic with Python decorators). What we want is a flexible wrapper for an entrypoint function. This wrapper should be easy to apply to various entrypoint functions (we’ll have two of them: one for table creation, another for launching the frontend app).
What is the easiest way to add logic to functions in Python? Exactly — decorators. Our code would look like this:
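As a minimal sketch of that shape (plain Python, no hydra involved; `FakeConfig` and `main_wrapper_sketch` are hypothetical stand-ins for `AppConfig` and `main_wrapper`), a decorator factory that builds the config and hands it to the wrapped entrypoint could look roughly like this:

```python
from dataclasses import dataclass
from functools import wraps
from typing import Callable


@dataclass
class FakeConfig:  # hypothetical stand-in for AppConfig
    cluster_name: str = "dev-cluster"
    debug: bool = True


def main_wrapper_sketch():
    # shared setup would go here; it runs when the decorator is created
    def decorator(func: Callable[[FakeConfig], object]):
        @wraps(func)  # keep the name and docstring of the wrapped entrypoint
        def _wrapped():
            cfg = FakeConfig()  # in the real app, hydra provides the config
            return func(cfg)

        return _wrapped

    return decorator


@main_wrapper_sketch()
def create_table_entrypoint(cfg: FakeConfig):
    return f"creating table via {cfg.cluster_name}"


@main_wrapper_sketch()
def frontend_entrypoint(cfg: FakeConfig):
    return f"serving app, debug={cfg.debug}"


print(create_table_entrypoint())
print(frontend_entrypoint())
```

Both entrypoints are called without arguments; the wrapper is the only place that knows how the configuration is produced.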
An attentive reader will also notice that we’ve added a cfg argument to the entrypoints. This way we can work with the configuration directly, without reading it separately in different parts of the application.
One more thing is the way hydra (and omegaconf, which runs under its hood) handles structured configs:
@hydra.main(version_base=None, config_name="config")
def _wrapped(cfg: DictConfig):
    ...  # work with cfg object
It doesn’t actually instantiate the dataclass itself; it simply provides a DictConfig after validating the structure. To get back the real object with proper methods, a small trick is required:
@hydra.main(version_base=None, config_name="config")
def _wrapped(cfg: DictConfig):
    # important - we're using dataclasses with additional functionality inside them
    # (e.g. methods or __post_init__).
    # by default hydra returns an untyped DictConfig.
    # with this we convert the config back into a Python dataclass with all methods.
    _cfg = OmegaConf.to_container(cfg, structured_config_mode=SCMode.INSTANTIATE)
    func(_cfg)
With all this ✨arcane knowledge✨ we can now wrap this into a nice decorator (code provided above), and use it without worries through the whole application.
⬆️ Populating the image metadata table
It’s time to populate the image metadata table. The code is provided below:
# db_connect_v2_image_classification/create_table.py
from databricks.connect.session import DatabricksSession as SparkSession
from databricks.sdk.core import Config
from pyspark.sql.functions import element_at, split

from db_connect_v2_image_classification.configs import AppConfig
from db_connect_v2_image_classification.main_wrapper import main_wrapper


class CreateTableTask:
    def __init__(self, cfg: AppConfig):
        self.cfg = cfg
        self.spark = SparkSession.builder.sdkConfig(
            Config(profile=self.cfg.profile, cluster_id=self.cfg.cluster_id)
        ).getOrCreate()

    def launch(self):
        print(f"Loading image metadata into the table {self.cfg.image_table}")
        print(f"Using cluster {self.cfg.cluster_name} with id {self.cfg.cluster_id}")
        metadata_df = self._load_image_metadata()
        metadata_df.write.saveAsTable(f"{self.cfg.image_table}", format="delta", mode=self.cfg.table_saving_mode)
        print(f"Image metadata has been successfully saved into the table {self.cfg.image_table}")

    def _load_image_metadata(self):
        images = (
            self.spark.read.format("image")
            .load("/databricks-datasets/flower_photos/*/*.jpg")
            .select("image.origin")
            # the class is the name of the folder that contains the image
            .withColumn("class", element_at(split("origin", "/"), 4))
            # the image id is the file name without the .jpg extension
            .withColumn(
                "image_id",
                element_at(split(element_at(split("origin", "/"), -1), ".jpg"), 1),
            )
            .select("image_id", "class", "origin")
        )
        return images


@main_wrapper()
def entrypoint(cfg: AppConfig):
    task = CreateTableTask(cfg)
    task.launch()
It’s a fairly simple piece of Spark code in terms of ETL: we load data from the dbfs:/databricks-datasets folder (available in every Databricks workspace by default) and simply put the metadata into a separate table.
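To see what the two `split` / `element_at` expressions extract, here is a plain-Python sketch of the same parsing. It assumes the origin column holds paths of the form `dbfs:/databricks-datasets/flower_photos/<class>/<file>.jpg`; the file name and the `parse_origin` helper are made up for illustration:

```python
def parse_origin(origin: str) -> dict:
    """Mimic the Spark column expressions on a single path string."""
    parts = origin.split("/")
    return {
        # last path element without the ".jpg" suffix
        "image_id": parts[-1].split(".jpg")[0],
        # element_at(..., 4) is 1-based, so it maps to index 3 in Python
        "class": parts[3],
        "origin": origin,
    }


# hypothetical file name, used only to illustrate the path layout
row = parse_origin("dbfs:/databricks-datasets/flower_photos/daisy/example_001.jpg")
print(row)
```

Note that Spark’s `split` treats its pattern as a regular expression while `str.split` is literal; for these inputs both give the same result.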
The interesting part is in the way how we initialize the Spark Session. Since we’re using DB Connect “V2”, it’s a fairly simple operation (snippet from the code above):
class CreateTableTask:
    def __init__(self, cfg: AppConfig):
        self.cfg = cfg
        self.spark = SparkSession.builder.sdkConfig(
            Config(profile=self.cfg.profile, cluster_id=self.cfg.cluster_id)
        ).getOrCreate()
This initializes a connection where DB Connect “V2” client is running on the local machine, and is capable of sending commands to the remote Databricks cluster.
🖼️ Adding a frontend app
For those who have had a chance to read my previous blog post, this part will be pretty straightforward. The code below starts from an entrypoint function, receives the config and spins up the Dash app:
# file: db_connect_v2_image_classification/frontend/app.py
import logging
from random import choice

from dash import Dash, dcc, html

from db_connect_v2_image_classification.configs import AppConfig
from db_connect_v2_image_classification.frontend.css_utils import external_scripts, external_stylesheets
from db_connect_v2_image_classification.frontend.callbacks import prepare_callbacks
from db_connect_v2_image_classification.frontend.components import data_container, guideline, header, nav_container
from db_connect_v2_image_classification.frontend.crud import DataOperator
from db_connect_v2_image_classification.main_wrapper import main_wrapper

logger = logging.getLogger("frontend_app")
logger.setLevel(logging.INFO)


def prepare_layout(op: DataOperator) -> html.Div:
    logger.info("Preparing the layout")
    layout = html.Div(
        children=[
            header,
            guideline,
            data_container(op.get_all_classes()),
            nav_container,
            # keeps the id of the currently displayed image on the client side
            dcc.Store(id="current_index", data=choice(op.get_all_indexes())),
        ],
        className="p-5",
        style={"height": "100vh"},
    )
    logger.info("layout prepared")
    return layout


@main_wrapper()
def entrypoint(cfg: AppConfig):
    app = Dash(
        __name__,
        title="Image Classification App",
        external_scripts=external_scripts,
        external_stylesheets=external_stylesheets,
    )
    logger.info("Preparing the data operator")
    op = DataOperator(cfg)
    logger.info("Data operator loaded successfully")
    app.layout = prepare_layout(op)
    prepare_callbacks(app, op)
    app.run(debug=cfg.debug, host="0.0.0.0")


if __name__ == "__main__":
    entrypoint()
Line-by-line, the following happens:
- Initialize an empty app object with some extra CSS (I’ve decided to add Tailwind and Daisy UI for some ready-to-use components).
- Prepare the DataOperator instance. I’ve moved all backend-related actions (i.e. read/write ops) into this class for modularity.
- Prepare the layout of the application
- Prepare the callback actions
- Serve the app
Let’s discuss some of the parts a bit deeper.
🔛 Backend interactions
As said above, the main backbone for R/W operations in this app is DB Connect “V2”. This allows users to write ETL-related code without wrapping it into an ORM-like solution. For instance, here is the code relevant to the manipulations that could happen in the UI:
# file: db_connect_v2_image_classification/frontend/crud.py
import logging
from functools import lru_cache
from typing import List

import numpy as np
from databricks.connect.session import DatabricksSession as SparkSession
from databricks.sdk.core import Config
from pyspark.sql import DataFrame

from db_connect_v2_image_classification.configs import AppConfig


class DataOperator:
    def __init__(self, cfg: AppConfig):
        self.logger = logging.getLogger("frontend_app")
        self.cfg = cfg
        self.logger.info("Initializing DB Connect")
        self.spark = SparkSession.builder.sdkConfig(
            Config(profile=self.cfg.profile, cluster_id=self.cfg.cluster_id)
        ).getOrCreate()
        self.logger.info(
            f"DB Connect initialized a connection to cluster {self.cfg.cluster_name} with id {self.cfg.cluster_id}"
        )

    @property
    def _source_table(self) -> DataFrame:
        return self.spark.table(f"{self.cfg.image_table}")

    @lru_cache(maxsize=10_000)
    def get_all_indexes(self) -> List[str]:
        self.logger.info("Loading all image indexes")
        indexes = self._source_table.select("image_id").distinct().toPandas()["image_id"].to_list()
        self.logger.info("Image indexes loaded")
        return indexes

    @lru_cache(maxsize=10_000)
    def get_all_classes(self) -> List[str]:
        return self._source_table.select("class").distinct().toPandas()["class"].to_list()

    def get_image_class(self, image_id: str) -> str:
        return self._source_table.select("class").where(f"image_id = '{image_id}'").toPandas().loc[0, "class"]

    def get_image_payload(self, image_id: str) -> np.ndarray:
        image_origin = self._source_table.select("origin").where(f"image_id = '{image_id}'").toPandas().loc[0, "origin"]
        image_info = self.spark.read.format("image").load(image_origin).toPandas().T.squeeze()
        # rebuild the (height, width, channels) array and flip the channel order to RGB
        img_payload = np.frombuffer(image_info["data"], dtype=np.uint8).reshape(
            image_info["height"], image_info["width"], image_info["nChannels"]
        )[:, :, ::-1]
        return img_payload

    def update_image_class(self, image_id: str, new_class: str):
        self.logger.info(f"Updating the image class for image {image_id}")
        command = f"UPDATE {self.cfg.image_table} SET class='{new_class}' WHERE image_id='{image_id}'"
        df = self.spark.sql(command)
        df.collect()
        self.logger.info("Update finished")
In the initialization phase we connect to a Databricks cluster via DB Connect, then we can add read-related methods in pure Spark APIs.
However, some fiddling is required to properly pass the binary image data to px.imshow (plotly.express):
def get_image_payload(self, image_id: str) -> np.ndarray:
    image_origin = self._source_table.select("origin").where(f"image_id = '{image_id}'").toPandas().loc[0, "origin"]
    image_info = self.spark.read.format("image").load(image_origin).toPandas().T.squeeze()
    img_payload = np.frombuffer(image_info["data"], dtype=np.uint8).reshape(
        image_info["height"], image_info["width"], image_info["nChannels"]
    )[:, :, ::-1]
    return img_payload
Line by line, here is what’s happening:
- By running a select on the metadata table, the function retrieves the image path (origin) in the cloud storage
- The image information is loaded as a single row and squeezed into a pd.Series
- The binary data, together with the height, width and channel count, is used to build a numpy ndarray
- As a last step, the channel axis of this array is reversed to get the RGB order (this is where we apply the [:,:,::-1] operation)
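The reshape-and-reverse step can be illustrated on a tiny hand-made buffer. The sketch below deliberately uses plain Python instead of NumPy so that it runs anywhere; `bgr_bytes_to_rgb_rows` is a hypothetical helper, and it relies on the Spark image data source storing pixels row by row in BGR channel order:

```python
def bgr_bytes_to_rgb_rows(data: bytes, height: int, width: int, n_channels: int):
    """Reshape a flat, row-major BGR buffer into rows of RGB pixel tuples."""
    if len(data) != height * width * n_channels:
        raise ValueError("buffer size does not match the image dimensions")
    rows = []
    for y in range(height):
        row = []
        for x in range(width):
            start = (y * width + x) * n_channels
            pixel = data[start:start + n_channels]
            # reversing the channel order turns BGR into RGB,
            # the same effect as [:, :, ::-1] on the ndarray
            row.append(tuple(reversed(pixel)))
        rows.append(row)
    return rows


# a 1x2 image: one pure-blue and one pure-red pixel, stored as BGR
raw = bytes([255, 0, 0, 0, 0, 255])
print(bgr_bytes_to_rgb_rows(raw, height=1, width=2, n_channels=3))
# [[(0, 0, 255), (255, 0, 0)]]
```

In the app the same transformation is a one-liner thanks to `np.frombuffer(...).reshape(...)[:, :, ::-1]`.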
A nice cherry on the cake is that it’s still possible to use good old SQL, e.g. for the UPDATE operations:
def update_image_class(self, image_id: str, new_class: str):
    self.logger.info(f"Updating the image class for image {image_id}")
    command = f"UPDATE {self.cfg.image_table} SET class='{new_class}' WHERE image_id='{image_id}'"
    df = self.spark.sql(command)
    df.collect()
    self.logger.info("Update finished")
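Because the statement is assembled with an f-string, it is worth validating the values before they reach it. Here is a minimal sketch of one way to do that; the `build_update_statement` helper and the id pattern are assumptions made for illustration, not part of the app, and the allowed classes are expected to come from a trusted source such as the metadata table itself:

```python
import re
from typing import Iterable

# assumption: image ids consist of letters, digits, underscores and dashes only
_ID_PATTERN = re.compile(r"[A-Za-z0-9_\-]+")


def build_update_statement(table: str, image_id: str, new_class: str, allowed_classes: Iterable[str]) -> str:
    """Build the UPDATE statement only from values that passed validation."""
    if new_class not in set(allowed_classes):
        raise ValueError(f"Unknown class: {new_class!r}")
    if not _ID_PATTERN.fullmatch(image_id):
        raise ValueError(f"Unexpected image id: {image_id!r}")
    return f"UPDATE {table} SET class='{new_class}' WHERE image_id='{image_id}'"


statement = build_update_statement(
    table="main.default.db_connect_v2_example_images_metadata",
    image_id="example_001",
    new_class="daisy",
    allowed_classes=["daisy", "roses", "tulips"],
)
print(statement)
```

In this app both values come from the UI’s own dropdown and store, so the risk is limited, but the check is cheap and keeps a tampered request from changing the shape of the query.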
With this class, it’s now fully possible to cover any UI operation. Let’s bind these methods to the UI buttons:
# file: db_connect_v2_image_classification/frontend/callbacks.py
from random import choice

import plotly.express as px
from dash import Dash, Input, Output, State

from db_connect_v2_image_classification.frontend.crud import DataOperator


def prepare_callbacks(app: Dash, op: DataOperator):
    @app.callback(
        Output("current_index", "data"),
        Output("image_id", "children"),
        Output("class_selector", "value"),
        Output("img_display", "figure"),
        Input("random_btn", "n_clicks"),
    )
    def get_next_random(_):
        # pick a random image and refresh every output that depends on it
        next_index = choice(op.get_all_indexes())
        img_class = op.get_image_class(next_index)
        img_payload = op.get_image_payload(next_index)
        figure = px.imshow(img_payload)
        figure.update_layout(coloraxis_showscale=False)
        figure.update_xaxes(showticklabels=False)
        figure.update_yaxes(showticklabels=False)
        return next_index, f"Image id: {next_index}", img_class, figure

    @app.callback(
        Output("output_mock", "children"),
        Input("confirm_btn", "n_clicks"),
        State("class_selector", "value"),
        State("current_index", "data"),
    )
    def save_selected_class(_, value, current_index):
        # persist the selected class for the currently displayed image
        if value:
            op.update_image_class(current_index, value)
            return value
It turns out that the app has pretty simple logic: we only have a random image picker (which updates the necessary outputs on each “Next” click) and a class saver (which is triggered on a “Confirm” click).
🔌 Adding UI components
Since the app itself is pretty simple, I’ve only used a couple of nice components and capabilities of Tailwind CSS and Daisy UI:
# file: db_connect_v2_image_classification/frontend/components.py
from dash import dcc, html

blogpost_link = (
    "https://www.databricks.com/blog/2022/07/07/introducing-spark-connect-the-power-of-apache-spark-everywhere.html"
)

header = dcc.Markdown(
    "## Image classification app, built with [Dash](https://plotly.com/dash/) "
    f"and [DB Connect V2]({blogpost_link}) 🔥",
    className="text-3xl mb-2",
)

guideline = dcc.Markdown(
    "Please click below to navigate between images and correct their class labels when required.",
)

random_btn = html.Button(
    "🔀 Next image",
    id="random_btn",
    n_clicks=0,
    className="btn btn-primary btn-lg mt-5",
)

nav_container = html.Div(children=[random_btn], className="flex justify-center")

confirm_button = dcc.Loading(
    id="submit-loading",
    children=[
        html.Button(
            "Confirm",
            id="confirm_btn",
            n_clicks=0,
            className="btn btn-success btn-block my-4",
        ),
        html.Div(id="output_mock", style={"display": "none"}),
    ],
)


def class_selector(class_choices):
    return html.Div(
        children=[
            html.Div(
                children=[
                    dcc.Loading(
                        children=[
                            html.P("Please choose the class below:"),
                            dcc.Dropdown(
                                class_choices,
                                id="class_selector",
                                className="dropdown-content text-black",
                                placeholder="Select the class",
                                multi=False,
                                clearable=False,
                            ),
                        ]
                    )
                ],
            ),
            confirm_button,
        ],
    )


def data_container(class_choices):
    return html.Div(
        children=[
            html.Div(
                className="flex justify-center pt-2",
                children=[
                    html.Div(
                        className="card w-96 bg-base-100 shadow-xl p-4",
                        children=[
                            dcc.Loading(
                                [
                                    dcc.Graph(id="img_display"),
                                    dcc.Markdown(id="image_id", className="text-sm text-sky-500"),
                                ]
                            ),
                            html.Div(className="divider"),
                            class_selector(class_choices),
                        ],
                    )
                ],
            ),
        ]
    )
The one I’ve liked the most is the ready-to-use card class with all relevant presets. With just a couple of CSS classes, the app becomes vivid and UI-friendly.
⚙️ Combining it all together
Finally, it’s time for a quick demo of the application:
With this app users can easily classify the images, without a need to use an external tool. At the same time, all data is stored in the cloud storage, whilst metadata is saved and updated in a Delta Table.
✅ Summary
With the flexibility of the Dash framework, it’s possible to build modern-looking and flexible frontend applications.
At the same time, DB Connect works in the backbone of the application, providing convenient Spark APIs to the client-side for proper operations with any kind of data (e.g. images).
The image data itself can be stored in the cloud storage, whilst the metadata information can be easily retrieved from and updated in the Delta Table.
All these technologies together provide a powerful toolset for building efficient and modern data applications on top of the Databricks Lakehouse, in pure Python and with the convenience of the well-known APIs.
The source code from this blog post is provided here. Feel free to copy it and fiddle around with your use cases 🙌.
Have you tried using DB Connect “V2” and Dash already? Have an opinion? Feel free to share it in the comments! Also, hit subscribe if you liked the post — it keeps the author motivated to write more.