feat: add torch_compat module for native torch support#550

Open
dionhaefner wants to merge 7 commits into main from dion/torch-support

Conversation

Contributor

@dionhaefner commented Apr 1, 2026

Relevant issue or PR

n/a

Description of changes

Adds a single module tesseract_core.torch_compat instead of a standalone Tesseract-Torch project (which would really be overkill here). It's basically a single torch.autograd.Function class with custom backward and jvp implementations, plus an apply_tesseract function for API compatibility with Tesseract-JAX.

Torch is simpler here because (1) it doesn't require abstract eval, (2) autograd functions support custom JVPs and VJPs at the same time, and (3) torch isn't picky about static vs. traced args.
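For readers who haven't written one before, the pattern the module builds on can be sketched as a `torch.autograd.Function` with both a custom VJP (`backward`) and a custom JVP (`jvp`). The `_apply`/`_vjp`/`_jvp` helpers below are hypothetical stand-ins for the Tesseract endpoint calls (a toy linear function, so neither derivative needs the primal input); the real `torch_compat` module is more involved:

```python
import torch

# Hypothetical stand-ins for a Tesseract's apply/vjp/jvp endpoints.
# Toy example f(x) = 2x: linear, so VJP/JVP don't need the primal.
def _apply(x):
    return 2.0 * x

def _vjp(cotangent):
    return 2.0 * cotangent

def _jvp(tangent):
    return 2.0 * tangent

class TesseractFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return _apply(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse mode: reached via .backward() / torch.autograd.grad.
        return _vjp(grad_output)

    @staticmethod
    def jvp(ctx, x_tangent):
        # Forward mode: reached when inputs carry a forward-AD tangent.
        return _jvp(x_tangent)
```

Reverse mode reaches `backward` through `y.backward()`; forward mode reaches `jvp` through `torch.autograd.forward_ad` dual tensors. The actual module presumably also handles structured inputs and mixed static/differentiable arguments on top of this skeleton.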

I'm not a torch expert by any means so this will need another set of eyes for sure.

Testing done

Basic tests pass on CI, NN training w/ Tesseract in the loop works on my machine.


codecov bot commented Apr 1, 2026

Codecov Report

❌ Patch coverage is 85.04673% with 16 lines in your changes missing coverage. Please review.
✅ Project coverage is 77.34%. Comparing base (61a8140) to head (e11ce4b).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| tesseract_core/torch_compat/function.py | 84.76% | 5 Missing and 11 partials ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #550      +/-   ##
==========================================
+ Coverage   77.16%   77.34%   +0.18%     
==========================================
  Files          32       34       +2     
  Lines        4418     4525     +107     
  Branches      728      745      +17     
==========================================
+ Hits         3409     3500      +91     
- Misses        714      719       +5     
- Partials      295      306      +11     

☔ View full report in Codecov by Sentry.

Contributor

PasteurBot commented Apr 1, 2026

Benchmark Results

Benchmarks use a no-op Tesseract to measure pure framework overhead.

🚀 0 faster, ⚠️ 0 slower, ✅ 36 unchanged

✅ No significant performance changes detected.

Full results
Benchmark Baseline Current Change Status
api/apply_1,000 0.778ms 0.784ms +0.8%
api/apply_100,000 0.783ms 0.775ms -1.0%
api/apply_10,000,000 0.789ms 0.775ms -1.8%
cli/apply_1,000 1638.038ms 1644.055ms +0.4%
cli/apply_100,000 1647.515ms 1644.475ms -0.2%
cli/apply_10,000,000 1699.322ms 1687.266ms -0.7%
decoding/base64_1,000 0.038ms 0.037ms -2.1%
decoding/base64_100,000 0.893ms 0.891ms -0.2%
decoding/base64_10,000,000 98.387ms 98.930ms +0.6%
decoding/binref_1,000 0.202ms 0.201ms -0.6%
decoding/binref_100,000 0.243ms 0.240ms -1.2%
decoding/binref_10,000,000 10.388ms 10.446ms +0.6%
decoding/json_1,000 0.109ms 0.108ms -0.9%
decoding/json_100,000 8.978ms 8.874ms -1.2%
decoding/json_10,000,000 1070.470ms 1069.608ms -0.1%
encoding/base64_1,000 0.041ms 0.041ms +1.0%
encoding/base64_100,000 0.146ms 0.148ms +1.1%
encoding/base64_10,000,000 24.531ms 24.783ms +1.0%
encoding/binref_1,000 0.306ms 0.304ms -0.4%
encoding/binref_100,000 0.482ms 0.484ms +0.5%
encoding/binref_10,000,000 18.281ms 18.286ms +0.0%
encoding/json_1,000 0.153ms 0.157ms +2.3%
encoding/json_100,000 13.319ms 13.678ms +2.7%
encoding/json_10,000,000 1421.158ms 1453.543ms +2.3%
http/apply_1,000 3.681ms 3.693ms +0.3%
http/apply_100,000 9.001ms 9.085ms +0.9%
http/apply_10,000,000 754.185ms 774.807ms +2.7%
roundtrip/base64_1,000 0.089ms 0.089ms +0.2%
roundtrip/base64_100,000 1.048ms 1.050ms +0.1%
roundtrip/base64_10,000,000 123.297ms 123.055ms -0.2%
roundtrip/binref_1,000 0.518ms 0.520ms +0.4%
roundtrip/binref_100,000 0.727ms 0.718ms -1.2%
roundtrip/binref_10,000,000 29.214ms 29.136ms -0.3%
roundtrip/json_1,000 0.297ms 0.277ms -6.8%
roundtrip/json_100,000 22.144ms 20.276ms -8.4%
roundtrip/json_10,000,000 2716.775ms 2524.034ms -7.1%
  • Runner: Linux 6.17.0-1008-azure x86_64

@dionhaefner dionhaefner marked this pull request as ready for review April 2, 2026 07:50
Contributor

@MatteoSalvador left a comment


Thanks @dionhaefner, this is really nice! I have a couple of considerations apart from the nits I left:

  1. I understand that tesseract-torch would be overkill for this, but wouldn't it be confusing to have a dedicated repo for tesseract-jax users while the torch-equivalent is deep into tesseract-core?
  2. Since autograd functions support both custom JVPs and VJPs at the same time, I was wondering whether the right endpoint is picked automatically based on the external call (e.g. backward() vs. forward_ad), and how errors are handled if torch AD invokes forward or reverse mode on a Tesseract that doesn't implement a JVP or VJP.
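For what it's worth, the dispatch side of question 2 can be illustrated with a toy `Function` (not the PR's code): torch picks the endpoint from how AD is invoked — `.backward()` routes to `backward`, forward_ad dual tensors route to `jvp` — so a missing Tesseract endpoint can only surface at call time, e.g. as an error raised from the corresponding method:

```python
import torch
import torch.autograd.forward_ad as fwAD

class Square(torch.autograd.Function):
    # Toy f(x) = x^2 with a VJP but (deliberately) no JVP endpoint.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # Reached only by reverse mode (.backward(), torch.autograd.grad).
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

    @staticmethod
    def jvp(ctx, x_tangent):
        # Reached only by forward mode; report the missing endpoint here.
        raise RuntimeError("this Tesseract does not implement jvp")

x = torch.tensor(3.0, requires_grad=True)
Square.apply(x).backward()   # reverse mode -> backward()
print(x.grad)                # tensor(6.)

with fwAD.dual_level():
    dual = fwAD.make_dual(torch.tensor(3.0), torch.tensor(1.0))
    try:
        Square.apply(dual)   # forward mode -> jvp(), which raises
    except RuntimeError as err:
        print(f"forward mode failed: {err}")
```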

@@ -0,0 +1,185 @@
(pytorch-integration)=


Suggested change
(pytorch-integration)=


```python
result = apply_tesseract(tesseract, {
    "x": x_tensor,  # differentiable — tracked by autograd
```


nit: maybe I would declare x_tensor explicitly beforehand?

```python
result = apply_tesseract(tesseract, {
    "x": x_tensor,  # differentiable — tracked by autograd
    "A": np.eye(3, dtype=np.float32),  # static — not tracked
    "b": torch.zeros(3),  # differentiable
```

@MatteoSalvador commented Apr 2, 2026


Is b not tracked by autograd because it's not stored in a variable beforehand?
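On that question — assuming `apply_tesseract` follows the convention the snippet's comments suggest (torch tensors treated as differentiable, NumPy arrays as static; this is an inference, not confirmed here), `b` would be tracked regardless of whether it is declared inline. In plain torch autograd, by contrast, tracking hinges on `requires_grad`, not on whether the tensor is bound to a name first:

```python
import torch

a = torch.zeros(3)                      # requires_grad defaults to False
b = torch.zeros(3, requires_grad=True)  # opted in to autograd tracking

(a + b).sum().backward()

print(a.grad)  # None — never tracked, named variable or not
print(b.grad)  # tensor([1., 1., 1.])
```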

```
              forward                      backward (reverse-mode)
┌──────────┐  ──────►  ┌──────────────┐  ──────►  ┌──────────┐
│ PyTorch  │           │  Tesseract   │           │ PyTorch  │
└──────────┘           └──────────────┘           └──────────┘
```


super nit: Is it possible to align those pipes on the right border to make cubes for nicer rendering?


Provides :func:`apply_tesseract`, which wraps any Tesseract as a differentiable
PyTorch operation supporting both reverse-mode (``.backward()``) and forward-mode
(``torch.autograd.forward_ad``) automatic differentiation.


Is forward_ad still in beta as of torch 2.11? If so, I would signal that functionality and interfaces may change (although that's very unlikely).
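For context, `torch.autograd.forward_ad` has carried a beta label for several releases (I haven't checked 2.11 specifically). Its core dual-tensor API, which the wrapped operation would be driven through, looks like this:

```python
import torch
import torch.autograd.forward_ad as fwAD

def f(x):
    # Any torch function with forward-mode support works the same way.
    return x.sin().sum()

primal = torch.tensor([0.0, 1.0])
tangent = torch.ones(2)  # direction of the directional derivative

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = fwAD.unpack_dual(f(dual))

# out.primal  == f(primal)
# out.tangent == directional derivative, i.e. sum(cos(primal) * tangent)
```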

when inputs require grad) and non-differentiable outputs as-is
(NumPy arrays or scalars).

Example::


Suggested change
Example::
Example:


3 participants