feat: add torch_compat module for native torch support #550
dionhaefner wants to merge 7 commits into main
Conversation
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #550 +/- ##
==========================================
+ Coverage 77.16% 77.34% +0.18%
==========================================
Files 32 34 +2
Lines 4418 4525 +107
Branches 728 745 +17
==========================================
+ Hits 3409 3500 +91
- Misses 714 719 +5
- Partials 295 306 +11
Benchmark Results
Benchmarks use a no-op Tesseract to measure pure framework overhead.
🚀 0 faster, ✅ No significant performance changes detected. Full results
MatteoSalvador
left a comment
Thanks @dionhaefner, this is really nice! I have a couple of considerations apart from the nits I left:

- I understand that `tesseract-torch` would be overkill for this, but wouldn't it be confusing to have a dedicated repo for `tesseract-jax` users while the torch equivalent is deep inside `tesseract-core`?
- Since `autograd` functions support both custom JVPs and VJPs at the same time, I was wondering if those endpoints are picked automatically according to the specific external calls (e.g. `backward()` vs. `forward_ad`), and how error handling is done if torch-AD calls forward or reverse mode on a Tesseract that does not have a JVP or VJP implemented.
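For reference, torch's dispatch between the two endpoints can be sketched with a toy standalone `torch.autograd.Function` (independent of this PR's actual implementation): `.backward()` routes through the `backward` staticmethod, while forward-mode dual tensors route through `jvp`.

```python
import torch
import torch.autograd.forward_ad as fwAD

class Cube(torch.autograd.Function):
    """Toy function with both a custom VJP (backward) and JVP (jvp)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # used by backward (reverse mode)
        ctx.x = x                 # plain attribute, used by jvp (forward mode)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        # VJP: invoked by .backward() / torch.autograd.grad
        (x,) = ctx.saved_tensors
        return grad_out * 3 * x ** 2

    @staticmethod
    def jvp(ctx, tangent):
        # JVP: invoked when inputs are forward-AD dual tensors
        return tangent * 3 * ctx.x ** 2

# Reverse mode picks `backward`:
x = torch.tensor(2.0, requires_grad=True)
Cube.apply(x).backward()
print(x.grad)  # d/dx x^3 at x=2 is 12

# Forward mode picks `jvp`:
with fwAD.dual_level():
    dual = fwAD.make_dual(torch.tensor(2.0), torch.tensor(1.0))
    _, tangent = fwAD.unpack_dual(Cube.apply(dual))
print(tangent)  # also 12
```

If `jvp` were not defined, the forward-mode call would raise at runtime, which is presumably where the error-handling question above applies.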
@@ -0,0 +1,185 @@
(pytorch-integration)=
```python
result = apply_tesseract(tesseract, {
    "x": x_tensor,  # differentiable — tracked by autograd
```

nit: maybe I would declare x_tensor explicitly beforehand?
```python
result = apply_tesseract(tesseract, {
    "x": x_tensor,  # differentiable — tracked by autograd
    "A": np.eye(3, dtype=np.float32),  # static — not tracked
    "b": torch.zeros(3),  # differentiable
```

Is b not tracked by autograd because it's not stored in a variable beforehand?
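In plain torch (independent of whatever `apply_tesseract` does internally), whether a tensor is tracked does not depend on it being bound to a variable first; it depends on `requires_grad`, and on the input being a torch tensor at all. A quick standalone check:

```python
import numpy as np
import torch

# Tracking is determined by requires_grad, not by whether the
# tensor is assigned to a variable before being passed along.
a = torch.zeros(3)                      # requires_grad defaults to False
b = torch.zeros(3, requires_grad=True)  # explicitly tracked
c = np.zeros(3, dtype=np.float32)       # numpy arrays are never tracked

print(a.requires_grad)  # False
print(b.requires_grad)  # True
```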
```
             forward                backward (reverse-mode)
┌──────────┐ ──────► ┌──────────────┐ ──────► ┌──────────┐
│ PyTorch  │         │  Tesseract   │         │ PyTorch  │
```
super nit: Is it possible to align those pipes on the right border to make cubes for nicer rendering?
Provides :func:`apply_tesseract`, which wraps any Tesseract as a differentiable
PyTorch operation supporting both reverse-mode (``.backward()``) and forward-mode
(``torch.autograd.forward_ad``) automatic differentiation.
Is forward_ad still in beta for torch 2.11? If it is, I would signal that functionality and interfaces may change (although very unlikely).
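For reference, the `forward_ad` API in question, as a minimal standalone usage sketch (torch's docs have marked this API as beta; worth double-checking for the targeted torch version):

```python
import torch
import torch.autograd.forward_ad as fwAD

# Forward-mode AD: propagate a tangent alongside the primal value.
with fwAD.dual_level():
    x = fwAD.make_dual(torch.tensor(1.0), torch.tensor(1.0))  # (primal, tangent)
    y = torch.sin(x)
    primal, tangent = fwAD.unpack_dual(y)

print(primal)   # sin(1.0)
print(tangent)  # cos(1.0), the directional derivative along the tangent
```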
when inputs require grad) and non-differentiable outputs as-is
(NumPy arrays or scalars).
- Example::
+ Example:
Relevant issue or PR
n/a
Description of changes
Adds a single module `tesseract_core.torch_compat` instead of a standalone Tesseract-Torch project (which would really be overkill here). It's basically a single `torch.autograd.Function` class with custom `backward` and `jvp` implementations, plus an `apply_tesseract` function for API compatibility with Tesseract-JAX.

The reason why torch is simpler is that (1) it doesn't require abstract eval, (2) autograd functions support both custom JVPs and VJPs at the same time, and (3) torch isn't picky about static vs. traced args.
I'm not a torch expert by any means so this will need another set of eyes for sure.
Testing done
Basic tests pass on CI, NN training w/ Tesseract in the loop works on my machine.