Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions benchmarks/tesseract_noop/tesseract_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,8 @@ def apply(inputs: InputSchema) -> OutputSchema:
HTTP transport, and deserialization.
"""
return OutputSchema(result=inputs.data)


def abstract_eval(abstract_inputs):
    """Return output shapes from input shapes.

    The no-op Tesseract passes its input straight through, so the output
    ``result`` carries exactly the shape/dtype spec of the input ``data``.
    """
    data_spec = abstract_inputs.data
    return dict(result=data_spec)
6 changes: 6 additions & 0 deletions examples/bandedblock_cholmod/julia/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Julia dependencies for the bandedblock_cholmod example.
# Enzyme provides AD; LinearSolve wraps the CHOLMOD solve; PythonCall
# bridges to the Python Tesseract runtime.
[deps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
63 changes: 63 additions & 0 deletions examples/bandedblock_cholmod/julia/apply.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Sparse CHOLMOD solver for SPD block systems with tridiagonal blocks.
#
# Assembles a sparse SparseMatrixCSC from tridiagonal block diagonals
# (identified by their Tesseract paths) and solves via CHOLMOD.
#
# Contract:
# apply_jl(diff_args, non_diff_args, diff_paths, non_diff_paths) -> Vector{Float64}

using LinearAlgebra, SparseArrays, LinearSolve

"""
    apply_jl(diff_args, non_diff_args, diff_paths, non_diff_paths) -> Vector{Float64}

Assemble a sparse SPD matrix from tridiagonal block diagonals (identified by
their Tesseract path strings) and solve `A x = b` via CHOLMOD.

# Arguments
- `diff_args`: differentiable leaf arrays — the block diagonals and the RHS `b`.
- `non_diff_args`: non-differentiable leaves; entries whose path starts with
  `"block_sizes."` give the per-block dimensions.
- `diff_paths` / `non_diff_paths`: Tesseract paths aligned element-wise with
  the two argument lists.

Returns the solution vector of length `sum(block_sizes)`.
"""
function apply_jl(diff_args, non_diff_args, diff_paths, non_diff_paths)
    # Recover block_sizes from the non-differentiable leaves.
    block_sizes = Int[]
    for (i, path) in enumerate(non_diff_paths)
        if startswith(path, "block_sizes.")
            push!(block_sizes, Int(non_diff_args[i]))
        end
    end

    N = sum(block_sizes)
    # offsets[k] = 1-based row/column at which block group k starts.
    offsets = cumsum([0; block_sizes[1:end-1]]) .+ 1

    # COO triplets for the assembled matrix, plus the RHS vector.
    I_idx = Int[]; J_idx = Int[]; V = Float64[]
    b = zeros(N)

    for (k, path) in enumerate(diff_paths)
        arr = diff_args[k]
        path == "b" && (b .= arr; continue)

        m = match(r"^blocks\.\[(\d+)\]\.\[(\d+)\]\.(sub|main|sup)$", path)
        m === nothing && continue

        bi = parse(Int, m[1]) + 1  # Tesseract paths are 0-indexed
        bj = parse(Int, m[2]) + 1
        comp = m[3]
        r0 = offsets[bi]
        c0 = offsets[bj]

        # Row/column shift relative to the main diagonal: "sub" sits one row
        # below it, "sup" one column to the right. Collapses the three
        # per-component loops into one.
        dr = comp == "sub" ? 1 : 0
        dc = comp == "sup" ? 1 : 0
        for idx in eachindex(arr)
            push!(I_idx, r0 + idx - 1 + dr)
            push!(J_idx, c0 + idx - 1 + dc)
            push!(V, arr[idx])
        end
    end

    A_raw = sparse(I_idx, J_idx, V, N, N)
    # Materialize symmetry into a plain SparseMatrixCSC (avoids Symmetric wrapper
    # which has a known bug with Enzyme reverse-mode in LinearSolve)
    A_sym = sparse(Symmetric(A_raw))

    prob = LinearProblem(A_sym, b)
    sol = solve(prob, CHOLMODFactorization())
    return copy(sol.u)
end
59 changes: 59 additions & 0 deletions examples/bandedblock_cholmod/julia/enzyme_wrappers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Generic Enzyme AD wrappers for Tesseract Julia recipes.
#
# These work with any apply_jl that follows the contract:
# apply_jl(diff_args, non_diff_args, diff_paths, non_diff_paths) -> Vector{Float64}
#
# IMPORTANT: All Python objects (PyArray, PyList, PyString) must be converted
# to native Julia types before Enzyme sees them. Enzyme traces at the LLVM
# level and cannot differentiate through PythonCall's conversion internals.

using LinearAlgebra
using Enzyme
using PythonCall: pyconvert

# Convert each incoming array-like (typically PyArray) to a fresh native
# Vector{Float64} so Enzyme never touches Python-backed memory.
function _to_jl_vecs(pyargs)
    converted = Vector{Float64}[]
    for entry in pyargs
        push!(converted, Vector{Float64}(entry))
    end
    return converted
end

# Convert each Python string-like to a native Julia String.
function _to_jl_strings(pyargs)
    out = String[]
    for item in pyargs
        push!(out, pyconvert(String, item))
    end
    return out
end

# Convert heterogeneous Python leaves to native Julia values, keeping the
# container loosely typed (Any) since leaf types vary.
function _to_jl_any(pyargs)
    out = Any[]
    for item in pyargs
        push!(out, pyconvert(Any, item))
    end
    return out
end

"""
    enzyme_jvp(apply_fn, diff_args, non_diff_args, diff_paths, non_diff_paths, tangents)

Forward-mode AD. Returns the JVP output vector.

`tangents` must be aligned element-wise with `diff_args`. All arguments may be
Python objects; they are converted to native Julia types before Enzyme runs.
"""
function enzyme_jvp(apply_fn, diff_args, non_diff_args, diff_paths, non_diff_paths, tangents)
    # Convert every Python container to native Julia data first — Enzyme
    # traces at the LLVM level and cannot differentiate through PythonCall.
    jl_args = _to_jl_vecs(diff_args)
    jl_tangents = _to_jl_vecs(tangents)
    jl_non_diff = _to_jl_any(non_diff_args)
    jl_diff_paths = _to_jl_strings(diff_paths)
    jl_non_diff_paths = _to_jl_strings(non_diff_paths)
    # Close over the non-differentiable data so only the diff arrays are
    # visible to Enzyme as varargs.
    closure(d...) = apply_fn(collect(d), jl_non_diff, jl_diff_paths, jl_non_diff_paths)
    # Pair each primal array with its tangent; Duplicated marks both active.
    dups = [Enzyme.Duplicated(jl_args[i], jl_tangents[i]) for i in eachindex(jl_args)]
    # Runtime activity is needed because activity of captured data cannot be
    # proven statically here; [1] extracts the derivative from the result tuple.
    return Enzyme.autodiff(set_runtime_activity(Enzyme.Forward), Enzyme.Const(closure), dups...)[1]
end

"""
    enzyme_vjp(apply_fn, diff_args, non_diff_args, diff_paths, non_diff_paths, cotangent)

Reverse-mode AD. Returns a list of gradients, one per element of diff_args.

`cotangent` must have the same length as the output of `apply_fn`. All
arguments may be Python objects; they are converted before Enzyme runs.
"""
function enzyme_vjp(apply_fn, diff_args, non_diff_args, diff_paths, non_diff_paths, cotangent)
    # Convert every Python container to native Julia data first — Enzyme
    # traces at the LLVM level and cannot differentiate through PythonCall.
    jl_args = _to_jl_vecs(diff_args)
    jl_cotangent = Vector{Float64}(cotangent)
    jl_non_diff = _to_jl_any(non_diff_args)
    jl_diff_paths = _to_jl_strings(diff_paths)
    jl_non_diff_paths = _to_jl_strings(non_diff_paths)
    closure(d...) = apply_fn(collect(d), jl_non_diff, jl_diff_paths, jl_non_diff_paths)
    # Zero-initialized shadows: Enzyme accumulates gradients into these.
    shadows = [zero(a) for a in jl_args]
    dups = [Enzyme.Duplicated(jl_args[i], shadows[i]) for i in eachindex(jl_args)]
    # Contract the vector output against the cotangent so reverse mode sees a
    # scalar objective: grad(⟨cotangent, f(d)⟩) is exactly the VJP.
    scalar_f(d...) = dot(jl_cotangent, closure(d...))
    Enzyme.autodiff(set_runtime_activity(Enzyme.Reverse), Enzyme.Const(scalar_f), Enzyme.Active, dups...)
    return shadows
end
246 changes: 246 additions & 0 deletions examples/bandedblock_cholmod/tesseract_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Tesseract wrapping a sparse CHOLMOD solver for SPD block systems with Enzyme AD.

Solves A * x = b where A is a symmetric positive definite matrix with block
structure. Each nonzero block is tridiagonal (or diagonal), stored compactly
as up to 3 diagonal vectors. Zero blocks are represented as None.

The matrix is assembled as a sparse SparseMatrixCSC and solved via
SuiteSparse CHOLMOD — a sparse Cholesky factorization not available in JAX.
Enzyme provides both forward-mode (JVP) and reverse-mode (VJP) automatic
differentiation through the LinearSolve.jl Enzyme extension, which
implements the implicit function theorem adjoint without differentiating
through CHOLMOD internals.

Example arrow structure (SPD, blocks [0][1]=[1][0]^T, [0][2]=[2][0]^T):

┌──────────────────────┐
│ A00 │ A01 │ A02 │
├───────┼───────┼──────┤
│ A10 │ A11 │ 0 │
├───────┼───────┼──────┤
│ A20 │ 0 │ A22 │
└──────────────────────┘
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
from juliacall import Main as jl
from pydantic import BaseModel, Field

from tesseract_core.runtime import Array, Differentiable, Float64
from tesseract_core.runtime.schema_generation import (
DICT_INDEX_SENTINEL,
SEQ_INDEX_SENTINEL,
get_all_model_path_patterns,
)
from tesseract_core.runtime.schema_types import is_differentiable
from tesseract_core.runtime.testing.finite_differences import expand_path_pattern
from tesseract_core.runtime.tree_transforms import get_at_path

# ---------------------------------------------------------------------------
# Julia setup — load solver and Enzyme wrappers from .jl files
# ---------------------------------------------------------------------------

_here = Path(__file__).parent
jl.include(str(_here / "julia" / "enzyme_wrappers.jl"))
jl.include(str(_here / "julia" / "apply.jl"))


# ---------------------------------------------------------------------------
# Schemata
# ---------------------------------------------------------------------------


class TridiagBlock(BaseModel):
"""A tridiagonal (or diagonal) block stored as diagonal vectors."""

sub: Differentiable[Array[(None,), Float64]] | None = Field(
default=None, description="Sub-diagonal, length n-1. None for diagonal blocks."
)
main: Differentiable[Array[(None,), Float64]] = Field(
description="Main diagonal, length n."
)
sup: Differentiable[Array[(None,), Float64]] | None = Field(
default=None,
description="Super-diagonal, length n-1. None for diagonal blocks.",
)


class InputSchema(BaseModel):
blocks: list[list[TridiagBlock | None]] = Field(
description="Block structure as nested list. None = zero block. "
"Example for 3x3 arrow: "
"[[A00, A01, A02], [A10, A11, None], [A20, None, A22]]",
)
b: Differentiable[Array[(None,), Float64]] = Field(
description="Right-hand side vector, length sum(block_sizes).",
)
block_sizes: list[int] = Field(
description="Size of each block group.",
)


class OutputSchema(BaseModel):
x: Differentiable[Array[(None,), Float64]] = Field(
description="Solution vector, length sum(block_sizes).",
)


# ---------------------------------------------------------------------------
# Schema-driven path extraction (generic, no need to modify)
# ---------------------------------------------------------------------------


def _path_tuple_to_pattern(path_tuple: tuple) -> str:
    """Convert a path tuple with sentinels to a pattern string with [] and {}."""

    def _render(part) -> str:
        # Sentinels mark "any sequence index" / "any dict key" positions.
        if part is SEQ_INDEX_SENTINEL:
            return "[]"
        if part is DICT_INDEX_SENTINEL:
            return "{}"
        return str(part)

    return ".".join(_render(part) for part in path_tuple)


# Path patterns over every InputSchema leaf (e.g. "blocks.[].[].main"), and
# the subset whose leaves are marked Differentiable. Computed once at import
# time; _expand_inputs expands them against concrete inputs per call.
_ALL_PATTERNS = [
    _path_tuple_to_pattern(p) for p in get_all_model_path_patterns(InputSchema)
]
_DIFF_PATTERNS = [
    _path_tuple_to_pattern(p)
    for p in get_all_model_path_patterns(InputSchema, filter_fn=is_differentiable)
]


def _expand_inputs(
    inputs_dict: dict,
) -> tuple[list[np.ndarray], list[str], list[Any], list[str]]:
    """Expand schema paths against concrete inputs, split into diff and non-diff.

    Returns:
        diff_args: list of differentiable arrays (always numpy float64)
        diff_paths: corresponding Tesseract paths
        non_diff_args: list of non-differentiable leaf values
        non_diff_paths: corresponding Tesseract paths
    """
    concrete_paths: list[str] = []
    for pattern in _ALL_PATTERNS:
        concrete_paths += expand_path_pattern(pattern, inputs_dict)

    differentiable: set[str] = set()
    for pattern in _DIFF_PATTERNS:
        differentiable |= set(expand_path_pattern(pattern, inputs_dict))

    diff_args: list[np.ndarray] = []
    diff_paths: list[str] = []
    non_diff_args: list[Any] = []
    non_diff_paths: list[str] = []

    for path in concrete_paths:
        leaf = get_at_path(inputs_dict, path)
        # Containers are not leaves; their contents appear under deeper paths.
        if isinstance(leaf, (dict, list)):
            continue

        if path in differentiable:
            diff_args.append(np.asarray(leaf, dtype=np.float64))
            diff_paths.append(path)
        else:
            non_diff_args.append(leaf)
            non_diff_paths.append(path)

    return diff_args, diff_paths, non_diff_args, non_diff_paths


# ---------------------------------------------------------------------------
# Required endpoints — modify evaluate and abstract_eval for your solver
# ---------------------------------------------------------------------------


def evaluate(inputs_dict: dict) -> dict:
    """Call the Julia CHOLMOD solver and wrap the solution as output dict."""
    expanded = _expand_inputs(inputs_dict)
    diff_args, diff_paths, non_diff_args, non_diff_paths = expanded
    solution = jl.apply_jl(diff_args, non_diff_args, diff_paths, non_diff_paths)
    return {"x": np.asarray(solution)}


def apply(inputs: InputSchema) -> OutputSchema:
    """Solve A x = b via the Julia CHOLMOD backend.

    Wraps the raw result dict in OutputSchema so the returned value actually
    matches the declared return annotation (the noop example's ``apply`` does
    the same); this also validates the solution array's shape/dtype.
    """
    return OutputSchema(**evaluate(inputs.model_dump()))


def abstract_eval(abstract_inputs):
    """Output x has length sum(block_sizes)."""
    block_sizes = abstract_inputs.model_dump()["block_sizes"]
    n_total = sum(block_sizes)
    return {"x": {"shape": [n_total], "dtype": "float64"}}


# ---------------------------------------------------------------------------
# Gradient endpoints (generic, no need to modify)
# ---------------------------------------------------------------------------


def jacobian_vector_product(
    inputs: InputSchema,
    jvp_inputs: set[str],
    jvp_outputs: set[str],
    tangent_vector: dict[str, Any],
):
    """Forward-mode JVP via Enzyme in Julia.

    Builds a full tangent list aligned with every differentiable leaf —
    zeros for leaves not in ``jvp_inputs`` — and pushes it through the Julia
    forward-mode wrapper.
    """
    inputs_dict = inputs.model_dump()
    diff_args, diff_paths, non_diff_args, non_diff_paths = _expand_inputs(inputs_dict)

    # One tangent per differentiable leaf, in the same order as diff_args;
    # leaves outside jvp_inputs get zero tangents (inactive directions).
    tangent_args = []
    for k, path in enumerate(diff_paths):
        if path in jvp_inputs:
            tangent_args.append(np.asarray(tangent_vector[path], dtype=np.float64))
        else:
            tangent_args.append(np.zeros_like(diff_args[k]))

    jvp_out = np.asarray(
        jl.enzyme_jvp(
            jl.apply_jl,
            diff_args,
            non_diff_args,
            diff_paths,
            non_diff_paths,
            tangent_args,
        )
    )
    # NOTE(review): the same JVP vector is fanned out to every requested
    # output path — valid here because OutputSchema has the single output
    # "x"; would need per-output slicing if more outputs were added.
    return {p: jvp_out for p in jvp_outputs}


def vector_jacobian_product(
    inputs: InputSchema,
    vjp_inputs: set[str],
    vjp_outputs: set[str],
    cotangent_vector: dict[str, Any],
):
    """Reverse-mode VJP via Enzyme in Julia.

    Combines the cotangents of all requested outputs, runs the Julia
    reverse-mode wrapper, and returns one gradient array per requested
    differentiable input path.
    """
    inputs_dict = inputs.model_dump()
    diff_args, diff_paths, non_diff_args, non_diff_paths = _expand_inputs(inputs_dict)

    # Keep only the cotangents for the requested outputs and sum them into a
    # single vector. NOTE(review): this assumes all output cotangents share a
    # shape and that vjp_outputs is non-empty (sum() of nothing yields int 0) —
    # both hold here since "x" is the only output.
    cotangent_vector = {key: cotangent_vector[key] for key in vjp_outputs}
    combined_cotangent = sum(
        np.asarray(v, dtype=np.float64) for v in cotangent_vector.values()
    )

    # Julia returns one gradient array per differentiable leaf, aligned with
    # diff_paths.
    grads = jl.enzyme_vjp(
        jl.apply_jl,
        diff_args,
        non_diff_args,
        diff_paths,
        non_diff_paths,
        combined_cotangent,
    )

    # Filter down to the gradients the caller actually asked for.
    result = {}
    for k, path in enumerate(diff_paths):
        if path in vjp_inputs:
            result[path] = np.asarray(grads[k])
    return result
Loading
Loading