From f1ccc47847f4e15aab67e506d1440936c31c96d9 Mon Sep 17 00:00:00 2001 From: Zijiang YANG Date: Mon, 23 Mar 2026 17:48:12 +0800 Subject: [PATCH 1/2] Fix custom float dtype (e.g. bfloat16) f-string formatting truncating exponent notation. --- ml_dtypes/_src/custom_float.h | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index bf2568a7..81ee4407 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -353,6 +353,40 @@ Py_hash_t PyCustomFloat_Hash(PyObject* self) { return HashImpl(&_Py_HashDouble, self, static_cast(x)); } +template +PyObject* PyCustomFloat_Format(PyObject* self, PyObject* args) { + PyObject* format_spec = nullptr; + if (!PyArg_ParseTuple(args, "U", &format_spec)) { + return nullptr; + } + + T x = reinterpret_cast*>(self)->value; + float f = static_cast(x); + + // Round to 6 significant digits to match PyCustomFloat_Str/Repr, which + // use std::ostringstream with its default precision of 6. This avoids + // exposing false precision from the float64 expansion. + char buf[14]; // max %.6g output: "-9.98378e+38" + '\0' = 14 + std::snprintf(buf, sizeof(buf), "%.6g", static_cast(f)); + double d = std::strtod(buf, nullptr); + + PyObject* float_obj = PyFloat_FromDouble(d); + if (!float_obj) { + return nullptr; + } + + PyObject* result = PyObject_Format(float_obj, format_spec); + Py_DECREF(float_obj); + return result; +} + +template +PyMethodDef CustomFloatType_methods[] = { + {"__format__", PyCustomFloat_Format, METH_VARARGS, + "Format a custom float value."}, + {nullptr, nullptr, 0, nullptr}, +}; + template PyType_Slot CustomFloatType::type_slots[] = { {Py_tp_new, reinterpret_cast(PyCustomFloat_New)}, @@ -362,6 +396,7 @@ PyType_Slot CustomFloatType::type_slots[] = { {Py_tp_doc, reinterpret_cast(const_cast(TypeDescriptor::kTpDoc))}, {Py_tp_richcompare, reinterpret_cast(PyCustomFloat_RichCompare)}, + {Py_tp_methods, reinterpret_cast(CustomFloatType_methods)}, {Py_nb_add, reinterpret_cast(PyCustomFloat_Add)}, {Py_nb_subtract, reinterpret_cast(PyCustomFloat_Subtract)}, {Py_nb_multiply, reinterpret_cast(PyCustomFloat_Multiply)}, From 5b9d2a6ef01ef25d0b7a0f85dd6a9fd677779362 Mon Sep 17 00:00:00 2001 From: Zijiang YANG Date: Mon, 23 Mar 2026 18:17:10 +0800 Subject: [PATCH 2/2] Remove redundant type conversion. --- ml_dtypes/_src/custom_float.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index 81ee4407..68ed496d 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -361,13 +361,11 @@ PyObject* PyCustomFloat_Format(PyObject* self, PyObject* args) { } T x = reinterpret_cast*>(self)->value; - float f = static_cast(x); - // Round to 6 significant digits to match PyCustomFloat_Str/Repr, which // use std::ostringstream with its default precision of 6. This avoids // exposing false precision from the float64 expansion. char buf[14]; // max %.6g output: "-9.98378e+38" + '\0' = 14 - std::snprintf(buf, sizeof(buf), "%.6g", static_cast(f)); + std::snprintf(buf, sizeof(buf), "%.6g", static_cast(x)); double d = std::strtod(buf, nullptr); PyObject* float_obj = PyFloat_FromDouble(d);