From 40fe82d4df9ca73250eb795656ce0a25fcae40ff Mon Sep 17 00:00:00 2001 From: Nicholas Nethercote Date: Wed, 13 Aug 2025 13:22:39 +1000 Subject: [PATCH] Shrink derived code size with `with_recursive_count`. The amount of code generated by `derive(Arbitrary)` is large, mostly due to the recursion count guard. This commit factors out that code with a new hidden function `with_recursive_count`. Because `with_recursive_count` is mark with `inline`, the generated code should end up being much the same. But compile times are reduced by 30-40% because rustc's frontend has much less code to chew through. --- derive/src/lib.rs | 44 +++++++++++++------------------------ src/lib.rs | 55 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 29 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 066e58c..9783e1b 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -158,31 +158,6 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics generics } -fn with_recursive_count_guard(recursive_count: &syn::Ident, expr: TokenStream) -> TokenStream { - quote! { - let guard_against_recursion = u.is_empty(); - if guard_against_recursion { - #recursive_count.with(|count| { - if count.get() > 0 { - return Err(arbitrary::Error::NotEnoughData); - } - count.set(count.get() + 1); - Ok(()) - })?; - } - - let result = (|| { #expr })(); - - if guard_against_recursion { - #recursive_count.with(|count| { - count.set(count.get() - 1); - }); - } - - result - } -} - fn gen_arbitrary_method( input: &DeriveInput, lifetime: LifetimeParam, @@ -195,11 +170,18 @@ fn gen_arbitrary_method( recursive_count: &syn::Ident, ) -> Result { let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?; - let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) }); + let body = quote! { + arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| { + Ok(#ident #arbitrary) + }) + }; let arbitrary_take_rest = construct_take_rest(fields)?; - let take_rest_body = - with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) }); + let take_rest_body = quote! { + arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| { + Ok(#ident #arbitrary_take_rest) + }) + }; Ok(quote! { fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { @@ -243,7 +225,11 @@ fn gen_arbitrary_method( }; if needs_recursive_count { - with_recursive_count_guard(recursive_count, do_variants) + quote! { + arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| { + #do_variants + }) + } } else { do_variants } diff --git a/src/lib.rs b/src/lib.rs index 0375032..8b28c69 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -551,3 +551,58 @@ mod test { /// ``` #[cfg(all(doctest, feature = "derive"))] pub struct CompileFailTests; + +// Support for `#[derive(Arbitrary)]`. +#[doc(hidden)] +#[cfg(feature = "derive")] +pub mod details { + use super::*; + + // Hidden trait that papers over the difference between `&mut Unstructured` and + // `Unstructured` arguments so that `with_recursive_count` can be used for both + // `arbitrary` and `arbitrary_take_rest`. + pub trait IsEmpty { + fn is_empty(&self) -> bool; + } + + impl IsEmpty for Unstructured<'_> { + fn is_empty(&self) -> bool { + Unstructured::is_empty(self) + } + } + + impl IsEmpty for &mut Unstructured<'_> { + fn is_empty(&self) -> bool { + Unstructured::is_empty(self) + } + } + + // Calls `f` with a recursive count guard. + #[inline] + pub fn with_recursive_count( + u: U, + recursive_count: &'static std::thread::LocalKey>, + f: impl FnOnce(U) -> Result, + ) -> Result { + let guard_against_recursion = u.is_empty(); + if guard_against_recursion { + recursive_count.with(|count| { + if count.get() > 0 { + return Err(Error::NotEnoughData); + } + count.set(count.get() + 1); + Ok(()) + })?; + } + + let result = f(u); + + if guard_against_recursion { + recursive_count.with(|count| { + count.set(count.get() - 1); + }); + } + + result + } +}