From d7c9b67d6f7f897aa48cb1184934337a8f84eb7f Mon Sep 17 00:00:00 2001 From: Pierre Fenoll Date: Tue, 17 Jun 2025 16:47:37 +0200 Subject: [PATCH] refacto: try and dedupe some code in derive crate Signed-off-by: Pierre Fenoll --- prost-validate-derive-core/src/any.rs | 3 +- prost-validate-derive-core/src/bool.rs | 2 +- prost-validate-derive-core/src/bytes.rs | 30 +++------ prost-validate-derive-core/src/duration.rs | 16 +---- prost-validate-derive-core/src/enum.rs | 5 +- prost-validate-derive-core/src/list.rs | 68 +++++++++------------ prost-validate-derive-core/src/map.rs | 42 +++++-------- prost-validate-derive-core/src/message.rs | 16 ++--- prost-validate-derive-core/src/number.rs | 16 +---- prost-validate-derive-core/src/string.rs | 40 +++--------- prost-validate-derive-core/src/timestamp.rs | 19 +----- 11 files changed, 76 insertions(+), 181 deletions(-) diff --git a/prost-validate-derive-core/src/any.rs b/prost-validate-derive-core/src/any.rs index d1af7da..9b76d45 100644 --- a/prost-validate-derive-core/src/any.rs +++ b/prost-validate-derive-core/src/any.rs @@ -15,10 +15,10 @@ pub struct AnyRules { impl ToValidationTokens for AnyRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { + let field = &ctx.name; let rules = prost_validate_types::AnyRules::from(self.to_owned()); let r#in = rules.r#in.is_empty().not().then(|| { let v = rules.r#in; - let field = &ctx.name; quote! { let values = vec![#(#v),*]; if !values.contains(&#name.type_url.as_str()) { @@ -28,7 +28,6 @@ impl ToValidationTokens for AnyRules { }); let not_in = rules.not_in.is_empty().not().then(|| { let v = rules.not_in; - let field = &ctx.name; quote! { let values = vec![#(#v),*]; if values.contains(&#name.type_url.as_str()) { diff --git a/prost-validate-derive-core/src/bool.rs b/prost-validate-derive-core/src/bool.rs index 1138fd4..a1fa27f 100644 --- a/prost-validate-derive-core/src/bool.rs +++ b/prost-validate-derive-core/src/bool.rs @@ -10,8 +10,8 @@ pub struct BoolRules { impl ToValidationTokens for BoolRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { + let field = &ctx.name; let r#const = self.r#const.map(|v| { - let field = &ctx.name; quote! { if *#name != #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::r#bool::Error::Const(#v))); diff --git a/prost-validate-derive-core/src/bytes.rs b/prost-validate-derive-core/src/bytes.rs index dca79e5..2186922 100644 --- a/prost-validate-derive-core/src/bytes.rs +++ b/prost-validate-derive-core/src/bytes.rs @@ -25,10 +25,10 @@ pub struct BytesRules { impl ToValidationTokens for BytesRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { + let field = &ctx.name; let rules = prost_validate_types::BytesRules::from(self.to_owned()); let r#const = rules.r#const.map(|v| { let v = LitByteStr::new(v.as_slice(), Span::call_site()); - let field = &ctx.name; quote! { if !#name.iter().eq(#v.iter()) { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Const(#v.to_vec()))); @@ -37,7 +37,6 @@ impl ToValidationTokens for BytesRules { }); let len = rules.len.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.len() != #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Len(#v))); @@ -46,7 +45,6 @@ impl ToValidationTokens for BytesRules { }); let min_len = rules.min_len.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.len() < #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::MinLen(#v))); @@ -55,7 +53,6 @@ impl ToValidationTokens for BytesRules { }); let max_len = rules.max_len.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.len() > #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::MaxLen(#v))); @@ -63,22 +60,22 @@ impl ToValidationTokens for BytesRules { } }); let pattern = rules.pattern.map(|v| { - let field = &ctx.name; - if let Err(err ) = regex::bytes::Regex::new(&v) { - panic!("{field}: Invalid regex pattern: {}", err); + if let Err(err) = regex::bytes::Regex::new(&v) { + panic!("{field}: Invalid regex pattern: {err}"); } quote! { - let regex = ::regex::bytes::Regex::new(#v).map_err(|err| { - ::prost_validate::Error::new(#field, format!("Invalid regex pattern: {}", err)) - })?; - if !regex.is_match(#name.iter().as_slice()) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Pattern(#v.to_string()))); + match ::regex::bytes::Regex::new(#v) { + Err(e) => return Err(::prost_validate::Error::new(#field, format!("Invalid regex pattern: {e}"))), + Ok(regex) => { + if !regex.is_match(#name.iter().as_slice()) { + return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Pattern(#v.to_string()))); + } + } } } }); let prefix = rules.prefix.map(|v| { let v = LitByteStr::new(v.as_slice(), Span::call_site()); - let field = &ctx.name; quote! { if !#name.starts_with(#v) { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Prefix(#v.to_vec()))); @@ -87,7 +84,6 @@ impl ToValidationTokens for BytesRules { }); let suffix = rules.suffix.map(|v| { let v = LitByteStr::new(v.as_slice(), Span::call_site()); - let field = &ctx.name; quote! { if !#name.ends_with(#v) { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Suffix(#v.to_vec()))); @@ -96,7 +92,6 @@ impl ToValidationTokens for BytesRules { }); let contains = rules.contains.map(|v| { let v = LitByteStr::new(v.as_slice(), Span::call_site()); - let field = &ctx.name; quote! { if !::prost_validate::ValidateBytesExt::contains(&#name, #v.as_slice()) { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Contains(#v.to_vec()))); @@ -109,7 +104,6 @@ impl ToValidationTokens for BytesRules { .iter() .map(|v| LitByteStr::new(v.as_slice(), Span::call_site())) .collect::>(); - let field = &ctx.name; quote! { let values = [#(#v.to_vec()),*]; if !values.contains(&#name) { @@ -123,7 +117,6 @@ impl ToValidationTokens for BytesRules { .iter() .map(|v| LitByteStr::new(v.as_slice(), Span::call_site())) .collect::>(); - let field = &ctx.name; quote! { let values = [#(#v.to_vec()),*]; if values.contains(&#name) { @@ -133,7 +126,6 @@ impl ToValidationTokens for BytesRules { }); let well_known = rules.well_known.map(|v| match v { bytes_rules::WellKnown::Ip(true) => { - let field = &ctx.name; quote! { if #name.len() != 4 && #name.len() != 16 { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ip)); @@ -141,7 +133,6 @@ impl ToValidationTokens for BytesRules { } } bytes_rules::WellKnown::Ipv4(true) => { - let field = &ctx.name; quote! { if #name.len() != 4 { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ipv4)); @@ -149,7 +140,6 @@ impl ToValidationTokens for BytesRules { } } bytes_rules::WellKnown::Ipv6(true) => { - let field = &ctx.name; quote! { if #name.len() != 16 { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ipv6)); diff --git a/prost-validate-derive-core/src/duration.rs b/prost-validate-derive-core/src/duration.rs index 9f098f4..1b4a8ea 100644 --- a/prost-validate-derive-core/src/duration.rs +++ b/prost-validate-derive-core/src/duration.rs @@ -32,10 +32,10 @@ pub fn duration_to_tokens(name: &Ident, want: &Duration) -> (TokenStream, TokenS impl ToValidationTokens for DurationRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { + let field = &ctx.name; let rules = prost_validate_types::DurationRules::from(self.clone()); let r#const = rules.r#const.map(|v| v.as_duration()).map(|v| { let (got, want) = duration_to_tokens(name, &v); - let field = &ctx.name; quote! { if #got != #want { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::Const(#want))); @@ -46,7 +46,6 @@ impl ToValidationTokens for DurationRules { let gte_lte = if let Some(lt) = rules.lt.map(|v| v.as_duration()) { if let Some(gt) = rules.gt.map(|v| v.as_duration()) { if lt > gt { - let field = &ctx.name; let (val, lt) = duration_to_tokens(name, <); let (_, gt) = duration_to_tokens(name, >); quote! { @@ -55,7 +54,6 @@ impl ToValidationTokens for DurationRules { } } } else { - let field = &ctx.name; let (val, lt) = duration_to_tokens(name, <); let (_, gt) = duration_to_tokens(name, >); quote! { @@ -66,7 +64,6 @@ impl ToValidationTokens for DurationRules { } } else if let Some(gte) = rules.gte.map(|v| v.as_duration()) { if lt > gte { - let field = &ctx.name; let (val, lt) = duration_to_tokens(name, <); let (_, gte) = duration_to_tokens(name, >e); quote! { @@ -75,7 +72,6 @@ impl ToValidationTokens for DurationRules { } } } else { - let field = &ctx.name; let (val, lt) = duration_to_tokens(name, <); let (_, gte) = duration_to_tokens(name, >e); quote! { @@ -85,7 +81,6 @@ impl ToValidationTokens for DurationRules { } } } else { - let field = &ctx.name; let (val, lt) = duration_to_tokens(name, <); quote! { if #val >= #lt { @@ -96,7 +91,6 @@ impl ToValidationTokens for DurationRules { } else if let Some(lte) = rules.lte.map(|v| v.as_duration()) { if let Some(gt) = rules.gt.map(|v| v.as_duration()) { if lte > gt { - let field = &ctx.name; let (val, lte) = duration_to_tokens(name, <e); let (_, gt) = duration_to_tokens(name, >); quote! { @@ -105,7 +99,6 @@ impl ToValidationTokens for DurationRules { } } } else { - let field = &ctx.name; let (val, lte) = duration_to_tokens(name, <e); let (_, gt) = duration_to_tokens(name, >); quote! { @@ -116,7 +109,6 @@ impl ToValidationTokens for DurationRules { } } else if let Some(gte) = rules.gte.map(|v| v.as_duration()) { if lte > gte { - let field = &ctx.name; let (val, lte) = duration_to_tokens(name, <e); let (_, gte) = duration_to_tokens(name, >e); quote! { @@ -125,7 +117,6 @@ impl ToValidationTokens for DurationRules { } } } else { - let field = &ctx.name; let (val, lte) = duration_to_tokens(name, <e); let (_, gte) = duration_to_tokens(name, >e); quote! { @@ -135,7 +126,6 @@ impl ToValidationTokens for DurationRules { } } } else { - let field = &ctx.name; let (val, lte) = duration_to_tokens(name, <e); quote! { if #val > #lte { @@ -144,7 +134,6 @@ impl ToValidationTokens for DurationRules { } } } else if let Some(gt) = rules.gt.map(|v| v.as_duration()) { - let field = &ctx.name; let (val, gt) = duration_to_tokens(name, >); quote! { if #val <= #gt { @@ -152,7 +141,6 @@ impl ToValidationTokens for DurationRules { } } } else if let Some(gte) = rules.gte.map(|v| v.as_duration()) { - let field = &ctx.name; let (val, gte) = duration_to_tokens(name, >e); quote! { if #val < #gte { @@ -168,7 +156,6 @@ impl ToValidationTokens for DurationRules { .iter() .map(|v| v.as_duration()) .collect::>(); - let field = &ctx.name; let (val, _) = duration_to_tokens(name, &vals[0]); let vals = rules .r#in @@ -188,7 +175,6 @@ impl ToValidationTokens for DurationRules { .iter() .map(|v| v.as_duration()) .collect::>(); - let field = &ctx.name; let (val, _) = duration_to_tokens(name, &vals[0]); let vals = rules .not_in diff --git a/prost-validate-derive-core/src/enum.rs b/prost-validate-derive-core/src/enum.rs index 09033ae..3316675 100644 --- a/prost-validate-derive-core/src/enum.rs +++ b/prost-validate-derive-core/src/enum.rs @@ -17,9 +17,9 @@ pub struct EnumRules { impl ToValidationTokens for EnumRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { + let field = &ctx.name; let rules = prost_validate_types::EnumRules::from(self.to_owned()); let r#const = rules.r#const.map(|v| { - let field = &ctx.name; quote! { if (*#name as i32) != #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::r#enum::Error::Const(#v))); @@ -45,7 +45,6 @@ impl ToValidationTokens for EnumRules { }; let enum_type: syn::Path = syn::parse_str(enumeration.as_str()) .expect("Invalid enum path"); - let field = &ctx.name; quote! { if !#enum_type::is_valid(*#name) { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::r#enum::Error::DefinedOnly)); @@ -54,7 +53,6 @@ impl ToValidationTokens for EnumRules { }); let r#in = rules.r#in.is_empty().not().then(|| { let v = rules.r#in.to_owned(); - let field = &ctx.name; quote! { let values = [#(#v),*]; if !values.contains(&#name) { @@ -64,7 +62,6 @@ impl ToValidationTokens for EnumRules { }); let not_in = rules.not_in.is_empty().not().then(|| { let v = rules.not_in.to_owned(); - let field = &ctx.name; quote! { let values = [#(#v),*]; if values.contains(#name) { diff --git a/prost-validate-derive-core/src/list.rs b/prost-validate-derive-core/src/list.rs index 4e0fee0..a2d4a97 100644 --- a/prost-validate-derive-core/src/list.rs +++ b/prost-validate-derive-core/src/list.rs @@ -16,9 +16,9 @@ pub struct RepeatedRules { impl ToValidationTokens for RepeatedRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { + let field = &ctx.name; let min_items = self.min_items.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.len() < #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::list::Error::MinItems(#v))); @@ -27,7 +27,6 @@ impl ToValidationTokens for RepeatedRules { }); let max_items = self.max_items.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.len() > #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::list::Error::MaxItems(#v))); @@ -35,53 +34,44 @@ impl ToValidationTokens for RepeatedRules { } }); let unique = self.unique.is_true_and(|| { - let field = &ctx.name; quote! { if ::prost_validate::VecExt::unique(#name).len() != #name.len() { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::list::Error::Unique)); } } }); - let item = self - .items - .as_ref() - .map(|v| { - let field = &ctx.name; - let item = format_ident!("item"); - let validation = v.to_validation_tokens(ctx, &item); - quote! { - for (i, item) in #name.iter().enumerate() { - || -> ::prost_validate::Result<_> { - #validation - Ok(()) - }().map_err(|e| ::prost_validate::Error::new(format!("{}[{}]", #field, i), ::prost_validate::errors::list::Error::Item(Box::new(e))))?; - } + let map = quote! { |e| ::prost_validate::Error::new(format!("{}[{i}]", #field), ::prost_validate::errors::list::Error::Item(Box::new(e))) }; + let items = self.items.as_ref().map(|v| { + let validation = v.to_validation_tokens(ctx, &format_ident!("item")); + quote! { + for (i, item) in #name.iter().enumerate() { + || -> ::prost_validate::Result<_> { + #validation + Ok(()) + }().map_err(#map)?; } - }); + } + }); let msg = (ctx.message && !ctx.wkt && !self - .items - .as_ref() - .map(|v| v.message.map(|v| v.skip).unwrap_or_default()) - .unwrap_or_default()) - .then(|| { - let field = &ctx.name; - if ctx.boxed { - quote! { - for (i, item) in #name.iter.enumerate() { - let item = item.as_ref(); - ::prost_validate::validate!(item).map_err(|e| ::prost_validate::Error::new(format!("{}[{}]", #field, i), ::prost_validate::errors::list::Error::Item(Box::new(e))))?; - } - } - } else { - quote! { - for (i, item) in #name.iter().enumerate() { - ::prost_validate::validate!(item).map_err(|e| ::prost_validate::Error::new(format!("{}[{}]", #field, i), ::prost_validate::errors::list::Error::Item(Box::new(e))))?; - } - } + .items + .as_ref() + .map(|v| v.message.map(|v| v.skip).unwrap_or_default()) + .unwrap_or_default()) + .then(|| { + let (name_iter, item_ref) = if ctx.boxed { + (quote! { #name.iter }, quote! { let item = item.as_ref(); }) + } else { + (quote! { #name.iter() }, quote! {}) + }; + quote! { + for (i, item) in #name_iter.enumerate() { + #item_ref + ::prost_validate::validate!(item).map_err(#map)?; } - }); + } + }); with_ignore_empty( name, self.ignore_empty, @@ -89,7 +79,7 @@ impl ToValidationTokens for RepeatedRules { #min_items #max_items #unique - #item + #items #msg }, ) diff --git a/prost-validate-derive-core/src/map.rs b/prost-validate-derive-core/src/map.rs index 926f227..7589dcb 100644 --- a/prost-validate-derive-core/src/map.rs +++ b/prost-validate-derive-core/src/map.rs @@ -18,10 +18,10 @@ pub struct MapRules { impl ToValidationTokens for MapRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { + let field = &ctx.name; let rules = prost_validate_types::MapRules::from(self.to_owned()); let min_pairs = rules.min_pairs.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.len() < #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::map::Error::MinPairs(#v))); @@ -30,7 +30,6 @@ impl ToValidationTokens for MapRules { }); let max_pairs = rules.max_pairs.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.len() > #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::map::Error::MaxPairs(#v))); @@ -41,7 +40,6 @@ impl ToValidationTokens for MapRules { let keys = self.keys.as_ref().map(|rules| { let validate = rules.to_validation_tokens(ctx, &key); validate.is_empty().not().then(|| { - let field = &ctx.name; quote! { for #key in #name.keys() { || -> ::prost_validate::Result<_> { @@ -53,37 +51,31 @@ impl ToValidationTokens for MapRules { }) }); let value = format_ident!("value"); - let msg = (ctx.message - && !ctx.wkt - && !self - .values - .as_ref() - .map(|v| v.message.map(|v| v.skip).unwrap_or_default()) - .unwrap_or_default()).then(|| { - let validation = MessageRules::default().to_validation_tokens(ctx, &value); - let field = &ctx.name; + let map = quote! { |e| ::prost_validate::Error::new(format!("{}[{k}]", #field), ::prost_validate::errors::map::Error::Values(Box::new(e))) }; + let quote_values = |validation: TokenStream| { quote! { for (k, #value) in #name.iter() { || -> ::prost_validate::Result<_> { #validation Ok(()) - }().map_err(|e| ::prost_validate::Error::new(format!("{}[{}]", #field, k), ::prost_validate::errors::map::Error::Values(Box::new(e))))?; + }().map_err(#map)?; } } + }; + let msg = (ctx.message + && !ctx.wkt + && !self + .values + .as_ref() + .map(|v| v.message.map(|v| v.skip).unwrap_or_default()) + .unwrap_or_default()) + .then(|| { + let validation = MessageRules::default().to_validation_tokens(ctx, &value); + quote_values(validation) }); let values = self.values.as_ref().map(|rules| { let validate = rules.to_validation_tokens(ctx, &value); - validate.is_empty().not().then(|| { - let field = &ctx.name; - quote! { - for (k, #value) in #name.iter() { - || -> ::prost_validate::Result<_> { - #validate - Ok(()) - }().map_err(|e| ::prost_validate::Error::new(format!("{}[{}]", #field, k), ::prost_validate::errors::map::Error::Values(Box::new(e))))?; - } - } - }) + validate.is_empty().not().then(|| quote_values(validate)) }); with_ignore_empty( name, @@ -108,8 +100,6 @@ impl From for prost_validate_types::MapRules { keys: value.keys.map(|v| (*v).into()).map(Box::new), values: value.values.map(|v| (*v).into()).map(Box::new), ignore_empty: Some(value.ignore_empty), - // keys: None, - // values: None, } } } diff --git a/prost-validate-derive-core/src/message.rs b/prost-validate-derive-core/src/message.rs index 891ef5b..8655f9f 100644 --- a/prost-validate-derive-core/src/message.rs +++ b/prost-validate-derive-core/src/message.rs @@ -14,20 +14,16 @@ pub struct MessageRules { impl ToValidationTokens for MessageRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { + let field = &ctx.name; if self.skip { return quote! {}; } let validate = self.skip.not().then(|| { - let field = &ctx.name; - if ctx.boxed { - quote! { - let #name = #name.as_ref(); - ::prost_validate::validate!(#name).map_err(|e| ::prost_validate::Error::new(#field, ::prost_validate::errors::message::Error::Message(Box::new(e))))?; - } - } else { - quote! { - ::prost_validate::validate!(#name).map_err(|e| ::prost_validate::Error::new(#field, ::prost_validate::errors::message::Error::Message(Box::new(e))))?; - } + let map = quote! { |e| ::prost_validate::Error::new(#field, ::prost_validate::errors::message::Error::Message(Box::new(e))) }; + let name_ref = ctx.boxed.then(|| quote! { let #name = #name.as_ref(); }); + quote! { + #name_ref + ::prost_validate::validate!(#name).map_err(#map)?; } }); validate.unwrap_or_default() diff --git a/prost-validate-derive-core/src/number.rs b/prost-validate-derive-core/src/number.rs index 7e35339..31eacf0 100644 --- a/prost-validate-derive-core/src/number.rs +++ b/prost-validate-derive-core/src/number.rs @@ -22,9 +22,9 @@ macro_rules! make_number_rules { impl ToValidationTokens for $name { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { + let field = &ctx.name; let rules = prost_validate_types::$name::from(self.to_owned()); let r#const = rules.r#const.map(|v| { - let field = &ctx.name; quote! { if *#name != #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Const(#v))); @@ -35,14 +35,12 @@ macro_rules! make_number_rules { let lte_gte = if let Some(lt) = rules.lt { if let Some(gt) = rules.gt { if lt > gt { - let field = &ctx.name; quote! { if *#name <= #gt || *#name >= #lt { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::in_range(false, #gt, #lt, false))); } } } else { - let field = &ctx.name; quote! { if *#name >= #lt && *#name <= #gt { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::not_in_range(true, #lt, #gt, true))); @@ -51,14 +49,12 @@ macro_rules! make_number_rules { } } else if let Some(gte) = rules.gte { if lt > gte { - let field = &ctx.name; quote! { if *#name < #gte || *#name >= #lt { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::in_range(true, #gte, #lt, false))); } } } else { - let field = &ctx.name; quote! { if *#name >= #lt && *#name < #gte { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::not_in_range(true, #lt, #gte, false))); @@ -66,7 +62,6 @@ macro_rules! make_number_rules { } } } else { - let field = &ctx.name; quote! { if *#name >= #lt { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Lt(#lt))); @@ -76,14 +71,12 @@ macro_rules! make_number_rules { } else if let Some(lte) = rules.lte { if let Some(gt) = rules.gt { if lte > gt { - let field = &ctx.name; quote! { if *#name <= #gt || *#name > #lte { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::in_range(false, #gt, #lte, true))); } } } else { - let field = &ctx.name; quote! { if *#name > #lte && *#name <= #gt { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::not_in_range(false, #lte, #gt, true))); @@ -92,14 +85,12 @@ macro_rules! make_number_rules { } } else if let Some(gte) = rules.gte { if lte > gte { - let field = &ctx.name; quote! { if *#name < #gte || *#name > #lte { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::in_range(true, #gte, #lte, true))); } } } else { - let field = &ctx.name; quote! { if *#name > #lte && *#name < #gte { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::not_in_range(false, #lte, #gte, false))); @@ -107,7 +98,6 @@ macro_rules! make_number_rules { } } } else { - let field = &ctx.name; quote! { if *#name > #lte { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Lte(#lte))); @@ -115,14 +105,12 @@ macro_rules! make_number_rules { } } } else if let Some(gt) = rules.gt { - let field = &ctx.name; quote! { if *#name <= #gt { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Gt(#gt))); } } } else if let Some(gte) = rules.gte { - let field = &ctx.name; quote! { if *#name < #gte { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Gte(#gte))); @@ -133,7 +121,6 @@ macro_rules! make_number_rules { }; let r#in = rules.r#in.is_empty().not().then(|| { let v = rules.r#in.to_owned(); - let field = &ctx.name; quote! { let values = vec![#(#v),*]; if !values.contains(#name) { @@ -143,7 +130,6 @@ macro_rules! make_number_rules { }); let not_in = rules.not_in.is_empty().not().then(|| { let v = rules.not_in.to_owned(); - let field = &ctx.name; quote! { let values = vec![#(#v),*]; if values.contains(#name) { diff --git a/prost-validate-derive-core/src/string.rs b/prost-validate-derive-core/src/string.rs index 50b93ec..16c58ea 100644 --- a/prost-validate-derive-core/src/string.rs +++ b/prost-validate-derive-core/src/string.rs @@ -29,9 +29,9 @@ pub struct StringRules { impl ToValidationTokens for StringRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { + let field = &ctx.name; let rules = prost_validate_types::StringRules::from(self.to_owned()); let r#const = rules.r#const.map(|v| { - let field = &ctx.name; quote! { if #name != #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Const(#v.to_string()))); @@ -40,7 +40,6 @@ impl ToValidationTokens for StringRules { }); let len = rules.len.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.chars().count() != #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Len(#v))); @@ -49,7 +48,6 @@ impl ToValidationTokens for StringRules { }); let min_len = rules.min_len.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.chars().count() < #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::MinLen(#v))); @@ -58,7 +56,6 @@ impl ToValidationTokens for StringRules { }); let max_len = rules.max_len.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.chars().count() > #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::MaxLen(#v))); @@ -67,7 +64,6 @@ impl ToValidationTokens for StringRules { }); let len_bytes = rules.len_bytes.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.len() != #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::LenBytes(#v))); @@ -76,7 +72,6 @@ impl ToValidationTokens for StringRules { }); let min_bytes = rules.min_bytes.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.len() < #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::MinLenBytes(#v))); @@ -85,7 +80,6 @@ impl ToValidationTokens for StringRules { }); let max_bytes = rules.max_bytes.map(|v| { let v = v as usize; - let field = &ctx.name; quote! { if #name.len() > #v { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::MaxLenBytes(#v))); @@ -93,21 +87,21 @@ impl ToValidationTokens for StringRules { } }); let pattern = rules.pattern.map(|v| { - let field = &ctx.name; if let Err(err) = regex::Regex::new(&v) { - panic!("{field}: Invalid regex pattern: {}", err); + panic!("{field}: Invalid regex pattern: {err}"); } quote! { - let regex = ::regex::Regex::new(#v).map_err(|e| { - ::prost_validate::Error::new(#field, format!("invalid regex pattern: {}", e)) - })?; - if !regex.is_match(#name.as_str()) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Pattern(#v.to_string()))); + match ::regex::Regex::new(#v) { + Err(e) => return Err(::prost_validate::Error::new(#field, format!("Invalid regex pattern: {e}"))), + Ok(regex) => { + if !regex.is_match(#name.as_str()) { + return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Pattern(#v.to_string()))); + } + } } } }); let prefix = rules.prefix.map(|v| { - let field = &ctx.name; quote! { if !#name.starts_with(#v) { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Prefix(#v.to_string()))); @@ -115,7 +109,6 @@ impl ToValidationTokens for StringRules { } }); let suffix = rules.suffix.map(|v| { - let field = &ctx.name; quote! { if !#name.ends_with(#v) { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Suffix(#v.to_string()))); @@ -123,7 +116,6 @@ impl ToValidationTokens for StringRules { } }); let contains = rules.contains.map(|v| { - let field = &ctx.name; quote! { if !#name.contains(#v) { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Contains(#v.to_string()))); @@ -131,7 +123,6 @@ impl ToValidationTokens for StringRules { } }); let not_contains = rules.not_contains.map(|v| { - let field = &ctx.name; quote! { if #name.contains(#v) { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::NotContains(#v.to_string()))); @@ -140,7 +131,6 @@ impl ToValidationTokens for StringRules { }); let r#in = rules.r#in.is_empty().not().then(|| { let v = rules.r#in; - let field = &ctx.name; quote! { let values = [#(#v),*]; if !values.contains(&#name.as_str()) { @@ -150,7 +140,6 @@ impl ToValidationTokens for StringRules { }); let not_in = rules.not_in.is_empty().not().then(|| { let v = rules.not_in; - let field = &ctx.name; quote! { let values = [#(#v),*]; if values.contains(&#name.as_str()) { @@ -161,7 +150,6 @@ impl ToValidationTokens for StringRules { let well_known = rules.well_known.map(|v| { match v { string_rules::WellKnown::Email(true) => { - let field = &ctx.name; quote! { if ::prost_validate::ValidateStringExt::validate_email(&#name).is_err() { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Email)); @@ -169,7 +157,6 @@ impl ToValidationTokens for StringRules { } } string_rules::WellKnown::Hostname(true) => { - let field = &ctx.name; quote! { if ::prost_validate::ValidateStringExt::validate_hostname(&#name).is_err() { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Hostname)); @@ -177,7 +164,6 @@ impl ToValidationTokens for StringRules { } } string_rules::WellKnown::Ip(true) => { - let field = &ctx.name; quote! { if ::prost_validate::ValidateStringExt::validate_ip(&#name).is_err() { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Ip)); @@ -185,7 +171,6 @@ impl ToValidationTokens for StringRules { } } string_rules::WellKnown::Ipv4(true) => { - let field = &ctx.name; quote! { if ::prost_validate::ValidateStringExt::validate_ipv4(&#name).is_err() { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Ipv4)); @@ -193,7 +178,6 @@ impl ToValidationTokens for StringRules { } } string_rules::WellKnown::Ipv6(true) => { - let field = &ctx.name; quote! { if ::prost_validate::ValidateStringExt::validate_ipv6(&#name).is_err() { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Ipv6)); @@ -201,7 +185,6 @@ impl ToValidationTokens for StringRules { } } string_rules::WellKnown::Uri(true) => { - let field = &ctx.name; quote! { if ::prost_validate::ValidateStringExt::validate_uri(&#name).is_err() { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Uri)); @@ -209,7 +192,6 @@ impl ToValidationTokens for StringRules { } } string_rules::WellKnown::UriRef(true) => { - let field = &ctx.name; quote! { if ::prost_validate::ValidateStringExt::validate_uri_ref(&#name).is_err() { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::UriRef)); @@ -217,7 +199,6 @@ impl ToValidationTokens for StringRules { } } string_rules::WellKnown::Address(true) => { - let field = &ctx.name; quote! { if ::prost_validate::ValidateStringExt::validate_address(&#name).is_err() { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Address)); @@ -225,7 +206,6 @@ impl ToValidationTokens for StringRules { } } string_rules::WellKnown::Uuid(true) => { - let field = &ctx.name; quote! { if ::prost_validate::ValidateStringExt::validate_uuid(&#name).is_err() { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Uuid)); @@ -236,7 +216,6 @@ impl ToValidationTokens for StringRules { let strict = rules.strict.unwrap_or(true); match prost_validate_types::KnownRegex::try_from(wk) { Ok(prost_validate_types::KnownRegex::HttpHeaderName) => { - let field = &ctx.name; quote! { if ::prost_validate::ValidateStringExt::validate_header_name(&#name, #strict).is_err() { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::HttpHeaderName)); @@ -244,7 +223,6 @@ impl ToValidationTokens for StringRules { } } Ok(prost_validate_types::KnownRegex::HttpHeaderValue) => { - let field = &ctx.name; quote! { if ::prost_validate::ValidateStringExt::validate_header_value(&#name, #strict).is_err() { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::HttpHeaderValue)); diff --git a/prost-validate-derive-core/src/timestamp.rs b/prost-validate-derive-core/src/timestamp.rs index cb92db7..f3d9e54 100644 --- a/prost-validate-derive-core/src/timestamp.rs +++ b/prost-validate-derive-core/src/timestamp.rs @@ -48,10 +48,10 @@ impl From for prost_validate_types::TimestampRules { impl ToValidationTokens for TimestampRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { + let field = &ctx.name; let rules = prost_validate_types::TimestampRules::from(self.clone()); let r#const = rules.r#const.map(|v| v.as_datetime()).map(|v| { let (got, want) = datetime_to_tokens(name, &v); - let field = &ctx.name; quote! { if #got != #want { return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Const(#want))); @@ -62,7 +62,6 @@ impl ToValidationTokens for TimestampRules { let gte_lte = if let Some(lt) = rules.lt.map(|v| v.as_datetime()) { if let Some(gt) = rules.gt.map(|v| v.as_datetime()) { if lt > gt { - let field = &ctx.name; let (val, lt) = datetime_to_tokens(name, <); let (_, gt) = datetime_to_tokens(name, >); quote! { @@ -71,7 +70,6 @@ impl ToValidationTokens for TimestampRules { } } } else { - let field = &ctx.name; let (val, lt) = datetime_to_tokens(name, <); let (_, gt) = datetime_to_tokens(name, >); quote! { @@ -82,7 +80,6 @@ impl ToValidationTokens for TimestampRules { } } else if let Some(gte) = rules.gte.map(|v| v.as_datetime()) { if lt > gte { - let field = &ctx.name; let (val, lt) = datetime_to_tokens(name, <); let (_, gte) = datetime_to_tokens(name, >e); quote! { @@ -91,7 +88,6 @@ impl ToValidationTokens for TimestampRules { } } } else { - let field = &ctx.name; let (val, lt) = datetime_to_tokens(name, <); let (_, gte) = datetime_to_tokens(name, >e); quote! { @@ -101,7 +97,6 @@ impl ToValidationTokens for TimestampRules { } } } else { - let field = &ctx.name; let (val, lt) = datetime_to_tokens(name, <); quote! { if #val >= #lt { @@ -112,7 +107,6 @@ impl ToValidationTokens for TimestampRules { } else if let Some(lte) = rules.lte.map(|v| v.as_datetime()) { if let Some(gt) = rules.gt.map(|v| v.as_datetime()) { if lte > gt { - let field = &ctx.name; let (val, lte) = datetime_to_tokens(name, <e); let (_, gt) = datetime_to_tokens(name, >); quote! { @@ -121,7 +115,6 @@ impl ToValidationTokens for TimestampRules { } } } else { - let field = &ctx.name; let (val, lte) = datetime_to_tokens(name, <e); let (_, gt) = datetime_to_tokens(name, >); quote! { @@ -132,7 +125,6 @@ impl ToValidationTokens for TimestampRules { } } else if let Some(gte) = rules.gte.map(|v| v.as_datetime()) { if lte > gte { - let field = &ctx.name; let (val, lte) = datetime_to_tokens(name, <e); let (_, gte) = datetime_to_tokens(name, >e); quote! { @@ -141,7 +133,6 @@ impl ToValidationTokens for TimestampRules { } } } else { - let field = &ctx.name; let (val, lte) = datetime_to_tokens(name, <e); let (_, gte) = datetime_to_tokens(name, >e); quote! { @@ -151,7 +142,6 @@ impl ToValidationTokens for TimestampRules { } } } else { - let field = &ctx.name; let (val, lte) = datetime_to_tokens(name, <e); quote! { if #val > #lte { @@ -160,7 +150,6 @@ impl ToValidationTokens for TimestampRules { } } } else if let Some(gt) = rules.gt.map(|v| v.as_datetime()) { - let field = &ctx.name; let (val, gt) = datetime_to_tokens(name, >); quote! { if #val <= #gt { @@ -168,7 +157,6 @@ impl ToValidationTokens for TimestampRules { } } } else if let Some(gte) = rules.gte.map(|v| v.as_datetime()) { - let field = &ctx.name; let (val, gte) = datetime_to_tokens(name, >e); quote! { if #val < #gte { @@ -179,7 +167,6 @@ impl ToValidationTokens for TimestampRules { if let Some(ref within) = rules.within.map(|v| v.as_duration()) { let (val, _) = datetime_to_tokens(name, &OffsetDateTime::now_utc()); let (_, d) = duration_to_tokens(name, within); - let field = &ctx.name; quote! { let now = ::time::OffsetDateTime::now_utc(); let d = #d; @@ -189,7 +176,6 @@ impl ToValidationTokens for TimestampRules { } } else { let (val, _) = datetime_to_tokens(name, &OffsetDateTime::now_utc()); - let field = &ctx.name; quote! { let now = ::time::OffsetDateTime::now_utc(); if #val >= now { @@ -201,7 +187,6 @@ impl ToValidationTokens for TimestampRules { if let Some(ref within) = rules.within.map(|v| v.as_duration()) { let (val, _) = datetime_to_tokens(name, &OffsetDateTime::now_utc()); let (_, d) = duration_to_tokens(name, within); - let field = &ctx.name; quote! { let now = ::time::OffsetDateTime::now_utc(); let d = #d; @@ -211,7 +196,6 @@ impl ToValidationTokens for TimestampRules { } } else { let (val, _) = datetime_to_tokens(name, &OffsetDateTime::now_utc()); - let field = &ctx.name; quote! { let now = ::time::OffsetDateTime::now_utc(); if #val <= now { @@ -222,7 +206,6 @@ impl ToValidationTokens for TimestampRules { } else if let Some(ref within) = rules.within.map(|v| v.as_duration()) { let (val, _) = datetime_to_tokens(name, &OffsetDateTime::now_utc()); let (_, d) = duration_to_tokens(name, within); - let field = &ctx.name; quote! { let now = ::time::OffsetDateTime::now_utc(); let d = #d;