Skip to content

Commit 2b141ea

Browse files
committed
Support deserializing Python dataclass into structs / mappings
1 parent a444b75 commit 2b141ea

File tree

1 file changed

+210
-9
lines changed

1 file changed

+210
-9
lines changed

src/de.rs

Lines changed: 210 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use pyo3::{types::*, Bound};
1+
use pyo3::exceptions::PyKeyError;
2+
use pyo3::{intern, types::*, Bound};
23
use serde::de::{self, IntoDeserializer};
34
use serde::Deserialize;
45

@@ -7,7 +8,16 @@ use crate::error::{ErrorImpl, PythonizeError, Result};
78
#[cfg(feature = "arbitrary_precision")]
89
const TOKEN: &str = "$serde_json::private::Number";
910

10-
/// Attempt to convert a Python object to an instance of `T`
11+
/// Attempt to convert a Python object to an instance of `T`.
12+
///
13+
/// Generally this only supports Python types that match `serde`'s object model well:
14+
/// - integers (including arbitrary precision integers if the `arbitrary_precision` feature is enabled)
15+
/// - floats
16+
/// - strings
17+
/// - bytes
18+
/// - `collections.abc.Sequence` instances (as serde sequences)
19+
/// - `collections.abc.Mapping` instances (as serde maps)
20+
/// - dataclasses (as serde maps)
1121
pub fn depythonize<'a, 'py, T>(obj: &'a Bound<'py, PyAny>) -> Result<T>
1222
where
1323
T: Deserialize<'a>,
@@ -55,6 +65,14 @@ impl<'a, 'py> Depythonizer<'a, 'py> {
5565
PyMappingAccess::new(self.input.cast()?)
5666
}
5767

68+
fn dataclass_access(&self) -> Result<Option<PyDataclassAccess<'py>>> {
69+
if let Some(dc) = DataclassCandidate::try_new(self.input) {
70+
Some(PyDataclassAccess::new(dc)).transpose()
71+
} else {
72+
Ok(None)
73+
}
74+
}
75+
5876
fn deserialize_any_int<'de, V>(&self, int: &Bound<'_, PyInt>, visitor: V) -> Result<V::Value>
5977
where
6078
V: de::Visitor<'de>,
@@ -147,6 +165,8 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
147165
self.deserialize_tuple(obj.len()?, visitor)
148166
} else if obj.cast::<PyMapping>().is_ok() {
149167
self.deserialize_map(visitor)
168+
} else if let Some(dc) = DataclassCandidate::try_new(obj) {
169+
visitor.visit_map(PyDataclassAccess::new(dc)?)
150170
} else {
151171
Err(obj.get_type().qualname().map_or_else(
152172
|_| PythonizeError::unsupported_type("unknown"),
@@ -293,7 +313,11 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
293313
where
294314
V: de::Visitor<'de>,
295315
{
296-
visitor.visit_map(self.dict_access()?)
316+
if let Some(dc_access) = self.dataclass_access()? {
317+
visitor.visit_map(dc_access)
318+
} else {
319+
visitor.visit_map(self.dict_access()?)
320+
}
297321
}
298322

299323
fn deserialize_struct<V>(
@@ -470,6 +494,79 @@ impl<'de> de::MapAccess<'de> for PyMappingAccess<'_> {
470494
}
471495
}
472496

497+
/// Intermediate structure used to denote that `obj` is a dataclass with `fields`.
498+
struct DataclassCandidate<'a, 'py> {
499+
obj: &'a Bound<'py, PyAny>,
500+
fields: Bound<'py, PyAny>,
501+
}
502+
503+
impl<'a, 'py> DataclassCandidate<'a, 'py> {
504+
fn try_new(obj: &'a Bound<'py, PyAny>) -> Option<Self> {
505+
let fields = obj
506+
.getattr_opt(intern!(obj.py(), "__dataclass_fields__"))
507+
.ok()
508+
.flatten()?;
509+
Some(Self { obj, fields })
510+
}
511+
}
512+
513+
struct PyDataclassAccess<'py> {
514+
fields: Bound<'py, PyList>,
515+
dict: Bound<'py, PyDict>,
516+
field_idx: usize,
517+
val_idx: usize,
518+
len: usize,
519+
}
520+
521+
impl<'py> PyDataclassAccess<'py> {
522+
fn new(dc: DataclassCandidate<'_, 'py>) -> Result<Self> {
523+
let fields = dc.fields.cast::<PyDict>()?.keys();
524+
let dict = dc
525+
.obj
526+
.getattr(intern!(dc.obj.py(), "__dict__"))?
527+
.cast_into()?;
528+
let len = fields.len();
529+
Ok(Self {
530+
fields,
531+
dict,
532+
field_idx: 0,
533+
val_idx: 0,
534+
len,
535+
})
536+
}
537+
}
538+
539+
impl<'de> de::MapAccess<'de> for PyDataclassAccess<'_> {
540+
type Error = PythonizeError;
541+
542+
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
543+
where
544+
K: de::DeserializeSeed<'de>,
545+
{
546+
if self.field_idx < self.len {
547+
let item = self.fields.get_item(self.field_idx)?;
548+
self.field_idx += 1;
549+
seed.deserialize(&mut Depythonizer::from_object(&item))
550+
.map(Some)
551+
} else {
552+
Ok(None)
553+
}
554+
}
555+
556+
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
557+
where
558+
V: de::DeserializeSeed<'de>,
559+
{
560+
let key = self.fields.get_item(self.val_idx)?;
561+
let value = self
562+
.dict
563+
.get_item(&key)?
564+
.ok_or_else(|| PyKeyError::new_err(key.unbind()))?;
565+
self.val_idx += 1;
566+
seed.deserialize(&mut Depythonizer::from_object(&value))
567+
}
568+
}
569+
473570
struct PyEnumAccess<'a, 'py> {
474571
de: Depythonizer<'a, 'py>,
475572
variant: Bound<'py, PyString>,
@@ -558,7 +655,7 @@ impl<'de> de::MapAccess<'de> for NumberDeserializer {
558655

559656
#[cfg(test)]
560657
mod test {
561-
use std::ffi::CStr;
658+
use std::{collections::HashMap, ffi::CStr};
562659

563660
use super::*;
564661
use crate::error::ErrorImpl;
@@ -572,14 +669,21 @@ mod test {
572669
{
573670
Python::attach(|py| {
574671
let obj = py.eval(code, None, None).unwrap();
575-
let actual: T = depythonize(&obj).unwrap();
576-
assert_eq!(&actual, expected);
577-
578-
let actual_json: JsonValue = depythonize(&obj).unwrap();
579-
assert_eq!(&actual_json, expected_json);
672+
test_de_with_obj(&obj, expected, expected_json);
580673
});
581674
}
582675

676+
fn test_de_with_obj<T>(obj: &Bound<'_, PyAny>, expected: &T, expected_json: &JsonValue)
677+
where
678+
T: de::DeserializeOwned + PartialEq + std::fmt::Debug,
679+
{
680+
let actual: T = depythonize(obj).unwrap();
681+
assert_eq!(&actual, expected);
682+
683+
let actual_json: JsonValue = depythonize(obj).unwrap();
684+
assert_eq!(&actual_json, expected_json);
685+
}
686+
583687
#[test]
584688
fn test_empty_struct() {
585689
#[derive(Debug, Deserialize, PartialEq)]
@@ -930,4 +1034,101 @@ mod test {
9301034
));
9311035
});
9321036
}
1037+
1038+
#[test]
fn test_dataclass() {
    // Rust mirror of the Python `Point` dataclass declared below.
    #[derive(Debug, Deserialize, PartialEq)]
    struct Point {
        x: i32,
        y: i32,
    }

    // Python source that defines a two-field dataclass and binds an instance to `point`.
    let code = c"\
from dataclasses import dataclass

@dataclass
class Point:
    x: int
    y: int

point = Point(1, 2)";

    Python::attach(|py| {
        let scope = PyDict::new(py);
        py.run(code, None, Some(&scope)).unwrap();
        let point = scope.get_item("point").unwrap().unwrap();

        // The instance must deserialize both into the struct and into JSON.
        let want = Point { x: 1, y: 2 };
        let want_json = json!({"x": 1, "y": 2});
        test_de_with_obj(&point, &want, &want_json);

        // It must also deserialize into a plain string-keyed map.
        let by_name: HashMap<String, i32> = depythonize(&point).unwrap();
        assert_eq!(by_name.len(), 2);
        assert_eq!(by_name["x"], 1);
        assert_eq!(by_name["y"], 2);
    });
}
1071+
1072+
#[test]
1073+
fn test_dataclass_missing_field() {
1074+
let code = c"\
1075+
from dataclasses import dataclass
1076+
1077+
@dataclass
1078+
class Point:
1079+
x: int
1080+
y: int
1081+
1082+
point = Point(1, 2)";
1083+
1084+
#[derive(Debug, Deserialize, PartialEq)]
1085+
struct Point {
1086+
x: i32,
1087+
y: i32,
1088+
z: i32,
1089+
}
1090+
1091+
Python::attach(|py| {
1092+
let locals = PyDict::new(py);
1093+
py.run(code, None, Some(&locals)).unwrap();
1094+
let obj = locals.get_item("point").unwrap().unwrap();
1095+
let err = depythonize::<Point>(&obj).unwrap_err();
1096+
assert!(matches!(
1097+
*err.inner,
1098+
ErrorImpl::Message(msg) if msg == "missing field `z`"
1099+
));
1100+
});
1101+
}
1102+
1103+
#[test]
fn test_dataclass_extra_field() {
    // `deny_unknown_fields` turns the dataclass's additional `z` field into an error.
    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(deny_unknown_fields)]
    struct Point {
        x: i32,
        y: i32,
    }

    let code = c"\
from dataclasses import dataclass

@dataclass
class Point:
    x: int
    y: int
    z: int

point = Point(1, 2, 3)";

    Python::attach(|py| {
        let scope = PyDict::new(py);
        py.run(code, None, Some(&scope)).unwrap();
        let point = scope.get_item("point").unwrap().unwrap();

        // Deserialization must fail with serde's "unknown field" message for `z`.
        let failure = depythonize::<Point>(&point).unwrap_err();
        assert!(matches!(
            *failure.inner,
            ErrorImpl::Message(msg) if msg == "unknown field `z`, expected `x` or `y`"
        ));
    });
}
9331134
}

0 commit comments

Comments
 (0)