Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ cmake.args = [

[tool.pyright]
include = ["python/pypto", "tests"]
exclude = ["**/__pycache__", "build", "dist", "tests/st"]
extraPaths = ["python"]
exclude = ["**/__pycache__", "build", "dist", "tests/st", "tests/ut/jit/test_roundtrip.py"]
extraPaths = ["python", "."]
typeCheckingMode = "basic"
pythonVersion = "3.10"
reportMissingTypeStubs = false
Expand All @@ -118,6 +118,9 @@ reportCallIssue = false
reportAssignmentType = false
reportOperatorIssue = false
reportRedeclaration = false
reportOptionalIterable = false
reportOptionalSubscript = false
reportGeneralTypeIssues = false

[tool.pylint]
# Pylint is not used in this project, but as it is a widely used linter,
Expand Down
36 changes: 36 additions & 0 deletions python/pypto/jit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------

"""PyPTO JIT compilation module.

Provides the ``@pl.jit`` decorator for writing kernel functions that are
automatically specialized and compiled on first call based on the shapes
and dtypes of their tensor arguments.

Example::

import pypto.language as pl

@pl.jit
def tile_add(a: pl.Tensor, b: pl.Tensor, c: pl.Out[pl.Tensor]):
with pl.incore():
M, N = a.shape
tile_a = pl.load(a, [0, 0], [M, N])
tile_b = pl.load(b, [0, 0], [M, N])
tile_c = pl.add(tile_a, tile_b)
pl.store(tile_c, [0, 0], c)
return c

# Compiles on first call, serves from cache on subsequent calls
prog = tile_add(torch.randn(128, 128), torch.randn(128, 128), torch.empty(128, 128))
"""

from .decorator import JITFunction, jit

__all__ = ["JITFunction", "jit"]
120 changes: 120 additions & 0 deletions python/pypto/jit/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------

"""Compilation cache for @pl.jit functions.

L1 cache: in-memory dict on each JITFunction instance.
Cache key encodes source hash, tensor shapes/dtypes, and scalar values.
Dynamic dimensions (marked via bind_dynamic) are stored as None in the key
so different concrete values for that dimension share the same cache entry.
"""

import hashlib
from dataclasses import dataclass

from pypto.pypto_core import DataType


@dataclass(frozen=True, slots=True)
class TensorCacheInfo:
    """Per-tensor component of a cache key.

    Frozen so instances are hashable (cache keys are used as dict keys in
    the in-memory L1 cache); ``slots=True`` (available on the project's
    Python 3.10 floor) keeps the per-call key objects small.

    Attributes:
        name: Parameter name.
        shape: Shape tuple with None for dynamic dimensions.
        dtype: DataType of the tensor.
    """

    name: str
    shape: tuple[int | None, ...]
    dtype: DataType


@dataclass(frozen=True)
class ScalarCacheInfo:
"""Per-scalar-param component of a cache key.

Attributes:
name: Parameter name.
value: Concrete scalar value passed at this call site.
"""

name: str
value: int | float | bool


# A cache key is a tuple of (source_hash, tensor_infos, scalar_infos).
# Using a plain tuple keeps it hashable without a custom __hash__ — the
# frozen dataclasses above already provide their own __eq__/__hash__, so
# the whole key can be used directly as a dict key.
CacheKey = tuple[str, tuple[TensorCacheInfo, ...], tuple[ScalarCacheInfo, ...]]


def compute_source_hash(sources: list[str]) -> str:
    """Compute a stable hash over one or more source strings.

    Each source is length-prefixed before hashing so the boundary between
    sources is unambiguous: ``["ab", "c"]`` and ``["a", "bc"]`` produce
    different digests. Plain concatenation would make them collide, which
    for a cache key means serving a stale compilation for different code.

    Args:
        sources: List of source code strings (main function + all deps).

    Returns:
        Hex digest string (SHA-256, first 16 chars for brevity).
    """
    h = hashlib.sha256()
    for src in sources:
        data = src.encode()
        # 8-byte little-endian length prefix disambiguates source boundaries.
        h.update(len(data).to_bytes(8, "little"))
        h.update(data)
    return h.hexdigest()[:16]


def make_cache_key(
    source_hash: str,
    param_names: list[str],
    tensor_shapes: dict[str, tuple[int, ...]],
    tensor_dtypes: dict[str, DataType],
    dynamic_dims: set[tuple[str, int]],
    scalar_values: dict[str, int | float | bool],
) -> CacheKey:
    """Build a cache key for a JIT call site.

    Args:
        source_hash: Hash of function source code (and all dep sources).
        param_names: Ordered list of all parameter names (preserves arg order).
        tensor_shapes: Concrete shape per tensor parameter name.
        tensor_dtypes: DataType per tensor parameter name.
        dynamic_dims: Set of (param_name, dim_index) pairs that are dynamic.
            Dynamic dims are stored as None in the cache key so different
            concrete values for that dimension produce the same cache entry.
        scalar_values: Concrete value per scalar parameter name.

    Returns:
        Hashable CacheKey tuple.
    """

    def _keyed_shape(pname: str) -> tuple[int | None, ...]:
        # Mask dynamic dimensions with None so every concrete extent for
        # that dimension maps onto the same cache entry.
        return tuple(
            None if (pname, idx) in dynamic_dims else extent
            for idx, extent in enumerate(tensor_shapes[pname])
        )

    # Iterate param_names (not the dicts) so the key preserves argument order.
    tensor_infos = tuple(
        TensorCacheInfo(name=pname, shape=_keyed_shape(pname), dtype=tensor_dtypes[pname])
        for pname in param_names
        if pname in tensor_shapes
    )
    scalar_infos = tuple(
        ScalarCacheInfo(name=pname, value=scalar_values[pname])
        for pname in param_names
        if pname in scalar_values
    )
    return (source_hash, tensor_infos, scalar_infos)


# Explicit public API of this module; anything not listed here is an
# implementation detail of the JIT compilation cache.
__all__ = [
    "CacheKey",
    "ScalarCacheInfo",
    "TensorCacheInfo",
    "compute_source_hash",
    "make_cache_key",
]
Loading
Loading