diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index bf2568a7..68ed496d 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -353,6 +353,38 @@ Py_hash_t PyCustomFloat_Hash(PyObject* self) { return HashImpl(&_Py_HashDouble, self, static_cast(x)); } +template +PyObject* PyCustomFloat_Format(PyObject* self, PyObject* args) { + PyObject* format_spec = nullptr; + if (!PyArg_ParseTuple(args, "U", &format_spec)) { + return nullptr; + } + + T x = reinterpret_cast*>(self)->value; + // Round to 6 significant digits to match PyCustomFloat_Str/Repr, which + // use std::ostringstream with its default precision of 6. This avoids + // exposing false precision from the float64 expansion. + char buf[14]; // max %.6g output: "-9.98378e+38" + '\0' = 14 + std::snprintf(buf, sizeof(buf), "%.6g", static_cast(x)); + double d = std::strtod(buf, nullptr); + + PyObject* float_obj = PyFloat_FromDouble(d); + if (!float_obj) { + return nullptr; + } + + PyObject* result = PyObject_Format(float_obj, format_spec); + Py_DECREF(float_obj); + return result; +} + +template +PyMethodDef CustomFloatType_methods[] = { + {"__format__", PyCustomFloat_Format, METH_VARARGS, + "Format a custom float value."}, + {nullptr, nullptr, 0, nullptr}, +}; + template PyType_Slot CustomFloatType::type_slots[] = { {Py_tp_new, reinterpret_cast(PyCustomFloat_New)}, @@ -362,6 +394,7 @@ PyType_Slot CustomFloatType::type_slots[] = { {Py_tp_doc, reinterpret_cast(const_cast(TypeDescriptor::kTpDoc))}, {Py_tp_richcompare, reinterpret_cast(PyCustomFloat_RichCompare)}, + {Py_tp_methods, reinterpret_cast(CustomFloatType_methods)}, {Py_nb_add, reinterpret_cast(PyCustomFloat_Add)}, {Py_nb_subtract, reinterpret_cast(PyCustomFloat_Subtract)}, {Py_nb_multiply, reinterpret_cast(PyCustomFloat_Multiply)},