This example shows how to test your application code (in `src/`) with the framework, keeping the tests themselves in `tests/`. The project layout:
```
my_databricks_project/
├── src/
│   ├── __init__.py
│   ├── data_processing/
│   │   ├── __init__.py
│   │   ├── transformations.py   # Your application code
│   │   ├── validators.py        # Your application code
│   │   └── aggregations.py      # Your application code
│   └── utils/
│       ├── __init__.py
│       └── helpers.py           # Your application code
├── tests/
│   ├── test_transformations.py  # Tests for transformations
│   ├── test_validators.py       # Tests for validators
│   └── test_aggregations.py     # Tests for aggregations
├── config/
│   └── test_config.yml
└── pyproject.toml
```
"""Data transformation functions."""
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
def clean_customer_data(df: DataFrame) -> DataFrame:
"""
Clean customer data by standardizing formats.
Args:
df: Input dataframe with columns: id, name, email, phone
Returns:
Cleaned dataframe
"""
return df.select(
"id",
F.upper(F.trim(df.name)).alias("name"),
F.lower(F.trim(df.email)).alias("email"),
F.regexp_replace(df.phone, r"[^\d]", "").alias("phone")
)
def calculate_customer_lifetime_value(df: DataFrame) -> DataFrame:
"""
Calculate customer lifetime value from transactions.
Args:
df: Dataframe with columns: customer_id, amount, date
Returns:
Dataframe with customer_id, total_value, transaction_count, first_purchase, last_purchase
"""
return df.groupBy("customer_id").agg(
F.sum("amount").alias("total_value"),
F.count("*").alias("transaction_count"),
F.min("date").alias("first_purchase"),
F.max("date").alias("last_purchase")
)
def enrich_with_demographics(customers: DataFrame, demographics: DataFrame) -> DataFrame:
"""
Enrich customer data with demographic information.
Args:
customers: Customer dataframe
demographics: Demographics dataframe
Returns:
Enriched customer dataframe
"""
return customers.join(
demographics,
customers.id == demographics.customer_id,
"left"
    ).drop(demographics.customer_id)
```

```python
# src/data_processing/validators.py
"""Data validation functions."""
from pyspark.sql import DataFrame
from pyspark.sql import functions as F  # used by the column-based email check below
from typing import List, Dict, Any
def validate_schema(df: DataFrame, required_columns: List[str]) -> Dict[str, Any]:
"""
Validate that dataframe has required columns.
Args:
df: Dataframe to validate
required_columns: List of required column names
Returns:
Dict with validation results
"""
actual_columns = set(df.columns)
required_set = set(required_columns)
missing = required_set - actual_columns
extra = actual_columns - required_set
return {
"valid": len(missing) == 0,
"missing_columns": list(missing),
"extra_columns": list(extra),
"message": "Schema valid" if len(missing) == 0 else f"Missing columns: {missing}"
}
def validate_no_nulls(df: DataFrame, columns: List[str]) -> Dict[str, Any]:
"""
Validate that specified columns have no null values.
Args:
df: Dataframe to validate
columns: Columns to check for nulls
Returns:
Dict with validation results
"""
null_counts = {}
for col in columns:
count = df.filter(f"{col} IS NULL").count()
if count > 0:
null_counts[col] = count
return {
"valid": len(null_counts) == 0,
"null_counts": null_counts,
"message": "No nulls found" if len(null_counts) == 0 else f"Nulls found: {null_counts}"
}
def validate_email_format(df: DataFrame, email_column: str = "email") -> Dict[str, Any]:
"""
Validate email format.
Args:
df: Dataframe to validate
email_column: Name of email column
Returns:
Dict with validation results
"""
    # Use the column API so the regex reaches Spark intact; inside a SQL string
    # literal the parser would strip the backslash escapes from the pattern.
    invalid_count = df.filter(
        ~F.col(email_column).rlike(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
    ).count()
return {
"valid": invalid_count == 0,
"invalid_count": invalid_count,
"message": "All emails valid" if invalid_count == 0 else f"Found {invalid_count} invalid emails"
}"""Data aggregation functions."""
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.window import Window
def daily_sales_summary(df: DataFrame) -> DataFrame:
"""
Create daily sales summary.
Args:
df: Transactions with columns: date, amount, customer_id
Returns:
Daily summary with total_sales, transaction_count, unique_customers
"""
return df.groupBy("date").agg(
F.sum("amount").alias("total_sales"),
F.count("*").alias("transaction_count"),
F.countDistinct("customer_id").alias("unique_customers")
)
def top_customers_by_revenue(df: DataFrame, top_n: int = 10) -> DataFrame:
"""
Get top N customers by revenue.
Args:
df: Transactions with columns: customer_id, amount
top_n: Number of top customers to return
Returns:
Top N customers with total revenue
"""
customer_revenue = df.groupBy("customer_id").agg(
F.sum("amount").alias("total_revenue")
)
return customer_revenue.orderBy(F.desc("total_revenue")).limit(top_n)
def calculate_running_total(df: DataFrame, partition_by: str, order_by: str, value_col: str) -> DataFrame:
"""
Calculate running total.
Args:
df: Input dataframe
partition_by: Column to partition by
order_by: Column to order by
value_col: Column to sum
Returns:
Dataframe with running_total column added
"""
window_spec = Window.partitionBy(partition_by).orderBy(order_by).rowsBetween(
Window.unboundedPreceding, Window.currentRow
)
    return df.withColumn("running_total", F.sum(value_col).over(window_spec))
```

```python
# tests/test_transformations.py
"""Tests for data_processing.transformations module."""
# Add src/ to Python path so we can import our application code
import sys
from pathlib import Path
src_path = Path(__file__).parent.parent / "src"
sys.path.insert(0, str(src_path))
from dbx_test import NotebookTestFixture
from data_processing.transformations import (
clean_customer_data,
calculate_customer_lifetime_value,
enrich_with_demographics
)
class TestCleanCustomerData(NotebookTestFixture):
"""Test the clean_customer_data function."""
def run_setup(self):
"""Create test data."""
# Create test customers with messy data
self.raw_customers = spark.createDataFrame([
(1, " alice smith ", "Alice.Smith@EXAMPLE.COM ", "(555) 123-4567"),
(2, "bob jones", " BOB@example.com", "555-234-5678"),
(3, "CHARLIE BROWN", "charlie@example.com", "5553456789"),
], ["id", "name", "email", "phone"])
# Apply transformation
self.cleaned = clean_customer_data(self.raw_customers)
def test_names_are_uppercase_and_trimmed(self):
"""Names should be uppercase with no leading/trailing spaces."""
for row in self.cleaned.collect():
name = row["name"]
assert name == name.upper(), f"Name not uppercase: {name}"
assert name == name.strip(), f"Name has whitespace: '{name}'"
def test_emails_are_lowercase_and_trimmed(self):
"""Emails should be lowercase with no leading/trailing spaces."""
for row in self.cleaned.collect():
email = row["email"]
assert email == email.lower(), f"Email not lowercase: {email}"
assert email == email.strip(), f"Email has whitespace: '{email}'"
def test_phones_have_only_digits(self):
"""Phone numbers should contain only digits."""
for row in self.cleaned.collect():
phone = row["phone"]
assert phone.isdigit(), f"Phone contains non-digits: {phone}"
assert len(phone) == 10, f"Phone should be 10 digits, got {len(phone)}"
def test_all_rows_preserved(self):
"""Should not lose any rows during cleaning."""
assert self.raw_customers.count() == self.cleaned.count()
class TestCalculateCustomerLifetimeValue(NotebookTestFixture):
"""Test the calculate_customer_lifetime_value function."""
def run_setup(self):
"""Create test transactions."""
self.transactions = spark.createDataFrame([
(1, 100.0, "2024-01-01"),
(1, 150.0, "2024-01-15"),
(1, 200.0, "2024-02-01"),
(2, 50.0, "2024-01-10"),
(2, 75.0, "2024-01-20"),
], ["customer_id", "amount", "date"])
# Calculate CLV
self.clv = calculate_customer_lifetime_value(self.transactions)
def test_correct_total_value(self):
"""Total value should sum all transactions per customer."""
customer_1 = self.clv.filter("customer_id = 1").collect()[0]
customer_2 = self.clv.filter("customer_id = 2").collect()[0]
assert customer_1["total_value"] == 450.0, "Customer 1 total incorrect"
assert customer_2["total_value"] == 125.0, "Customer 2 total incorrect"
def test_correct_transaction_count(self):
"""Transaction count should be accurate."""
customer_1 = self.clv.filter("customer_id = 1").collect()[0]
customer_2 = self.clv.filter("customer_id = 2").collect()[0]
assert customer_1["transaction_count"] == 3, "Customer 1 count incorrect"
assert customer_2["transaction_count"] == 2, "Customer 2 count incorrect"
def test_date_ranges(self):
"""First and last purchase dates should be correct."""
customer_1 = self.clv.filter("customer_id = 1").collect()[0]
assert customer_1["first_purchase"] == "2024-01-01"
assert customer_1["last_purchase"] == "2024-02-01"
class TestEnrichWithDemographics(NotebookTestFixture):
"""Test the enrich_with_demographics function."""
def run_setup(self):
"""Create test customers and demographics."""
self.customers = spark.createDataFrame([
(1, "Alice", "alice@example.com"),
(2, "Bob", "bob@example.com"),
(3, "Charlie", "charlie@example.com"),
], ["id", "name", "email"])
self.demographics = spark.createDataFrame([
(1, 25, "CA", "USA"),
(2, 30, "NY", "USA"),
# Note: Customer 3 has no demographics
], ["customer_id", "age", "state", "country"])
# Enrich
self.enriched = enrich_with_demographics(self.customers, self.demographics)
def test_all_customers_preserved(self):
"""Left join should preserve all customers."""
assert self.enriched.count() == 3, "Should have all 3 customers"
def test_demographics_added(self):
"""Demographics columns should be added."""
columns = set(self.enriched.columns)
assert "age" in columns
assert "state" in columns
assert "country" in columns
def test_customer_without_demographics_has_nulls(self):
"""Customers without demographics should have null values."""
charlie = self.enriched.filter("id = 3").collect()[0]
assert charlie["age"] is None
assert charlie["state"] is None
def test_no_duplicate_customer_id_column(self):
"""Should not have duplicate customer_id column."""
columns = self.enriched.columns
customer_id_count = columns.count("customer_id")
        assert customer_id_count == 0, "customer_id column should be dropped"
```

```python
# tests/test_validators.py
"""Tests for data_processing.validators module."""
import sys
from pathlib import Path
src_path = Path(__file__).parent.parent / "src"
sys.path.insert(0, str(src_path))
from dbx_test import NotebookTestFixture
from data_processing.validators import (
validate_schema,
validate_no_nulls,
validate_email_format
)
class TestValidateSchema(NotebookTestFixture):
"""Test the validate_schema function."""
def run_setup(self):
"""Create test dataframe."""
self.df = spark.createDataFrame([
(1, "Alice", "alice@example.com"),
], ["id", "name", "email"])
def test_valid_schema(self):
"""Should pass when all required columns present."""
result = validate_schema(self.df, ["id", "name", "email"])
assert result["valid"] is True
assert len(result["missing_columns"]) == 0
def test_missing_columns_detected(self):
"""Should detect missing columns."""
result = validate_schema(self.df, ["id", "name", "email", "phone"])
assert result["valid"] is False
assert "phone" in result["missing_columns"]
def test_extra_columns_reported(self):
"""Should report extra columns."""
result = validate_schema(self.df, ["id", "name"])
assert "email" in result["extra_columns"]
class TestValidateNoNulls(NotebookTestFixture):
"""Test the validate_no_nulls function."""
def run_setup(self):
"""Create test dataframe with some nulls."""
self.df = spark.createDataFrame([
(1, "Alice", "alice@example.com"),
(2, None, "bob@example.com"),
(3, "Charlie", None),
], ["id", "name", "email"])
def test_detects_nulls(self):
"""Should detect null values."""
result = validate_no_nulls(self.df, ["name", "email"])
assert result["valid"] is False
assert "name" in result["null_counts"]
assert "email" in result["null_counts"]
def test_no_nulls_in_id_column(self):
"""Should pass when column has no nulls."""
result = validate_no_nulls(self.df, ["id"])
assert result["valid"] is True
class TestValidateEmailFormat(NotebookTestFixture):
"""Test the validate_email_format function."""
def run_setup(self):
"""Create test dataframe with valid and invalid emails."""
self.valid_df = spark.createDataFrame([
(1, "alice@example.com"),
(2, "bob.jones@company.co.uk"),
(3, "charlie+test@domain.org"),
], ["id", "email"])
self.invalid_df = spark.createDataFrame([
(1, "not-an-email"),
(2, "@example.com"),
(3, "missing@domain"),
], ["id", "email"])
def test_valid_emails_pass(self):
"""Should validate correct email formats."""
result = validate_email_format(self.valid_df)
assert result["valid"] is True
assert result["invalid_count"] == 0
def test_invalid_emails_detected(self):
"""Should detect invalid email formats."""
result = validate_email_format(self.invalid_df)
assert result["valid"] is False
assert result["invalid_count"] == 3"""Tests for data_processing.aggregations module."""
import sys
from pathlib import Path
src_path = Path(__file__).parent.parent / "src"
sys.path.insert(0, str(src_path))
from dbx_test import NotebookTestFixture
from data_processing.aggregations import (
daily_sales_summary,
top_customers_by_revenue,
calculate_running_total
)
class TestDailySalesSummary(NotebookTestFixture):
"""Test the daily_sales_summary function."""
def run_setup(self):
"""Create test transactions."""
self.transactions = spark.createDataFrame([
("2024-01-01", 100.0, 1),
("2024-01-01", 150.0, 2),
("2024-01-01", 200.0, 1), # Customer 1 buys twice
("2024-01-02", 50.0, 3),
("2024-01-02", 75.0, 3), # Customer 3 buys twice
], ["date", "amount", "customer_id"])
# Calculate summary
self.summary = daily_sales_summary(self.transactions)
def test_correct_daily_totals(self):
"""Daily totals should be correct."""
day1 = self.summary.filter("date = '2024-01-01'").collect()[0]
day2 = self.summary.filter("date = '2024-01-02'").collect()[0]
assert day1["total_sales"] == 450.0, "Day 1 total incorrect"
assert day2["total_sales"] == 125.0, "Day 2 total incorrect"
def test_correct_transaction_counts(self):
"""Transaction counts should be accurate."""
day1 = self.summary.filter("date = '2024-01-01'").collect()[0]
assert day1["transaction_count"] == 3, "Day 1 count incorrect"
def test_correct_unique_customers(self):
"""Unique customer counts should be accurate."""
day1 = self.summary.filter("date = '2024-01-01'").collect()[0]
day2 = self.summary.filter("date = '2024-01-02'").collect()[0]
assert day1["unique_customers"] == 2, "Day 1 should have 2 unique customers"
assert day2["unique_customers"] == 1, "Day 2 should have 1 unique customer"
class TestTopCustomersByRevenue(NotebookTestFixture):
"""Test the top_customers_by_revenue function."""
def run_setup(self):
"""Create test transactions."""
self.transactions = spark.createDataFrame([
(1, 100.0),
(1, 150.0), # Customer 1: 250.0 total
(2, 300.0), # Customer 2: 300.0 total
(3, 50.0),
(3, 75.0), # Customer 3: 125.0 total
(4, 500.0), # Customer 4: 500.0 total (highest)
], ["customer_id", "amount"])
# Get top 3
self.top3 = top_customers_by_revenue(self.transactions, top_n=3)
def test_returns_correct_number(self):
"""Should return requested number of customers."""
assert self.top3.count() == 3, "Should return top 3"
def test_correct_order(self):
"""Should be ordered by revenue descending."""
results = self.top3.collect()
assert results[0]["customer_id"] == 4, "Customer 4 should be #1"
assert results[1]["customer_id"] == 2, "Customer 2 should be #2"
assert results[2]["customer_id"] == 1, "Customer 1 should be #3"
def test_correct_revenue_values(self):
"""Revenue values should be correct."""
customer_4 = self.top3.filter("customer_id = 4").collect()[0]
assert customer_4["total_revenue"] == 500.0
class TestCalculateRunningTotal(NotebookTestFixture):
"""Test the calculate_running_total function."""
def run_setup(self):
"""Create test data."""
self.df = spark.createDataFrame([
("A", "2024-01-01", 10.0),
("A", "2024-01-02", 20.0),
("A", "2024-01-03", 30.0),
("B", "2024-01-01", 5.0),
("B", "2024-01-02", 10.0),
], ["category", "date", "value"])
# Calculate running total
self.result = calculate_running_total(
self.df,
partition_by="category",
order_by="date",
value_col="value"
)
def test_running_total_correct(self):
"""Running totals should be cumulative."""
category_a = self.result.filter("category = 'A'").orderBy("date").collect()
assert category_a[0]["running_total"] == 10.0, "First value should be 10"
assert category_a[1]["running_total"] == 30.0, "Second should be 10+20=30"
assert category_a[2]["running_total"] == 60.0, "Third should be 10+20+30=60"
def test_partitioning_works(self):
"""Each partition should have its own running total."""
category_b = self.result.filter("category = 'B'").orderBy("date").collect()
assert category_b[0]["running_total"] == 5.0, "Category B should start at 5"
        assert category_b[1]["running_total"] == 15.0, "Category B second should be 5+10=15"
```

Run the tests with the CLI:

```bash
# Run all tests
dbx_test run --local --tests-dir tests
# Run specific test file
dbx_test run --local --pattern "*test_transformations*"
# Run remotely on Databricks
dbx_test run --remote --tests-dir tests --profile dev
```

Create a notebook that imports and tests your code:
```python
# Cell 1: Install framework
# (%pip may restart the Python interpreter, so install before any path setup)
%pip install /dbfs/FileStore/wheels/dbx_test-0.1.0-py3-none-any.whl

# Cell 2: Set up path
import sys

# Add src to path
src_path = "/Workspace/Repos/my-repo/my_databricks_project/src"
sys.path.insert(0, src_path)
# Cell 3: Import your code
from data_processing.transformations import clean_customer_data
# Cell 4: Write and run tests
from dbx_test import NotebookTestFixture, run_notebook_tests
class TestTransformations(NotebookTestFixture):
def run_setup(self):
self.raw_data = spark.createDataFrame([
(1, " alice ", "ALICE@EXAMPLE.COM"),
], ["id", "name", "email"])
self.cleaned = clean_customer_data(self.raw_data)
def test_name_uppercase(self):
row = self.cleaned.collect()[0]
assert row["name"] == "ALICE"
run_notebook_tests()
```

```toml
# pyproject.toml
[project]
name = "my-databricks-project"
version = "0.1.0"
[tool.setuptools]
package-dir = {"" = "src"}
[tool.setuptools.packages.find]
where = ["src"]
include = ["data_processing*", "utils*"]$ dbx_test run --local --tests-dir tests

Example output:

```
$ dbx_test run --local --tests-dir tests
Discovered 3 test notebook(s):
- tests/test_transformations.py (3 classes, 11 tests)
- tests/test_validators.py (3 classes, 7 tests)
- tests/test_aggregations.py (3 classes, 8 tests)
Running tests locally...
============================================================
Running TestCleanCustomerData
============================================================
Running test_names_are_uppercase_and_trimmed...
✓ PASSED
Running test_emails_are_lowercase_and_trimmed...
✓ PASSED
Running test_phones_have_only_digits...
✓ PASSED
Running test_all_rows_preserved...
✓ PASSED
============================================================
Running TestCalculateCustomerLifetimeValue
============================================================
Running test_correct_total_value...
✓ PASSED
Running test_correct_transaction_count...
✓ PASSED
Running test_date_ranges...
✓ PASSED
[... more tests ...]
============================================================
SUMMARY
============================================================
Total Tests: 26
✓ Passed: 26
✗ Failed: 0
✗ Errors: 0
🎉 All tests passed!
============================================================
```

- Path Management: Always add `src/` to the Python path at the start of test files
- Modular Code: Keep functions small and testable
- Type Hints: Use type hints for better IDE support and documentation
- Test Data: Create realistic test data that covers edge cases
- Separate Concerns: One test class per function/feature
- Clean Up: Drop temp tables/views in `run_cleanup()` (see the sketch below)
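
As an illustration of the last point, here is a minimal sketch of a fixture that registers a temp view in `run_setup()` and drops it again in `run_cleanup()`. It assumes the notebook-provided `spark` session, as in the examples above; the view name `tmp_orders` is hypothetical:

```python
from dbx_test import NotebookTestFixture

class TestWithTempView(NotebookTestFixture):
    def run_setup(self):
        # Register a temp view used by the tests
        df = spark.createDataFrame([(1, 100.0)], ["order_id", "amount"])
        df.createOrReplaceTempView("tmp_orders")

    def test_view_has_one_row(self):
        assert spark.table("tmp_orders").count() == 1

    def run_cleanup(self):
        # Drop the temp view so later tests start from a clean state
        spark.catalog.dropTempView("tmp_orders")
```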
Create a test helper to avoid repeating path setup:
```python
# tests/test_helpers.py
import sys
from pathlib import Path
def setup_src_path():
"""Add src/ to Python path."""
src_path = Path(__file__).parent.parent / "src"
if str(src_path) not in sys.path:
sys.path.insert(0, str(src_path))
# Use in tests
from test_helpers import setup_src_path
setup_src_path()
```

You can also share common setup and cleanup through a base fixture class:

```python
# tests/fixtures.py
from dbx_test import NotebookTestFixture
class BaseTestWithSpark(NotebookTestFixture):
"""Base test class with common setup."""
def run_setup(self):
# Common setup for all tests
self.test_db = "test_database"
spark.sql(f"CREATE DATABASE IF NOT EXISTS {self.test_db}")
def run_cleanup(self):
spark.sql(f"DROP DATABASE IF EXISTS {self.test_db} CASCADE")
# Use in tests
from fixtures import BaseTestWithSpark
class TestMyFeature(BaseTestWithSpark):
    # Inherits run_setup/run_cleanup from BaseTestWithSpark
    pass
```

This example shows the complete workflow for testing application code! 🎉