Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -224,25 +224,17 @@ follow_imports = 'silent'
module = 'gt4py.cartesian.*'

[[tool.mypy.overrides]]
ignore_errors = true
disable_error_code = "call-arg"
module = 'gt4py.cartesian.frontend.nodes'

[[tool.mypy.overrides]]
ignore_errors = true
module = 'gt4py.cartesian.frontend.node_util'

[[tool.mypy.overrides]]
ignore_errors = true
disable_error_code = "call-arg"
module = 'gt4py.cartesian.frontend.gtscript_frontend'

[[tool.mypy.overrides]]
ignore_errors = true
disable_error_code = "call-arg"
module = 'gt4py.cartesian.frontend.defir_to_gtir'

[[tool.mypy.overrides]]
ignore_errors = true
module = 'gt4py.cartesian.frontend.meta'

[[tool.mypy.overrides]]
module = 'gt4py.eve.extended_typing'
warn_unused_ignores = false
Expand Down
5 changes: 4 additions & 1 deletion src/gt4py/cartesian/frontend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ def generate(
@classmethod
@abc.abstractmethod
def prepare_stencil_definition(
cls, definition: AnyStencilFunc, externals: dict[str, Any]
cls,
definition: AnyStencilFunc,
externals: dict[str, Any],
options: BuildOptions | None = None,
) -> AnnotatedStencilFunc:
"""
Annotate the stencil function if not already done so.
Expand Down
30 changes: 14 additions & 16 deletions src/gt4py/cartesian/frontend/defir_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import functools
import itertools
import numbers
from typing import Any, Final, List, Optional, Tuple, Union, cast
from typing import Any, Final, List, Optional, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -66,9 +66,7 @@ def _convert_dtype(data_type) -> common.DataType:
if dtype == common.DataType.DEFAULT:
# TODO: this will be a frontend choice later
# in non-GTC parts, this is set in the backend
dtype = cast(
common.DataType, common.DataType.FLOAT64
) # see https://github.com/GridTools/gtc/issues/100
dtype = common.DataType.FLOAT64
return dtype


Expand Down Expand Up @@ -161,7 +159,7 @@ def _nested_list_dim(self, a: List) -> List[int]:

def visit_Assign(
self, node: Assign, *, fields_decls: dict[str, FieldDecl], **kwargs
) -> Union[gtir.ParAssignStmt, List[gtir.ParAssignStmt]]:
) -> Assign | list[Assign]:
if self._is_vector_assignment(node, fields_decls):
assert isinstance(node.target, FieldRef) or isinstance(node.target, VarRef)
target_dims = fields_decls[node.target.name].data_dims
Expand Down Expand Up @@ -249,20 +247,20 @@ def visit_FieldRef(self, node: FieldRef, *, fields_decls: dict[str, FieldDecl],

def visit_UnaryOpExpr(self, node: UnaryOpExpr, *, fields_decls: dict[str, FieldDecl], **kwargs):
if node.op == UnaryOperator.TRANSPOSED:
node = self.visit(node.arg, fields_decls=fields_decls, **kwargs)
assert isinstance(node, list) and all(
isinstance(row, list) and len(row) == len(node[0]) for row in node
argument = self.visit(node.arg, fields_decls=fields_decls, **kwargs)
assert isinstance(argument, list) and all(
isinstance(row, list) and len(row) == len(argument[0]) for row in argument
)
# transpose list
node = [list(x) for x in zip(*node)]
return node
argument = [list(x) for x in zip(*argument)]
return argument

return self.generic_visit(node, **kwargs)

def visit_BinOpExpr(self, node: BinOpExpr, *, fields_decls: dict[str, FieldDecl], **kwargs):
lhs = self.visit(node.lhs, fields_decls=fields_decls, **kwargs)
rhs = self.visit(node.rhs, fields_decls=fields_decls, **kwargs)
result: Union[List[BinOpExpr], BinOpExpr] = []
result: list[BinOpExpr] = []

if node.op == BinaryOperator.MATMULT:
for j in range(len(lhs)):
Expand Down Expand Up @@ -587,20 +585,20 @@ def visit_While(self, node: While) -> gtir.While:
def visit_VarRef(self, node: VarRef, **kwargs) -> gtir.ScalarAccess:
return gtir.ScalarAccess(name=node.name, loc=location_to_source_location(node.loc))

def visit_AxisInterval(self, node: AxisInterval) -> Tuple[gtir.AxisBound, gtir.AxisBound]:
def visit_AxisInterval(self, node: AxisInterval) -> tuple[common.AxisBound, common.AxisBound]:
return self.visit(node.start), self.visit(node.end)

def visit_AxisBound(self, node: AxisBound) -> gtir.AxisBound:
def visit_AxisBound(self, node: AxisBound) -> common.AxisBound:
# TODO(havogt) add support VarRef
return gtir.AxisBound(
return common.AxisBound(
level=self.GT4PY_LEVELMARKER_TO_GTIR_LEVELMARKER[node.level], offset=node.offset
)

def visit_RuntimeAxisBound(self, node: RuntimeAxisBound) -> gtir.RuntimeAxisBound:
def visit_RuntimeAxisBound(self, node: RuntimeAxisBound) -> common.RuntimeAxisBound:
utils.warn_experimental_feature(
feature="Runtime Interval Bounds", ADR="experimental/runtime-intervals.md"
)
return gtir.RuntimeAxisBound(
return common.RuntimeAxisBound(
level=self.GT4PY_LEVELMARKER_TO_GTIR_LEVELMARKER[node.level],
offset=self.visit(node.offset),
)
Expand Down
Loading
Loading