From 86aff804ec0bc0ac7d3f7ed6a2234c9d76640184 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Thu, 2 Apr 2026 15:16:28 +0100 Subject: [PATCH 1/3] feat: sketch for a Julia template and MWE --- benchmarks/tesseract_noop/tesseract_api.py | 5 + examples/bandedblock_cholmod/apply.jl | 63 +++++ .../bandedblock_cholmod/enzyme_wrappers.jl | 59 ++++ examples/bandedblock_cholmod/tesseract_api.py | 246 +++++++++++++++++ .../bandedblock_cholmod/tesseract_config.yaml | 12 + .../tesseract_requirements.txt | 1 + .../test_cases/test_abstract_eval.json | 240 ++++++++++++++++ .../test_cases/test_apply.json | 245 +++++++++++++++++ .../test_cases/test_jvp_wrt_b.json | 258 ++++++++++++++++++ .../test_cases/test_vjp_wrt_b.json | 258 ++++++++++++++++++ .../runtime/testing/finite_differences.py | 12 +- tesseract_core/sdk/engine.py | 7 + tesseract_core/sdk/templates/julia/apply.jl | 21 ++ .../sdk/templates/julia/enzyme_wrappers.jl | 59 ++++ .../sdk/templates/julia/tesseract_api.py | 203 ++++++++++++++ .../sdk/templates/julia/tesseract_config.yaml | 35 +++ .../julia/tesseract_requirements.txt | 9 + 17 files changed, 1731 insertions(+), 2 deletions(-) create mode 100644 examples/bandedblock_cholmod/apply.jl create mode 100644 examples/bandedblock_cholmod/enzyme_wrappers.jl create mode 100644 examples/bandedblock_cholmod/tesseract_api.py create mode 100644 examples/bandedblock_cholmod/tesseract_config.yaml create mode 100644 examples/bandedblock_cholmod/tesseract_requirements.txt create mode 100644 examples/bandedblock_cholmod/test_cases/test_abstract_eval.json create mode 100644 examples/bandedblock_cholmod/test_cases/test_apply.json create mode 100644 examples/bandedblock_cholmod/test_cases/test_jvp_wrt_b.json create mode 100644 examples/bandedblock_cholmod/test_cases/test_vjp_wrt_b.json create mode 100644 tesseract_core/sdk/templates/julia/apply.jl create mode 100644 tesseract_core/sdk/templates/julia/enzyme_wrappers.jl create mode 100644 tesseract_core/sdk/templates/julia/tesseract_api.py create mode 100644 tesseract_core/sdk/templates/julia/tesseract_config.yaml create mode 100644 tesseract_core/sdk/templates/julia/tesseract_requirements.txt diff --git a/benchmarks/tesseract_noop/tesseract_api.py b/benchmarks/tesseract_noop/tesseract_api.py index ed8c5fff..d3084341 100644 --- a/benchmarks/tesseract_noop/tesseract_api.py +++ b/benchmarks/tesseract_noop/tesseract_api.py @@ -32,3 +32,8 @@ def apply(inputs: InputSchema) -> OutputSchema: HTTP transport, and deserialization. """ return OutputSchema(result=inputs.data) + + +def abstract_eval(abstract_inputs): + """Return output shapes from input shapes.""" + return {"result": abstract_inputs.data} diff --git a/examples/bandedblock_cholmod/apply.jl b/examples/bandedblock_cholmod/apply.jl new file mode 100644 index 00000000..03ac73e9 --- /dev/null +++ b/examples/bandedblock_cholmod/apply.jl @@ -0,0 +1,63 @@ +# Sparse CHOLMOD solver for SPD block systems with tridiagonal blocks. +# +# Assembles a sparse SparseMatrixCSC from tridiagonal block diagonals +# (identified by their Tesseract paths) and solves via CHOLMOD. +# +# Contract: +# apply_jl(diff_args, non_diff_args, diff_paths, non_diff_paths) -> Vector{Float64} + +using LinearAlgebra, SparseArrays, LinearSolve + +function apply_jl(diff_args, non_diff_args, diff_paths, non_diff_paths) + # Parse block_sizes from non_diff inputs + block_sizes = Int[] + for (i, path) in enumerate(non_diff_paths) + if startswith(path, "block_sizes.") + push!(block_sizes, Int(non_diff_args[i])) + end + end + + N = sum(block_sizes) + offsets = cumsum([0; block_sizes[1:end-1]]) .+ 1 + + # Build sparse matrix from tridiagonal block diagonals + I_idx = Int[]; J_idx = Int[]; V = Float64[] + b = zeros(N) + + for (k, path) in enumerate(diff_paths) + arr = diff_args[k] + path == "b" && (b .= arr; continue) + + m = match(r"^blocks\.\[(\d+)\]\.\[(\d+)\]\.(sub|main|sup)$", path) + m === nothing && continue + + bi = parse(Int, m[1]) + 1 # 1-indexed + bj = parse(Int, m[2]) + 1 + comp = m[3] + r0 = offsets[bi] + c0 = offsets[bj] + + if comp == "main" + for idx in 1:length(arr) + push!(I_idx, r0+idx-1); push!(J_idx, c0+idx-1); push!(V, arr[idx]) + end + elseif comp == "sub" + for idx in 1:length(arr) + push!(I_idx, r0+idx); push!(J_idx, c0+idx-1); push!(V, arr[idx]) + end + elseif comp == "sup" + for idx in 1:length(arr) + push!(I_idx, r0+idx-1); push!(J_idx, c0+idx); push!(V, arr[idx]) + end + end + end + + A_raw = sparse(I_idx, J_idx, V, N, N) + # Materialize symmetry into a plain SparseMatrixCSC (avoids Symmetric wrapper + # which has a known bug with Enzyme reverse-mode in LinearSolve) + A_sym = sparse(Symmetric(A_raw)) + + prob = LinearProblem(A_sym, b) + sol = solve(prob, CHOLMODFactorization()) + return copy(sol.u) +end diff --git a/examples/bandedblock_cholmod/enzyme_wrappers.jl b/examples/bandedblock_cholmod/enzyme_wrappers.jl new file mode 100644 index 00000000..43aa1d0b --- /dev/null +++ b/examples/bandedblock_cholmod/enzyme_wrappers.jl @@ -0,0 +1,59 @@ +# Generic Enzyme AD wrappers for Tesseract Julia recipes. +# +# These work with any apply_jl that follows the contract: +# apply_jl(diff_args, non_diff_args, diff_paths, non_diff_paths) -> Vector{Float64} +# +# IMPORTANT: All Python objects (PyArray, PyList, PyString) must be converted +# to native Julia types before Enzyme sees them. Enzyme traces at the LLVM +# level and cannot differentiate through PythonCall's conversion internals. + +using LinearAlgebra +using Enzyme +using PythonCall: pyconvert + +function _to_jl_vecs(pyargs) + return [Vector{Float64}(a) for a in pyargs] +end + +function _to_jl_strings(pyargs) + return String[pyconvert(String, s) for s in pyargs] +end + +function _to_jl_any(pyargs) + return Any[pyconvert(Any, a) for a in pyargs] +end + +""" + enzyme_jvp(apply_fn, diff_args, non_diff_args, diff_paths, non_diff_paths, tangents) + +Forward-mode AD. Returns the JVP output vector. +""" +function enzyme_jvp(apply_fn, diff_args, non_diff_args, diff_paths, non_diff_paths, tangents) + jl_args = _to_jl_vecs(diff_args) + jl_tangents = _to_jl_vecs(tangents) + jl_non_diff = _to_jl_any(non_diff_args) + jl_diff_paths = _to_jl_strings(diff_paths) + jl_non_diff_paths = _to_jl_strings(non_diff_paths) + closure(d...) = apply_fn(collect(d), jl_non_diff, jl_diff_paths, jl_non_diff_paths) + dups = [Enzyme.Duplicated(jl_args[i], jl_tangents[i]) for i in eachindex(jl_args)] + return Enzyme.autodiff(set_runtime_activity(Enzyme.Forward), Enzyme.Const(closure), dups...)[1] +end + +""" + enzyme_vjp(apply_fn, diff_args, non_diff_args, diff_paths, non_diff_paths, cotangent) + +Reverse-mode AD. Returns a list of gradients, one per element of diff_args. +""" +function enzyme_vjp(apply_fn, diff_args, non_diff_args, diff_paths, non_diff_paths, cotangent) + jl_args = _to_jl_vecs(diff_args) + jl_cotangent = Vector{Float64}(cotangent) + jl_non_diff = _to_jl_any(non_diff_args) + jl_diff_paths = _to_jl_strings(diff_paths) + jl_non_diff_paths = _to_jl_strings(non_diff_paths) + closure(d...) = apply_fn(collect(d), jl_non_diff, jl_diff_paths, jl_non_diff_paths) + shadows = [zero(a) for a in jl_args] + dups = [Enzyme.Duplicated(jl_args[i], shadows[i]) for i in eachindex(jl_args)] + scalar_f(d...) = dot(jl_cotangent, closure(d...)) + Enzyme.autodiff(set_runtime_activity(Enzyme.Reverse), Enzyme.Const(scalar_f), Enzyme.Active, dups...) + return shadows +end diff --git a/examples/bandedblock_cholmod/tesseract_api.py b/examples/bandedblock_cholmod/tesseract_api.py new file mode 100644 index 00000000..ebc8b5b9 --- /dev/null +++ b/examples/bandedblock_cholmod/tesseract_api.py @@ -0,0 +1,246 @@ +# Copyright 2025 Pasteur Labs. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tesseract wrapping a sparse CHOLMOD solver for SPD block systems with Enzyme AD. + +Solves A * x = b where A is a symmetric positive definite matrix with block +structure. Each nonzero block is tridiagonal (or diagonal), stored compactly +as up to 3 diagonal vectors. Zero blocks are represented as None. + +The matrix is assembled as a sparse SparseMatrixCSC and solved via +SuiteSparse CHOLMOD — a sparse Cholesky factorization not available in JAX. +Enzyme provides both forward-mode (JVP) and reverse-mode (VJP) automatic +differentiation through the LinearSolve.jl Enzyme extension, which +implements the implicit function theorem adjoint without differentiating +through CHOLMOD internals. + +Example arrow structure (SPD, blocks [0][1]=[1][0]^T, [0][2]=[2][0]^T): + + ┌──────────────────────┐ + │ A00 │ A01 │ A02 │ + ├───────┼───────┼──────┤ + │ A10 │ A11 │ 0 │ + ├───────┼───────┼──────┤ + │ A20 │ 0 │ A22 │ + └──────────────────────┘ +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import numpy as np +from juliacall import Main as jl +from pydantic import BaseModel, Field + +from tesseract_core.runtime import Array, Differentiable, Float64 +from tesseract_core.runtime.schema_generation import ( + DICT_INDEX_SENTINEL, + SEQ_INDEX_SENTINEL, + get_all_model_path_patterns, +) +from tesseract_core.runtime.schema_types import is_differentiable +from tesseract_core.runtime.testing.finite_differences import expand_path_pattern +from tesseract_core.runtime.tree_transforms import get_at_path + +# --------------------------------------------------------------------------- +# Julia setup — load solver and Enzyme wrappers from .jl files +# --------------------------------------------------------------------------- + +_here = Path(__file__).parent +jl.include(str(_here / "apply.jl")) +jl.include(str(_here / "enzyme_wrappers.jl")) + + +# --------------------------------------------------------------------------- +# Schemata +# --------------------------------------------------------------------------- + + +class TridiagBlock(BaseModel): + """A tridiagonal (or diagonal) block stored as diagonal vectors.""" + + sub: Differentiable[Array[(None,), Float64]] | None = Field( + default=None, description="Sub-diagonal, length n-1. None for diagonal blocks." + ) + main: Differentiable[Array[(None,), Float64]] = Field( + description="Main diagonal, length n." + ) + sup: Differentiable[Array[(None,), Float64]] | None = Field( + default=None, + description="Super-diagonal, length n-1. None for diagonal blocks.", + ) + + +class InputSchema(BaseModel): + blocks: list[list[TridiagBlock | None]] = Field( + description="Block structure as nested list. None = zero block. " + "Example for 3x3 arrow: " + "[[A00, A01, A02], [A10, A11, None], [A20, None, A22]]", + ) + b: Differentiable[Array[(None,), Float64]] = Field( + description="Right-hand side vector, length sum(block_sizes).", + ) + block_sizes: list[int] = Field( + description="Size of each block group.", + ) + + +class OutputSchema(BaseModel): + x: Differentiable[Array[(None,), Float64]] = Field( + description="Solution vector, length sum(block_sizes).", + ) + + +# --------------------------------------------------------------------------- +# Schema-driven path extraction (generic, no need to modify) +# --------------------------------------------------------------------------- + + +def _path_tuple_to_pattern(path_tuple: tuple) -> str: + """Convert a path tuple with sentinels to a pattern string with [] and {}.""" + parts = [] + for part in path_tuple: + if part is SEQ_INDEX_SENTINEL: + parts.append("[]") + elif part is DICT_INDEX_SENTINEL: + parts.append("{}") + else: + parts.append(str(part)) + return ".".join(parts) + + +_ALL_PATTERNS = [ + _path_tuple_to_pattern(p) for p in get_all_model_path_patterns(InputSchema) +] +_DIFF_PATTERNS = [ + _path_tuple_to_pattern(p) + for p in get_all_model_path_patterns(InputSchema, filter_fn=is_differentiable) +] + + +def _expand_inputs( + inputs_dict: dict, +) -> tuple[list[np.ndarray], list[str], list[Any], list[str]]: + """Expand schema paths against concrete inputs, split into diff and non-diff. + + Returns: + diff_args: list of differentiable arrays (always numpy float64) + diff_paths: corresponding Tesseract paths + non_diff_args: list of non-differentiable leaf values + non_diff_paths: corresponding Tesseract paths + """ + all_paths = [] + for pattern in _ALL_PATTERNS: + all_paths.extend(expand_path_pattern(pattern, inputs_dict)) + + diff_path_set = set() + for pattern in _DIFF_PATTERNS: + diff_path_set.update(expand_path_pattern(pattern, inputs_dict)) + + diff_args, diff_paths = [], [] + non_diff_args, non_diff_paths = [], [] + + for path in all_paths: + value = get_at_path(inputs_dict, path) + if isinstance(value, (dict, list)): + continue + + if path in diff_path_set: + diff_args.append(np.asarray(value, dtype=np.float64)) + diff_paths.append(path) + else: + non_diff_args.append(value) + non_diff_paths.append(path) + + return diff_args, diff_paths, non_diff_args, non_diff_paths + + +# --------------------------------------------------------------------------- +# Required endpoints — modify evaluate and abstract_eval for your solver +# --------------------------------------------------------------------------- + + +def evaluate(inputs_dict: dict) -> dict: + """Call the Julia CHOLMOD solver.""" + diff_args, diff_paths, non_diff_args, non_diff_paths = _expand_inputs(inputs_dict) + result = jl.apply_jl(diff_args, non_diff_args, diff_paths, non_diff_paths) + return {"x": np.asarray(result)} + + +def apply(inputs: InputSchema) -> OutputSchema: + return evaluate(inputs.model_dump()) + + +def abstract_eval(abstract_inputs): + """Output x has length sum(block_sizes).""" + inp = abstract_inputs.model_dump() + total_size = sum(inp["block_sizes"]) + return { + "x": {"shape": [total_size], "dtype": "float64"}, + } + + +# --------------------------------------------------------------------------- +# Gradient endpoints (generic, no need to modify) +# --------------------------------------------------------------------------- + + +def jacobian_vector_product( + inputs: InputSchema, + jvp_inputs: set[str], + jvp_outputs: set[str], + tangent_vector: dict[str, Any], +): + inputs_dict = inputs.model_dump() + diff_args, diff_paths, non_diff_args, non_diff_paths = _expand_inputs(inputs_dict) + + tangent_args = [] + for k, path in enumerate(diff_paths): + if path in jvp_inputs: + tangent_args.append(np.asarray(tangent_vector[path], dtype=np.float64)) + else: + tangent_args.append(np.zeros_like(diff_args[k])) + + jvp_out = np.asarray( + jl.enzyme_jvp( + jl.apply_jl, + diff_args, + non_diff_args, + diff_paths, + non_diff_paths, + tangent_args, + ) + ) + return {p: jvp_out for p in jvp_outputs} + + +def vector_jacobian_product( + inputs: InputSchema, + vjp_inputs: set[str], + vjp_outputs: set[str], + cotangent_vector: dict[str, Any], +): + inputs_dict = inputs.model_dump() + diff_args, diff_paths, non_diff_args, non_diff_paths = _expand_inputs(inputs_dict) + + cotangent_vector = {key: cotangent_vector[key] for key in vjp_outputs} + combined_cotangent = sum( + np.asarray(v, dtype=np.float64) for v in cotangent_vector.values() + ) + + grads = jl.enzyme_vjp( + jl.apply_jl, + diff_args, + non_diff_args, + diff_paths, + non_diff_paths, + combined_cotangent, + ) + + result = {} + for k, path in enumerate(diff_paths): + if path in vjp_inputs: + result[path] = np.asarray(grads[k]) + return result diff --git a/examples/bandedblock_cholmod/tesseract_config.yaml b/examples/bandedblock_cholmod/tesseract_config.yaml new file mode 100644 index 00000000..38ac3b0e --- /dev/null +++ b/examples/bandedblock_cholmod/tesseract_config.yaml @@ -0,0 +1,12 @@ +name: "bandedblock-cholmod" +version: "0.1.0" +description: "Differentiable sparse CHOLMOD solver for SPD block systems with tridiagonal blocks, using Julia Enzyme AD" + +build_config: + target_platform: "native" + + custom_build_steps: + - | + RUN apt-get update && apt-get install -y curl && \ + curl -fsSL https://install.julialang.org | sh -s -- -y && \ + /root/.juliaup/bin/julia -e 'using Pkg; Pkg.add(["Enzyme", "LinearSolve"])' diff --git a/examples/bandedblock_cholmod/tesseract_requirements.txt b/examples/bandedblock_cholmod/tesseract_requirements.txt new file mode 100644 index 00000000..1ada8bed --- /dev/null +++ b/examples/bandedblock_cholmod/tesseract_requirements.txt @@ -0,0 +1 @@ +juliacall diff --git a/examples/bandedblock_cholmod/test_cases/test_abstract_eval.json b/examples/bandedblock_cholmod/test_cases/test_abstract_eval.json new file mode 100644 index 00000000..223c7339 --- /dev/null +++ b/examples/bandedblock_cholmod/test_cases/test_abstract_eval.json @@ -0,0 +1,240 @@ +{ + "endpoint": "abstract_eval", + "expected_outputs": { + "x": { + "shape": [9], + "dtype": "float64" + } + }, + "expected_exception": null, + "expected_exception_regex": null, + "atol": 1e-8, + "rtol": 1e-5, + "payload": { + "inputs": { + "blocks": [ + [ + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAAEEAAAAAAAAAQQAAAAAAAABBA", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZuT+amZmZmZm5Pw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZuT+amZmZmZm5Pw==", + "encoding": "base64" + } + } + }, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA4D8AAAAAAADgPwAAAAAAAOA/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + } + }, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMz0z8zMzMzMzPTPzMzMzMzM9M/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + } + } + ], + [ + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA4D8AAAAAAADgPwAAAAAAAOA/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + } + }, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAAFEAAAAAAAAAUQAAAAAAAABRA", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZyT+amZmZmZnJPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZyT+amZmZmZnJPw==", + "encoding": "base64" + } + } + }, + null + ], + [ + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMz0z8zMzMzMzPTPzMzMzMzM9M/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + } + }, + null, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAAGEAAAAAAAAAYQAAAAAAAABhA", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMzwz8zMzMzMzPDPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMzwz8zMzMzMzPDPw==", + "encoding": "base64" + } + } + } + ] + ], + "b": { + "object_type": "array", + "shape": [9], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA8D8AAAAAAAAAQAAAAAAAAAhAAAAAAAAAEEAAAAAAAAAUQAAAAAAAABhAAAAAAAAAHEAAAAAAAAAgQAAAAAAAACJA", + "encoding": "base64" + } + }, + "block_sizes": [3, 3, 3] + } + } +} diff --git a/examples/bandedblock_cholmod/test_cases/test_apply.json b/examples/bandedblock_cholmod/test_cases/test_apply.json new file mode 100644 index 00000000..081e6b90 --- /dev/null +++ b/examples/bandedblock_cholmod/test_cases/test_apply.json @@ -0,0 +1,245 @@ +{ + "endpoint": "apply", + "expected_outputs": { + "x": { + "object_type": "array", + "shape": [9], + "dtype": "float64", + "data": { + "buffer": "eOYpetp1pj9HKspYPX3OP9QuzujzdN4//ufHfoI86D+EKdelA63sPyUhWgiz0/E/n7Yanlkc8j/82zzGKhL0P1Xvt6Q0Gfc/", + "encoding": "base64" + } + } + }, + "expected_exception": null, + "expected_exception_regex": null, + "atol": 1e-10, + "rtol": 1e-8, + "payload": { + "inputs": { + "blocks": [ + [ + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAAEEAAAAAAAAAQQAAAAAAAABBA", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZuT+amZmZmZm5Pw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZuT+amZmZmZm5Pw==", + "encoding": "base64" + } + } + }, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA4D8AAAAAAADgPwAAAAAAAOA/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + } + }, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMz0z8zMzMzMzPTPzMzMzMzM9M/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + } + } + ], + [ + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA4D8AAAAAAADgPwAAAAAAAOA/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + } + }, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAAFEAAAAAAAAAUQAAAAAAAABRA", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZyT+amZmZmZnJPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZyT+amZmZmZnJPw==", + "encoding": "base64" + } + } + }, + null + ], + [ + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMz0z8zMzMzMzPTPzMzMzMzM9M/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + } + }, + null, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAAGEAAAAAAAAAYQAAAAAAAABhA", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMzwz8zMzMzMzPDPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMzwz8zMzMzMzPDPw==", + "encoding": "base64" + } + } + } + ] + ], + "b": { + "object_type": "array", + "shape": [9], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA8D8AAAAAAAAAQAAAAAAAAAhAAAAAAAAAEEAAAAAAAAAUQAAAAAAAABhAAAAAAAAAHEAAAAAAAAAgQAAAAAAAACJA", + "encoding": "base64" + } + }, + "block_sizes": [3, 3, 3] + } + } +} diff --git a/examples/bandedblock_cholmod/test_cases/test_jvp_wrt_b.json b/examples/bandedblock_cholmod/test_cases/test_jvp_wrt_b.json new file mode 100644 index 00000000..53e057c6 --- /dev/null +++ b/examples/bandedblock_cholmod/test_cases/test_jvp_wrt_b.json @@ -0,0 +1,258 @@ +{ + "endpoint": "jacobian_vector_product", + "expected_outputs": { + "x": { + "object_type": "array", + "shape": [9], + "dtype": "float64", + "data": { + "buffer": "/BTkhBdG0D8hquHMia13v1Y4cdiQdSM/SnpzJZHwmb83a8fR71RPv9jVw2fFSRU/lywGntjxib8KNEhsEt5Fv+qKEdVMAgQ/", + "encoding": "base64" + } + } + }, + "expected_exception": null, + "expected_exception_regex": null, + "atol": 1e-6, + "rtol": 0.0001, + "payload": { + "inputs": { + "blocks": [ + [ + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAAEEAAAAAAAAAQQAAAAAAAABBA", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZuT+amZmZmZm5Pw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZuT+amZmZmZm5Pw==", + "encoding": "base64" + } + } + }, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA4D8AAAAAAADgPwAAAAAAAOA/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + } + }, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMz0z8zMzMzMzPTPzMzMzMzM9M/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + } + } + ], + [ + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA4D8AAAAAAADgPwAAAAAAAOA/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + } + }, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAAFEAAAAAAAAAUQAAAAAAAABRA", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZyT+amZmZmZnJPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZyT+amZmZmZnJPw==", + "encoding": "base64" + } + } + }, + null + ], + [ + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMz0z8zMzMzMzPTPzMzMzMzM9M/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + } + }, + null, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAAGEAAAAAAAAAYQAAAAAAAABhA", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMzwz8zMzMzMzPDPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMzwz8zMzMzMzPDPw==", + "encoding": "base64" + } + } + } + ] + ], + "b": { + "object_type": "array", + "shape": [9], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA8D8AAAAAAAAAQAAAAAAAAAhAAAAAAAAAEEAAAAAAAAAUQAAAAAAAABhAAAAAAAAAHEAAAAAAAAAgQAAAAAAAACJA", + "encoding": "base64" + } + }, + "block_sizes": [3, 3, 3] + }, + "jvp_inputs": ["b"], + "jvp_outputs": ["x"], + "tangent_vector": { + "b": { + "object_type": "array", + "shape": [9], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA8D8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", + "encoding": "base64" + } + } + } + } +} diff --git a/examples/bandedblock_cholmod/test_cases/test_vjp_wrt_b.json b/examples/bandedblock_cholmod/test_cases/test_vjp_wrt_b.json new file mode 100644 index 00000000..356a000e --- /dev/null +++ b/examples/bandedblock_cholmod/test_cases/test_vjp_wrt_b.json @@ -0,0 +1,258 @@ +{ + "endpoint": "vector_jacobian_product", + "expected_outputs": { + "b": { + "object_type": "array", + "shape": [9], + "dtype": "float64", + "data": { + "buffer": "1qBEzCbFyj/kg597h9LJP9agRMwmxco/hJURA9nVxT/+8rTVS7zEP4SVEQPZ1cU/V6VO10Vlwz8wvcj/As7CP1mlTtdFZcM/", + "encoding": "base64" + } + } + }, + "expected_exception": null, + "expected_exception_regex": null, + "atol": 1e-6, + "rtol": 0.0001, + "payload": { + "inputs": { + "blocks": [ + [ + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAAEEAAAAAAAAAQQAAAAAAAABBA", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZuT+amZmZmZm5Pw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZuT+amZmZmZm5Pw==", + "encoding": "base64" + } + } + }, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA4D8AAAAAAADgPwAAAAAAAOA/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + } + }, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMz0z8zMzMzMzPTPzMzMzMzM9M/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + } + } + ], + [ + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA4D8AAAAAAADgPwAAAAAAAOA/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZqT+amZmZmZmpPw==", + "encoding": "base64" + } + } + }, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAAFEAAAAAAAAAUQAAAAAAAABRA", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZyT+amZmZmZnJPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "mpmZmZmZyT+amZmZmZnJPw==", + "encoding": "base64" + } + } + }, + null + ], + [ + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMz0z8zMzMzMzPTPzMzMzMzM9M/", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "uB6F61G4nj+4HoXrUbiePw==", + "encoding": "base64" + } + } + }, + null, + { + "main": { + "object_type": "array", + "shape": [3], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAAGEAAAAAAAAAYQAAAAAAAABhA", + "encoding": "base64" + } + }, + "sub": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMzwz8zMzMzMzPDPw==", + "encoding": "base64" + } + }, + "sup": { + "object_type": "array", + "shape": [2], + "dtype": "float64", + "data": { + "buffer": "MzMzMzMzwz8zMzMzMzPDPw==", + "encoding": "base64" + } + } + } + ] + ], + "b": { + "object_type": "array", + "shape": [9], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA8D8AAAAAAAAAQAAAAAAAAAhAAAAAAAAAEEAAAAAAAAAUQAAAAAAAABhAAAAAAAAAHEAAAAAAAAAgQAAAAAAAACJA", + "encoding": "base64" + } + }, + "block_sizes": [3, 3, 3] + }, + "vjp_inputs": ["b"], + "vjp_outputs": ["x"], + "cotangent_vector": { + "x": { + "object_type": "array", + "shape": [9], + "dtype": "float64", + "data": { + "buffer": "AAAAAAAA8D8AAAAAAADwPwAAAAAAAPA/AAAAAAAA8D8AAAAAAADwPwAAAAAAAPA/AAAAAAAA8D8AAAAAAADwPwAAAAAAAPA/", + "encoding": "base64" + } + } + } + } +} diff --git a/tesseract_core/runtime/testing/finite_differences.py b/tesseract_core/runtime/testing/finite_differences.py index 24238384..61905950 100644 --- a/tesseract_core/runtime/testing/finite_differences.py +++ b/tesseract_core/runtime/testing/finite_differences.py @@ -58,6 +58,11 @@ def _handle_part( parts: Sequence[str], current_inputs: Any, current_path: list[str] ) -> list[str]: """Recursively expand each part separately.""" + if current_inputs is None: + # None means this branch doesn't exist (e.g. None entry in a + # list, or an Optional field that is absent). No paths here. + return [] + if not parts: return [".".join(current_path)] @@ -79,9 +84,12 @@ def _handle_part( ) paths.extend(subpaths) else: - subpaths = _handle_part( - parts[1:], current_inputs[part], [*current_path, part] + value = ( + current_inputs.get(part) + if isinstance(current_inputs, dict) + else getattr(current_inputs, part, None) ) + subpaths = _handle_part(parts[1:], value, [*current_path, part]) paths.extend(subpaths) return paths diff --git a/tesseract_core/sdk/engine.py b/tesseract_core/sdk/engine.py index fd2b32c2..3e26dd93 100644 --- a/tesseract_core/sdk/engine.py +++ b/tesseract_core/sdk/engine.py @@ -422,6 +422,13 @@ def init_api( "tesseract_requirements.txt", target_dir, template_vars, recipe=Path(recipe) ) + # Julia recipe: copy Julia source files + if recipe == "julia": + _write_template_file("apply.jl", target_dir, template_vars, recipe=Path(recipe)) + _write_template_file( + "enzyme_wrappers.jl", target_dir, template_vars, recipe=Path(recipe) + ) + return target_dir / "tesseract_api.py" diff --git a/tesseract_core/sdk/templates/julia/apply.jl b/tesseract_core/sdk/templates/julia/apply.jl new file mode 100644 index 00000000..7636a4e8 --- /dev/null +++ b/tesseract_core/sdk/templates/julia/apply.jl @@ -0,0 +1,21 @@ +# Core computation for {{name}}. +# +# Contract: apply_jl receives all inputs split into diff and non-diff, +# with paths describing each value's role. Replace the body with your solver. +# +# Arguments: +# diff_args::Vector{Vector{Float64}} — differentiable arrays +# non_diff_args::Vector{Any} — static values (ints, strings, etc.) +# diff_paths::Vector{String} — Tesseract path for each diff arg +# non_diff_paths::Vector{String} — Tesseract path for each non-diff arg + +function apply_jl( + diff_args::Vector{Vector{Float64}}, + non_diff_args::Vector{Any}, + diff_paths::Vector{String}, + non_diff_paths::Vector{String}, +)::Vector{Float64} + # Example: square the first (and only) differentiable input. + # Replace with your solver. + return diff_args[1] .^ 2 +end diff --git a/tesseract_core/sdk/templates/julia/enzyme_wrappers.jl b/tesseract_core/sdk/templates/julia/enzyme_wrappers.jl new file mode 100644 index 00000000..43aa1d0b --- /dev/null +++ b/tesseract_core/sdk/templates/julia/enzyme_wrappers.jl @@ -0,0 +1,59 @@ +# Generic Enzyme AD wrappers for Tesseract Julia recipes. +# +# These work with any apply_jl that follows the contract: +# apply_jl(diff_args, non_diff_args, diff_paths, non_diff_paths) -> Vector{Float64} +# +# IMPORTANT: All Python objects (PyArray, PyList, PyString) must be converted +# to native Julia types before Enzyme sees them. Enzyme traces at the LLVM +# level and cannot differentiate through PythonCall's conversion internals. + +using LinearAlgebra +using Enzyme +using PythonCall: pyconvert + +function _to_jl_vecs(pyargs) + return [Vector{Float64}(a) for a in pyargs] +end + +function _to_jl_strings(pyargs) + return String[pyconvert(String, s) for s in pyargs] +end + +function _to_jl_any(pyargs) + return Any[pyconvert(Any, a) for a in pyargs] +end + +""" + enzyme_jvp(apply_fn, diff_args, non_diff_args, diff_paths, non_diff_paths, tangents) + +Forward-mode AD. Returns the JVP output vector. +""" +function enzyme_jvp(apply_fn, diff_args, non_diff_args, diff_paths, non_diff_paths, tangents) + jl_args = _to_jl_vecs(diff_args) + jl_tangents = _to_jl_vecs(tangents) + jl_non_diff = _to_jl_any(non_diff_args) + jl_diff_paths = _to_jl_strings(diff_paths) + jl_non_diff_paths = _to_jl_strings(non_diff_paths) + closure(d...) = apply_fn(collect(d), jl_non_diff, jl_diff_paths, jl_non_diff_paths) + dups = [Enzyme.Duplicated(jl_args[i], jl_tangents[i]) for i in eachindex(jl_args)] + return Enzyme.autodiff(set_runtime_activity(Enzyme.Forward), Enzyme.Const(closure), dups...)[1] +end + +""" + enzyme_vjp(apply_fn, diff_args, non_diff_args, diff_paths, non_diff_paths, cotangent) + +Reverse-mode AD. Returns a list of gradients, one per element of diff_args. +""" +function enzyme_vjp(apply_fn, diff_args, non_diff_args, diff_paths, non_diff_paths, cotangent) + jl_args = _to_jl_vecs(diff_args) + jl_cotangent = Vector{Float64}(cotangent) + jl_non_diff = _to_jl_any(non_diff_args) + jl_diff_paths = _to_jl_strings(diff_paths) + jl_non_diff_paths = _to_jl_strings(non_diff_paths) + closure(d...) = apply_fn(collect(d), jl_non_diff, jl_diff_paths, jl_non_diff_paths) + shadows = [zero(a) for a in jl_args] + dups = [Enzyme.Duplicated(jl_args[i], shadows[i]) for i in eachindex(jl_args)] + scalar_f(d...) = dot(jl_cotangent, closure(d...)) + Enzyme.autodiff(set_runtime_activity(Enzyme.Reverse), Enzyme.Const(scalar_f), Enzyme.Active, dups...) + return shadows +end diff --git a/tesseract_core/sdk/templates/julia/tesseract_api.py b/tesseract_core/sdk/templates/julia/tesseract_api.py new file mode 100644 index 00000000..969124bf --- /dev/null +++ b/tesseract_core/sdk/templates/julia/tesseract_api.py @@ -0,0 +1,203 @@ +# Copyright 2025 Pasteur Labs. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Tesseract API module for {{name}} +# Generated by tesseract {{version}} on {{timestamp}} + +from pathlib import Path +from typing import Any + +import numpy as np +from juliacall import Main as jl +from pydantic import BaseModel + +from tesseract_core.runtime import Array, Differentiable, Float64 +from tesseract_core.runtime.schema_generation import ( + DICT_INDEX_SENTINEL, + SEQ_INDEX_SENTINEL, + get_all_model_path_patterns, +) +from tesseract_core.runtime.schema_types import is_differentiable +from tesseract_core.runtime.testing.finite_differences import expand_path_pattern +from tesseract_core.runtime.tree_transforms import get_at_path + +# --------------------------------------------------------------------------- +# Julia setup — load solver and Enzyme wrappers from .jl files +# --------------------------------------------------------------------------- + +_here = Path(__file__).parent +jl.include(str(_here / "apply.jl")) +jl.include(str(_here / "enzyme_wrappers.jl")) + +# --------------------------------------------------------------------------- +# Schemata +# --------------------------------------------------------------------------- + + +class InputSchema(BaseModel): + x: Differentiable[Array[(None,), Float64]] + + +class OutputSchema(BaseModel): + y: Differentiable[Array[(None,), Float64]] + + +# --------------------------------------------------------------------------- +# Schema-driven path extraction (generic, no need to modify) +# --------------------------------------------------------------------------- + + +def _path_tuple_to_pattern(path_tuple: tuple) -> str: + """Convert a path tuple with sentinels to a pattern string with [] and {}.""" + parts = [] + for part in path_tuple: + if part is SEQ_INDEX_SENTINEL: + parts.append("[]") + elif part is DICT_INDEX_SENTINEL: + parts.append("{}") + else: + parts.append(str(part)) + return ".".join(parts) + + +# Pre-compute path patterns from schema at module load time +_ALL_PATTERNS = [ + _path_tuple_to_pattern(p) for p in get_all_model_path_patterns(InputSchema) +] +_DIFF_PATTERNS = [ + _path_tuple_to_pattern(p) + for p in get_all_model_path_patterns(InputSchema, filter_fn=is_differentiable) +] + + +def _expand_inputs( + inputs_dict: dict, +) -> tuple[list[np.ndarray], list[str], list[Any], list[str]]: + """Expand schema paths against concrete inputs, split into diff and non-diff. + + Returns: + diff_args: list of differentiable arrays (always numpy float64) + diff_paths: corresponding Tesseract paths + non_diff_args: list of non-differentiable leaf values + non_diff_paths: corresponding Tesseract paths + """ + # Expand all concrete paths + all_paths = [] + for pattern in _ALL_PATTERNS: + all_paths.extend(expand_path_pattern(pattern, inputs_dict)) + + # Expand diff paths + diff_path_set = set() + for pattern in _DIFF_PATTERNS: + diff_path_set.update(expand_path_pattern(pattern, inputs_dict)) + + # Split by differentiability, skip intermediate nodes (dicts/lists) + diff_args, diff_paths = [], [] + non_diff_args, non_diff_paths = [], [] + + for path in all_paths: + value = get_at_path(inputs_dict, path) + if isinstance(value, (dict, list)): + continue # intermediate node, not a leaf + + if path in diff_path_set: + diff_args.append(np.asarray(value, dtype=np.float64)) + diff_paths.append(path) + else: + non_diff_args.append(value) + non_diff_paths.append(path) + + return diff_args, diff_paths, non_diff_args, non_diff_paths + + +# --------------------------------------------------------------------------- +# Required endpoints — modify evaluate and abstract_eval for your solver +# --------------------------------------------------------------------------- + + +def evaluate(inputs_dict: dict) -> dict: + """Call the Julia function.""" + diff_args, diff_paths, non_diff_args, non_diff_paths = _expand_inputs(inputs_dict) + result = jl.apply_jl(diff_args, non_diff_args, diff_paths, non_diff_paths) + return {"y": np.asarray(result)} + + +def apply(inputs: InputSchema) -> OutputSchema: + return evaluate(inputs.model_dump()) + + +def abstract_eval(abstract_inputs): + """Calculate output shapes from input shapes. + + Must be implemented explicitly since JAX cannot trace through Julia code. + """ + inp = abstract_inputs.model_dump() + return { + "y": {"shape": inp["x"]["shape"], "dtype": inp["x"]["dtype"]}, + } + + +# --------------------------------------------------------------------------- +# Gradient endpoints (generic, no need to modify) +# --------------------------------------------------------------------------- + + +def jacobian_vector_product( + inputs: InputSchema, + jvp_inputs: set[str], + jvp_outputs: set[str], + tangent_vector: dict[str, Any], +): + inputs_dict = inputs.model_dump() + diff_args, diff_paths, non_diff_args, non_diff_paths = _expand_inputs(inputs_dict) + + # Build tangent for each diff arg: provided tangent or zero + tangent_args = [] + for k, path in enumerate(diff_paths): + if path in jvp_inputs: + tangent_args.append(np.asarray(tangent_vector[path], dtype=np.float64)) + else: + tangent_args.append(np.zeros_like(diff_args[k])) + + jvp_out = np.asarray( + jl.enzyme_jvp( + jl.apply_jl, + diff_args, + non_diff_args, + diff_paths, + non_diff_paths, + tangent_args, + ) + ) + return {p: jvp_out for p in jvp_outputs} + + +def vector_jacobian_product( + inputs: InputSchema, + vjp_inputs: set[str], + vjp_outputs: set[str], + cotangent_vector: dict[str, Any], +): + inputs_dict = inputs.model_dump() + diff_args, diff_paths, non_diff_args, non_diff_paths = _expand_inputs(inputs_dict) + + cotangent_vector = {key: cotangent_vector[key] for key in vjp_outputs} + combined_cotangent = sum( + np.asarray(v, dtype=np.float64) for v in cotangent_vector.values() + ) + + grads = jl.enzyme_vjp( + jl.apply_jl, + diff_args, + non_diff_args, + diff_paths, + non_diff_paths, + combined_cotangent, + ) + + # Map gradients back to requested paths + result = {} + for k, path in enumerate(diff_paths): + if path in vjp_inputs: + result[path] = np.asarray(grads[k]) + return result diff --git a/tesseract_core/sdk/templates/julia/tesseract_config.yaml b/tesseract_core/sdk/templates/julia/tesseract_config.yaml new file mode 100644 index 00000000..ce55b295 --- /dev/null +++ b/tesseract_core/sdk/templates/julia/tesseract_config.yaml @@ -0,0 +1,35 @@ +# Tesseract configuration file +# Generated by tesseract {{version}} on {{timestamp}} + +name: "{{name}}" +version: "0.1.0" +description: "" + +# Arbitrary user-defined metadata +# metadata: +# key: value + +build_config: + # Base image to use for the container, must be Ubuntu or Debian-based + # base_image: "debian:bookworm-slim" + + # Platform to build the container for. In general, images can only be executed + # on the platform they were built for. + target_platform: "native" + + # Additional packages to install in the container (via apt-get) + # Julia is installed via custom_build_steps below + # extra_packages: + # - package_name + + # Data to copy into the container, relative to the project root + # package_data: + # - [path/to/source, path/to/destination] + + # Install Julia and precompile packages in the container + custom_build_steps: + - | + RUN apt-get update && apt-get install -y curl && \ + curl -fsSL https://install.julialang.org | sh -s -- -y && \ + /root/.juliaup/bin/julia -e 'using Pkg; Pkg.add(["Enzyme"])' + # TODO: Add your Julia packages to the Pkg.add list above diff --git a/tesseract_core/sdk/templates/julia/tesseract_requirements.txt b/tesseract_core/sdk/templates/julia/tesseract_requirements.txt new file mode 100644 index 00000000..43017393 --- /dev/null +++ b/tesseract_core/sdk/templates/julia/tesseract_requirements.txt @@ -0,0 +1,9 @@ +# Tesseract requirements file +# Generated by tesseract {{version}} on {{timestamp}} + +# juliacall provides the Python <-> Julia bridge (zero-copy array sharing) +juliacall + +# This may contain private dependencies via SSH URLs: +# git+ssh://git@github.com/username/repo.git@branch +# (use `tesseract build --forward-ssh-agent` to grant the builder access to your SSH keys) From 3245d22e57c89cf5586c274b9bc742a3fd98ad81 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Thu, 2 Apr 2026 16:49:45 +0100 Subject: [PATCH 2/3] add test cases to config --- tests/endtoend_tests/test_examples.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/endtoend_tests/test_examples.py b/tests/endtoend_tests/test_examples.py index c8db19a4..505cff96 100644 --- a/tests/endtoend_tests/test_examples.py +++ b/tests/endtoend_tests/test_examples.py @@ -105,6 +105,7 @@ class Config: "filereference": Config(input_path="test_cases/testdata", output_path="output"), "metrics": Config(test_with_random_inputs=True), "qp_solve": Config(), + "bandedblock_cholmod": Config(), "tesseractreference": Config(), # Can't test requests standalone; needs target Tesseract. Covered in separate test. "userhandling": Config(), } From 328a66169f149a4222830f604eaf342229ec1d49 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Thu, 2 Apr 2026 23:42:58 +0100 Subject: [PATCH 3/3] fix offline build --- .../bandedblock_cholmod/julia/Project.toml | 6 ++++ .../bandedblock_cholmod/{ => julia}/apply.jl | 0 .../{ => julia}/enzyme_wrappers.jl | 0 examples/bandedblock_cholmod/tesseract_api.py | 4 +-- .../bandedblock_cholmod/tesseract_config.yaml | 24 +++++++++---- tesseract_core/sdk/engine.py | 12 ++++--- .../sdk/templates/julia/julia/Project.toml | 4 +++ .../sdk/templates/julia/{ => julia}/apply.jl | 0 .../julia/{ => julia}/enzyme_wrappers.jl | 0 .../sdk/templates/julia/tesseract_api.py | 4 +-- .../sdk/templates/julia/tesseract_config.yaml | 34 ++++++++----------- 11 files changed, 53 insertions(+), 35 deletions(-) create mode 100644 examples/bandedblock_cholmod/julia/Project.toml rename examples/bandedblock_cholmod/{ => julia}/apply.jl (100%) rename examples/bandedblock_cholmod/{ => julia}/enzyme_wrappers.jl (100%) create mode 100644 tesseract_core/sdk/templates/julia/julia/Project.toml rename tesseract_core/sdk/templates/julia/{ => julia}/apply.jl (100%) rename tesseract_core/sdk/templates/julia/{ => julia}/enzyme_wrappers.jl (100%) diff --git a/examples/bandedblock_cholmod/julia/Project.toml b/examples/bandedblock_cholmod/julia/Project.toml new file mode 100644 index 00000000..a5deae57 --- /dev/null +++ b/examples/bandedblock_cholmod/julia/Project.toml @@ -0,0 +1,6 @@ +[deps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/examples/bandedblock_cholmod/apply.jl b/examples/bandedblock_cholmod/julia/apply.jl similarity index 100% rename from examples/bandedblock_cholmod/apply.jl rename to examples/bandedblock_cholmod/julia/apply.jl diff --git a/examples/bandedblock_cholmod/enzyme_wrappers.jl b/examples/bandedblock_cholmod/julia/enzyme_wrappers.jl similarity index 100% rename from examples/bandedblock_cholmod/enzyme_wrappers.jl rename to examples/bandedblock_cholmod/julia/enzyme_wrappers.jl diff --git a/examples/bandedblock_cholmod/tesseract_api.py b/examples/bandedblock_cholmod/tesseract_api.py index ebc8b5b9..dff814ae 100644 --- a/examples/bandedblock_cholmod/tesseract_api.py +++ b/examples/bandedblock_cholmod/tesseract_api.py @@ -49,8 +49,8 @@ # --------------------------------------------------------------------------- _here = Path(__file__).parent -jl.include(str(_here / "apply.jl")) -jl.include(str(_here / "enzyme_wrappers.jl")) +jl.include(str(_here / "julia" / "enzyme_wrappers.jl")) +jl.include(str(_here / "julia" / "apply.jl")) # --------------------------------------------------------------------------- diff --git a/examples/bandedblock_cholmod/tesseract_config.yaml b/examples/bandedblock_cholmod/tesseract_config.yaml index 38ac3b0e..798ec462 100644 --- a/examples/bandedblock_cholmod/tesseract_config.yaml +++ b/examples/bandedblock_cholmod/tesseract_config.yaml @@ -1,12 +1,24 @@ name: "bandedblock-cholmod" version: "0.1.0" -description: "Differentiable sparse CHOLMOD solver for SPD block systems with tridiagonal blocks, using Julia Enzyme AD" +description: | + Differentiable sparse CHOLMOD solver for SPD block systems with + tridiagonal blocks. Uses Julia's LinearSolve.jl with Enzyme AD. + + Demonstrates wrapping a Julia sparse solver as a differentiable Tesseract + that Python consumers can call — with gradients — without installing Julia. build_config: - target_platform: "native" + base_image: "julia:1.11-bookworm" + + # Copy Julia project and solver source into the container + package_data: + - ["julia/", "julia/"] custom_build_steps: - - | - RUN apt-get update && apt-get install -y curl && \ - curl -fsSL https://install.julialang.org | sh -s -- -y && \ - /root/.juliaup/bin/julia -e 'using Pkg; Pkg.add(["Enzyme", "LinearSolve"])' + # Use a shared Julia depot so packages are accessible to any runtime user. + - "ENV JULIA_DEPOT_PATH=/tesseract/.julia" + # Ensure precompiled cache is portable across CPUs of the same architecture. + - 'RUN if [ "$(uname -m)" = "x86_64" ]; then export JULIA_CPU_TARGET="generic;sandybridge,-xsaveopt,clone_all;haswell,-rdrnd,base(1)"; else export JULIA_CPU_TARGET="generic"; fi && mkdir -p /tesseract/.julia && julia --project=/tesseract/julia -e "import Pkg; Pkg.instantiate(); Pkg.precompile()"' + - "RUN chmod -R 777 /tesseract/.julia" + # Tell JuliaCall to use the system Julia and skip its own package management. + - "ENV PYTHON_JULIACALL_EXE=/usr/local/julia/bin/julia PYTHON_JULIACALL_PROJECT=/tesseract/julia PYTHON_JULIAPKG_OFFLINE=yes JULIA_PKG_PRECOMPILE_AUTO=0" diff --git a/tesseract_core/sdk/engine.py b/tesseract_core/sdk/engine.py index 3e26dd93..c004e29b 100644 --- a/tesseract_core/sdk/engine.py +++ b/tesseract_core/sdk/engine.py @@ -422,12 +422,14 @@ def init_api( "tesseract_requirements.txt", target_dir, template_vars, recipe=Path(recipe) ) - # Julia recipe: copy Julia source files + # Julia recipe: copy Julia project directory (solver + Enzyme wrappers + Project.toml) if recipe == "julia": - _write_template_file("apply.jl", target_dir, template_vars, recipe=Path(recipe)) - _write_template_file( - "enzyme_wrappers.jl", target_dir, template_vars, recipe=Path(recipe) - ) + julia_dir = target_dir / "julia" + julia_dir.mkdir(parents=True, exist_ok=True) + for jl_file in ("apply.jl", "enzyme_wrappers.jl", "Project.toml"): + _write_template_file( + f"julia/{jl_file}", target_dir, template_vars, recipe=Path(recipe) + ) return target_dir / "tesseract_api.py" diff --git a/tesseract_core/sdk/templates/julia/julia/Project.toml b/tesseract_core/sdk/templates/julia/julia/Project.toml new file mode 100644 index 00000000..d8c00a40 --- /dev/null +++ b/tesseract_core/sdk/templates/julia/julia/Project.toml @@ -0,0 +1,4 @@ +[deps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" diff --git a/tesseract_core/sdk/templates/julia/apply.jl b/tesseract_core/sdk/templates/julia/julia/apply.jl similarity index 100% rename from tesseract_core/sdk/templates/julia/apply.jl rename to tesseract_core/sdk/templates/julia/julia/apply.jl diff --git a/tesseract_core/sdk/templates/julia/enzyme_wrappers.jl b/tesseract_core/sdk/templates/julia/julia/enzyme_wrappers.jl similarity index 100% rename from tesseract_core/sdk/templates/julia/enzyme_wrappers.jl rename to tesseract_core/sdk/templates/julia/julia/enzyme_wrappers.jl diff --git a/tesseract_core/sdk/templates/julia/tesseract_api.py b/tesseract_core/sdk/templates/julia/tesseract_api.py index 969124bf..1b4addec 100644 --- a/tesseract_core/sdk/templates/julia/tesseract_api.py +++ b/tesseract_core/sdk/templates/julia/tesseract_api.py @@ -26,8 +26,8 @@ # --------------------------------------------------------------------------- _here = Path(__file__).parent -jl.include(str(_here / "apply.jl")) -jl.include(str(_here / "enzyme_wrappers.jl")) +jl.include(str(_here / "julia" / "enzyme_wrappers.jl")) +jl.include(str(_here / "julia" / "apply.jl")) # --------------------------------------------------------------------------- # Schemata diff --git a/tesseract_core/sdk/templates/julia/tesseract_config.yaml b/tesseract_core/sdk/templates/julia/tesseract_config.yaml index ce55b295..7e9027e8 100644 --- a/tesseract_core/sdk/templates/julia/tesseract_config.yaml +++ b/tesseract_core/sdk/templates/julia/tesseract_config.yaml @@ -5,31 +5,25 @@ name: "{{name}}" version: "0.1.0" description: "" -# Arbitrary user-defined metadata -# metadata: -# key: value - build_config: - # Base image to use for the container, must be Ubuntu or Debian-based - # base_image: "debian:bookworm-slim" + # Use the official Julia base image (includes Julia + stdlib) + base_image: "julia:1.11-bookworm" # Platform to build the container for. In general, images can only be executed # on the platform they were built for. target_platform: "native" - # Additional packages to install in the container (via apt-get) - # Julia is installed via custom_build_steps below - # extra_packages: - # - package_name - - # Data to copy into the container, relative to the project root - # package_data: - # - [path/to/source, path/to/destination] + # Copy Julia project and solver source into the container + package_data: + - ["julia/", "julia/"] - # Install Julia and precompile packages in the container custom_build_steps: - - | - RUN apt-get update && apt-get install -y curl && \ - curl -fsSL https://install.julialang.org | sh -s -- -y && \ - /root/.juliaup/bin/julia -e 'using Pkg; Pkg.add(["Enzyme"])' - # TODO: Add your Julia packages to the Pkg.add list above + # Use a shared Julia depot so packages are accessible to any runtime user. + - "ENV JULIA_DEPOT_PATH=/tesseract/.julia" + # Ensure precompiled cache is portable across CPUs of the same architecture. + - 'RUN if [ "$(uname -m)" = "x86_64" ]; then export JULIA_CPU_TARGET="generic;sandybridge,-xsaveopt,clone_all;haswell,-rdrnd,base(1)"; else export JULIA_CPU_TARGET="generic"; fi && mkdir -p /tesseract/.julia && julia --project=/tesseract/julia -e "import Pkg; Pkg.instantiate(); Pkg.precompile()"' + - "RUN chmod -R 777 /tesseract/.julia" + # Tell JuliaCall to use the system Julia and skip its own package management. + # JULIA_PKG_PRECOMPILE_AUTO=0 prevents Julia from recompiling packages at + # runtime when the container UID differs from the build UID. + - "ENV PYTHON_JULIACALL_EXE=/usr/local/julia/bin/julia PYTHON_JULIACALL_PROJECT=/tesseract/julia PYTHON_JULIAPKG_OFFLINE=yes JULIA_PKG_PRECOMPILE_AUTO=0"