diff --git a/.github/actions/fuzz_tests/action.yml b/.github/actions/fuzz_tests/action.yml index 95c7b09..7b736ac 100644 --- a/.github/actions/fuzz_tests/action.yml +++ b/.github/actions/fuzz_tests/action.yml @@ -8,7 +8,7 @@ inputs: fuzz_time: description: 'Maximum time in seconds to run fuzzing' required: false - default: '180' + default: '300' cargo_fuzz_version: description: 'Version of cargo-fuzz to install' required: false @@ -26,5 +26,5 @@ runs: - name: Run Fuzz Tests shell: bash working-directory: fuzz - run: cargo fuzz run ${{ inputs.fuzz_target }} --release -- -max_total_time=${{ inputs.fuzz_time }} + run: cargo fuzz run ${{ inputs.fuzz_target }} --release -- -max_total_time=${{ inputs.fuzz_time }} -ignore_ooms=1 -rss_limit_mb=0 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f690705..79a6cad 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,7 +24,8 @@ jobs: matrix: fuzz_target: - fuzz_json_de - # Add more fuzz targets here as needed + - fuzz_cbor_decode + - fuzz_cbor_roundtrip steps: - name: Checkout repository uses: actions/checkout@v4 diff --git a/Cargo.toml b/Cargo.toml index 9332c82..4ef9042 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,10 +30,15 @@ serde = { workspace = true } serde_json = { workspace = true } ctor = { version = "0.1.16", optional = true } paste = "1.0.15" -half = "2.0.0" +ciborium = { git = "https://github.com/AvivDavid23/ciborium", branch = "main" } +ciborium-ll = { git = "https://github.com/AvivDavid23/ciborium", branch = "main" } +bytemuck = "1" +half = { version = "2.0.0", features = ["bytemuck"] } +thiserror = "2.0.18" +zstd = "0.13" [dev-dependencies] mockalloc = "0.1.2" ctor = "0.1.16" rand = "0.8.4" - +zstd = "0.13" diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 09cd0b9..74da9c7 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -7,11 +7,16 @@ edition = "2021" [package.metadata] cargo-fuzz = true +[lib] +name = "ijson_fuzz" +path = "src/lib.rs" + [dependencies] 
libfuzzer-sys = "0.4" arbitrary = { version = "1.3", features = ["derive"] } -serde = { workspace = true } +serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +arbitrary-json = "=0.1.1" [dependencies.ijson] path = ".." @@ -22,3 +27,17 @@ path = "fuzz_targets/fuzz_json_de.rs" test = false doc = false bench = false + +[[bin]] +name = "fuzz_cbor_decode" +path = "fuzz_targets/fuzz_cbor_decode.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "fuzz_cbor_roundtrip" +path = "fuzz_targets/fuzz_cbor_roundtrip.rs" +test = false +doc = false +bench = false diff --git a/fuzz/fuzz_targets/fuzz_cbor_decode.rs b/fuzz/fuzz_targets/fuzz_cbor_decode.rs new file mode 100644 index 0000000..c58951d --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_cbor_decode.rs @@ -0,0 +1,8 @@ +#![no_main] + +use ijson::cbor::decode; +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data: &[u8]| { + let _ = decode(data); +}); diff --git a/fuzz/fuzz_targets/fuzz_cbor_roundtrip.rs b/fuzz/fuzz_targets/fuzz_cbor_roundtrip.rs new file mode 100644 index 0000000..7da09a5 --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_cbor_roundtrip.rs @@ -0,0 +1,22 @@ +#![no_main] + +use arbitrary_json::ArbitraryValue; +use ijson::{cbor, IValue}; +use libfuzzer_sys::fuzz_target; +use serde::Deserialize; + +fuzz_target!(|value: ArbitraryValue| { + let json_string = value.to_string(); + let mut deserializer = serde_json::Deserializer::from_str(&json_string); + let Ok(original) = IValue::deserialize(&mut deserializer) else { + return; + }; + + let encoded = cbor::encode(&original); + let decoded = cbor::decode(&encoded).expect("encode->decode round-trip must not fail"); + + assert_eq!( + original, decoded, + "round-trip mismatch for input: {json_string}" + ); +}); diff --git a/fuzz/fuzz_targets/fuzz_json_de.rs b/fuzz/fuzz_targets/fuzz_json_de.rs index 1592e59..b5a35cf 100644 --- a/fuzz/fuzz_targets/fuzz_json_de.rs +++ b/fuzz/fuzz_targets/fuzz_json_de.rs @@ -1,54 +1,12 @@ #![no_main] -use 
arbitrary::Arbitrary; +use arbitrary_json::ArbitraryValue; use ijson::IValue; use libfuzzer_sys::fuzz_target; use serde::Deserialize; -use std::collections::HashMap; -#[derive(Arbitrary, Debug)] -enum JsonValue { - Null, - Bool(bool), - Number(f64), - String(String), - Array(Vec), - Object(HashMap), -} - -impl JsonValue { - fn to_json_string(&self) -> String { - match self { - JsonValue::Null => "null".to_string(), - JsonValue::Bool(b) => b.to_string(), - JsonValue::Number(n) => { - if n.is_finite() { - n.to_string() - } else { - "0".to_string() - } - } - JsonValue::String(s) => format!("\"{}\"", s), - JsonValue::Array(arr) => { - let items: Vec = arr.iter().map(|v| v.to_json_string()).collect(); - format!("[{}]", items.join(",")) - } - JsonValue::Object(obj) => { - let items: Vec = obj - .iter() - .map(|(k, v)| { - let key = k.clone(); - format!("\"{}\":{}", key, v.to_json_string()) - }) - .collect(); - format!("{{{}}}", items.join(",")) - } - } - } -} - -fuzz_target!(|value: JsonValue| { - let json_string = value.to_json_string(); +fuzz_target!(|value: ArbitraryValue| { + let json_string = value.to_string(); let mut deserializer = serde_json::Deserializer::from_str(&json_string); let _ = IValue::deserialize(&mut deserializer); }); diff --git a/fuzz/src/lib.rs b/fuzz/src/lib.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/fuzz/src/lib.rs @@ -0,0 +1 @@ + diff --git a/src/array.rs b/src/array.rs index 2cc0d70..864f85e 100644 --- a/src/array.rs +++ b/src/array.rs @@ -9,8 +9,9 @@ use std::iter::FromIterator; use std::ops::{Index, IndexMut}; use std::slice::{from_raw_parts, from_raw_parts_mut, SliceIndex}; +use crate::error::IJsonError; use crate::{ - alloc::AllocError, + error::AllocError, thin::{ThinMut, ThinMutExt, ThinRef, ThinRefExt}, value::TypeTag, Defrag, DefragAllocator, IValue, @@ -54,6 +55,55 @@ impl Default for ArrayTag { } } +/// Enum representing different types of floating-point types +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub 
enum FloatType { + /// F16 + F16 = 1, + /// BF16 + BF16, + /// F32 + F32, + /// F64 + F64, +} + +impl fmt::Display for FloatType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FloatType::F16 => write!(f, "F16"), + FloatType::BF16 => write!(f, "BF16"), + FloatType::F32 => write!(f, "F32"), + FloatType::F64 => write!(f, "F64"), + } + } +} + +impl TryFrom for FloatType { + type Error = (); + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(FloatType::F16), + 2 => Ok(FloatType::BF16), + 3 => Ok(FloatType::F32), + 4 => Ok(FloatType::F64), + _ => Err(()), + } + } +} + +impl From for ArrayTag { + fn from(fp_type: FloatType) -> Self { + match fp_type { + FloatType::F16 => ArrayTag::F16, + FloatType::BF16 => ArrayTag::BF16, + FloatType::F32 => ArrayTag::F32, + FloatType::F64 => ArrayTag::F64, + } + } +} + impl ArrayTag { fn from_type() -> Self { use ArrayTag::*; @@ -401,11 +451,11 @@ impl Header { const TAG_MASK: u64 = 0xF; const TAG_SHIFT: u64 = 60; - const fn new(len: usize, cap: usize, tag: ArrayTag) -> Result { + const fn new(len: usize, cap: usize, tag: ArrayTag) -> Result { // assert!(len <= Self::LEN_MASK as usize, "Length exceeds 30-bit limit"); // assert!(cap <= Self::CAP_MASK as usize, "Capacity exceeds 30-bit limit"); if len > Self::LEN_MASK as usize || cap > Self::CAP_MASK as usize { - return Err(AllocError); + return Err(IJsonError::Alloc(AllocError)); } let packed = ((len as u64) & Self::LEN_MASK) << Self::LEN_SHIFT @@ -561,6 +611,26 @@ trait HeaderMut<'a>: ThinMutExt<'a, Header> { self.set_len(index + 1); } + // Safety: Space must already be allocated for the item, + // and the item must be a number. The array type must be a floating-point type. + unsafe fn push_lossy(&mut self, item: IValue) { + use ArrayTag::*; + let index = self.len(); + + macro_rules! 
push_lossy_impl { + ($(($tag:ident, $ty:ty)),*) => { + match self.type_tag() { + $($tag => self.reborrow().raw_array_ptr_mut().cast::<$ty>().add(index).write( + paste::paste!(item.[]()).unwrap()),)* + _ => unreachable!(), + } + } + } + + push_lossy_impl!((F16, f16), (BF16, bf16), (F32, f32), (F64, f64)); + self.set_len(index + 1); + } + fn pop(&mut self) -> Option { if self.len() == 0 { None @@ -670,7 +740,7 @@ impl IArray { .pad_to_align()) } - fn alloc(cap: usize, tag: ArrayTag) -> Result<*mut Header, AllocError> { + fn alloc(cap: usize, tag: ArrayTag) -> Result<*mut Header, IJsonError> { unsafe { let ptr = alloc(Self::layout(cap, tag).map_err(|_| AllocError)?).cast::
(); ptr.write(Header::new(0, cap, tag)?); @@ -678,7 +748,7 @@ impl IArray { } } - fn realloc(ptr: *mut Header, new_cap: usize) -> Result<*mut Header, AllocError> { + fn realloc(ptr: *mut Header, new_cap: usize) -> Result<*mut Header, IJsonError> { unsafe { let tag = (*ptr).type_tag(); let old_layout = Self::layout((*ptr).cap(), tag).map_err(|_| AllocError)?; @@ -706,13 +776,13 @@ impl IArray { /// Constructs a new `IArray` with the specified capacity. At least that many items /// can be added to the array without reallocating. #[must_use] - pub fn with_capacity(cap: usize) -> Result { + pub fn with_capacity(cap: usize) -> Result { Self::with_capacity_and_tag(cap, ArrayTag::Heterogeneous) } /// Constructs a new `IArray` with the specified capacity and array type. #[must_use] - fn with_capacity_and_tag(cap: usize, tag: ArrayTag) -> Result { + fn with_capacity_and_tag(cap: usize, tag: ArrayTag) -> Result { if cap == 0 { Ok(Self::new()) } else { @@ -743,7 +813,7 @@ impl IArray { /// Converts this array to a new type, promoting all existing elements. /// This is used for automatic type promotion when incompatible types are added. - fn promote_to_type(&mut self, new_tag: ArrayTag) -> Result<(), AllocError> { + fn promote_to_type(&mut self, new_tag: ArrayTag) -> Result<(), IJsonError> { if self.is_static() || self.header().type_tag() == new_tag { return Ok(()); } @@ -898,7 +968,7 @@ impl IArray { self.header_mut().as_mut_slice_unchecked::() } - fn resize_internal(&mut self, cap: usize) -> Result<(), AllocError> { + fn resize_internal(&mut self, cap: usize) -> Result<(), IJsonError> { if self.is_static() || cap == 0 { let tag = if self.is_static() { ArrayTag::Heterogeneous @@ -916,7 +986,7 @@ impl IArray { } /// Reserves space for at least this many additional items. 
- pub fn reserve(&mut self, additional: usize) -> Result<(), AllocError> { + pub fn reserve(&mut self, additional: usize) -> Result<(), IJsonError> { let hd = self.header(); let current_capacity = hd.cap(); let desired_capacity = hd.len().checked_add(additional).ok_or(AllocError)?; @@ -956,7 +1026,7 @@ impl IArray { /// on or after this index will be shifted down to accomodate this. For large /// arrays, insertions near the front will be slow as it will require shifting /// a large number of items. - pub fn insert(&mut self, index: usize, item: impl Into) -> Result<(), AllocError> { + pub fn insert(&mut self, index: usize, item: impl Into) -> Result<(), IJsonError> { let item = item.into(); let current_tag = self.header().type_tag(); let len = self.len(); @@ -1080,8 +1150,50 @@ impl IArray { } } + /// Pushes a new item onto the back of the array with a specific floating-point type, potentially losing precision. + /// If the item is not a number, it is pushed as is. + pub(crate) fn push_with_fp_type( + &mut self, + item: impl Into, + fp_type: FloatType, + ) -> Result<(), IJsonError> { + let item = item.into(); + let desired_tag = fp_type.into(); + let current_tag = self.header().type_tag(); + let len = self.len(); + if !item.is_number() || (current_tag != desired_tag && len > 0) { + return self.push(item); + } + let can_fit = || match fp_type { + FloatType::F16 => item.to_f16_lossy().map_or(false, |v| v.is_finite()), + FloatType::BF16 => item.to_bf16_lossy().map_or(false, |v| v.is_finite()), + FloatType::F32 => item.to_f32_lossy().map_or(false, |v| v.is_finite()), + FloatType::F64 => item.to_f64_lossy().map_or(false, |v| v.is_finite()), + }; + + if !can_fit() { + return Err(IJsonError::OutOfRange(fp_type)); + } + + // We can fit the item into the array, so we can push it directly + + if len == 0 { + if self.is_static() { + *self = IArray::with_capacity_and_tag(4, desired_tag)?; + } else { + self.promote_to_type(desired_tag)?; + } + } + + self.reserve(1)?; + unsafe { 
+ self.header_mut().push_lossy(item); + } + Ok(()) + } + /// Pushes a new item onto the back of the array. - pub fn push(&mut self, item: impl Into) -> Result<(), AllocError> { + pub fn push(&mut self, item: impl Into) -> Result<(), IJsonError> { let item = item.into(); let current_tag = self.header().type_tag(); let len = self.len(); @@ -1425,11 +1537,11 @@ pub trait TryExtend { /// Returns an `AllocError` if allocation fails. /// # Errors /// Returns an `AllocError` if memory allocation fails during the extension. - fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), AllocError>; + fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), IJsonError>; } impl + private::Sealed> TryExtend for IArray { - fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), AllocError> { + fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), IJsonError> { let iter = iter.into_iter(); self.reserve(iter.size_hint().0)?; for v in iter { @@ -1442,7 +1554,7 @@ impl + private::Sealed> TryExtend for IArray { macro_rules! extend_impl_int { ($($ty:ty),*) => { $(impl TryExtend<$ty> for IArray { - fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), AllocError> { + fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), IJsonError> { let expected_tag = ArrayTag::from_type::<$ty>(); let iter = iter.into_iter(); let size_hint = iter.size_hint().0; @@ -1494,7 +1606,7 @@ macro_rules! extend_impl_int { macro_rules! extend_impl_float { ($($ty:ty),*) => { $(impl TryExtend<$ty> for IArray { - fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), AllocError> { + fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), IJsonError> { let expected_tag = ArrayTag::from_type::<$ty>(); let iter = iter.into_iter(); let size_hint = iter.size_hint().0; @@ -1564,13 +1676,13 @@ pub trait TryFromIterator { /// Returns an `AllocError` if allocation fails. 
/// # Errors /// Returns `AllocError` if memory allocation fails during the construction. - fn try_from_iter>(iter: U) -> Result + fn try_from_iter>(iter: U) -> Result where Self: Sized; } impl + private::Sealed> TryFromIterator for IArray { - fn try_from_iter>(iter: T) -> Result { + fn try_from_iter>(iter: T) -> Result { let mut res = IArray::new(); res.try_extend(iter)?; Ok(res) @@ -1580,7 +1692,7 @@ impl + private::Sealed> TryFromIterator for IArray { macro_rules! from_iter_impl { ($($ty:ty),*) => { $(impl TryFromIterator<$ty> for IArray { - fn try_from_iter>(iter: T) -> Result { + fn try_from_iter>(iter: T) -> Result { let iter = iter.into_iter(); let mut res = IArray::with_capacity_and_tag(iter.size_hint().0, ArrayTag::from_type::<$ty>())?; res.try_extend(iter)?; @@ -1599,13 +1711,13 @@ pub trait TryCollect: Iterator + Sized { /// Returns an `AllocError` if allocation fails. /// # Errors /// Returns `AllocError` if memory allocation fails during the collection. - fn try_collect(self) -> Result + fn try_collect(self) -> Result where B: TryFromIterator; } impl> TryCollect for I { - fn try_collect(self) -> Result + fn try_collect(self) -> Result where B: TryFromIterator, { @@ -1614,7 +1726,7 @@ impl> TryCollect for I { } impl + private::Sealed> TryFrom> for IArray { - type Error = AllocError; + type Error = IJsonError; fn try_from(other: Vec) -> Result { let mut res = IArray::with_capacity(other.len())?; res.try_extend(other.into_iter().map(Into::into))?; @@ -1623,7 +1735,7 @@ impl + private::Sealed> TryFrom> for IArray { } impl + Clone + private::Sealed> TryFrom<&[T]> for IArray { - type Error = AllocError; + type Error = IJsonError; fn try_from(other: &[T]) -> Result { let mut res = IArray::with_capacity(other.len())?; res.try_extend(other.iter().cloned().map(Into::into))?; @@ -1634,7 +1746,7 @@ impl + Clone + private::Sealed> TryFrom<&[T]> for IArray { macro_rules! 
from_slice_impl { ($($ty:ty),*) => {$( impl TryFrom> for IArray { - type Error = AllocError; + type Error = IJsonError; fn try_from(other: Vec<$ty>) -> Result { let mut res = IArray::with_capacity_and_tag(other.len(), ArrayTag::from_type::<$ty>())?; TryExtend::<$ty>::try_extend(&mut res, other.into_iter().map(Into::into))?; @@ -1642,7 +1754,7 @@ macro_rules! from_slice_impl { } } impl TryFrom<&[$ty]> for IArray { - type Error = AllocError; + type Error = IJsonError; fn try_from(other: &[$ty]) -> Result { let mut res = IArray::with_capacity_and_tag(other.len(), ArrayTag::from_type::<$ty>())?; TryExtend::<$ty>::try_extend(&mut res, other.iter().cloned().map(Into::into))?; @@ -3207,4 +3319,28 @@ mod tests { } } } + + #[test] + fn test_push_with_fp_type_creates_typed_array() { + let mut arr = IArray::new(); + arr.push_with_fp_type(IValue::from(1.5), FloatType::F16) + .unwrap(); + arr.push_with_fp_type(IValue::from(2.5), FloatType::F16) + .unwrap(); + + assert_eq!(arr.len(), 2); + assert!(matches!(arr.as_slice(), ArraySliceRef::F16(_))); + } + + #[test] + fn test_push_with_fp_type_overflow_rejected() { + let mut arr = IArray::new(); + arr.push_with_fp_type(IValue::from(1.5), FloatType::F16) + .unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F16(_))); + arr.push_with_fp_type(IValue::from(100000.0), FloatType::F16) + .unwrap_err(); + assert_eq!(arr.len(), 1); + assert!(matches!(arr.as_slice(), ArraySliceRef::F16(_))); + } } diff --git a/src/cbor.rs b/src/cbor.rs new file mode 100644 index 0000000..b67fda0 --- /dev/null +++ b/src/cbor.rs @@ -0,0 +1,489 @@ +//! CBOR encode/decode for [`IValue`], preserving typed array tags via RFC 8746. +//! +//! Typed homogeneous arrays are encoded as `Tag(rfc8746_tag, Bytes(raw_le_bytes))`. +//! BF16 arrays use private tag `0x10000` (no standard RFC 8746 equivalent). +//! +//! Use [`encode`] / [`decode`] for raw CBOR, or [`encode_compressed`] / +//! [`decode_compressed`] for zstd-compressed CBOR. 
+ +use std::fmt; + +use bytemuck::Pod; +use ciborium::value::{Integer, Value}; +use ciborium_ll::tag; +use half::{bf16, f16}; + +/// Converts a typed slice to/from little-endian bytes using each type's own +/// `to_le_bytes` / `from_le_bytes`, which are correct on every host endianness. +trait LeBytes: Pod + Copy { + fn slice_to_le_bytes(s: &[Self]) -> Vec; + fn slice_from_le_bytes(bytes: &[u8]) -> Result, CborDecodeError>; +} + +macro_rules! impl_le_bytes { + ($($t:ty => $n:literal),* $(,)?) => {$( + impl LeBytes for $t { + fn slice_to_le_bytes(s: &[Self]) -> Vec { + s.iter().flat_map(|v| v.to_le_bytes()).collect() + } + fn slice_from_le_bytes(bytes: &[u8]) -> Result, CborDecodeError> { + if bytes.len() % $n != 0 { + return Err(CborDecodeError::CastError); + } + Ok(bytes + .chunks_exact($n) + .map(|c| { + // SAFETY: `chunks_exact($n)` guarantees every chunk + // is exactly $n bytes, matching [u8; $n]. + let arr: [u8; $n] = c.try_into().unwrap(); + Self::from_le_bytes(arr) + }) + .collect()) + } + } + )*}; +} + +impl_le_bytes!( + i8 => 1, u8 => 1, + i16 => 2, u16 => 2, f16 => 2, bf16 => 2, + i32 => 4, u32 => 4, f32 => 4, + i64 => 8, u64 => 8, f64 => 8, +); + +use crate::array::ArraySliceRef; +use crate::{DestructuredRef, IArray, INumber, IObject, IString, IValue}; + +use tag::{ + TYPED_F16_LE as TAG_F16_LE, TYPED_F32_LE as TAG_F32_LE, TYPED_F64_LE as TAG_F64_LE, + TYPED_I16_LE as TAG_I16_LE, TYPED_I32_LE as TAG_I32_LE, TYPED_I64_LE as TAG_I64_LE, + TYPED_I8 as TAG_I8, TYPED_U16_LE as TAG_U16_LE, TYPED_U32_LE as TAG_U32_LE, + TYPED_U64_LE as TAG_U64_LE, TYPED_U8 as TAG_U8, +}; + +/// Private CBOR tag for BF16 arrays (no RFC 8746 standard tag exists for BF16). +const TAG_BF16_LE: u64 = 0x10000; + +const MAX_DEPTH: u32 = 128; + +/// Error returned when CBOR decoding fails. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CborDecodeError { + /// The CBOR stream was malformed or could not be parsed. 
+ DecodeError, + /// An unrecognised CBOR tag was encountered where a typed array was expected. + UnknownTag(u64), + /// A CBOR map key was not a text string. + InvalidValue, + /// An array allocation failed. + AllocError, + /// Nesting depth exceeded the limit. + DepthLimitExceeded, + /// Failed to reinterpret a byte slice. + CastError, + /// Zstd decompression failed. + DecompressError, +} + +impl fmt::Display for CborDecodeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CborDecodeError::DecodeError => write!(f, "CBOR decode error"), + CborDecodeError::UnknownTag(t) => write!(f, "unknown CBOR tag: {t}"), + CborDecodeError::InvalidValue => write!(f, "unexpected CBOR value type"), + CborDecodeError::AllocError => write!(f, "memory allocation failed"), + CborDecodeError::DepthLimitExceeded => write!(f, "nesting depth limit exceeded"), + CborDecodeError::CastError => write!(f, "failed to cast byte slice"), + CborDecodeError::DecompressError => write!(f, "zstd decompression failed"), + } + } +} + +impl std::error::Error for CborDecodeError {} + +// ── Encode ──────────────────────────────────────────────────────────────────── + +/// Encodes an [`IValue`] tree into CBOR bytes, preserving typed array tags via +/// RFC 8746. +pub fn encode(value: &IValue) -> Vec { + let cbor = ivalue_to_cbor(value); + let mut out = Vec::new(); + ciborium::into_writer(&cbor, &mut out).expect("write to Vec never fails"); + out +} + +/// Encodes an [`IValue`] tree as CBOR and then compresses it with zstd (level 3). +/// +/// Use [`decode_compressed`] to decode the output. 
+pub fn encode_compressed(value: &IValue) -> Vec { + let raw = encode(value); + zstd::bulk::Compressor::default() + .compress(&raw) + .expect("zstd compress") +} + +fn ivalue_to_cbor(value: &IValue) -> Value { + match value.destructure_ref() { + DestructuredRef::Null => Value::Null, + DestructuredRef::Bool(b) => Value::Bool(b), + DestructuredRef::Number(n) => number_to_cbor(n), + DestructuredRef::String(s) => Value::Text(s.as_str().to_owned()), + DestructuredRef::Array(a) => array_to_cbor(a), + DestructuredRef::Object(o) => object_to_cbor(o), + } +} + +fn number_to_cbor(n: &INumber) -> Value { + if n.has_decimal_point() { + Value::Float(n.to_f64().unwrap()) + } else if let Some(i) = n.to_i64() { + Value::Integer(Integer::from(i)) + } else { + Value::Integer(Integer::from(n.to_u64().unwrap())) + } +} + +fn array_to_cbor(a: &IArray) -> Value { + match a.as_slice() { + ArraySliceRef::Heterogeneous(s) => Value::Array(s.iter().map(ivalue_to_cbor).collect()), + ArraySliceRef::I8(s) => typed_le_tag(TAG_I8, s), + ArraySliceRef::U8(s) => typed_le_tag(TAG_U8, s), + ArraySliceRef::I16(s) => typed_le_tag(TAG_I16_LE, s), + ArraySliceRef::U16(s) => typed_le_tag(TAG_U16_LE, s), + ArraySliceRef::F16(s) => typed_le_tag(TAG_F16_LE, s), + ArraySliceRef::BF16(s) => typed_le_tag(TAG_BF16_LE, s), + ArraySliceRef::I32(s) => typed_le_tag(TAG_I32_LE, s), + ArraySliceRef::U32(s) => typed_le_tag(TAG_U32_LE, s), + ArraySliceRef::F32(s) => typed_le_tag(TAG_F32_LE, s), + ArraySliceRef::I64(s) => typed_le_tag(TAG_I64_LE, s), + ArraySliceRef::U64(s) => typed_le_tag(TAG_U64_LE, s), + ArraySliceRef::F64(s) => typed_le_tag(TAG_F64_LE, s), + } +} + +fn object_to_cbor(o: &IObject) -> Value { + Value::Map( + o.iter() + .map(|(k, v)| (Value::Text(k.as_str().to_owned()), ivalue_to_cbor(v))) + .collect(), + ) +} + +fn typed_le_tag(tag: u64, s: &[T]) -> Value { + Value::Tag(tag, Box::new(Value::Bytes(T::slice_to_le_bytes(s)))) +} + +// ── Decode 
──────────────────────────────────────────────────────────────────── + +/// Decodes an [`IValue`] tree from CBOR bytes produced by [`encode`]. +pub fn decode(bytes: &[u8]) -> Result { + let cbor: Value = ciborium::from_reader(bytes).map_err(|_| CborDecodeError::DecodeError)?; + cbor_to_ivalue(cbor, 0) +} + +/// Decodes an [`IValue`] tree from bytes produced by [`encode_compressed`]. +pub fn decode_compressed(bytes: &[u8]) -> Result { + let raw = zstd::decode_all(bytes).map_err(|_| CborDecodeError::DecompressError)?; + decode(&raw) +} + +fn cbor_to_ivalue(val: Value, depth: u32) -> Result { + if depth >= MAX_DEPTH { + return Err(CborDecodeError::DepthLimitExceeded); + } + match val { + Value::Null => Ok(IValue::NULL), + Value::Bool(b) => Ok(b.into()), + Value::Float(f) => Ok(INumber::try_from(f).map(Into::into).unwrap_or(IValue::NULL)), + Value::Integer(i) => { + if let Ok(v) = i64::try_from(i.clone()) { + Ok(IValue::from(v)) + } else if let Ok(v) = u64::try_from(i) { + Ok(IValue::from(v)) + } else { + Err(CborDecodeError::InvalidValue) + } + } + Value::Text(s) => Ok(IString::from(s.as_str()).into()), + Value::Array(arr) => { + let hint = arr.len().min(1024); + let mut out = IArray::with_capacity(hint).map_err(|_| CborDecodeError::AllocError)?; + for v in arr { + let iv = cbor_to_ivalue(v, depth + 1)?; + out.push(iv).map_err(|_| CborDecodeError::AllocError)?; + } + Ok(out.into()) + } + Value::Map(entries) => { + let mut obj = IObject::with_capacity(entries.len()); + for (k, v) in entries { + let key = match k { + Value::Text(s) => s, + _ => return Err(CborDecodeError::InvalidValue), + }; + let val = cbor_to_ivalue(v, depth + 1)?; + obj.insert(&key, val); + } + Ok(obj.into()) + } + Value::Tag(tag, inner) => decode_typed_array(tag, *inner), + Value::Bytes(_) => Err(CborDecodeError::InvalidValue), + _ => Err(CborDecodeError::InvalidValue), + } +} + +fn decode_typed_array(tag: u64, inner: Value) -> Result { + let bytes = match inner { + Value::Bytes(b) => b, + _ => 
return Err(CborDecodeError::InvalidValue), + }; + match tag { + TAG_U8 => decode_le_array::(&bytes), + TAG_I8 => decode_le_array::(&bytes), + TAG_U16_LE => decode_le_array::(&bytes), + TAG_I16_LE => decode_le_array::(&bytes), + TAG_F16_LE => decode_le_array::(&bytes), + TAG_BF16_LE => decode_le_array::(&bytes), + TAG_U32_LE => decode_le_array::(&bytes), + TAG_I32_LE => decode_le_array::(&bytes), + TAG_F32_LE => decode_le_array::(&bytes), + TAG_U64_LE => decode_le_array::(&bytes), + TAG_I64_LE => decode_le_array::(&bytes), + TAG_F64_LE => decode_le_array::(&bytes), + other => Err(CborDecodeError::UnknownTag(other)), + } +} + +fn decode_le_array(bytes: &[u8]) -> Result +where + T: LeBytes, + IArray: TryFrom>, +{ + IArray::try_from(T::slice_from_le_bytes(bytes)?) + .map(Into::into) + .map_err(|_| CborDecodeError::AllocError) +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::ArraySliceRef; + use crate::IValueDeserSeed; + use serde::de::DeserializeSeed; + + fn round_trip(value: &IValue) -> IValue { + let bytes = encode(value); + decode(&bytes).expect("decode should succeed") + } + + #[test] + fn test_null() { + assert_eq!(round_trip(&IValue::NULL), IValue::NULL); + } + + #[test] + fn test_bool() { + let t: IValue = true.into(); + let f: IValue = false.into(); + assert_eq!(round_trip(&t), t); + assert_eq!(round_trip(&f), f); + } + + #[test] + fn test_numbers() { + let cases: Vec = vec![ + 0i64.into(), + 42i64.into(), + (-1i64).into(), + i64::MAX.into(), + u64::MAX.into(), + 1.5f64.into(), + (-3.14f64).into(), + ]; + for v in &cases { + assert_eq!(round_trip(v), *v); + } + } + + #[test] + fn test_string() { + let v: IValue = IString::from("hello world").into(); + assert_eq!(round_trip(&v), v); + } + + #[test] + fn test_heterogeneous_array() { + let mut arr = IArray::new(); + arr.push(IValue::NULL).unwrap(); + arr.push(IValue::from(true)).unwrap(); + 
arr.push(IValue::from(42i64)).unwrap(); + arr.push(IValue::from(IString::from("hi"))).unwrap(); + let v: IValue = arr.into(); + let result = round_trip(&v); + let result_arr = result.as_array().unwrap(); + assert!(matches!( + result_arr.as_slice(), + ArraySliceRef::Heterogeneous(_) + )); + assert_eq!(result_arr.len(), 4); + } + + #[test] + fn test_f32_array_preserves_tag() { + let seed = IValueDeserSeed::new(Some(crate::FPHAConfig::new_with_type( + crate::FloatType::F32, + ))); + let json = r#"[1.5, 2.5, 3.5]"#; + let mut de = serde_json::Deserializer::from_str(json); + let v = seed.deserialize(&mut de).unwrap(); + assert!(matches!( + v.as_array().unwrap().as_slice(), + ArraySliceRef::F32(_) + )); + + let result = round_trip(&v); + let arr = result.as_array().unwrap(); + assert!( + matches!(arr.as_slice(), ArraySliceRef::F32(_)), + "F32 tag should survive CBOR encode/decode" + ); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_f16_array_preserves_tag() { + let seed = IValueDeserSeed::new(Some(crate::FPHAConfig::new_with_type( + crate::FloatType::F16, + ))); + let json = r#"[0.5, 1.0, 1.5]"#; + let mut de = serde_json::Deserializer::from_str(json); + let v = seed.deserialize(&mut de).unwrap(); + assert!(matches!( + v.as_array().unwrap().as_slice(), + ArraySliceRef::F16(_) + )); + + let result = round_trip(&v); + let arr = result.as_array().unwrap(); + assert!( + matches!(arr.as_slice(), ArraySliceRef::F16(_)), + "F16 tag should survive CBOR encode/decode" + ); + } + + #[test] + fn test_bf16_array_preserves_tag() { + let seed = IValueDeserSeed::new(Some(crate::FPHAConfig::new_with_type( + crate::FloatType::BF16, + ))); + let json = r#"[1.0, 2.0, 3.0]"#; + let mut de = serde_json::Deserializer::from_str(json); + let v = seed.deserialize(&mut de).unwrap(); + + let result = round_trip(&v); + let arr = result.as_array().unwrap(); + assert!( + matches!(arr.as_slice(), ArraySliceRef::BF16(_)), + "BF16 tag should survive CBOR encode/decode" + ); + } + + #[test] + fn 
test_f64_array_preserves_tag() { + let seed = IValueDeserSeed::new(Some(crate::FPHAConfig::new_with_type( + crate::FloatType::F64, + ))); + let json = r#"[1.0, 2.0, 3.0]"#; + let mut de = serde_json::Deserializer::from_str(json); + let v = seed.deserialize(&mut de).unwrap(); + + let result = round_trip(&v); + let arr = result.as_array().unwrap(); + assert!( + matches!(arr.as_slice(), ArraySliceRef::F64(_)), + "F64 tag should survive CBOR encode/decode" + ); + } + + #[test] + fn test_nested_object_with_typed_arrays() { + let seed = IValueDeserSeed::new(Some(crate::FPHAConfig::new_with_type( + crate::FloatType::F32, + ))); + let json = r#"{"a": [1.0, 2.0], "b": "text", "c": [3.0, 4.0]}"#; + let mut de = serde_json::Deserializer::from_str(json); + let v = seed.deserialize(&mut de).unwrap(); + + let result = round_trip(&v); + let obj = result.as_object().unwrap(); + let a = obj.get("a").unwrap().as_array().unwrap(); + let c = obj.get("c").unwrap().as_array().unwrap(); + assert!(matches!(a.as_slice(), ArraySliceRef::F32(_))); + assert!(matches!(c.as_slice(), ArraySliceRef::F32(_))); + assert_eq!(obj.get("b").unwrap().as_string().unwrap().as_str(), "text"); + } + + #[test] + fn test_compressed_round_trip() { + let seed = IValueDeserSeed::new(Some(crate::FPHAConfig::new_with_type( + crate::FloatType::F32, + ))); + let json = r#"[1.5, 2.5, 3.5, 4.5, 5.5]"#; + let mut de = serde_json::Deserializer::from_str(json); + let v = seed.deserialize(&mut de).unwrap(); + + let bytes = encode_compressed(&v); + let result = decode_compressed(&bytes).expect("decode_compressed should succeed"); + assert!(matches!( + result.as_array().unwrap().as_slice(), + ArraySliceRef::F32(_) + )); + } + + #[test] + fn test_small_integers_compact() { + // Small integers should be encoded more compactly in CBOR than custom binary. 
+ let v: IValue = 42i64.into(); + let cbor_bytes = encode(&v); + // 42 fits in a single CBOR byte (major type 0, value 24 triggers 1-byte header + 1-byte value) + // Either way it's much smaller than the custom binary's fixed 9 bytes. + assert!( + cbor_bytes.len() < 9, + "expected CBOR to be smaller than 9-byte fixed encoding" + ); + } + + #[test] + fn test_misaligned_typed_array_is_rejected() { + // Hand-craft valid CBOR: Tag(85, Bytes[13 bytes]). + // 13 bytes is not divisible by 4 (F32 element size), so decode must + // return an error rather than silently truncating to 3 elements. + // + // CBOR layout: + // 0xD8 0x55 — tag 85 (RFC 8746 F32-LE), 2-byte form (tag >= 24) + // 0x4D — byte string, length 13 (0x40 | 13) + // [0u8; 13] — 13 zero bytes (misaligned: 13 % 4 != 0) + let mut bytes = vec![0xD8, 0x55, 0x4D]; + bytes.extend_from_slice(&[0u8; 13]); + assert_eq!( + decode(&bytes), + Err(CborDecodeError::CastError), + "F32 typed array with 13 bytes (not a multiple of 4) must be rejected" + ); + } + + #[test] + fn test_misaligned_f16_array_is_rejected() { + // Tag 84 (F16-LE), byte string of 5 bytes (not divisible by 2). + // 0xD8 0x54 — tag 84; 0x45 — byte string length 5 + let mut bytes = vec![0xD8, 0x54, 0x45]; + bytes.extend_from_slice(&[0u8; 5]); + assert_eq!( + decode(&bytes), + Err(CborDecodeError::CastError), + "F16 typed array with 5 bytes (not a multiple of 2) must be rejected" + ); + } +} diff --git a/src/de.rs b/src/de.rs index 4eabe26..e1c9068 100644 --- a/src/de.rs +++ b/src/de.rs @@ -8,14 +8,54 @@ use serde::de::{ use serde::{forward_to_deserialize_any, Deserialize, Deserializer}; use serde_json::error::Error; -use crate::{DestructuredRef, IArray, INumber, IObject, IString, IValue}; +use crate::{DestructuredRef, FloatType, IArray, INumber, IObject, IString, IValue}; + +#[derive(Debug, Clone, Copy)] +/// Configuration for floating point homogeneous arrays. +pub struct FPHAConfig { + /// Floating point type for homogeneous arrays. 
+ pub fpha_type: FloatType, +} + +impl FPHAConfig { + /// Creates a new [`FPHAConfig`] with the given floating point type. + pub fn new_with_type(fpha_type: FloatType) -> Self { + Self { fpha_type } + } +} + +/// Seed for deserializing an [`IValue`]. +#[derive(Debug, Default)] +pub struct IValueDeserSeed { + /// Optional FPHA configuration for homogeneous arrays. + pub fpha_config: Option, +} + +impl IValueDeserSeed { + /// Creates a new [`IValueDeserSeed`] with the given optional floating point type enforcement configuration for homogeneous arrays. + pub fn new(fpha_config: Option) -> Self { + IValueDeserSeed { fpha_config } + } +} + +impl<'de> DeserializeSeed<'de> for IValueDeserSeed { + type Value = IValue; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // Pass hint to a custom visitor + deserializer.deserialize_any(ValueVisitor::new(self.fpha_config)) + } +} impl<'de> Deserialize<'de> for IValue { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { - deserializer.deserialize_any(ValueVisitor) + deserializer.deserialize_any(ValueVisitor::new(None)) } } @@ -42,7 +82,7 @@ impl<'de> Deserialize<'de> for IArray { where D: Deserializer<'de>, { - deserializer.deserialize_seq(ArrayVisitor) + deserializer.deserialize_seq(ArrayVisitor { fpha_config: None }) } } @@ -51,11 +91,19 @@ impl<'de> Deserialize<'de> for IObject { where D: Deserializer<'de>, { - deserializer.deserialize_map(ObjectVisitor) + deserializer.deserialize_map(ObjectVisitor { fpha_config: None }) } } -struct ValueVisitor; +struct ValueVisitor { + fpha_config: Option, +} + +impl ValueVisitor { + fn new(fpha_config: Option) -> Self { + ValueVisitor { fpha_config } + } +} impl<'de> Visitor<'de> for ValueVisitor { type Value = IValue; @@ -104,7 +152,7 @@ impl<'de> Visitor<'de> for ValueVisitor { where D: Deserializer<'de>, { - Deserialize::deserialize(deserializer) + IValueDeserSeed::new(self.fpha_config).deserialize(deserializer) } #[inline] @@ -117,14 +165,22 @@
impl<'de> Visitor<'de> for ValueVisitor { where V: SeqAccess<'de>, { - ArrayVisitor.visit_seq(visitor).map(Into::into) + ArrayVisitor { + fpha_config: self.fpha_config, + } + .visit_seq(visitor) + .map(Into::into) } fn visit_map(self, visitor: V) -> Result where V: MapAccess<'de>, { - ObjectVisitor.visit_map(visitor).map(Into::into) + ObjectVisitor { + fpha_config: self.fpha_config, + } + .visit_map(visitor) + .map(Into::into) } } @@ -192,7 +248,9 @@ impl<'de> Visitor<'de> for StringVisitor { } } -struct ArrayVisitor; +struct ArrayVisitor { + fpha_config: Option, +} impl<'de> Visitor<'de> for ArrayVisitor { type Value = IArray; @@ -208,15 +266,20 @@ impl<'de> Visitor<'de> for ArrayVisitor { { let mut arr = IArray::with_capacity(visitor.size_hint().unwrap_or(0)) .map_err(|_| SError::custom("Failed to allocate array"))?; - while let Some(v) = visitor.next_element::()? { - arr.push(v) - .map_err(|_| SError::custom("Failed to push to array"))?; + while let Some(v) = visitor.next_element_seed(IValueDeserSeed::new(self.fpha_config))? { + match self.fpha_config { + Some(FPHAConfig { fpha_type }) => arr.push_with_fp_type(v, fpha_type), + None => arr.push(v).map_err(Into::into), + } + .map_err(|e| SError::custom(e.to_string()))?; } Ok(arr) } } -struct ObjectVisitor; +struct ObjectVisitor { + fpha_config: Option, +} impl<'de> Visitor<'de> for ObjectVisitor { type Value = IObject; @@ -230,7 +293,8 @@ impl<'de> Visitor<'de> for ObjectVisitor { V: MapAccess<'de>, { let mut obj = IObject::with_capacity(visitor.size_hint().unwrap_or(0)); - while let Some((k, v)) = visitor.next_entry::()? { + while let Some(k) = visitor.next_key::()? 
{ + let v = visitor.next_value_seed(IValueDeserSeed::new(self.fpha_config))?; obj.insert(k, v); } Ok(obj) @@ -999,3 +1063,177 @@ where { T::deserialize(value) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::ArraySliceRef; + use serde::de::DeserializeSeed; + + #[test] + fn test_deserialize_with_f64_fp() { + let json = r#"[1.5, 2.5, 3.5]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F64))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F64(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_with_f32_fp() { + let json = r#"[1.5, 2.5, 3.5]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F32(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_with_f16_fp() { + let json = r#"[0.5, 1.0, 1.5]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F16))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F16(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_with_bf16_fp() { + let json = r#"[0.5, 1.0, 2.0]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::BF16))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::BF16(_))); + assert_eq!(arr.len(), 3); + } + + 
#[test] + fn test_deserialize_mixed_array_with_fp() { + let json = r#"[1, "string", 3.5]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::Heterogeneous(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_integer_array_with_fp() { + let json = r#"[1, 2, 3]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F32(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_f16_value_overflow_rejected() { + let json = r#"[0.5, 100000.0, 1.5]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F16))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let _error = seed.deserialize(&mut deserializer).unwrap_err(); + } + + #[test] + fn test_deserialize_bf16_value_overflow_rejected() { + let json = r#"[1e39, 2e39]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::BF16))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let _error = seed.deserialize(&mut deserializer).unwrap_err(); + } + + #[test] + fn test_deserialize_f32_value_overflow_rejected() { + let json = r#"[1e39, 2e39]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let _error = seed.deserialize(&mut deserializer).unwrap_err(); + } + + #[test] + fn test_fpha_outer_array_of_objects_succeeds() { + // The classic embedding use-case: outer array holds objects, not numbers. 
+ // Before the fix, push_with_fp_type would error on the object element. + let json = r#"[{"embedding": [1.0, 2.0]}, {"embedding": [3.0, 4.0]}]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F16))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert_eq!(arr.len(), 2); + assert!(matches!(arr.as_slice(), ArraySliceRef::Heterogeneous(_))); + + // Inner arrays should still be typed f16 + assert!(matches!( + arr[0] + .as_object() + .unwrap() + .get("embedding") + .unwrap() + .as_array() + .unwrap() + .as_slice(), + ArraySliceRef::F16(_) + )); + } + + #[test] + fn test_fpha_outer_array_of_nested_arrays_succeeds() { + // Outer array holds inner float arrays; outer must become heterogeneous. + let json = r#"[[1.0, 2.0], [3.0, 4.0]]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F16))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert_eq!(arr.len(), 2); + assert!(matches!(arr.as_slice(), ArraySliceRef::Heterogeneous(_))); + // Inner arrays should still be typed f16 + assert!(matches!( + arr[0].as_array().unwrap().as_slice(), + ArraySliceRef::F16(_) + )); + } + + #[test] + fn test_ser_deser_roundtrip_preserves_type() { + let json = r#"[0.2, 1.0, 1.2]"#; + + for fp_type in [FloatType::F16, FloatType::BF16, FloatType::F32] { + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(fp_type))); + let mut de = serde_json::Deserializer::from_str(json); + let original = seed.deserialize(&mut de).unwrap(); + + let serialized = serde_json::to_string(&original).unwrap(); + + let reload_seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(fp_type))); + let mut de = serde_json::Deserializer::from_str(&serialized); + let roundtripped = 
reload_seed.deserialize(&mut de).unwrap(); + + let arr = roundtripped.as_array().unwrap(); + assert_eq!(arr.len(), 3); + let roundtrip_tag = arr.as_slice().type_tag(); + assert_eq!( + roundtrip_tag, + fp_type.into(), + "roundtrip should preserve {fp_type}" + ); + } + } +} diff --git a/src/alloc.rs b/src/error.rs similarity index 56% rename from src/alloc.rs rename to src/error.rs index af0c87d..d5b9f36 100644 --- a/src/alloc.rs +++ b/src/error.rs @@ -2,6 +2,9 @@ use std::error::Error; use std::fmt; +use thiserror::Error; + +use crate::FloatType; /// Error type for fallible allocation /// This error is returned when an allocation fails. @@ -16,3 +19,14 @@ impl fmt::Display for AllocError { f.write_str("memory allocation failed") } } + +/// Error type for ijson +#[derive(Error, Debug)] +pub enum IJsonError { + /// Memory allocation failed + #[error("memory allocation failed")] + Alloc(#[from] AllocError), + /// Value out of range for the specified floating-point type + #[error("value out of range for {0}")] + OutOfRange(FloatType), +} diff --git a/src/lib.rs b/src/lib.rs index afcff7b..ecebc69 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,11 +36,11 @@ pub mod unsafe_string; #[cfg(not(feature = "thread_safe"))] pub use unsafe_string::IString; -pub mod alloc; +pub mod error; mod thin; mod value; -pub use array::IArray; +pub use array::{FloatType, IArray}; pub use number::INumber; pub use object::IObject; use std::alloc::Layout; @@ -49,9 +49,12 @@ pub use value::{ BoolMut, Destructured, DestructuredMut, DestructuredRef, IValue, ValueIndex, ValueType, }; +/// CBOR encode/decode for [`IValue`] using RFC 8746 typed array tags. +pub mod cbor; mod de; mod ser; -pub use de::from_value; +pub use cbor::{decode, decode_compressed, encode, encode_compressed, CborDecodeError}; +pub use de::{from_value, FPHAConfig, IValueDeserSeed}; pub use ser::to_value; /// Trait to implement defrag allocator