Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/ntops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
abs,
add,
addmm,
avg_pool2d,
bitwise_and,
bitwise_not,
bitwise_or,
bmm,
clamp,
conv2d,
cos,
div,
dropout,
Expand All @@ -20,6 +22,7 @@
layer_norm,
le,
lt,
max_pool2d,
mm,
mul,
ne,
Expand All @@ -42,11 +45,13 @@
"abs",
"add",
"addmm",
"avg_pool2d",
"bitwise_and",
"bitwise_not",
"bitwise_or",
"bmm",
"clamp",
"conv2d",
"cos",
"div",
"dropout",
Expand All @@ -60,6 +65,7 @@
"layer_norm",
"le",
"lt",
"max_pool2d",
"mm",
"mul",
"ne",
Expand Down
42 changes: 42 additions & 0 deletions src/ntops/kernels/avg_pool2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.pooling import arrangement


def application(input, output):
    """Average-pool one flattened window per output element.

    NOTE: in the ninetoothed DSL the assignment to the ``output``
    parameter is the kernel's output write, not a dead store — hence
    the ``noqa: F841`` marker.
    """
    # Each row of `input` is one flattened pooling window (arranged by
    # ntops.kernels.pooling.arrangement); dividing by the full window
    # length includes padded elements in the count — presumably
    # count_include_pad=True semantics; TODO(review): confirm.
    output = ntl.sum(input, axis=-1) / input.shape[-1]  # noqa: F841


def premake(
    kernel_size_h=None,
    kernel_size_w=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    ceil_mode=None,
    dtype=None,
    block_size=None,
):
    """Build the (arrangement, application, tensors) triple for avg_pool2d.

    Any pooling parameter left as ``None`` is resolved inside the
    shared pooling ``arrangement`` (typically to a compile-time
    symbol), so callers may pin only the parameters they know.
    """
    pooling_options = {
        "kernel_size_h": kernel_size_h,
        "kernel_size_w": kernel_size_w,
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding_h": padding_h,
        "padding_w": padding_w,
        "dilation_h": dilation_h,
        "dilation_w": dilation_w,
        "ceil_mode": ceil_mode,
        "block_size": block_size,
    }

    # Bind the pooling hyper-parameters onto the shared arrangement.
    bound_arrangement = functools.partial(arrangement, **pooling_options)

    # One 4-D input tensor and one 4-D output tensor (NCHW).
    input_tensor = Tensor(4, dtype=dtype)
    output_tensor = Tensor(4, dtype=dtype)

    return bound_arrangement, application, (input_tensor, output_tensor)
144 changes: 144 additions & 0 deletions src/ntops/kernels/conv2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import copy
import functools

import ninetoothed.language as ntl
from ninetoothed import Symbol, Tensor

from ntops.kernels import mm


def arrangement(
    input,
    weight,
    bias,
    output,
    input_precision,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    """Arrange conv2d operands as an implicit-GEMM matrix multiply.

    The padded input is tiled into per-output-position patches and
    flattened into the GEMM "A" matrix, the weight is flattened and
    transposed into "B", and the (N, H, W)-expanded bias and the output
    are flattened to match the GEMM result rows.  All four operands are
    then passed through ``mm.arrangement`` so the existing matmul
    kernel arrangement can be reused.

    Parameters left as ``None`` default to compile-time symbols (the
    conv hyper-parameters) or to the mm kernel's default block sizes.
    """
    # Unspecified conv hyper-parameters become constexpr symbols so they
    # are supplied at kernel-launch/compile time.
    if stride_h is None:
        stride_h = Symbol("stride_h", constexpr=True)

    if stride_w is None:
        stride_w = Symbol("stride_w", constexpr=True)

    if padding_h is None:
        padding_h = Symbol("padding_h", constexpr=True)

    if padding_w is None:
        padding_w = Symbol("padding_w", constexpr=True)

    if dilation_h is None:
        dilation_h = Symbol("dilation_h", constexpr=True)

    if dilation_w is None:
        dilation_w = Symbol("dilation_w", constexpr=True)

    # Fall back to the mm kernel's default tile sizes.
    if block_size_m is None:
        block_size_m = mm.BLOCK_SIZE_M

    if block_size_n is None:
        block_size_n = mm.BLOCK_SIZE_N

    if block_size_k is None:
        block_size_k = mm.BLOCK_SIZE_K

    mm_arrangement = functools.partial(
        mm.arrangement,
        block_size_m=block_size_m,
        block_size_n=block_size_n,
        block_size_k=block_size_k,
    )

    # Zero-pad only the two spatial dims, then tile one patch per output
    # position: patch shape (1, C_in, K_h, K_w), stepping by the conv
    # strides with the conv dilations; -1 strides on N/C mean those dims
    # are not stepped over.
    input_arranged = input.pad(
        ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w))
    )
    input_arranged = input_arranged.tile(
        (1, *weight.shape[1:]),
        strides=(-1, -1, stride_h, stride_w),
        dilation=(1, 1, dilation_h, dilation_w),
        floor_mode=True,
    )
    input_arranged = input_arranged.squeeze(1)
    input_arranged.dtype = input_arranged.dtype.squeeze(0)
    input_arranged = input_arranged.ravel()
    # Collapse (N, H_out, W_out) into GEMM rows and (C_in, K_h, K_w)
    # into GEMM columns: A is (N*H_out*W_out, C_in*K_h*K_w).
    input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1)

    # B is (C_in*K_h*K_w, C_out): flatten weight then transpose.
    weight_arranged = weight.flatten(start_dim=1)
    weight_arranged = weight_arranged.permute((1, 0))

    # Broadcast the per-channel bias across batch and spatial dims, then
    # flatten to one row per GEMM output row (channels-last ordering to
    # match the output arrangement below).
    bias_arranged = bias[None, :, None, None].expand(
        (output.shape[0], -1, output.shape[2], output.shape[3])
    )
    bias_arranged = bias_arranged.permute((0, 2, 3, 1)).flatten(end_dim=3)

    output_arranged = output.permute((0, 2, 3, 1)).flatten(end_dim=3)

    # Run the bias through mm.arrangement separately (with deepcopies of
    # the other operands, so the real ones are not mutated) to obtain a
    # bias laid out like the mm output tile.
    _, _, bias_arranged, _ = mm_arrangement(
        copy.deepcopy(input_arranged),
        copy.deepcopy(weight_arranged),
        bias_arranged,
        copy.deepcopy(input_precision),
    )

    input_arranged, weight_arranged, output_arranged, input_precision_arranged = (
        mm_arrangement(
            input_arranged, weight_arranged, output_arranged, input_precision
        )
    )

    return (
        input_arranged,
        weight_arranged,
        bias_arranged,
        output_arranged,
        input_precision_arranged,
    )


def application(input, weight, bias, output, input_precision):
    """Compute one conv2d output tile: matmul accumulate, then add bias.

    NOTE: in the ninetoothed DSL the assignment to the ``output``
    parameter is the kernel's output write, not a dead store.
    """
    # Accumulate the implicit-GEMM product in float32 for precision,
    # reusing the mm kernel's application body.
    mm_output = ntl.zeros(output.shape, dtype=ntl.float32)
    mm.application(input, weight, mm_output, input_precision)
    output = mm_output + bias


def premake(
    input_precision=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    dtype=None,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    """Build the (arrangement, application, tensors) triple for conv2d.

    Conv hyper-parameters left as ``None`` are resolved inside
    ``arrangement`` (to compile-time symbols or mm default block
    sizes).  ``input_precision`` is carried as a 0-D constexpr tensor.
    """
    conv_options = {
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding_h": padding_h,
        "padding_w": padding_w,
        "dilation_h": dilation_h,
        "dilation_w": dilation_w,
        "block_size_m": block_size_m,
        "block_size_n": block_size_n,
        "block_size_k": block_size_k,
    }

    # Bind the conv hyper-parameters onto the arrangement.
    bound_arrangement = functools.partial(arrangement, **conv_options)

    # NCHW input, (C_out, C_in, K_h, K_w) weight, NCHW output.
    input_tensor = Tensor(4, dtype=dtype)
    weight_tensor = Tensor(4, dtype=dtype)
    output_tensor = Tensor(4, dtype=dtype)
    # Per-output-channel bias vector.
    bias_tensor = Tensor(1, dtype=dtype)
    # The matmul precision flag travels as a 0-D compile-time constant.
    precision_tensor = Tensor(0, dtype=dtype, constexpr=True, value=input_precision)

    tensors = (
        input_tensor,
        weight_tensor,
        bias_tensor,
        output_tensor,
        precision_tensor,
    )

    return bound_arrangement, application, tensors
42 changes: 42 additions & 0 deletions src/ntops/kernels/max_pool2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.pooling import arrangement


def application(input, output):
    """Max-pool one flattened window per output element.

    NOTE: in the ninetoothed DSL the assignment to the ``output``
    parameter is the kernel's output write, not a dead store — hence
    the ``noqa: F841`` marker.
    """
    # Each row of `input` is one flattened pooling window; out-of-bounds
    # (padded) elements are loaded as -inf (see the `other` value in
    # premake), so they never win the max.
    output = ntl.max(input, axis=-1)  # noqa: F841


def premake(
    kernel_size_h=None,
    kernel_size_w=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    ceil_mode=None,
    dtype=None,
    block_size=None,
):
    """Build the (arrangement, application, tensors) triple for max_pool2d.

    Any pooling parameter left as ``None`` is resolved inside the
    shared pooling ``arrangement``.  The input tensor uses ``-inf`` as
    its out-of-bounds fill value so padding cannot affect the max.
    """
    pooling_options = {
        "kernel_size_h": kernel_size_h,
        "kernel_size_w": kernel_size_w,
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding_h": padding_h,
        "padding_w": padding_w,
        "dilation_h": dilation_h,
        "dilation_w": dilation_w,
        "ceil_mode": ceil_mode,
        "block_size": block_size,
    }

    # Bind the pooling hyper-parameters onto the shared arrangement.
    bound_arrangement = functools.partial(arrangement, **pooling_options)

    # NCHW input; padded loads read -inf so max ignores them.
    input_tensor = Tensor(4, dtype=dtype, other=float("-inf"))
    output_tensor = Tensor(4, dtype=dtype)

    return bound_arrangement, application, (input_tensor, output_tensor)
68 changes: 68 additions & 0 deletions src/ntops/kernels/pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import ninetoothed
from ninetoothed import Symbol


def arrangement(
    input,
    output,
    kernel_size_h=None,
    kernel_size_w=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    ceil_mode=None,
    block_size=None,
):
    """Shared 2-D pooling arrangement (used by avg_pool2d and max_pool2d).

    Tiles the padded NCHW input into one (kernel_size_h, kernel_size_w)
    window per output position, flattens every window into a row, and
    groups ``block_size`` windows per program instance; the output is
    flattened and grouped the same way so application bodies can reduce
    each window along the last axis.
    """
    # Unspecified pooling parameters become constexpr launch symbols.
    # The kernel-size symbols carry upper_bound=16 — presumably a bound
    # the compiler uses for specialization; windows larger than 16 in
    # either dim would need an explicit kernel_size_* argument.
    if kernel_size_h is None:
        kernel_size_h = Symbol("kernel_size_h", constexpr=True, upper_bound=16)

    if kernel_size_w is None:
        kernel_size_w = Symbol("kernel_size_w", constexpr=True, upper_bound=16)

    if stride_h is None:
        stride_h = Symbol("stride_h", constexpr=True)

    if stride_w is None:
        stride_w = Symbol("stride_w", constexpr=True)

    if padding_h is None:
        padding_h = Symbol("padding_h", constexpr=True)

    if padding_w is None:
        padding_w = Symbol("padding_w", constexpr=True)

    if dilation_h is None:
        dilation_h = Symbol("dilation_h", constexpr=True)

    if dilation_w is None:
        dilation_w = Symbol("dilation_w", constexpr=True)

    # ceil_mode defaults to floor-division of the output extent,
    # matching the common pooling default.
    if ceil_mode is None:
        ceil_mode = False

    if block_size is None:
        block_size = ninetoothed.block_size()

    # Pad only the two spatial dims, then cut one pooling window per
    # output position: strides step by the pool strides over H/W (-1
    # means N/C are not stepped), dilation applies within each window.
    # floor_mode mirrors ceil_mode: floor division unless ceil_mode.
    input_arranged = input.pad(
        ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w))
    )
    input_arranged = input_arranged.tile(
        (1, 1, kernel_size_h, kernel_size_w),
        strides=(-1, -1, stride_h, stride_w),
        dilation=(1, 1, dilation_h, dilation_w),
        floor_mode=not ceil_mode,
    )
    input_arranged = input_arranged.ravel()
    # One row per window (window elements on the last axis), then group
    # block_size windows per program instance.
    input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1)
    input_arranged = input_arranged.tile((block_size, -1))

    # Mirror the same flattening/grouping on the output: one scalar per
    # window, block_size of them per program instance.
    output_arranged = output.tile((1, 1, 1, 1))
    output_arranged = output_arranged.ravel()
    output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1)
    output_arranged = output_arranged.tile((block_size, -1))
    output_arranged.dtype = output_arranged.dtype.squeeze(1)

    return input_arranged, output_arranged
6 changes: 6 additions & 0 deletions src/ntops/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from ntops.torch.abs import abs
from ntops.torch.add import add
from ntops.torch.addmm import addmm
from ntops.torch.avg_pool2d import avg_pool2d
from ntops.torch.bitwise_and import bitwise_and
from ntops.torch.bitwise_not import bitwise_not
from ntops.torch.bitwise_or import bitwise_or
from ntops.torch.bmm import bmm
from ntops.torch.clamp import clamp
from ntops.torch.conv2d import conv2d
from ntops.torch.cos import cos
from ntops.torch.div import div
from ntops.torch.dropout import dropout
Expand All @@ -20,6 +22,7 @@
from ntops.torch.le import le
from ntops.torch.lt import lt
from ntops.torch.matmul import matmul
from ntops.torch.max_pool2d import max_pool2d
from ntops.torch.mm import mm
from ntops.torch.mul import mul
from ntops.torch.ne import ne
Expand All @@ -41,11 +44,13 @@
"abs",
"add",
"addmm",
"avg_pool2d",
"bitwise_and",
"bitwise_not",
"bitwise_or",
"bmm",
"clamp",
"conv2d",
"cos",
"div",
"dropout",
Expand All @@ -60,6 +65,7 @@
"le",
"lt",
"matmul",
"max_pool2d",
"mm",
"mul",
"ne",
Expand Down
Loading