
feat(jit): implement @pl.jit preprocessor over @pl.program pipeline #915

Open
huxinyuan1215 wants to merge 1 commit into hw-native-sys:main from huxinyuan1215:feat/jit-preprocessor

Conversation

@huxinyuan1215
Contributor

Summary

Implements issue #878: a JIT preprocessor that sits above the existing @pl.program pipeline. Users write a generic kernel once with @pl.jit; at call-time, JIT specializes it to a concrete @pl.program by filling in actual tensor shapes and dtypes, then runs the normal compiler pipeline.

  • python/pypto/jit/decorator.py: @jit and @jit.incore decorators — collect runtime tensor metadata, discover @jit.incore dependencies, drive specialization, maintain an L1 compilation cache
  • python/pypto/jit/specializer.py: AST rewriter — fills concrete shapes/dtypes into type annotations, inlines scalar constants, handles pl.dynamic DynVars, rewrites dep calls to self.method() style for Style B
  • python/pypto/jit/cache.py: L1 cache keyed on (source_hash, shapes, dtypes, scalar_values, dynamic_dims) to skip recompilation for repeated calls with the same specialization
  • python/pypto/language/__init__.py: expose jit and JITFunction in the pl.* namespace
  • python/pypto/language/typing/tensor.py: add Tensor.bind_dynamic() no-op used by the specializer to statically detect dynamic dimensions
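The specializer's core move, filling concrete shapes and dtypes into generic type annotations by AST rewriting, can be illustrated with a stdlib-only sketch. Everything below (the bare `Tensor` name, the `Tensor[(shape), dtype]` subscript form, the kernel itself) is an illustrative stand-in, not the actual pypto API:

```python
import ast

GENERIC_SRC = """
def kernel(a: Tensor, b: Tensor) -> Tensor:
    return a + b
"""

class AnnotationFiller(ast.NodeTransformer):
    """Rewrite bare `Tensor` names into `Tensor[(shape...), dtype]` subscripts."""

    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype

    def visit_Name(self, node):
        if node.id != "Tensor":
            return node
        # Build e.g. Tensor[(128, 128), 'fp32'] as a Subscript node.
        return ast.Subscript(
            value=ast.Name(id="Tensor", ctx=ast.Load()),
            slice=ast.Tuple(
                elts=[
                    ast.Tuple(elts=[ast.Constant(d) for d in self.shape],
                              ctx=ast.Load()),
                    ast.Constant(self.dtype),
                ],
                ctx=ast.Load(),
            ),
            ctx=ast.Load(),
        )

tree = ast.fix_missing_locations(
    AnnotationFiller((128, 128), "fp32").visit(ast.parse(GENERIC_SRC)))
specialized = ast.unparse(tree)
print(specialized)
```

At call time the real specializer would feed the shapes/dtypes extracted from the actual tensor arguments into a transformer of this shape.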

Testing

  • tests/ut/jit/test_cache.py: cache key stability and hit/miss behavior
  • tests/ut/jit/test_decorator.py: dep discovery, dynamic-dim scanning, call-arg extraction, cache integration, accurate OSError on missing source
  • tests/ut/jit/test_specializer.py: annotation filling, Out params, dynvar emission, shape/dtype substitution, generated program structure
  • tests/ut/jit/test_roundtrip.py: 23 structural-equality round-trip tests — each examples/kernels/ program expressible as @pl.jit is rewritten in JIT style and verified with ir.assert_structural_equal against the hand-written @pl.program

All 90 new tests pass.
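The round-trip idea, that generated source must be structurally identical to the hand-written version regardless of formatting, can be approximated generically by comparing AST dumps; the real tests use ir.assert_structural_equal on pypto IR instead:

```python
import ast

def structurally_equal(src_a: str, src_b: str) -> bool:
    """True when both sources parse to the same AST, ignoring formatting."""
    return ast.dump(ast.parse(src_a)) == ast.dump(ast.parse(src_b))

hand_written = "def add(a, b):\n    return a + b\n"
generated = "def add(a, b): return (a + b)"  # same structure, different layout
assert structurally_equal(hand_written, generated)
assert not structurally_equal(hand_written, "def add(a, b): return a - b")
```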

Related Issues

Fixes #878

@coderabbitai

coderabbitai bot commented Apr 8, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a JIT compilation subsystem under python/pypto/jit/: decorator, AST-based specializer, compilation cache, language typing hook for dynamic dims, and extensive unit/integration tests to validate specialization, caching, dependency discovery, and IR round-trips.

Changes

  • JIT Core (python/pypto/jit/__init__.py): package initializer exporting JITFunction and jit.
  • Cache (python/pypto/jit/cache.py): defines TensorCacheInfo, ScalarCacheInfo, CacheKey, compute_source_hash (truncated SHA-256), and make_cache_key, which assembles cache keys with dynamic-dim handling.
  • Decorator / API (python/pypto/jit/decorator.py): implements the JITFunction and jit API: lazy dependency discovery (AST), arg classification (tensors vs. scalars), cache-key construction, cache lookup/store, specialization orchestration, and parsing to ir.Program.
  • Specializer (python/pypto/jit/specializer.py): AST transformer and Specializer that emit pl.program/pl.function source: removes bind_dynamic, substitutes concrete shapes/dtypes or dynvar refs, emits module-level dynvars, rewrites dependency calls, and produces parseable DSL source.
  • Language typing (python/pypto/language/__init__.py, python/pypto/language/typing/tensor.py): re-exports jit/JITFunction from pypto.language; adds Tensor.bind_dynamic(dim, var) as a runtime no-op used by the specializer.
  • Tests: cache & conftest (tests/ut/jit/conftest.py, tests/ut/jit/test_cache.py, tests/ut/jit/__init__.py): pytest setup (project root path) and unit tests for compute_source_hash and make_cache_key, verifying order sensitivity, dynamic-dim masking, scalar inclusion, and hashability.
  • Tests: decorator & round-trip (tests/ut/jit/test_decorator.py, tests/ut/jit/test_roundtrip.py): decorator behavior, cache hit/miss semantics (torch-backed tests skipped if torch is missing), dependency discovery (Style B), end-to-end compilation, and structural-equality checks against hand-written @pl.program examples.
  • Tests: specializer (tests/ut/jit/test_specializer.py): unit tests for AST analyses and _BodyTransformer rewrites: dynamic-dim collection, dynvar extraction, param classification, shape/dtype literal substitution, dependency-call rewriting, and generated-source parseability/IR equality.
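The cache-key scheme described above (truncated SHA-256 source hash plus shape/dtype/scalar tuples, with dynamic dims masked out) can be sketched minimally; the function names and exact key layout here are illustrative, not pypto's actual implementation:

```python
import hashlib

def source_hash_sketch(sources):
    """Truncated SHA-256 over the concatenated function sources."""
    return hashlib.sha256("\n".join(sources).encode()).hexdigest()[:16]

def cache_key_sketch(src_hash, param_names, shapes, dtypes,
                     scalar_values, dynamic_dims):
    """Hashable tuple key; dynamic dims are masked to None so any size matches."""
    parts = []
    for name in param_names:
        if name in shapes:  # tensor param: shape (with masking) + dtype
            masked = tuple(None if (name, i) in dynamic_dims else d
                           for i, d in enumerate(shapes[name]))
            parts.append((name, masked, dtypes[name]))
        else:  # scalar param: concrete value participates in the key
            parts.append((name, scalar_values.get(name)))
    return (src_hash, tuple(parts))

h = source_hash_sketch(["def kernel(a, n): ..."])
key1 = cache_key_sketch(h, ["a", "n"], {"a": (64, 128)}, {"a": "fp32"},
                        {"n": 4}, {("a", 0)})
key2 = cache_key_sketch(h, ["a", "n"], {"a": (32, 128)}, {"a": "fp32"},
                        {"n": 4}, {("a", 0)})
assert key1 == key2  # dim 0 of `a` is dynamic, so different sizes share an entry
```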

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant JIT as JITFunction
    participant Cache as Cache Layer
    participant Specializer as Specializer
    participant Parser as Parser
    participant Program as ir.Program

    User->>JIT: call jit_kernel(torch_tensor, scalar)
    JIT->>JIT: extract shapes/dtypes, classify args
    JIT->>JIT: compute source_hash + cache key (mask dyn dims)
    JIT->>Cache: lookup(cache_key)

    alt cache hit
        Cache-->>JIT: cached Program
        JIT-->>User: return Program
    else cache miss
        JIT->>JIT: discover incore deps via AST
        JIT->>Specializer: build contexts (deps first, entry last)
        Specializer->>Specializer: rewrite AST (remove bind_dynamic, substitute shapes/dtypes, emit dynvars)
        Specializer-->>JIT: specialized DSL source
        JIT->>Parser: parse(DSL) -> Program
        Parser-->>Program: ir.Program
        JIT->>Cache: store(cache_key, Program)
        JIT-->>User: return Program
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • lyfne123

Poem

🐰
I nibble AST leaves in moonlight bright,
Bind dynamic hops, then shapes take flight,
Sources hashed, the cache chest hums soft—
Kernels compile, the carrots lift aloft! 🥕✨

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 38.30%, below the required 80.00% threshold. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (4)

  • Title check: the title clearly and concisely describes the main change, implementing a JIT preprocessor (@pl.jit) for the PyPTO @pl.program pipeline, which is the primary objective of this pull request.
  • Description check: the description comprehensively covers the implementation details, testing approach, and related issues, accurately reflecting the changeset's scope and objectives.
  • Linked Issues check: the pull request implements all core requirements from issue #878: the @pl.jit decorator, shape/dtype extraction, a compilation cache with dynamic dims, the Tensor.bind_dynamic API, dep discovery, and extensive test coverage (90 tests) validating decorator behavior, caching, specializer correctness, and structural equality.
  • Out of Scope Changes check: all changes align with issue #878: new JIT modules (decorator, specializer, cache), the Tensor.bind_dynamic addition, @pl.jit/@pl.jit.incore exports, and a comprehensive test suite; no unrelated modifications detected.




Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a JIT compilation module for PyPTO, providing a @pl.jit decorator that enables automatic specialization and compilation of kernel functions based on tensor shapes and dtypes. The implementation includes a compilation cache, an AST specializer to transform JIT-style code into @pl.program source, and necessary updates to the pl.Tensor API. My review suggests optimizing the _is_tensor function to avoid redundant calls to _get_torch and updating _is_tensor_annotation to consistently handle both native and pl.Tensor types.

Comment on lines +109 to +112
torch = _get_torch()
if torch is None:
    return False
return isinstance(obj, torch.Tensor)


Severity: medium

The _is_tensor function calls _get_torch(), which performs a try/except import. This work is redundant: _get_torch() already caches its result and is also called inside _extract_tensor_meta and _torch_dtype_to_pypto. Prefer calling _get_torch() once and checking the cached result, or using isinstance directly if torch can be imported at module scope.

is_scalar_ann = _is_scalar_annotation(outer)
if is_scalar_ann:
    scalar_dtype_strs[name] = _ast_to_str(inner)
    continue


Severity: medium

The _is_tensor_annotation function only checks for the native Tensor type. It should be updated to also check for the custom pl.Tensor type to be consistent with how types are resolved in the DSL (similar to how pl.Scalar and pl.Tuple are handled). This ensures that type resolution correctly identifies tensor annotations regardless of whether the native or custom type is used.

References
  1. When checking for nested tuple types during type resolution, ensure the check covers both native tuple (which may resolve to a list) and the custom pl.Tuple type (which resolves to ir.TupleType).

@huxinyuan1215 force-pushed the feat/jit-preprocessor branch from 572087d to 3de82d2 on April 8, 2026 at 15:24

@coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (3)
python/pypto/jit/decorator.py (2)

580-586: Consider adding strict=True for closure variable extraction.

The co_freevars and __closure__ tuples are guaranteed by Python to have matching lengths, but adding strict=True would make this invariant explicit and catch any future edge cases.

♻️ Optional: add strict=True
-    for name, cell in zip(co_freevars, closure):
+    for name, cell in zip(co_freevars, closure, strict=True):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/jit/decorator.py` around lines 580 - 586, The loop extracting
closure variables (using co_freevars = getattr(getattr(func, "__code__", None),
"co_freevars", ()) and closure = getattr(func, "__closure__", None) or () and
for name, cell in zip(co_freevars, closure): ...) should use zip(...,
strict=True) to assert the freevar/closure length invariant; update the for-loop
to for name, cell in zip(co_freevars, closure, strict=True): so mismatched
lengths raise immediately and still catch ValueError from cell.cell_contents as
before (ensure runtime Python supports zip(strict=True)).

520-530: Potential silent truncation when dependency has more parameters than positional call args.

If the entry function calls a dependency with keyword arguments or fewer positional args than the dependency expects, zip(dep_param_names, call_args) will silently truncate, leaving some dependency parameters without metadata. This could cause downstream issues in specialization.

Consider adding strict=True or handling the length mismatch explicitly:

♻️ Proposed fix to handle length mismatch
         if call_args is not None:
-            for dep_param, entry_arg in zip(dep_param_names, call_args):
+            for dep_param, entry_arg in zip(dep_param_names, call_args, strict=False):
+                # Note: strict=False is intentional - call_args may be shorter if
+                # some args are passed as kwargs. Unmatched params fall through
+                # to name-based matching below.
                 if entry_arg is None:
                     continue
                 if entry_arg in all_tensor_meta:

Alternatively, if partial positional matching followed by name-based fallback is intended, document this behavior.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/jit/decorator.py` around lines 520 - 530, The loop over
zip(dep_param_names, call_args) can silently truncate when a dependency expects
more parameters than provided as positional args (or when kwargs are used);
update the logic in the block handling dep_param_names/call_args to detect
length mismatches and handle them explicitly: either iterate over
dep_param_names and index into call_args with bounds checks (filling remaining
params from keyword mapping if available) or raise/log an explicit error when
required dep parameters lack metadata, and ensure dep_tensor_meta,
dep_scalar_values, and dep_scalar_dtypes are filled from
entry_scalar_values/entry_scalar_dtypes or all_tensor_meta by parameter name
fallback rather than relying solely on zip truncation. Ensure you touch the code
referencing dep_param_names, call_args, dep_tensor_meta, all_tensor_meta,
entry_scalar_values, and entry_scalar_dtypes.
tests/ut/jit/test_specializer.py (1)

172-180: Add one specialize() assertion for scalar substitution.

These tests prove _classify_params() recognizes pl.INDEX / pl.FP32, but nothing here feeds concrete scalar values into specialize() and checks that something like BLOCK_M or alpha becomes a literal in the generated source. One focused case would cover the biggest remaining gap in the new JIT contract.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ut/jit/test_specializer.py` around lines 172 - 180, The test currently
verifies _classify_params() recognizes pl.INDEX/pl.FP32 but doesn't assert that
specialize() substitutes scalar params into the generated source; update
test_scalar_bare_dtype to call specialize() on the parsed function (use concrete
values e.g. BLOCK_M=16 and alpha=0.5), get the generated source from the
specialization result, and assert that the source contains the literal "16" for
BLOCK_M and "0.5" (or equivalent literal) for alpha so the scalar substitution
is validated; keep references to _parse_func(func_def), _classify_params(), and
specialize() when locating code to change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@python/pypto/jit/specializer.py`:
- Around line 165-186: _collect_dynvar_names currently returns only Python
identifiers which lets specialize recreate pl.dynamic(...) using the local name;
instead capture and return the original literal used in the call (e.g., map
variable name -> literal string when the Call has a Constant/Str arg) so that
specialize can re-emit pl.dynamic(original_literal) instead of
pl.dynamic(variable_name); update _collect_dynvar_names to return a dict[str,
str] (or similar) and change specialize (and the other similar site that handles
dynvars) to consult that mapping when reconstructing pl.dynamic calls.
- Around line 279-295: The constructor stores scalar_values on self._scalars but
nothing uses it; update the _BodyTransformer to actually specialize compile-time
scalars by replacing references to parameter names with their concrete values
when walking the AST (e.g., implement handling in visit_Name / visit_Constant or
add a helper like _specialize_scalar_usage called from existing visit_*
methods). Ensure you read self._scalars (keys are parameter names) and replace
occurrences of pl.INDEX / pl.FP32-style compile-time params with literal values
so the transformed body reflects the scalar and the cache key then matches the
actual IR.

In `@tests/ut/jit/test_roundtrip.py`:
- Around line 66-122: The tests create tile_add/tile_mul JIT functions that
load/store 128x128 tensors but compare against reference IR fixtures
TileAddProgram and TileMulProgram which are 32x32, causing structural
mismatches; update the test functions (test_tile_add_128x128 and
test_tile_mul_128x128) to match the referenced programs by either (A) changing
the loads, stores, and test tensor shapes to 32x32 (adjust tile_a/tile_b load
sizes and created tensors a, b, c) or (B) importing or referencing the correct
128x128 fixtures if those exist; likewise scan the other failing tests that
compare against FusedAddScaleProgram and TileAssembleAccMatProgram and align
their JIT bodies (e.g., apply the 0.5 scale in the fused case and accept a
precomputed acc tile signature in the assemble case) so the orchestrator/tile_*
functions (and their argument shapes) exactly match the structure of the
referenced program symbols before calling ir.assert_structural_equal().


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 8fecc048-92c4-4d96-aa29-285e3f44a771

📥 Commits

Reviewing files that changed from the base of the PR and between 08bdedd and 572087d.

📒 Files selected for processing (12)
  • python/pypto/jit/__init__.py
  • python/pypto/jit/cache.py
  • python/pypto/jit/decorator.py
  • python/pypto/jit/specializer.py
  • python/pypto/language/__init__.py
  • python/pypto/language/typing/tensor.py
  • tests/ut/jit/__init__.py
  • tests/ut/jit/conftest.py
  • tests/ut/jit/test_cache.py
  • tests/ut/jit/test_decorator.py
  • tests/ut/jit/test_roundtrip.py
  • tests/ut/jit/test_specializer.py


@coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (2)
python/pypto/jit/specializer.py (2)

279-295: ⚠️ Potential issue | 🟠 Major

Scalar values affect the cache key, but not the generated body.

scalar_values is stored on _BodyTransformer, but no visitor ever reads self._scalars. Different pl.INDEX / pl.FP32 call values therefore generate different cache entries while producing the same IR, and scalar-dependent DSL constructs never get specialized by value. Please inline scalar loads and add a regression in tests/ut/jit/test_specializer.py.

Minimal fix
 class _BodyTransformer(ast.NodeTransformer):
@@
     def _shape_dim_node(self, param_name: str, dim_idx: int) -> ast.expr:
         meta = self._meta[param_name]
         if (param_name, dim_idx) in self._dynamic_dims:
             dv = _dynvar_name_for_dim(param_name, dim_idx, self._dv_names)
             return ast.Name(id=dv, ctx=ast.Load())
         return ast.Constant(value=meta.shape[dim_idx])
+
+    def visit_Name(self, node: ast.Name) -> ast.expr:
+        if isinstance(node.ctx, ast.Load) and node.id in self._scalars:
+            return ast.copy_location(ast.Constant(value=self._scalars[node.id]), node)
+        return node
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/jit/specializer.py` around lines 279 - 295, The transformer
currently stores scalar_values on _BodyTransformer as self._scalars but never
uses them, causing cache divergence; update _BodyTransformer to inline scalar
constants by replacing loads of scalar DSL variables with their literal values
during IR generation (e.g., implement logic in the visitor used to walk
expressions — such as visit_Name/visit_Attribute or the method that handles DSL
scalar loads — to check self._scalars and return a Constant node for keys
present), remove reliance on scalar_values from any cache key generation if
present, and add a regression test in tests/ut/jit/test_specializer.py that
creates two specializations differing only by scalar_values and asserts they
produce identical generated bodies/IR.
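The proposed visit_Name fix can be exercised standalone with the stdlib ast module; the class and parameter names below are illustrative, not pypto's actual transformer:

```python
import ast

class ScalarInliner(ast.NodeTransformer):
    """Replace loads of known scalar parameters with literal constants."""

    def __init__(self, scalars):
        self._scalars = scalars

    def visit_Name(self, node):
        # Only rewrite loads; stores and parameter definitions are untouched.
        if isinstance(node.ctx, ast.Load) and node.id in self._scalars:
            return ast.copy_location(ast.Constant(self._scalars[node.id]), node)
        return node

src = "def scale(x):\n    return x * alpha + BLOCK_M\n"
tree = ast.fix_missing_locations(
    ScalarInliner({"alpha": 0.5, "BLOCK_M": 16}).visit(ast.parse(src)))
out = ast.unparse(tree)
print(out)  # body becomes: return x * 0.5 + 16
```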

615-623: ⚠️ Potential issue | 🟠 Major

DynVar emission renames user-defined dynamic symbols.

_iter_dynvar_names() only yields the Python binding name, so line 622 reconstructs pl.dynamic() as rows = pl.dynamic("rows") even when the source was rows = pl.dynamic("M"). That silently changes the symbol name and can also merge unrelated dynvars when different functions reuse the same local variable name. Please carry the original literal through specialization and emit that here instead.

Also applies to: 642-649

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/jit/specializer.py` around lines 615 - 623, The emitted
module-level DynVar declarations are using the Python binding name from
_iter_dynvar_names() which loses the original literal passed to pl.dynamic()
(causing e.g. pl.dynamic("rows") instead of pl.dynamic("M")); update the
pipeline so the original literal is preserved through specialization (e.g.,
store the literal string on the context or the DynVar metadata when you discover
pl.dynamic) and change the emission logic to iterate those preserved literals
(not the local binding names) when building lines (the code around
_iter_dynvar_names and the emission loop that appends f'{dv_varname} =
pl.dynamic("{dv_varname}")'); apply the same fix to the second emission site
mentioned (the 642-649 block) so you emit the exact original string passed to
pl.dynamic() and avoid merging different DynVars that happen to share a local
variable name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@python/pypto/jit/decorator.py`:
- Around line 369-375: The current _get_source_hash and dependency handling only
includes direct deps, so nested `@jit.incore` dependencies (e.g., dep_a calling
dep_b) are not specialized or hashed; update _get_deps to perform a recursive,
deduplicating walk of dependency contexts and change call-site emission (where
contexts and dep_names are built) to include all transitive deps so emitted
contexts list contains every unique dep; ensure _get_source_hash collects
inspect.getsource for self._func plus the full deduped transitive set from
_get_deps and then calls compute_source_hash on that complete list (apply same
recursion fix to the other occurrences referenced around lines 472-490 and
540-549).
- Around line 239-257: The bug is that _extract_call_args_for_dep (and the same
logic at the other occurrence) only inspects node.args so keyword-only calls
like sub(a=a, c=out) yield an empty list which is treated as "found" and causes
_build_dep_context to take the positional branch; update
_extract_call_args_for_dep to also inspect node.keywords and build a mapping:
for each keyword in node.keywords, append keyword.arg (or None for None arg) at
the corresponding logical position or, simpler, if node.args is empty but
node.keywords is non-empty, return None so the caller falls back; apply the same
change to the duplicate block at the other location (lines referenced in the
comment) and ensure callers (_build_dep_context) correctly interpret a None
return as "no mapping".
- Around line 423-440: The entry-level dynamic_dims computed by
_scan_dynamic_dims must be augmented with any dependency-level bind_dynamic
markers before building the cache key and creating the entry specialization:
after calling _scan_dynamic_dims(self._func, param_names) merge in dynamic-dim
mappings produced/propagated by dependent `@jit.incore` functions (i.e. any
dep-derived dynamic dim map you already compute or can obtain from the
dependency analysis) so that dynamic_dims reflects both entry and dependency
bindings; then pass the merged dynamic_dims into make_cache_key(...) and into
the entry specialization/_compile(...) (and into the SpecializeContext creation)
so dependency-level DynVar bindings affect caching and compilation. Apply the
same merge at the other locations noted (the blocks around lines 479-489 and
537-548).


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 905c2132-c18d-427d-b9e6-adf01bc749df

📥 Commits

Reviewing files that changed from the base of the PR and between 572087d and 3de82d2.

📒 Files selected for processing (12)
  • python/pypto/jit/__init__.py
  • python/pypto/jit/cache.py
  • python/pypto/jit/decorator.py
  • python/pypto/jit/specializer.py
  • python/pypto/language/__init__.py
  • python/pypto/language/typing/tensor.py
  • tests/ut/jit/__init__.py
  • tests/ut/jit/conftest.py
  • tests/ut/jit/test_cache.py
  • tests/ut/jit/test_decorator.py
  • tests/ut/jit/test_roundtrip.py
  • tests/ut/jit/test_specializer.py
✅ Files skipped from review due to trivial changes (5)
  • tests/ut/jit/__init__.py
  • tests/ut/jit/conftest.py
  • python/pypto/language/typing/tensor.py
  • python/pypto/jit/__init__.py
  • python/pypto/jit/cache.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/ut/jit/test_cache.py
  • tests/ut/jit/test_decorator.py

Comment on lines +423 to +440
# Scan dynamic dims from this function's AST
dynamic_dims = _scan_dynamic_dims(self._func, param_names)

# Build cache key (based on entry function's params only)
key = make_cache_key(
    source_hash=self._get_source_hash(),
    param_names=param_names,
    tensor_shapes={n: m.shape for n, m in tensor_meta.items()},
    tensor_dtypes={n: m.dtype for n, m in tensor_meta.items()},
    dynamic_dims=dynamic_dims,
    scalar_values=scalar_values,
)

if key in self._cache:
    return self._cache[key]

# Cache miss: specialize, parse, cache
program = self._compile(tensor_meta, scalar_values, scalar_dtypes, dynamic_dims, pl)


⚠️ Potential issue | 🟠 Major

Dependency-level bind_dynamic() markers never reach the entry specialization.

dynamic_dims is computed from the entry AST only, then reused for both the cache key and the entry SpecializeContext. If an entry param is forwarded into a @jit.incore dep that binds it dynamic, the dep sees a DynVar but the public entry signature and cache key stay concrete, so Style B dynamic dims compile/cache as fixed-shape programs. Merge the mapped dep dynamic dims back onto the entry-param set before building the key and the entry context.

Also applies to: 479-489, 537-548


@huxinyuan1215 force-pushed the feat/jit-preprocessor branch 5 times, most recently from 9b6b353 to df06df4 on April 9, 2026 at 02:49
Add @pl.jit decorator that specializes JIT-decorated functions into
@pl.program source at call-time based on concrete tensor shapes/dtypes.

Core components:
- python/pypto/jit/decorator.py: @jit and @jit.incore decorators, dep
  discovery, tensor-meta collection, L1 specialization cache
- python/pypto/jit/specializer.py: AST rewriter that fills in concrete
  shapes/dtypes, inlines scalar constants, handles pl.dynamic DynVars,
  rewrites dep calls to self.method() style
- python/pypto/jit/cache.py: L1 cache keyed on (source_hash, shapes,
  dtypes, scalar_values, dynamic_dims)
- python/pypto/language/__init__.py: expose jit and JITFunction in pl.*
- python/pypto/language/typing/tensor.py: add Tensor.bind_dynamic() no-op
  for marking dynamic dimensions in JIT kernels

Tests:
- tests/ut/jit/test_cache.py: cache key stability and hit/miss behavior
- tests/ut/jit/test_decorator.py: dep discovery, dynamic-dim scanning,
  call-arg extraction, cache integration, OSError on missing source
- tests/ut/jit/test_specializer.py: annotation filling, Out params,
  dynvar emission, shape/dtype substitution, structural generation
- tests/ut/jit/test_roundtrip.py: 23 structural-equality tests covering
  all expressible examples/kernels/ programs (01_elementwise through
  08_assemble), verifying JIT output matches hand-written @pl.program IR
@huxinyuan1215 force-pushed the feat/jit-preprocessor branch from df06df4 to 44b5260 on April 9, 2026 at 02:53

Labels

None yet

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

[RFC] JIT Compilation Interface for PyPTO

1 participant