feat(jit): implement @pl.jit preprocessor over @pl.program pipeline #915

huxinyuan1215 wants to merge 1 commit into hw-native-sys:main
Conversation
📝 Walkthrough

Adds a JIT compilation subsystem under `python/pypto/jit/`.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant JIT as JITFunction
    participant Cache as Cache Layer
    participant Specializer as Specializer
    participant Parser as Parser
    participant Program as ir.Program
    User->>JIT: call jit_kernel(torch_tensor, scalar)
    JIT->>JIT: extract shapes/dtypes, classify args
    JIT->>JIT: compute source_hash + cache key (mask dyn dims)
    JIT->>Cache: lookup(cache_key)
    alt cache hit
        Cache-->>JIT: cached Program
        JIT-->>User: return Program
    else cache miss
        JIT->>JIT: discover incore deps via AST
        JIT->>Specializer: build contexts (deps first, entry last)
        Specializer->>Specializer: rewrite AST (remove bind_dynamic, substitute shapes/dtypes, emit dynvars)
        Specializer-->>JIT: specialized DSL source
        JIT->>Parser: parse(DSL) -> Program
        Parser-->>Program: ir.Program
        JIT->>Cache: store(cache_key, Program)
        JIT-->>User: return Program
    end
```
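The cache-hit/cache-miss flow in the diagram can be sketched in stdlib-only Python. This is a toy illustration, not the real pypto API: `jit`, `make_key`, `FakeTensor`, and the `("specialized", ...)` tuple stand in for the actual JITFunction, cache key, tensor argument, and ir.Program.

```python
import hashlib

def jit(func):
    """Toy sketch of the call-time flow: key on shapes + scalar values, cache the result."""
    cache = {}
    # Stand-in for hashing the function's source (the real JIT hashes inspect.getsource).
    src_hash = hashlib.sha256(func.__name__.encode()).hexdigest()

    def make_key(args):
        # Classify args: "tensors" contribute shape, everything else its value.
        parts = []
        for a in args:
            if hasattr(a, "shape"):
                parts.append(("tensor", tuple(a.shape)))
            else:
                parts.append(("scalar", a))
        return (src_hash, tuple(parts))

    def wrapper(*args):
        key = make_key(args)
        if key in cache:                                  # cache hit: reuse "program"
            return cache[key]
        program = ("specialized", func.__name__, key[1])  # cache miss: "compile"
        cache[key] = program
        return program

    wrapper.cache = cache
    return wrapper

class FakeTensor:
    def __init__(self, shape):
        self.shape = shape

@jit
def tile_add(a, b, n):
    pass

p1 = tile_add(FakeTensor((128, 128)), FakeTensor((128, 128)), 4)
p2 = tile_add(FakeTensor((128, 128)), FakeTensor((128, 128)), 4)
```

Repeated calls with the same shapes and scalar values return the cached object; a new shape triggers a fresh "compile" and a second cache entry.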
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces a JIT compilation module for PyPTO, providing a @pl.jit decorator that enables automatic specialization and compilation of kernel functions based on tensor shapes and dtypes. The implementation includes a compilation cache, an AST specializer to transform JIT-style code into @pl.program source, and necessary updates to the pl.Tensor API. My review suggests optimizing the _is_tensor function to avoid redundant calls to _get_torch and updating _is_tensor_annotation to consistently handle both native and pl.Tensor types.
```python
torch = _get_torch()
if torch is None:
    return False
return isinstance(obj, torch.Tensor)
```
The `_is_tensor` function calls `_get_torch()`, which performs a try/except import. This call is redundant here because `_get_torch()` is already invoked inside `_extract_tensor_meta` and `_torch_dtype_to_pypto`. Since `_get_torch()` caches its result, it is better to call it once and check the cached result, or to use `isinstance` directly if `torch` is imported at module scope.
```python
is_scalar_ann = _is_scalar_annotation(outer)
if is_scalar_ann:
    scalar_dtype_strs[name] = _ast_to_str(inner)
    continue
```
The `_is_tensor_annotation` function only checks for the native `Tensor` type. It should be updated to also check for the custom `pl.Tensor` type, consistent with how types are resolved in the DSL (similar to how `pl.Scalar` and `pl.Tuple` are handled). This ensures that type resolution correctly identifies tensor annotations regardless of whether the native or custom type is used.
References
- When checking for nested tuple types during type resolution, ensure the check covers both native tuple (which may resolve to a list) and the custom pl.Tuple type (which resolves to ir.TupleType).
Force-pushed: 572087d → 3de82d2
Actionable comments posted: 3
🧹 Nitpick comments (3)
python/pypto/jit/decorator.py (2)

580-586: Consider adding `strict=True` for closure variable extraction.

The `co_freevars` and `__closure__` tuples are guaranteed by Python to have matching lengths, but adding `strict=True` would make this invariant explicit and catch any future edge cases.

♻️ Optional: add strict=True

```diff
- for name, cell in zip(co_freevars, closure):
+ for name, cell in zip(co_freevars, closure, strict=True):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@python/pypto/jit/decorator.py` around lines 580 - 586, The loop extracting closure variables (using co_freevars = getattr(getattr(func, "__code__", None), "co_freevars", ()) and closure = getattr(func, "__closure__", None) or () and for name, cell in zip(co_freevars, closure): ...) should use zip(..., strict=True) to assert the freevar/closure length invariant; update the for-loop to for name, cell in zip(co_freevars, closure, strict=True): so mismatched lengths raise immediately and still catch ValueError from cell.cell_contents as before (ensure runtime Python supports zip(strict=True)).
520-530: Potential silent truncation when a dependency has more parameters than positional call args.

If the entry function calls a dependency with keyword arguments or fewer positional args than the dependency expects, `zip(dep_param_names, call_args)` will silently truncate, leaving some dependency parameters without metadata. This could cause downstream issues in specialization.

Consider adding `strict=True` or handling the length mismatch explicitly:

♻️ Proposed fix to handle length mismatch

```diff
 if call_args is not None:
-    for dep_param, entry_arg in zip(dep_param_names, call_args):
+    for dep_param, entry_arg in zip(dep_param_names, call_args, strict=False):
+        # Note: strict=False is intentional - call_args may be shorter if
+        # some args are passed as kwargs. Unmatched params fall through
+        # to name-based matching below.
         if entry_arg is None:
             continue
         if entry_arg in all_tensor_meta:
```

Alternatively, if partial positional matching followed by name-based fallback is intended, document this behavior.
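One way the suggested name-based fallback could look, as a standalone sketch. The helper `map_dep_args` is hypothetical, not the PR's code; it shows bounds-checked positional matching with a keyword fallback instead of `zip()` truncation.

```python
def map_dep_args(dep_params, pos_args, kw_args):
    """Match dependency parameters to entry-level argument names.

    Positional args are matched by index with a bounds check; anything left
    over falls back to the keyword mapping instead of being silently dropped
    by zip() truncation.
    """
    mapping = {}
    for i, param in enumerate(dep_params):
        if i < len(pos_args) and pos_args[i] is not None:
            mapping[param] = pos_args[i]
        else:
            mapping[param] = kw_args.get(param)  # name-based fallback; may be None
    return mapping
```

With this shape, a call like `sub(a=a, c=out)` yields explicit `None` gaps rather than silently missing entries.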
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@python/pypto/jit/decorator.py` around lines 520 - 530, The loop over zip(dep_param_names, call_args) can silently truncate when a dependency expects more parameters than provided as positional args (or when kwargs are used); update the logic in the block handling dep_param_names/call_args to detect length mismatches and handle them explicitly: either iterate over dep_param_names and index into call_args with bounds checks (filling remaining params from keyword mapping if available) or raise/log an explicit error when required dep parameters lack metadata, and ensure dep_tensor_meta, dep_scalar_values, and dep_scalar_dtypes are filled from entry_scalar_values/entry_scalar_dtypes or all_tensor_meta by parameter name fallback rather than relying solely on zip truncation. Ensure you touch the code referencing dep_param_names, call_args, dep_tensor_meta, all_tensor_meta, entry_scalar_values, and entry_scalar_dtypes.

tests/ut/jit/test_specializer.py (1)
172-180: Add one `specialize()` assertion for scalar substitution.

These tests prove `_classify_params()` recognizes `pl.INDEX`/`pl.FP32`, but nothing here feeds concrete scalar values into `specialize()` and checks that something like `BLOCK_M` or `alpha` becomes a literal in the generated source. One focused case would cover the biggest remaining gap in the new JIT contract.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ut/jit/test_specializer.py` around lines 172 - 180, The test currently verifies _classify_params() recognizes pl.INDEX/pl.FP32 but doesn't assert that specialize() substitutes scalar params into the generated source; update test_scalar_bare_dtype to call specialize() on the parsed function (use concrete values e.g. BLOCK_M=16 and alpha=0.5), get the generated source from the specialization result, and assert that the source contains the literal "16" for BLOCK_M and "0.5" (or equivalent literal) for alpha so the scalar substitution is validated; keep references to _parse_func(func_def), _classify_params(), and specialize() when locating code to change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@python/pypto/jit/specializer.py`:
- Around line 165-186: _collect_dynvar_names currently returns only Python
identifiers which lets specialize recreate pl.dynamic(...) using the local name;
instead capture and return the original literal used in the call (e.g., map
variable name -> literal string when the Call has a Constant/Str arg) so that
specialize can re-emit pl.dynamic(original_literal) instead of
pl.dynamic(variable_name); update _collect_dynvar_names to return a dict[str,
str] (or similar) and change specialize (and the other similar site that handles
dynvars) to consult that mapping when reconstructing pl.dynamic calls.
- Around line 279-295: The constructor stores scalar_values on self._scalars but
nothing uses it; update the _BodyTransformer to actually specialize compile-time
scalars by replacing references to parameter names with their concrete values
when walking the AST (e.g., implement handling in visit_Name / visit_Constant or
add a helper like _specialize_scalar_usage called from existing visit_*
methods). Ensure you read self._scalars (keys are parameter names) and replace
occurrences of pl.INDEX / pl.FP32-style compile-time params with literal values
so the transformed body reflects the scalar and the cache key then matches the
actual IR.
In `@tests/ut/jit/test_roundtrip.py`:
- Around line 66-122: The tests create tile_add/tile_mul JIT functions that
load/store 128x128 tensors but compare against reference IR fixtures
TileAddProgram and TileMulProgram which are 32x32, causing structural
mismatches; update the test functions (test_tile_add_128x128 and
test_tile_mul_128x128) to match the referenced programs by either (A) changing
the loads, stores, and test tensor shapes to 32x32 (adjust tile_a/tile_b load
sizes and created tensors a, b, c) or (B) importing or referencing the correct
128x128 fixtures if those exist; likewise scan the other failing tests that
compare against FusedAddScaleProgram and TileAssembleAccMatProgram and align
their JIT bodies (e.g., apply the 0.5 scale in the fused case and accept a
precomputed acc tile signature in the assemble case) so the orchestrator/tile_*
functions (and their argument shapes) exactly match the structure of the
referenced program symbols before calling ir.assert_structural_equal().
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 8fecc048-92c4-4d96-aa29-285e3f44a771
📒 Files selected for processing (12)

- python/pypto/jit/__init__.py
- python/pypto/jit/cache.py
- python/pypto/jit/decorator.py
- python/pypto/jit/specializer.py
- python/pypto/language/__init__.py
- python/pypto/language/typing/tensor.py
- tests/ut/jit/__init__.py
- tests/ut/jit/conftest.py
- tests/ut/jit/test_cache.py
- tests/ut/jit/test_decorator.py
- tests/ut/jit/test_roundtrip.py
- tests/ut/jit/test_specializer.py
Actionable comments posted: 3
♻️ Duplicate comments (2)
python/pypto/jit/specializer.py (2)
279-295: ⚠️ Potential issue | 🟠 Major

Scalar values affect the cache key, but not the generated body.

`scalar_values` is stored on `_BodyTransformer`, but no visitor ever reads `self._scalars`. Different `pl.INDEX`/`pl.FP32` call values therefore generate different cache entries while producing the same IR, and scalar-dependent DSL constructs never get specialized by value. Please inline scalar loads and add a regression in `tests/ut/jit/test_specializer.py`.

Minimal fix

```diff
 class _BodyTransformer(ast.NodeTransformer):
 @@
     def _shape_dim_node(self, param_name: str, dim_idx: int) -> ast.expr:
         meta = self._meta[param_name]
         if (param_name, dim_idx) in self._dynamic_dims:
             dv = _dynvar_name_for_dim(param_name, dim_idx, self._dv_names)
             return ast.Name(id=dv, ctx=ast.Load())
         return ast.Constant(value=meta.shape[dim_idx])
+
+    def visit_Name(self, node: ast.Name) -> ast.expr:
+        if isinstance(node.ctx, ast.Load) and node.id in self._scalars:
+            return ast.copy_location(ast.Constant(value=self._scalars[node.id]), node)
+        return node
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@python/pypto/jit/specializer.py` around lines 279 - 295, The transformer currently stores scalar_values on _BodyTransformer as self._scalars but never uses them, causing cache divergence; update _BodyTransformer to inline scalar constants by replacing loads of scalar DSL variables with their literal values during IR generation (e.g., implement logic in the visitor used to walk expressions — such as visit_Name/visit_Attribute or the method that handles DSL scalar loads — to check self._scalars and return a Constant node for keys present), remove reliance on scalar_values from any cache key generation if present, and add a regression test in tests/ut/jit/test_specializer.py that creates two specializations differing only by scalar_values and asserts they produce identical generated bodies/IR.
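The proposed `visit_Name` fix can be exercised standalone with the stdlib `ast` module. This is a simplified transformer illustrating the same technique, not the PR's `_BodyTransformer`:

```python
import ast

class ScalarInliner(ast.NodeTransformer):
    """Replace loads of known compile-time scalars with literal constants."""

    def __init__(self, scalars):
        self._scalars = scalars  # param name -> concrete value

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Load) and node.id in self._scalars:
            return ast.copy_location(ast.Constant(value=self._scalars[node.id]), node)
        return node

src = "def kernel(x):\n    return x * BLOCK_M + alpha\n"
tree = ScalarInliner({"BLOCK_M": 16, "alpha": 0.5}).visit(ast.parse(src))
specialized = ast.unparse(ast.fix_missing_locations(tree))
# BLOCK_M and alpha are now the literals 16 and 0.5 in the unparsed source.
```

Two specializations differing only in scalar values would then produce different bodies, so the cache key and the generated IR stay in agreement.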
615-623: ⚠️ Potential issue | 🟠 Major

DynVar emission renames user-defined dynamic symbols.

`_iter_dynvar_names()` only yields the Python binding name, so line 622 reconstructs `pl.dynamic()` as `rows = pl.dynamic("rows")` even when the source was `rows = pl.dynamic("M")`. That silently changes the symbol name and can also merge unrelated dynvars when different functions reuse the same local variable name. Please carry the original literal through specialization and emit that here instead.

Also applies to: 642-649
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@python/pypto/jit/specializer.py` around lines 615 - 623, The emitted module-level DynVar declarations are using the Python binding name from _iter_dynvar_names() which loses the original literal passed to pl.dynamic() (causing e.g. pl.dynamic("rows") instead of pl.dynamic("M")); update the pipeline so the original literal is preserved through specialization (e.g., store the literal string on the context or the DynVar metadata when you discover pl.dynamic) and change the emission logic to iterate those preserved literals (not the local binding names) when building lines (the code around _iter_dynvar_names and the emission loop that appends f'{dv_varname} = pl.dynamic("{dv_varname}")'); apply the same fix to the second emission site mentioned (the 642-649 block) so you emit the exact original string passed to pl.dynamic() and avoid merging different DynVars that happen to share a local variable name.
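A sketch of carrying the original literal through specialization. The helper `collect_dynvar_literals` is hypothetical; the idea is that the PR's `_collect_dynvar_names` would return such a name-to-literal mapping instead of bare binding names:

```python
import ast

def collect_dynvar_literals(source: str) -> dict:
    """Map each binding name to the original string literal passed to pl.dynamic()."""
    literals = {}
    for node in ast.walk(ast.parse(source)):
        if (isinstance(node, ast.Assign)
                and isinstance(node.value, ast.Call)
                and isinstance(node.value.func, ast.Attribute)
                and node.value.func.attr == "dynamic"
                and node.value.args
                and isinstance(node.value.args[0], ast.Constant)):
            target = node.targets[0]
            if isinstance(target, ast.Name):
                literals[target.id] = node.value.args[0].value
    return literals
```

Emission would then write `rows = pl.dynamic("M")` for binding `rows` rather than recycling the local variable name, preserving the user's symbol.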
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@python/pypto/jit/decorator.py`:
- Around line 369-375: The current _get_source_hash and dependency handling only
includes direct deps, so nested `@jit.incore` dependencies (e.g., dep_a calling
dep_b) are not specialized or hashed; update _get_deps to perform a recursive,
deduplicating walk of dependency contexts and change call-site emission (where
contexts and dep_names are built) to include all transitive deps so emitted
contexts list contains every unique dep; ensure _get_source_hash collects
inspect.getsource for self._func plus the full deduped transitive set from
_get_deps and then calls compute_source_hash on that complete list (apply same
recursion fix to the other occurrences referenced around lines 472-490 and
540-549).
- Around line 239-257: The bug is that _extract_call_args_for_dep (and the same
logic at the other occurrence) only inspects node.args so keyword-only calls
like sub(a=a, c=out) yield an empty list which is treated as "found" and causes
_build_dep_context to take the positional branch; update
_extract_call_args_for_dep to also inspect node.keywords and build a mapping:
for each keyword in node.keywords, append keyword.arg (or None for None arg) at
the corresponding logical position or, simpler, if node.args is empty but
node.keywords is non-empty, return None so the caller falls back; apply the same
change to the duplicate block at the other location (lines referenced in the
comment) and ensure callers (_build_dep_context) correctly interpret a None
return as "no mapping".
- Around line 423-440: The entry-level dynamic_dims computed by
_scan_dynamic_dims must be augmented with any dependency-level bind_dynamic
markers before building the cache key and creating the entry specialization:
after calling _scan_dynamic_dims(self._func, param_names) merge in dynamic-dim
mappings produced/propagated by dependent `@jit.incore` functions (i.e. any
dep-derived dynamic dim map you already compute or can obtain from the
dependency analysis) so that dynamic_dims reflects both entry and dependency
bindings; then pass the merged dynamic_dims into make_cache_key(...) and into
the entry specialization/_compile(...) (and into the SpecializeContext creation)
so dependency-level DynVar bindings affect caching and compilation. Apply the
same merge at the other locations noted (the blocks around lines 479-489 and
537-548).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 905c2132-c18d-427d-b9e6-adf01bc749df
📒 Files selected for processing (12)

- python/pypto/jit/__init__.py
- python/pypto/jit/cache.py
- python/pypto/jit/decorator.py
- python/pypto/jit/specializer.py
- python/pypto/language/__init__.py
- python/pypto/language/typing/tensor.py
- tests/ut/jit/__init__.py
- tests/ut/jit/conftest.py
- tests/ut/jit/test_cache.py
- tests/ut/jit/test_decorator.py
- tests/ut/jit/test_roundtrip.py
- tests/ut/jit/test_specializer.py
✅ Files skipped from review due to trivial changes (5)
- tests/ut/jit/__init__.py
- tests/ut/jit/conftest.py
- python/pypto/language/typing/tensor.py
- python/pypto/jit/__init__.py
- python/pypto/jit/cache.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/ut/jit/test_cache.py
- tests/ut/jit/test_decorator.py
```python
# Scan dynamic dims from this function's AST
dynamic_dims = _scan_dynamic_dims(self._func, param_names)

# Build cache key (based on entry function's params only)
key = make_cache_key(
    source_hash=self._get_source_hash(),
    param_names=param_names,
    tensor_shapes={n: m.shape for n, m in tensor_meta.items()},
    tensor_dtypes={n: m.dtype for n, m in tensor_meta.items()},
    dynamic_dims=dynamic_dims,
    scalar_values=scalar_values,
)

if key in self._cache:
    return self._cache[key]

# Cache miss: specialize, parse, cache
program = self._compile(tensor_meta, scalar_values, scalar_dtypes, dynamic_dims, pl)
```
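The shape-masking step the sequence diagram calls "mask dyn dims" could be approximated like this. This `make_cache_key` is illustrative, not the PR's implementation in `cache.py`; it assumes `dynamic_dims` is a set of `(param_name, dim_index)` pairs.

```python
import hashlib
import json

def make_cache_key(source_hash, param_names, tensor_shapes, tensor_dtypes,
                   dynamic_dims, scalar_values):
    """Stable key: dynamic dims are masked so any runtime size hits the same entry."""
    masked = {
        name: tuple("*" if (name, i) in dynamic_dims else d
                    for i, d in enumerate(shape))
        for name, shape in tensor_shapes.items()
    }
    payload = json.dumps(
        [source_hash, param_names, masked, tensor_dtypes, sorted(scalar_values.items())],
        sort_keys=True, default=str,
    )
    return hashlib.sha256(payload.encode()).hexdigest()
```

Two calls that differ only in a dynamic dimension then share one cache entry, while any static-dim or scalar change produces a distinct key.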
Dependency-level `bind_dynamic()` markers never reach the entry specialization.

`dynamic_dims` is computed from the entry AST only, then reused for both the cache key and the entry `SpecializeContext`. If an entry param is forwarded into a `@jit.incore` dep that binds it dynamic, the dep sees a DynVar but the public entry signature and cache key stay concrete, so Style B dynamic dims compile/cache as fixed-shape programs. Merge the mapped dep dynamic dims back onto the entry-param set before building the key and the entry context.
Also applies to: 479-489, 537-548
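The suggested merge could be a small set union keyed back to entry parameters. The helper name `merge_dynamic_dims` and the `dep_to_entry` mapping are hypothetical; the latter corresponds to the dep-param-to-entry-arg mapping produced by call-arg extraction.

```python
def merge_dynamic_dims(entry_dims, dep_dims, dep_to_entry):
    """Propagate dependency-level dynamic dims back onto entry parameters.

    entry_dims / dep_dims hold (param_name, dim_index) pairs; dep_to_entry
    maps a dep parameter name to the entry parameter forwarded into it.
    """
    merged = set(entry_dims)
    for dep_param, dim_idx in dep_dims:
        entry_param = dep_to_entry.get(dep_param)
        if entry_param is not None:
            merged.add((entry_param, dim_idx))
    return merged
```

The merged set would then feed both `make_cache_key(...)` and the entry `SpecializeContext`, so dependency-level DynVar bindings affect caching and compilation.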
Force-pushed: 9b6b353 → df06df4
Add @pl.jit decorator that specializes JIT-decorated functions into @pl.program source at call-time based on concrete tensor shapes/dtypes.

Core components:
- python/pypto/jit/decorator.py: @jit and @jit.incore decorators, dep discovery, tensor-meta collection, L1 specialization cache
- python/pypto/jit/specializer.py: AST rewriter that fills in concrete shapes/dtypes, inlines scalar constants, handles pl.dynamic DynVars, rewrites dep calls to self.method() style
- python/pypto/jit/cache.py: L1 cache keyed on (source_hash, shapes, dtypes, scalar_values, dynamic_dims)
- python/pypto/language/__init__.py: expose jit and JITFunction in pl.*
- python/pypto/language/typing/tensor.py: add Tensor.bind_dynamic() no-op for marking dynamic dimensions in JIT kernels

Tests:
- tests/ut/jit/test_cache.py: cache key stability and hit/miss behavior
- tests/ut/jit/test_decorator.py: dep discovery, dynamic-dim scanning, call-arg extraction, cache integration, OSError on missing source
- tests/ut/jit/test_specializer.py: annotation filling, Out params, dynvar emission, shape/dtype substitution, structural generation
- tests/ut/jit/test_roundtrip.py: 23 structural-equality tests covering all expressible examples/kernels/ programs (01_elementwise through 08_assemble), verifying JIT output matches hand-written @pl.program IR
Force-pushed: df06df4 → 44b5260
Summary
Implements issue #878: a JIT preprocessor that sits above the existing `@pl.program` pipeline. Users write a generic kernel once with `@pl.jit`; at call-time, JIT specializes it to a concrete `@pl.program` by filling in actual tensor shapes and dtypes, then runs the normal compiler pipeline.

- `python/pypto/jit/decorator.py`: `@jit` and `@jit.incore` decorators — collect runtime tensor metadata, discover `@jit.incore` dependencies, drive specialization, maintain an L1 compilation cache
- `python/pypto/jit/specializer.py`: AST rewriter — fills concrete shapes/dtypes into type annotations, inlines scalar constants, handles `pl.dynamic` DynVars, rewrites dep calls to `self.method()` style for Style B
- `python/pypto/jit/cache.py`: L1 cache keyed on `(source_hash, shapes, dtypes, scalar_values, dynamic_dims)` to skip recompilation for repeated calls with the same specialization
- `python/pypto/language/__init__.py`: expose `jit` and `JITFunction` in the `pl.*` namespace
- `python/pypto/language/typing/tensor.py`: add `Tensor.bind_dynamic()` no-op used by the specializer to statically detect dynamic dimensions

Testing

- `tests/ut/jit/test_cache.py`: cache key stability and hit/miss behavior
- `tests/ut/jit/test_decorator.py`: dep discovery, dynamic-dim scanning, call-arg extraction, cache integration, accurate OSError on missing source
- `tests/ut/jit/test_specializer.py`: annotation filling, Out params, dynvar emission, shape/dtype substitution, generated program structure
- `tests/ut/jit/test_roundtrip.py`: 23 structural-equality round-trip tests — each `examples/kernels/` program expressible as `@pl.jit` is re-written in JIT style and verified with `ir.assert_structural_equal` against the hand-written `@pl.program`

All 90 new tests pass.
Related Issues
Fixes #878