1- use pyo3:: { types:: * , Bound } ;
1+ use pyo3:: exceptions:: PyKeyError ;
2+ use pyo3:: { intern, types:: * , Bound } ;
23use serde:: de:: { self , IntoDeserializer } ;
34use serde:: Deserialize ;
45
@@ -7,7 +8,16 @@ use crate::error::{ErrorImpl, PythonizeError, Result};
78#[ cfg( feature = "arbitrary_precision" ) ]
89const TOKEN : & str = "$serde_json::private::Number" ;
910
10- /// Attempt to convert a Python object to an instance of `T`
11+ /// Attempt to convert a Python object to an instance of `T`.
12+ ///
13+ /// Generally this only supports Python types that match `serde`'s object model well:
14+ /// - integers (including arbitrary precision integers if the `arbitrary_precision` feature is enabled)
15+ /// - floats
16+ /// - strings
17+ /// - bytes
18+ /// - `collections.abc.Sequence` instances (as serde sequences)
19+ /// - `collections.abc.Mapping` instances (as serde maps)
20+ /// - dataclasses (as serde maps)
1121pub fn depythonize < ' a , ' py , T > ( obj : & ' a Bound < ' py , PyAny > ) -> Result < T >
1222where
1323 T : Deserialize < ' a > ,
@@ -55,6 +65,14 @@ impl<'a, 'py> Depythonizer<'a, 'py> {
5565 PyMappingAccess :: new ( self . input . cast ( ) ?)
5666 }
5767
68+ fn dataclass_access ( & self ) -> Result < Option < PyDataclassAccess < ' py > > > {
69+ if let Some ( dc) = DataclassCandidate :: try_new ( self . input ) {
70+ Some ( PyDataclassAccess :: new ( dc) ) . transpose ( )
71+ } else {
72+ Ok ( None )
73+ }
74+ }
75+
5876 fn deserialize_any_int < ' de , V > ( & self , int : & Bound < ' _ , PyInt > , visitor : V ) -> Result < V :: Value >
5977 where
6078 V : de:: Visitor < ' de > ,
@@ -147,6 +165,8 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
147165 self . deserialize_tuple ( obj. len ( ) ?, visitor)
148166 } else if obj. cast :: < PyMapping > ( ) . is_ok ( ) {
149167 self . deserialize_map ( visitor)
168+ } else if let Some ( dc) = DataclassCandidate :: try_new ( obj) {
169+ visitor. visit_map ( PyDataclassAccess :: new ( dc) ?)
150170 } else {
151171 Err ( obj. get_type ( ) . qualname ( ) . map_or_else (
152172 |_| PythonizeError :: unsupported_type ( "unknown" ) ,
@@ -293,7 +313,11 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
293313 where
294314 V : de:: Visitor < ' de > ,
295315 {
296- visitor. visit_map ( self . dict_access ( ) ?)
316+ if let Some ( dc_access) = self . dataclass_access ( ) ? {
317+ visitor. visit_map ( dc_access)
318+ } else {
319+ visitor. visit_map ( self . dict_access ( ) ?)
320+ }
297321 }
298322
299323 fn deserialize_struct < V > (
@@ -470,6 +494,79 @@ impl<'de> de::MapAccess<'de> for PyMappingAccess<'_> {
470494 }
471495}
472496
497+ /// Intermediate structure used to denote that `obj` is a dataclass with `fields`.
498+ struct DataclassCandidate < ' a , ' py > {
499+ obj : & ' a Bound < ' py , PyAny > ,
500+ fields : Bound < ' py , PyAny > ,
501+ }
502+
503+ impl < ' a , ' py > DataclassCandidate < ' a , ' py > {
504+ fn try_new ( obj : & ' a Bound < ' py , PyAny > ) -> Option < Self > {
505+ let fields = obj
506+ . getattr_opt ( intern ! ( obj. py( ) , "__dataclass_fields__" ) )
507+ . ok ( )
508+ . flatten ( ) ?;
509+ Some ( Self { obj, fields } )
510+ }
511+ }
512+
513+ struct PyDataclassAccess < ' py > {
514+ fields : Bound < ' py , PyList > ,
515+ dict : Bound < ' py , PyDict > ,
516+ field_idx : usize ,
517+ val_idx : usize ,
518+ len : usize ,
519+ }
520+
521+ impl < ' py > PyDataclassAccess < ' py > {
522+ fn new ( dc : DataclassCandidate < ' _ , ' py > ) -> Result < Self > {
523+ let fields = dc. fields . cast :: < PyDict > ( ) ?. keys ( ) ;
524+ let dict = dc
525+ . obj
526+ . getattr ( intern ! ( dc. obj. py( ) , "__dict__" ) ) ?
527+ . cast_into ( ) ?;
528+ let len = fields. len ( ) ;
529+ Ok ( Self {
530+ fields,
531+ dict,
532+ field_idx : 0 ,
533+ val_idx : 0 ,
534+ len,
535+ } )
536+ }
537+ }
538+
539+ impl < ' de > de:: MapAccess < ' de > for PyDataclassAccess < ' _ > {
540+ type Error = PythonizeError ;
541+
542+ fn next_key_seed < K > ( & mut self , seed : K ) -> Result < Option < K :: Value > >
543+ where
544+ K : de:: DeserializeSeed < ' de > ,
545+ {
546+ if self . field_idx < self . len {
547+ let item = self . fields . get_item ( self . field_idx ) ?;
548+ self . field_idx += 1 ;
549+ seed. deserialize ( & mut Depythonizer :: from_object ( & item) )
550+ . map ( Some )
551+ } else {
552+ Ok ( None )
553+ }
554+ }
555+
556+ fn next_value_seed < V > ( & mut self , seed : V ) -> Result < V :: Value >
557+ where
558+ V : de:: DeserializeSeed < ' de > ,
559+ {
560+ let key = self . fields . get_item ( self . val_idx ) ?;
561+ let value = self
562+ . dict
563+ . get_item ( & key) ?
564+ . ok_or_else ( || PyKeyError :: new_err ( key. unbind ( ) ) ) ?;
565+ self . val_idx += 1 ;
566+ seed. deserialize ( & mut Depythonizer :: from_object ( & value) )
567+ }
568+ }
569+
473570struct PyEnumAccess < ' a , ' py > {
474571 de : Depythonizer < ' a , ' py > ,
475572 variant : Bound < ' py , PyString > ,
@@ -558,7 +655,7 @@ impl<'de> de::MapAccess<'de> for NumberDeserializer {
558655
559656#[ cfg( test) ]
560657mod test {
561- use std:: ffi:: CStr ;
658+ use std:: { collections :: HashMap , ffi:: CStr } ;
562659
563660 use super :: * ;
564661 use crate :: error:: ErrorImpl ;
@@ -572,14 +669,21 @@ mod test {
572669 {
573670 Python :: attach ( |py| {
574671 let obj = py. eval ( code, None , None ) . unwrap ( ) ;
575- let actual: T = depythonize ( & obj) . unwrap ( ) ;
576- assert_eq ! ( & actual, expected) ;
577-
578- let actual_json: JsonValue = depythonize ( & obj) . unwrap ( ) ;
579- assert_eq ! ( & actual_json, expected_json) ;
672+ test_de_with_obj ( & obj, expected, expected_json) ;
580673 } ) ;
581674 }
582675
676+ fn test_de_with_obj < T > ( obj : & Bound < ' _ , PyAny > , expected : & T , expected_json : & JsonValue )
677+ where
678+ T : de:: DeserializeOwned + PartialEq + std:: fmt:: Debug ,
679+ {
680+ let actual: T = depythonize ( obj) . unwrap ( ) ;
681+ assert_eq ! ( & actual, expected) ;
682+
683+ let actual_json: JsonValue = depythonize ( obj) . unwrap ( ) ;
684+ assert_eq ! ( & actual_json, expected_json) ;
685+ }
686+
583687 #[ test]
584688 fn test_empty_struct ( ) {
585689 #[ derive( Debug , Deserialize , PartialEq ) ]
@@ -930,4 +1034,101 @@ mod test {
9301034 ) ) ;
9311035 } ) ;
9321036 }
1037+
1038+ #[ test]
1039+ fn test_dataclass ( ) {
1040+ let code = c"\
1041+ from dataclasses import dataclass
1042+
1043+ @dataclass
1044+ class Point:
1045+ x: int
1046+ y: int
1047+
1048+ point = Point(1, 2)" ;
1049+
1050+ #[ derive( Debug , Deserialize , PartialEq ) ]
1051+ struct Point {
1052+ x : i32 ,
1053+ y : i32 ,
1054+ }
1055+
1056+ let expected = Point { x : 1 , y : 2 } ;
1057+ let expected_json = json ! ( { "x" : 1 , "y" : 2 } ) ;
1058+
1059+ Python :: attach ( |py| {
1060+ let locals = PyDict :: new ( py) ;
1061+ py. run ( code, None , Some ( & locals) ) . unwrap ( ) ;
1062+ let obj = locals. get_item ( "point" ) . unwrap ( ) . unwrap ( ) ;
1063+ test_de_with_obj ( & obj, & expected, & expected_json) ;
1064+
1065+ let map: HashMap < String , i32 > = depythonize ( & obj) . unwrap ( ) ;
1066+ assert_eq ! ( map. len( ) , 2 ) ;
1067+ assert_eq ! ( * map. get( "x" ) . unwrap( ) , 1 ) ;
1068+ assert_eq ! ( * map. get( "y" ) . unwrap( ) , 2 ) ;
1069+ } ) ;
1070+ }
1071+
1072+ #[ test]
1073+ fn test_dataclass_missing_field ( ) {
1074+ let code = c"\
1075+ from dataclasses import dataclass
1076+
1077+ @dataclass
1078+ class Point:
1079+ x: int
1080+ y: int
1081+
1082+ point = Point(1, 2)" ;
1083+
1084+ #[ derive( Debug , Deserialize , PartialEq ) ]
1085+ struct Point {
1086+ x : i32 ,
1087+ y : i32 ,
1088+ z : i32 ,
1089+ }
1090+
1091+ Python :: attach ( |py| {
1092+ let locals = PyDict :: new ( py) ;
1093+ py. run ( code, None , Some ( & locals) ) . unwrap ( ) ;
1094+ let obj = locals. get_item ( "point" ) . unwrap ( ) . unwrap ( ) ;
1095+ let err = depythonize :: < Point > ( & obj) . unwrap_err ( ) ;
1096+ assert ! ( matches!(
1097+ * err. inner,
1098+ ErrorImpl :: Message ( msg) if msg == "missing field `z`"
1099+ ) ) ;
1100+ } ) ;
1101+ }
1102+
1103+ #[ test]
1104+ fn test_dataclass_extra_field ( ) {
1105+ let code = c"\
1106+ from dataclasses import dataclass
1107+
1108+ @dataclass
1109+ class Point:
1110+ x: int
1111+ y: int
1112+ z: int
1113+
1114+ point = Point(1, 2, 3)" ;
1115+
1116+ #[ derive( Debug , Deserialize , PartialEq ) ]
1117+ #[ serde( deny_unknown_fields) ]
1118+ struct Point {
1119+ x : i32 ,
1120+ y : i32 ,
1121+ }
1122+
1123+ Python :: attach ( |py| {
1124+ let locals = PyDict :: new ( py) ;
1125+ py. run ( code, None , Some ( & locals) ) . unwrap ( ) ;
1126+ let obj = locals. get_item ( "point" ) . unwrap ( ) . unwrap ( ) ;
1127+ let err = depythonize :: < Point > ( & obj) . unwrap_err ( ) ;
1128+ assert ! ( matches!(
1129+ * err. inner,
1130+ ErrorImpl :: Message ( msg) if msg == "unknown field `z`, expected `x` or `y`"
1131+ ) ) ;
1132+ } ) ;
1133+ }
9331134}
0 commit comments