Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,53 @@ def to_python(self, obj_id: int):
return ctypes.cast(obj_id, ctypes.py_object).value


class Float32sr(typeclass):
    """
    32-bit floating-point type with stochastic rounding.

    Stochastic rounding randomly rounds to the nearest representable value
    with probability proportional to the distance, reducing systematic bias
    in repeated computations.

    Limitations of current implementation: library functions like BLAS fall
    back to round-to-nearest float32 for compatibility reasons; targets CPU
    only.
    """

    def __init__(self):
        self.type = numpy.float32  # underlying numpy representation
        self.bytes = 4
        self.dtype = self
        self.typename = "float"
        # Marker attribute: other modules detect SR types via
        # hasattr(dtype, "stochastically_rounded").
        self.stochastically_rounded = True

    def to_json(self):
        return 'float32sr'

    @staticmethod
    def from_json(json_obj, context=None):
        # Return the module-level singleton typeclass. The previous code
        # returned ``float32sr()`` — *calling* the singleton — which invokes
        # typeclass.__call__ and yields a cast scalar value rather than the
        # typeclass itself; it also carried an unused ``pystr_to_symbolic``
        # import.
        return float32sr

    @property
    def ctype(self):
        return "dace::float32sr"

    @property
    def ctype_unaligned(self):
        return self.ctype

    def as_ctypes(self):
        """ Returns the ctypes version of the typeclass. """
        return _FFI_CTYPES[self.type]

    def as_numpy_dtype(self):
        return numpy.dtype(self.type)

    @property
    def base_type(self):
        return self


class compiletime:
"""
Data descriptor type hint signalling that argument evaluation is
Expand Down Expand Up @@ -1180,6 +1227,7 @@ class uint32(_DaCeArray, npt.NDArray[numpy.uint32]): ...
class uint64(_DaCeArray, npt.NDArray[numpy.uint64]): ...
class float16(_DaCeArray, npt.NDArray[numpy.float16]): ...
class float32(_DaCeArray, npt.NDArray[numpy.float32]): ...
class float32sr(_DaCeArray, npt.NDArray[numpy.float32]): ...
class float64(_DaCeArray, npt.NDArray[numpy.float64]): ...
class complex64(_DaCeArray, npt.NDArray[numpy.complex64]): ...
class complex128(_DaCeArray, npt.NDArray[numpy.complex128]): ...
Expand All @@ -1201,6 +1249,7 @@ class MPI_Request(_DaCeArray, npt.NDArray[numpy.void]): ...
uint64 = typeclass(numpy.uint64)
float16 = typeclass(numpy.float16)
float32 = typeclass(numpy.float32)
float32sr = Float32sr()
float64 = typeclass(numpy.float64)
complex64 = typeclass(numpy.complex64)
complex128 = typeclass(numpy.complex128)
Expand Down Expand Up @@ -1258,6 +1307,7 @@ def dtype_to_typeclass(dtype=None):
int64: "dace::int64",
float16: "dace::float16",
float32: "dace::float32",
float32sr: "dace::float32sr",
float64: "dace::float64",
complex64: "dace::complex64",
complex128: "dace::complex128"
Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/python/replacements/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op
type1 = arr1.dtype.type
type2 = arr2.dtype.type
restype = dtypes.dtype_to_typeclass(np.result_type(type1, type2).type)
if arr1.dtype == dtypes.float32sr and arr2.dtype == dtypes.float32sr:
restype = dtypes.float32sr

op3, arr3 = sdfg.add_transient(visitor.get_target_name(), output_shape, restype, arr1.storage, find_new_name=True)

Expand Down
58 changes: 54 additions & 4 deletions dace/frontend/python/replacements/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,51 @@ def _is_op_boolean(op: str):
return False


def _handle_casting_for_stochastically_rounded_types(input_datatypes: List[dtypes.typeclass], restype: dtypes.typeclass,
                                                     cast_types: List) -> Tuple[dtypes.typeclass, List]:
    """
    Adjusts result type and casts for stochastically rounded inputs.

    If all inputs are stochastically rounded types, promotes the result type
    to its SR equivalent (e.g., float32 -> float32sr) and updates cast_types.

    Args:
        input_datatypes: List of input data types.
        restype: The computed result type.
        cast_types: List of cast strings (or None), parallel to input_datatypes.

    Returns:
        Tuple of (possibly promoted result type, updated cast_types).
    """
    float_to_sr = {
        dace.float32: dace.float32sr,
    }

    # Drop redundant float32 casts for SR inputs: the SR type already uses
    # float32 as its underlying representation, so no cast is needed.
    for i, dtype in enumerate(input_datatypes):
        if hasattr(dtype, "stochastically_rounded") and cast_types[i] == "dace.float32":
            cast_types[i] = None

    # Promote only when there is at least one input and *every* input is SR.
    # The explicit emptiness guard fixes an edge case where an empty input
    # list would previously count as "all stochastically rounded".
    all_inputs_sr = bool(input_datatypes) and all(
        hasattr(dtype, "stochastically_rounded") for dtype in input_datatypes)

    if all_inputs_sr:
        # Promote the result type to its SR equivalent, if one exists.
        restype = float_to_sr.get(restype, restype)

        # Promote any matching intermediate cast entries in place.
        # NOTE(review): entries of cast_types are cast *strings* produced by
        # cast_str(); this lookup only takes effect if typeclass keys compare
        # equal to those entries — confirm intended.
        for i, ct in enumerate(cast_types):
            if ct in float_to_sr:
                cast_types[i] = float_to_sr[ct]

    return restype, cast_types


def result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basic]],
operator: str = None) -> Tuple[Union[List[dtypes.typeclass], dtypes.typeclass, str], ...]:

Expand Down Expand Up @@ -144,12 +189,15 @@ def result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basic
raise TypeError("Type {t} of argument {a} is not supported".format(t=type(arg), a=arg))

complex_types = {dtypes.complex64, dtypes.complex128, np.complex64, np.complex128}
float_types = {dtypes.float16, dtypes.float32, dtypes.float64, np.float16, np.float32, np.float64}
float_types = {dtypes.float16, dtypes.float32, dtypes.float32sr, dtypes.float64, np.float16, np.float32, np.float64}
signed_types = {dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, np.int8, np.int16, np.int32, np.int64}
# unsigned_types = {np.uint8, np.uint16, np.uint32, np.uint64}

coarse_types = []
for dtype in datatypes:
if hasattr(dtype, "srtype"): # unwrap stochastically rounded vars
dtype = dtype.srtype

if dtype in complex_types:
coarse_types.append(3) # complex
elif dtype in float_types:
Expand Down Expand Up @@ -336,18 +384,20 @@ def result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basic
else: # Operators with 3 or more arguments
restype = np_result_type(dtypes_for_result)
coarse_result_type = None
if result_type in complex_types:
if restype in complex_types:
coarse_result_type = 3 # complex
elif result_type in float_types:
elif restype in float_types:
coarse_result_type = 2 # float
elif result_type in signed_types:
elif restype in signed_types:
coarse_result_type = 1 # signed integer, bool
else:
coarse_result_type = 0 # unsigned integer
for i, t in enumerate(coarse_types):
if t != coarse_result_type:
casting[i] = cast_str(restype)

restype, casting = _handle_casting_for_stochastically_rounded_types(datatypes, restype, casting)

return restype, casting


Expand Down
15 changes: 13 additions & 2 deletions dace/libraries/blas/nodes/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
(desc_x, stride_x), (desc_y, stride_y), desc_res, sz = node.validate(parent_sdfg, parent_state)
dtype = desc_x.dtype.base_type
veclen = desc_x.dtype.veclen
cast = "(float *)" if dtype == dace.float32sr else ""

try:
func, _, _ = blas_helpers.cublas_type_metadata(dtype)
Expand All @@ -82,7 +83,8 @@ def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
n = n or node.n or sz
if veclen != 1:
n /= veclen
code = f"_result = cblas_{func}({n}, _x, {stride_x}, _y, {stride_y});"

code = f"_result = cblas_{func}({n}, {cast} _x, {stride_x}, {cast} _y, {stride_y});"
# The return type is scalar in cblas_?dot signature
tasklet = dace.sdfg.nodes.Tasklet(node.name,
node.in_connectors, {'_result': dtype},
Expand Down Expand Up @@ -203,7 +205,16 @@ def validate(self, sdfg, state):
if desc_x.dtype != desc_y.dtype:
raise TypeError(f"Data types of input operands must be equal: {desc_x.dtype}, {desc_y.dtype}")
if desc_x.dtype.base_type != desc_res.dtype.base_type:
raise TypeError(f"Data types of input and output must be equal: {desc_x.dtype}, {desc_res.dtype}")
arg_types = (desc_x.dtype.base_type, desc_res.dtype.base_type)
if dace.float32 in arg_types and dace.float32sr in arg_types:
"""
When using stochastic rounding, a legitimate (i.e not a bug) mismatch between the input and output
arguments may arise where one argument is a float32sr and the other is a float32 (round-to-nearest).
The underlying data type is the same so this should not cause the validation to fail.
"""
pass
else:
raise TypeError(f"Data types of input and output must be equal: {desc_x.dtype}, {desc_res.dtype}")

# Squeeze input memlets
squeezed1 = copy.deepcopy(in_memlets[0].subset)
Expand Down
7 changes: 5 additions & 2 deletions dace/libraries/blas/nodes/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def expansion(node, state, sdfg):
node.validate(sdfg, state)
(_, adesc, _, _, _, _), (_, bdesc, _, _, _, _), _ = _get_matmul_operands(node, state, sdfg)
dtype = adesc.dtype.base_type

func = to_blastype(dtype.type).lower() + 'gemm'
alpha = f'{dtype.ctype}({node.alpha})'
beta = f'{dtype.ctype}({node.beta})'
Expand All @@ -178,6 +179,8 @@ def expansion(node, state, sdfg):
check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)

opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, dtype.ctype, func)
opt['cast'] = "(float *)" if dtype == dace.float32sr else ""
opt['c'] = '_c'

# Adaptations for BLAS API
opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
Expand All @@ -193,8 +196,8 @@ def expansion(node, state, sdfg):
opt['beta'] = '&__beta'

code += ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
"{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
"_c, {ldc});").format_map(opt)
"{M}, {N}, {K}, {alpha},{cast} {x}, {lda}, {cast} {y}, {ldb}, {beta}, "
"{cast} {c}, {ldc});").format_map(opt)

tasklet = dace.sdfg.nodes.Tasklet(
node.name,
Expand Down
1 change: 1 addition & 0 deletions dace/libraries/blas/nodes/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
name_out="_y")
dtype_a = outer_array_a.dtype.type
dtype = outer_array_x.dtype.base_type

veclen = outer_array_x.dtype.veclen
alpha = f'{dtype.ctype}({node.alpha})'
beta = f'{dtype.ctype}({node.beta})'
Expand Down
3 changes: 2 additions & 1 deletion dace/libraries/lapack/nodes/potrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ class ExpandPotrfOpenBLAS(ExpandTransformation):
def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
(desc_x, stride_x, rows_x, cols_x), desc_result = node.validate(parent_sdfg, parent_state)
dtype = desc_x.dtype.base_type
cast = "(float *)" if dtype == dace.float32sr else ""
lapack_dtype = blas_helpers.to_blastype(dtype.type).lower()
if desc_x.dtype.veclen > 1:
raise (NotImplementedError)

n = n or node.n
uplo = "'L'" if node.lower else "'U'"
code = f"_res = LAPACKE_{lapack_dtype}potrf(LAPACK_ROW_MAJOR, {uplo}, {rows_x}, _xin, {stride_x});"
code = f"_res = LAPACKE_{lapack_dtype}potrf(LAPACK_ROW_MAJOR, {uplo}, {rows_x}, {cast} _xin, {stride_x});"
tasklet = dace.sdfg.nodes.Tasklet(node.name,
node.in_connectors,
node.out_connectors,
Expand Down
3 changes: 2 additions & 1 deletion dace/libraries/linalg/nodes/cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def _make_sdfg(node, parent_state, parent_sdfg, implementation):

inp_desc, inp_shape, out_desc, out_shape = node.validate(parent_sdfg, parent_state)
dtype = inp_desc.dtype
cast = "(dace::float32sr)" if dtype == dace.float32sr else ""
storage = inp_desc.storage

sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))
Expand All @@ -36,7 +37,7 @@ def _make_sdfg(node, parent_state, parent_sdfg, implementation):
_, me, mx = state.add_mapped_tasklet('_uzero_',
dict(__i="0:%s" % out_shape[0], __j="0:%s" % out_shape[1]),
dict(_inp=Memlet.simple('_b', '__i, __j')),
'_out = (__i < __j) ? 0 : _inp;',
f'_out = (__i < __j) ? {cast}(0) : _inp;',
dict(_out=Memlet.simple('_b', '__i, __j')),
language=dace.dtypes.Language.CPP,
external_edges=True)
Expand Down
1 change: 1 addition & 0 deletions dace/runtime/include/dace/dace.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "perf/reporting.h"
#include "comm.h"
#include "serialization.h"
#include "stochastic_rounding.h"

#if defined(__CUDACC__) || defined(__HIPCC__)
#include "cuda/cudacommon.cuh"
Expand Down
67 changes: 67 additions & 0 deletions dace/runtime/include/dace/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#define __DACE_REDUCTION_H

#include <cstdint>
#include <dace/stochastic_rounding.h>

#include "types.h"
#include "vector.h"
Expand Down Expand Up @@ -121,6 +122,40 @@ namespace dace {
}
};

// Specialization of wcr_custom for the stochastically rounded 32-bit float
// type. Write-conflict resolution: applies a user-provided reduction (WCR)
// to *ptr and the incoming value, storing the result and returning the
// previous contents of *ptr.
template <>
struct wcr_custom<dace::float32sr> {
    // Atomic variant.
    template <typename WCR>
    static DACE_HDFI dace::float32sr reduce_atomic(WCR wcr, dace::float32sr *ptr, const dace::float32sr& value) {
#ifdef DACE_USE_GPU_ATOMICS
        // Stochastic rounding version of atomic float reduction.
        // Classic GPU compare-and-swap loop: reinterpret the 32-bit payload
        // as int so atomicCAS can be used; the reduction itself is computed
        // through the SR type, then the plain-float bit pattern is stored
        // back. Retries until no other thread modified *ptr in between.
        // NOTE(review): assumes dace::float32sr is layout-compatible with a
        // 4-byte float/int — confirm against the type's definition in
        // stochastic_rounding.h.
        int *iptr = reinterpret_cast<int *>(ptr);
        int old = *iptr, assumed;
        do {
            assumed = old;
            float old_val = __int_as_float(assumed);
            float new_val = static_cast<float>(wcr(static_cast<dace::float32sr>(old_val), value));
            old = atomicCAS(iptr, assumed, __float_as_int(new_val));
        } while (assumed != old);
        return static_cast<dace::float32sr>(__int_as_float(old));
#else
        // CPU fallback: serialize the read-modify-write with an OpenMP
        // critical section.
        dace::float32sr old;
#pragma omp critical
        {
            old = *ptr;
            *ptr = wcr(old, value);
        }
        return old;
#endif
    }

    // Non-atomic variant: plain read-modify-write; returns the old value.
    template <typename WCR>
    static DACE_HDFI dace::float32sr reduce(WCR wcr, dace::float32sr *ptr, const dace::float32sr& value) {
        dace::float32sr old = *ptr;
        *ptr = wcr(old, value);
        return old;
    }
};

template <>
struct wcr_custom<double> {
template <typename WCR>
Expand Down Expand Up @@ -313,6 +348,33 @@ namespace dace {
DACE_HDFI float operator()(const float &a, const float &b) const { return ::max(a, b); }
};


// Minimum-reduction specialization for the stochastically rounded float
// type. The atomic path delegates to the plain-float specialization — the
// result of min() is one of the existing operands, so no rounding occurs.
template <>
struct _wcr_fixed<ReductionType::Min, dace::float32sr> {

    static DACE_HDFI dace::float32sr reduce_atomic(dace::float32sr *ptr, const dace::float32sr& value) {
        // View the SR storage as a plain float and forward to the float atomic.
        return static_cast<dace::float32sr>(_wcr_fixed<ReductionType::Min, float>::reduce_atomic(
            reinterpret_cast<float *>(ptr), static_cast<float>(value)));
    }

    DACE_HDFI dace::float32sr operator()(const dace::float32sr &a, const dace::float32sr &b) const {
        return ::min(a, b);
    }
};

// Maximum-reduction specialization for the stochastically rounded float
// type. The atomic path delegates to the plain-float specialization — the
// result of max() is one of the existing operands, so no rounding occurs.
template <>
struct _wcr_fixed<ReductionType::Max, dace::float32sr> {

    static DACE_HDFI dace::float32sr reduce_atomic(dace::float32sr *ptr, const dace::float32sr& value) {
        // View the SR storage as a plain float and forward to the float atomic.
        return static_cast<dace::float32sr>(_wcr_fixed<ReductionType::Max, float>::reduce_atomic(
            reinterpret_cast<float *>(ptr), static_cast<float>(value)));
    }

    DACE_HDFI dace::float32sr operator()(const dace::float32sr &a, const dace::float32sr &b) const {
        return ::max(a, b);
    }
};


template <>
struct _wcr_fixed<ReductionType::Min, double> {

Expand Down Expand Up @@ -555,6 +617,11 @@ namespace dace {
return _wcr_fixed<REDTYPE, T>::reduce_atomic(ptr, value);
}

static DACE_HDFI float reduce_atomic(dace::float32sr *ptr, const float& value)
{
return _wcr_fixed<REDTYPE, float>::reduce_atomic(reinterpret_cast<float*>(ptr), value);
}

DACE_HDFI T operator()(const T &a, const T &b) const
{
return _wcr_fixed<REDTYPE, T>()(a, b);
Expand Down
Loading