Building an image classification app with Databricks Connect “V2” and Dash

Ivan Trusov
May 18, 2023


The Large Figure Paintings, nr 5 by Hilma af Klint, 1907. Image source - wikiart.org

Adoption of intelligent data applications is growing rapidly across data platforms. With this growth, engineers are looking for efficient and flexible APIs that let them work with their data inside such applications the same way they are used to operating their pipelines.

In this blog post, I’ll demonstrate how the upgraded Databricks Connect protocol (based on Spark Connect) enables developers to write concise and expressive data applications on top of the Databricks Lakehouse Platform with the Dash framework.

⛓️ The Tale of Protocols

There are several ways to perform operations on the data stored in the Databricks Lakehouse from a data application.

One of them I've already described in a previous blog post: a straightforward and well-known approach that uses DBSQL and the SQLAlchemy ORM mechanism.

However, this time I would like to highlight a (relatively) new piece of Databricks functionality called Databricks Connect “V2”. Although this component may sound familiar to many Databricks users, its new version provides a more flexible and faster approach, with a thinner client that can potentially be used from applications in various programming languages.

Why would a Data Engineer prefer using Databricks Connect over a standard ORM-like solution?

There are several reasons for that, and one of the most important for me is that most Data Engineers know Apache Spark and its APIs very well, whilst SQLAlchemy is usually closer to the experience of Software Engineers.

Having been a Data Engineer myself in the past, I know exactly which of the two APIs below I would choose:

self-made via excalidraw.com
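Since the original comparison lives in the image above, here is a rough sketch of the same idea in code. The table and column names, the passed-in Spark session and the ImageMetadata ORM model are all made up for illustration:

from pyspark.sql import DataFrame, SparkSession


def fetch_some_roses(spark: SparkSession) -> DataFrame:
    # Spark API: reads like any other pipeline code a Data Engineer writes daily
    # (illustrative table and column names)
    return spark.table("main.default.images_metadata").where("class = 'roses'").limit(10)


# The SQLAlchemy-flavoured equivalent would look roughly like this (hypothetical ORM model):
# stmt = select(ImageMetadata).where(ImageMetadata.image_class == "roses").limit(10)
# rows = session.execute(stmt).scalars().all()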

👁️ Solution overview

The architecture for such an app is relatively simple — there are two main components, namely:

  • the frontend app (written in the robust Dash framework in almost pure Python)
  • the backend (which turns out to be a Databricks Cluster)

Connectivity and data exchange between the backend and the frontend are achieved via the DB Connect “V2” protocol and the databricks-connect Python package.

The image metadata (e.g. classes) is stored in a Delta Table, while the images themselves are stored as files in cloud storage.

self-made in excalidraw.com

The main benefits of this architecture are the following:

  • Images can be stored and retrieved natively, without a need to put them into tables as blobs or binaries
  • The metadata of an image can be easily changed and modified thanks to Delta Lake UPDATE functionality (see the sketch right after this list)
  • We don’t need a specific ORM and can perform the backend operations using the Spark APIs themselves
  • Our frontend app can be fully written in Python, without a need to switch language contexts
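To make the second and third points more concrete, here is a minimal sketch of what a label correction boils down to. It assumes an existing Spark session and uses the table name we’ll create later in this post; the actual app wraps the very same pattern into its DataOperator class:

from pyspark.sql import SparkSession


def fix_label(spark: SparkSession, image_id: str, new_class: str) -> None:
    # a single Delta Lake UPDATE on the metadata table; the image file itself
    # stays untouched in cloud storage
    spark.sql(
        "UPDATE main.default.db_connect_v2_example_images_metadata "
        f"SET class = '{new_class}' WHERE image_id = '{image_id}'"
    )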

💻 Setting up the development environment

Prerequisites for this step:

Let’s start by organizing the project and the dev environment. Our app is a poetry-based Python package, and we can easily kick off the project with all dependencies as follows:

# create project dir
mkdir db-connect-v2-image-classification

# move into project dir
cd db-connect-v2-image-classification

# initialize poetry with the main dependencies:
# - databricks-connect: the DB Connect "V2" client
# - hydra-core: for configurations
# - databricks-sdk: for cluster ops
# - dash: frontend framework
poetry init -n \
  --dependency="databricks-connect=13.*" \
  --dependency="hydra-core=1.3.*" \
  --dependency="databricks-sdk" \
  --dependency="dash"

# add dev dependencies
poetry add -G dev black ruff isort

# poetry references this file, let's keep it empty for this moment
touch README.md

# add package directory
mkdir db_connect_v2_image_classification
touch db_connect_v2_image_classification/__init__.py

# install the project and its deps
poetry install

With these steps, we have a fully prepared Python app skeleton. Let’s step into the project and add the first bit that is responsible for preparing the metadata table:

code . # open VS Code in the project directory

Add a .gitignore file with the following contents:

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
*.egg-info/
build
dist

# Unit test / coverage reports
.coverage
coverage.xml
junit/*
htmlcov/*

# Caches
.pytest_cache/

# VSCode
.vscode/

# Idea
.idea/
*.iml

# MacOS
.DS_Store

And then init the git project:

git init .

Hint — you can immediately create a repo on GitHub if you have the GitHub CLI installed and you’re using a Bash-like shell:

gh repo create "${PWD##*/}" --private --source=. --remote=origin

Now we have a proper project skeleton, and it’s pretty convenient to work with. However, our app will need to manage several configurations: connection-related ones, as well as ETL- and data-related ones.

🪄 Adding configurations

One of the most convenient ways to handle configurations is to use the hydra-core library.

Let’s add the config-related logic:

# db_connect_v2_image_classification/configs.py

from dataclasses import dataclass, field
from typing import Optional

from databricks.sdk import WorkspaceClient


@dataclass
class ImageTableConfig:
    catalog: str = "main"
    database: str = "default"
    table: str = "db_connect_v2_example_images_metadata"

    # note - repr is used both inside print statements and when the table name is provided to Spark APIs
    def __repr__(self) -> str:
        return f"{self.catalog}.{self.database}.{self.table}"


@dataclass
class AppConfig:
    cluster_name: str
    cluster_id: str = field(init=False)
    image_table: ImageTableConfig = field(default_factory=ImageTableConfig)
    profile: Optional[str] = "DEFAULT"
    debug: Optional[bool] = True

    # defines if we want to overwrite the table, fail when it exists or append data into it
    table_saving_mode: Optional[str] = "error"

    def __post_init__(self):
        _w = WorkspaceClient(profile=self.profile)
        found = list(filter(lambda c: c.cluster_name == self.cluster_name, _w.clusters.list()))
        assert found, f"Not found any clusters with name {self.cluster_name}"
        assert len(found) == 1, f"There is more than one cluster with name {self.cluster_name}"
        self.cluster_id = found[0].cluster_id

The code above introduces data classes and a logical structure for the application config. Since dataclasses have a convenient __post_init__ method, we can automatically resolve the provided cluster_name argument into a cluster_id by using the latest Databricks SDK for Python.
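If you’d like to sanity-check this resolution logic outside of hydra, a quick (hypothetical) snippet could look like this. It assumes a Databricks config profile named DEFAULT and an existing cluster called my-dev-cluster:

from db_connect_v2_image_classification.configs import AppConfig

cfg = AppConfig(cluster_name="my-dev-cluster")
print(cfg.cluster_id)   # resolved from the cluster name via the Databricks SDK in __post_init__
print(cfg.image_table)  # main.default.db_connect_v2_example_images_metadata, thanks to __repr__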

🌯 Wrapping the entrypoints

With the configuration structure defined, it’s also important to add a bit of glue code so that hydra works with the configurations in a way that is convenient for our app:

# db_connect_v2_image_classification/main_wrapper.py

from typing import Callable

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf, SCMode

from db_connect_v2_image_classification.configs import AppConfig


def main_wrapper():
    # prepares the config store for any entrypoint that will use this decorator.
    # since we're using the same AppConfig both for the table creation task and
    # for the frontend app, we'll simply re-use this function.
    cs = ConfigStore.instance()
    cs.store(name="base_config", node=AppConfig)

    # a decorator that will wrap any function under @main_wrapper()
    def decorator(func: Callable[[AppConfig], None]):
        # hydra.main to enable flexible config provisioning
        @hydra.main(version_base=None, config_name="config")
        def _wrapped(cfg: DictConfig):
            # important - we're using dataclasses with additional functionality inside them
            # (e.g. methods or __post_init__).
            # by default hydra returns an untyped DictConfig.
            # with this we convert the config back into a Python dataclass with all methods.
            _cfg = OmegaConf.to_container(cfg, structured_config_mode=SCMode.INSTANTIATE)
            func(_cfg)

        return _wrapped

    return decorator

Whoosh, here things get a bit complex (a classic with Python decorators). Let’s describe this logic schematically. What we want is a flexible wrapper for an entrypoint function:

self-made in excalidraw.com

This wrapper should be easy to apply to various entrypoint functions (we’ll have two of them — one for table creation, another one for frontend app launch).

What is the easiest way to add logic to functions in Python? Exactly — decorators. Our code would look like this:

self-made in excalidraw.com

An attentive reader will also notice that we’ve added a cfg argument to the entrypoints. This is done so we can easily work with the configuration without re-reading it in different parts of the application.

One more thing is the way hydra (and omegaconf, which runs under its hood) handles structured configs:

@hydra.main(version_base=None, config_name="config")
def _wrapped(cfg: DictConfig):
    ...  # work with the cfg object

It doesn’t actually instantiate the dataclass itself, it simply provides a DictConfig after structure validation. To get back a real object with proper methods, a small trick is required:

@hydra.main(version_base=None, config_name="config")
def _wrapped(cfg: DictConfig):
    # important - we're using dataclasses with additional functionality inside them
    # (e.g. methods or __post_init__).
    # by default hydra returns an untyped DictConfig.
    # with this we convert the config back into a Python dataclass with all methods.
    _cfg = OmegaConf.to_container(cfg, structured_config_mode=SCMode.INSTANTIATE)
    func(_cfg)

With all this ✨arcane knowledge✨ we can now wrap everything into a nice decorator (code provided above) and use it without worries throughout the whole application.
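Putting it together, a minimal (illustrative) entrypoint built on top of this wrapper looks like this; the real entrypoints in the next sections follow exactly the same shape:

from db_connect_v2_image_classification.configs import AppConfig
from db_connect_v2_image_classification.main_wrapper import main_wrapper


@main_wrapper()
def entrypoint(cfg: AppConfig):
    # hydra parses the config, converts it into a typed AppConfig and passes it here
    print(f"Using cluster {cfg.cluster_name} (id: {cfg.cluster_id})")


if __name__ == "__main__":
    entrypoint()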

⬆️ Populating the image metadata table

It’s time to populate the image metadata table. The code is provided below:

# db_connect_v2_image_classification/create_table.py

from databricks.connect.session import DatabricksSession as SparkSession
from databricks.sdk.core import Config
from pyspark.sql.functions import element_at, split

from db_connect_v2_image_classification.configs import AppConfig
from db_connect_v2_image_classification.main_wrapper import main_wrapper


class CreateTableTask:
    def __init__(self, cfg: AppConfig):
        self.cfg = cfg
        self.spark = SparkSession.builder.sdkConfig(
            Config(profile=self.cfg.profile, cluster_id=self.cfg.cluster_id)
        ).getOrCreate()

    def launch(self):
        print(f"Loading image metadata into the table {self.cfg.image_table}")
        print(f"Using cluster {self.cfg.cluster_name} with id {self.cfg.cluster_id}")

        metadata_df = self._load_image_metadata()
        metadata_df.write.saveAsTable(f"{self.cfg.image_table}", format="delta", mode=self.cfg.table_saving_mode)
        print(f"Image metadata has been successfully saved into the table {self.cfg.image_table}")

    def _load_image_metadata(self):
        images = (
            self.spark.read.format("image")
            .load("/databricks-datasets/flower_photos/*/*.jpg")
            .select("image.origin")
            .withColumn("class", element_at(split("origin", "/"), 4))
            .withColumn(
                "image_id",
                element_at(split(element_at(split("origin", "/"), -1), ".jpg"), 1),
            )
            .select("image_id", "class", "origin")
        )
        return images


@main_wrapper()
def entrypoint(cfg: AppConfig):
    task = CreateTableTask(cfg)
    task.launch()

In ETL terms it’s a fairly simple piece of Spark code: we load data from the default dbfs:/databricks-datasets folder (available in every Databricks workspace) and simply put the metadata into a separate table:

screenshot of a generated table
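If you prefer checking the result from code rather than from the UI, a quick (illustrative) preview could look like this, assuming an already initialized DB Connect session and the default table name:

from pyspark.sql import SparkSession


def preview_metadata(spark: SparkSession) -> None:
    # eyeball the first few rows of the generated metadata table
    spark.table("main.default.db_connect_v2_example_images_metadata").show(5, truncate=False)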

The interesting part is the way we initialize the Spark session. Since we’re using DB Connect “V2”, it’s a fairly simple operation (snippet from the code above):

class CreateTableTask:
    def __init__(self, cfg: AppConfig):
        self.cfg = cfg
        self.spark = SparkSession.builder.sdkConfig(
            Config(profile=self.cfg.profile, cluster_id=self.cfg.cluster_id)
        ).getOrCreate()

This initializes a connection where the DB Connect “V2” client runs on the local machine and is capable of sending commands to the remote Databricks cluster.
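As a quick (hypothetical) sanity check: once the session is created, any Spark action is executed remotely on the Databricks cluster rather than on the local machine. The profile name and cluster id below are placeholders:

from databricks.connect.session import DatabricksSession as SparkSession
from databricks.sdk.core import Config

spark = SparkSession.builder.sdkConfig(
    Config(profile="DEFAULT", cluster_id="<your-cluster-id>")
).getOrCreate()

print(spark.range(10).count())  # prints 10, computed on the remote cluster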

🖼️ Adding a frontend app

For those who have had a chance to read my previous blog post, this part will be pretty straightforward. The code below starts from an entrypoint function, passes the config around and spins up the Dash app:

# file: db_connect_v2_image_classification/frontend/app.py

import logging
from random import choice

from dash import Dash, dcc, html

from db_connect_v2_image_classification.configs import AppConfig
from db_connect_v2_image_classification.frontend.css_utils import external_scripts, external_stylesheets
from db_connect_v2_image_classification.frontend.callbacks import prepare_callbacks
from db_connect_v2_image_classification.frontend.components import data_container, guideline, header, nav_container
from db_connect_v2_image_classification.frontend.crud import DataOperator
from db_connect_v2_image_classification.main_wrapper import main_wrapper

logger = logging.getLogger("frontend_app")
logger.setLevel(logging.INFO)


def prepare_layout(op: DataOperator) -> html.Div:
    logger.info("Preparing the layout")
    layout = html.Div(
        children=[
            header,
            guideline,
            data_container(op.get_all_classes()),
            nav_container,
            dcc.Store(id="current_index", data=choice(op.get_all_indexes())),
        ],
        className="p-5",
        style={"height": "100vh"},
    )
    logger.info("layout prepared")
    return layout


@main_wrapper()
def entrypoint(cfg: AppConfig):
    app = Dash(
        __name__,
        title="Image Classification App",
        external_scripts=external_scripts,
        external_stylesheets=external_stylesheets,
    )
    logger.info("Preparing the data operator")
    op = DataOperator(cfg)
    logger.info("Data operator loaded successfully")
    app.layout = prepare_layout(op)
    prepare_callbacks(app, op)
    app.run(debug=cfg.debug, host="0.0.0.0")


if __name__ == "__main__":
    entrypoint()

Line-by-line, the following happens:

  1. Initialize an empty app object with some extra CSS (I’ve decided to add Tailwind and Daisy UI for some ready-to-use components).
  2. Prepare the DataOperator instance. I’ve moved all backend-related actions (i.e. read/write ops) into this class for modularity.
  3. Prepare the layout of the application
  4. Prepare the callback actions
  5. Serve the app

Let’s discuss some of the parts a bit deeper.

🔛 Backend interactions

As said above, the main backbone for R/W operations in this app is DB Connect “V2”. It allows us to write ETL-related code without wrapping it into an ORM-like solution. For instance, here is the code for the manipulations that can happen in the UI:

# file: db_connect_v2_image_classification/frontend/crud.py

import logging
from functools import lru_cache
from typing import List

import numpy as np
from databricks.connect.session import DatabricksSession as SparkSession
from databricks.sdk.core import Config
from pyspark.sql import DataFrame

from db_connect_v2_image_classification.configs import AppConfig


class DataOperator:
    def __init__(self, cfg: AppConfig):
        self.logger = logging.getLogger("frontend_app")
        self.cfg = cfg
        self.logger.info("Initializing DB Connect")
        self.spark = SparkSession.builder.sdkConfig(
            Config(profile=self.cfg.profile, cluster_id=self.cfg.cluster_id)
        ).getOrCreate()
        self.logger.info(
            f"DB Connect initialized a connection to cluster {self.cfg.cluster_name} with id {self.cfg.cluster_id}"
        )

    @property
    def _source_table(self) -> DataFrame:
        return self.spark.table(f"{self.cfg.image_table}")

    @lru_cache(maxsize=10_000)
    def get_all_indexes(self) -> List[str]:
        self.logger.info("Loading all image indexes")
        indexes = self._source_table.select("image_id").distinct().toPandas()["image_id"].to_list()
        self.logger.info("Image indexes loaded")
        return indexes

    @lru_cache(maxsize=10_000)
    def get_all_classes(self) -> List[str]:
        return self._source_table.select("class").distinct().toPandas()["class"].to_list()

    def get_image_class(self, image_id: str) -> str:
        return self._source_table.select("class").where(f"image_id = '{image_id}'").toPandas().loc[0, "class"]

    def get_image_payload(self, image_id: str) -> np.ndarray:
        image_origin = self._source_table.select("origin").where(f"image_id = '{image_id}'").toPandas().loc[0, "origin"]
        image_info = self.spark.read.format("image").load(image_origin).toPandas().T.squeeze()

        img_payload = np.frombuffer(image_info["data"], dtype=np.uint8).reshape(
            image_info["height"], image_info["width"], image_info["nChannels"]
        )[:, :, ::-1]

        return img_payload

    def update_image_class(self, image_id: str, new_class: str):
        self.logger.info(f"Updating the image class for image {image_id}")
        command = f"UPDATE {self.cfg.image_table} SET class='{new_class}' WHERE image_id='{image_id}'"
        df = self.spark.sql(command)
        df.collect()
        self.logger.info("Update finished")

In the initialization phase we connect to a Databricks cluster via DB Connect, and then we can implement the read-related methods in pure Spark APIs.
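For a feel of how DataOperator behaves outside of the Dash callbacks, here is an illustrative smoke test. It assumes a valid AppConfig pointing to an existing cluster and to the metadata table created earlier:

from db_connect_v2_image_classification.configs import AppConfig
from db_connect_v2_image_classification.frontend.crud import DataOperator


def smoke_test(cfg: AppConfig) -> None:
    op = DataOperator(cfg)
    print(op.get_all_classes())          # e.g. ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    some_id = op.get_all_indexes()[0]
    print(op.get_image_class(some_id))   # the current label of that image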

However, some fiddling is required to properly feed the binary image data into plotly's imshow:

def get_image_payload(self, image_id: str) -> np.ndarray:
    image_origin = self._source_table.select("origin").where(f"image_id = '{image_id}'").toPandas().loc[0, "origin"]
    image_info = self.spark.read.format("image").load(image_origin).toPandas().T.squeeze()

    img_payload = np.frombuffer(image_info["data"], dtype=np.uint8).reshape(
        image_info["height"], image_info["width"], image_info["nChannels"]
    )[:, :, ::-1]

    return img_payload

Here is what happens, step by step:

  1. By running a select on the metadata table, the function retrieves the image path (origin) in the cloud storage
  2. The image information is loaded into a single-element pd.Series
  3. The binary data, together with the height/width and channel information, is used to build a numpy ndarray
  4. As the last step, the channel axis is reversed to get the RGB order (this is where the [:, :, ::-1] operation comes in; see the toy example below)
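If the channel trick looks cryptic, here is a toy, self-contained version of the same transformation (the Spark image source stores pixels in BGR order, hence the reversal; the buffer below is fake data, shapes only):

import numpy as np

data = bytes(range(12))  # pretend this is the raw BGR pixel buffer of a 2x2 image
arr = np.frombuffer(data, dtype=np.uint8).reshape(2, 2, 3)  # height=2, width=2, nChannels=3
rgb = arr[:, :, ::-1]  # reverse the channel axis: BGR -> RGB
print(rgb.shape)  # (2, 2, 3)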

A nice cherry on top is that it’s still possible to use good old SQL, e.g. for the UPDATE operations:

def update_image_class(self, image_id: str, new_class: str):
    self.logger.info(f"Updating the image class for image {image_id}")
    command = f"UPDATE {self.cfg.image_table} SET class='{new_class}' WHERE image_id='{image_id}'"
    df = self.spark.sql(command)
    df.collect()
    self.logger.info("Update finished")

With this class, it’s now fully possible to cover any UI operation. Let’s bind these methods to the UI buttons:

# file: db_connect_v2_image_classification/frontend/callbacks.py

from random import choice

import plotly.express as px
from dash import Dash, Input, Output, State

from db_connect_v2_image_classification.frontend.crud import DataOperator


def prepare_callbacks(app: Dash, op: DataOperator):
    @app.callback(
        Output("current_index", "data"),
        Output("image_id", "children"),
        Output("class_selector", "value"),
        Output("img_display", "figure"),
        Input("random_btn", "n_clicks"),
    )
    def get_next_random(_):
        next_index = choice(op.get_all_indexes())
        img_class = op.get_image_class(next_index)
        img_payload = op.get_image_payload(next_index)
        figure = px.imshow(img_payload)
        figure.update_layout(coloraxis_showscale=False)
        figure.update_xaxes(showticklabels=False)
        figure.update_yaxes(showticklabels=False)
        return next_index, f"Image id: {next_index}", img_class, figure

    @app.callback(
        Output("output_mock", "children"),
        Input("confirm_btn", "n_clicks"),
        State("class_selector", "value"),
        State("current_index", "data"),
    )
    def save_selected_class(_, value, current_index):
        if value:
            op.update_image_class(current_index, value)
        return value
It turns out the app’s logic is pretty simple: there is a random image picker (which updates the necessary outputs on each “Next image” click) and a class saver (which is triggered on a “Confirm” click).

🔌 Adding UI components

Since the app itself is pretty simple, I’ve only used a couple of nice components and capabilities of Tailwind CSS and Daisy UI:

# file: db_connect_v2_image_classification/frontend/components.py

from dash import dcc, html

blogpost_link = (
    "https://www.databricks.com/blog/2022/07/07/introducing-spark-connect-the-power-of-apache-spark-everywhere.html"
)

header = dcc.Markdown(
    "## Image classification app, built with [Dash](https://plotly.com/dash/) "
    f"and [DB Connect V2]({blogpost_link}) 🔥",
    className="text-3xl mb-2",
)


guideline = dcc.Markdown(
    "Please click below to navigate between images and correct their class labels when required.",
)

random_btn = html.Button(
    "🔀 Next image",
    id="random_btn",
    n_clicks=0,
    className="btn btn-primary btn-lg mt-5",
)

nav_container = html.Div(children=[random_btn], className="flex justify-center")

confirm_button = dcc.Loading(
    id="submit-loading",
    children=[
        html.Button(
            "Confirm",
            id="confirm_btn",
            n_clicks=0,
            className="btn btn-success btn-block my-4",
        ),
        html.Div(id="output_mock", style={"display": "none"}),
    ],
)


def class_selector(class_choices):
    return html.Div(
        children=[
            html.Div(
                children=[
                    dcc.Loading(
                        children=[
                            html.P("Please choose the class below:"),
                            dcc.Dropdown(
                                class_choices,
                                id="class_selector",
                                className="dropdown-content text-black",
                                placeholder="Select the class",
                                multi=False,
                                clearable=False,
                            ),
                        ]
                    )
                ],
            ),
            confirm_button,
        ],
    )


def data_container(class_choices):
    return html.Div(
        children=[
            html.Div(
                className="flex justify-center pt-2",
                children=[
                    html.Div(
                        className="card w-96 bg-base-100 shadow-xl p-4",
                        children=[
                            dcc.Loading(
                                [
                                    dcc.Graph(id="img_display"),
                                    dcc.Markdown(id="image_id", className="text-sm text-sky-500"),
                                ]
                            ),
                            html.Div(className="divider"),
                            class_selector(class_choices),
                        ],
                    )
                ],
            ),
        ]
    )

The one I liked the most is the ready-to-use card class with all the relevant presets. Take a look at how, with just a couple of CSS classes, the app becomes vivid and UI-friendly:

⚙️ Combining it all together

Finally, it’s time for a quick demo of the application:

screen-recorded with gifcap.dev

With this app, users can easily classify images without needing an external tool. At the same time, all image data is stored in cloud storage, whilst the metadata is saved and updated in a Delta Table.

Summary

With the flexibility of the Dash framework, it’s possible to build modern-looking and flexible frontend applications.

At the same time, DB Connect works as the backbone of the application, providing convenient Spark APIs on the client side for proper operations with any kind of data (e.g. images).

The image data itself can be stored in the cloud storage, whilst the metadata information can be easily retrieved from and updated in the Delta Table.

All these technologies together provide a powerful toolset for building efficient and modern data applications on top of the Databricks Lakehouse, in pure Python and with the convenience of the well-known APIs.

The source code from this blog post is provided here. Feel free to copy it and fiddle around with your use cases 🙌.

Have you tried using DB Connect “V2” and Dash already? Have an opinion? Feel free to share it in the comments! Also, hit subscribe if you liked the post — it keeps the author motivated to write more.


Ivan Trusov

Senior Specialist Solutions Architect @ Databricks. All opinions are my own.