Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,38 @@ Py_hash_t PyCustomFloat_Hash(PyObject* self) {
return HashImpl(&_Py_HashDouble, self, static_cast<double>(x));
}

template <typename T>
PyObject* PyCustomFloat_Format(PyObject* self, PyObject* args) {
PyObject* format_spec = nullptr;
if (!PyArg_ParseTuple(args, "U", &format_spec)) {
return nullptr;
}

T x = reinterpret_cast<PyCustomFloat<T>*>(self)->value;
// Round to 6 significant digits to match PyCustomFloat_Str/Repr, which
// use std::ostringstream with its default precision of 6. This avoids
// exposing false precision from the float64 expansion.
char buf[14]; // max %.6g output: "-9.98378e+38" + '\0' = 14
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm finding this part strange. For any floating point type, numpy appears to do the following: https://github.com/numpy/numpy/blob/f105cf2d7c20c9829b431c2ae0cdb1f07efaccf2/numpy/_core/src/multiarray/scalartypes.c.src#L633

i.e., convert the type to a python float using nb_format and then call PyObject_Format on the result.

Is there a reason for us to deviate here?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think?

std::snprintf(buf, sizeof(buf), "%.6g", static_cast<double>(x));
double d = std::strtod(buf, nullptr);

PyObject* float_obj = PyFloat_FromDouble(d);
if (!float_obj) {
return nullptr;
}

PyObject* result = PyObject_Format(float_obj, format_spec);
Py_DECREF(float_obj);
return result;
}

template <typename T>
PyMethodDef CustomFloatType_methods[] = {
{"__format__", PyCustomFloat_Format<T>, METH_VARARGS,
"Format a custom float value."},
{nullptr, nullptr, 0, nullptr},
};

template <typename T>
PyType_Slot CustomFloatType<T>::type_slots[] = {
{Py_tp_new, reinterpret_cast<void*>(PyCustomFloat_New<T>)},
Expand All @@ -362,6 +394,7 @@ PyType_Slot CustomFloatType<T>::type_slots[] = {
{Py_tp_doc,
reinterpret_cast<void*>(const_cast<char*>(TypeDescriptor<T>::kTpDoc))},
{Py_tp_richcompare, reinterpret_cast<void*>(PyCustomFloat_RichCompare<T>)},
{Py_tp_methods, reinterpret_cast<void*>(CustomFloatType_methods<T>)},
{Py_nb_add, reinterpret_cast<void*>(PyCustomFloat_Add<T>)},
{Py_nb_subtract, reinterpret_cast<void*>(PyCustomFloat_Subtract<T>)},
{Py_nb_multiply, reinterpret_cast<void*>(PyCustomFloat_Multiply<T>)},
Expand Down
Loading