diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 066e58c..9783e1b 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -158,31 +158,6 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics generics } -fn with_recursive_count_guard(recursive_count: &syn::Ident, expr: TokenStream) -> TokenStream { - quote! { - let guard_against_recursion = u.is_empty(); - if guard_against_recursion { - #recursive_count.with(|count| { - if count.get() > 0 { - return Err(arbitrary::Error::NotEnoughData); - } - count.set(count.get() + 1); - Ok(()) - })?; - } - - let result = (|| { #expr })(); - - if guard_against_recursion { - #recursive_count.with(|count| { - count.set(count.get() - 1); - }); - } - - result - } -} - fn gen_arbitrary_method( input: &DeriveInput, lifetime: LifetimeParam, @@ -195,11 +170,18 @@ fn gen_arbitrary_method( recursive_count: &syn::Ident, ) -> Result { let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?; - let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) }); + let body = quote! { + arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| { + Ok(#ident #arbitrary) + }) + }; let arbitrary_take_rest = construct_take_rest(fields)?; - let take_rest_body = - with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) }); + let take_rest_body = quote! { + arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| { + Ok(#ident #arbitrary_take_rest) + }) + }; Ok(quote! { fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { @@ -243,7 +225,11 @@ fn gen_arbitrary_method( }; if needs_recursive_count { - with_recursive_count_guard(recursive_count, do_variants) + quote! { + arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| { + #do_variants + }) + } } else { do_variants } diff --git a/src/lib.rs b/src/lib.rs index 0375032..8b28c69 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -551,3 +551,58 @@ mod test { /// ``` #[cfg(all(doctest, feature = "derive"))] pub struct CompileFailTests; + +// Support for `#[derive(Arbitrary)]`. +#[doc(hidden)] +#[cfg(feature = "derive")] +pub mod details { + use super::*; + + // Hidden trait that papers over the difference between `&mut Unstructured` and + // `Unstructured` arguments so that `with_recursive_count` can be used for both + // `arbitrary` and `arbitrary_take_rest`. + pub trait IsEmpty { + fn is_empty(&self) -> bool; + } + + impl IsEmpty for Unstructured<'_> { + fn is_empty(&self) -> bool { + Unstructured::is_empty(self) + } + } + + impl IsEmpty for &mut Unstructured<'_> { + fn is_empty(&self) -> bool { + Unstructured::is_empty(self) + } + } + + // Calls `f` with a recursive count guard. + #[inline] + pub fn with_recursive_count( + u: U, + recursive_count: &'static std::thread::LocalKey>, + f: impl FnOnce(U) -> Result, + ) -> Result { + let guard_against_recursion = u.is_empty(); + if guard_against_recursion { + recursive_count.with(|count| { + if count.get() > 0 { + return Err(Error::NotEnoughData); + } + count.set(count.get() + 1); + Ok(()) + })?; + } + + let result = f(u); + + if guard_against_recursion { + recursive_count.with(|count| { + count.set(count.get() - 1); + }); + } + + result + } +}