Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 178 additions & 2 deletions confidence-resolver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,35 @@ const TARGETING_KEY: &str = "targeting_key";
const NULL: Value = Value { kind: None };

const MAX_NO_OF_FLAGS_TO_BATCH_RESOLVE: usize = 1000;
const FLAG_NAME_PREFIX: &str = "flags/";
const MIN_FLAG_ID_LEN: usize = 2;
const MAX_FLAG_ID_LEN: usize = 63;

fn validate_flag_name(name: &str) -> Result<(), String> {
let id = name
.strip_prefix(FLAG_NAME_PREFIX)
.ok_or_else(|| format!("Invalid flag name '{}': must start with 'flags/'", name))?;

let len = id.len();
if !(MIN_FLAG_ID_LEN..=MAX_FLAG_ID_LEN).contains(&len) {
return Err(format!(
"Invalid flag name '{}': id must be {}-{} characters, got {}",
name, MIN_FLAG_ID_LEN, MAX_FLAG_ID_LEN, len
));
}

if !id
.bytes()
.all(|b| b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'-')
{
return Err(format!(
"Invalid flag name '{}': id must contain only lowercase letters, digits, and hyphens",
name
));
}

Ok(())
}

/// Seeds the thread-local random number generator.
///
Expand Down Expand Up @@ -800,14 +829,20 @@ impl<'a, H: Host> AccountResolver<'a, H> {
};

let resolve_request = state.resolve_request.as_ref().or_fail()?;
let flag_names = resolve_request.flags.clone();
let requested_specific_flags = !resolve_request.flags.is_empty();
let flag_names: Vec<String> = resolve_request
.flags
.iter()
.filter(|name| validate_flag_name(name).is_ok())
.cloned()
.collect();
let flags_to_resolve = self
.state
.flags
.values()
.filter(|flag| flag.state() == flags_admin::flag::State::Active)
.filter(|flag| flag.clients.contains(&self.client.client_name))
.filter(|flag| flag_names.is_empty() || flag_names.contains(&flag.name))
.filter(|flag| !requested_specific_flags || flag_names.contains(&flag.name))
// Skip flags that were already resolved in a prior attempt
.filter(|flag| !state.resolved_flags.iter().any(|rf| rf.flag == flag.name))
.collect::<Vec<&Flag>>();
Expand Down Expand Up @@ -953,6 +988,9 @@ impl<'a, H: Host> AccountResolver<'a, H> {
// ensure that all flags are present before we start sending events
let mut assigned_flags: Vec<FlagToApply> = Vec::with_capacity(request.flags.len());
for applied_flag in &request.flags {
if validate_flag_name(&applied_flag.flag).is_err() {
continue;
}
let Some(assigned_flag) = assignments.get(&applied_flag.flag) else {
return Err("Flag in resolve token does not match flag in request".to_string());
};
Expand Down Expand Up @@ -2479,6 +2517,144 @@ mod tests {
);
}

#[test]
fn test_validate_flag_name() {
assert!(validate_flag_name("flags/my-flag").is_ok());
assert!(validate_flag_name("flags/ab").is_ok());
assert!(validate_flag_name("flags/a-b-c-123").is_ok());
let max_id = "a".repeat(63);
assert!(validate_flag_name(&format!("flags/{}", max_id)).is_ok());
}

#[test]
fn test_validate_flag_name_rejects_missing_prefix() {
let result = validate_flag_name("my-flag");
assert!(result.is_err());
assert!(result.unwrap_err().contains("must start with 'flags/'"));
}

#[test]
fn test_validate_flag_name_rejects_uppercase() {
let result = validate_flag_name("flags/My-Flag");
assert!(result.is_err());
assert!(result.unwrap_err().contains("lowercase"));
}

#[test]
fn test_validate_flag_name_rejects_spaces() {
let result = validate_flag_name("flags/my flag");
assert!(result.is_err());
assert!(result.unwrap_err().contains("lowercase"));
}

#[test]
fn test_validate_flag_name_rejects_too_short() {
let result = validate_flag_name("flags/a");
assert!(result.is_err());
assert!(result.unwrap_err().contains("2-63 characters"));
}

#[test]
fn test_validate_flag_name_rejects_too_long() {
let long_id = "a".repeat(64);
let result = validate_flag_name(&format!("flags/{}", long_id));
assert!(result.is_err());
assert!(result.unwrap_err().contains("2-63 characters"));
}

#[test]
fn test_validate_flag_name_rejects_double_prefix() {
let result = validate_flag_name("flags/flags/my-flag");
assert!(result.is_err());
}

#[test]
fn test_validate_flag_name_rejects_special_chars() {
assert!(validate_flag_name("flags/my_flag").is_err());
assert!(validate_flag_name("flags/my.flag").is_err());
assert!(validate_flag_name("flags/my/flag").is_err());
}

#[test]
fn test_resolve_skips_invalid_flag_name() {
let state = ResolverState::from_proto(
EXAMPLE_STATE.to_owned().try_into().unwrap(),
"confidence-demo-june",
None,
)
.unwrap();

let context_json = r#"{"visitor_id": "tutorial_visitor"}"#;
let resolver: AccountResolver<'_, L> = state
.get_resolver_with_json_context(SECRET, context_json, &ENCRYPTION_KEY)
.unwrap();

let resolve_flag_req = flags_resolver::ResolveFlagsRequest {
evaluation_context: Some(Struct::default()),
client_secret: SECRET.to_string(),
flags: vec!["flags/My Cool Flag".to_string()],
apply: false,
sdk: None,
};

let result = resolver
.resolve_flags_no_materialization(&resolve_flag_req)
.unwrap();
assert!(
result.resolved_flags.is_empty(),
"Invalid flag names should be silently skipped"
);
}

#[test]
fn test_apply_skips_invalid_flag_name() {
let state = ResolverState::from_proto(
EXAMPLE_STATE.to_owned().try_into().unwrap(),
"confidence-demo-june",
None,
)
.unwrap();

let context_json = r#"{"visitor_id": "tutorial_visitor"}"#;
let resolver: AccountResolver<'_, L> = state
.get_resolver_with_json_context(SECRET, context_json, &ENCRYPTION_KEY)
.unwrap();

let resolve_flag_req = flags_resolver::ResolveFlagsRequest {
evaluation_context: Some(Struct::default()),
client_secret: SECRET.to_string(),
flags: vec!["flags/tutorial-feature".to_string()],
apply: false,
sdk: None,
};

let response: ResolveFlagsResponse = resolver
.resolve_flags_no_materialization(&resolve_flag_req)
.unwrap();

let now = Timestamp {
seconds: 1704067200,
nanos: 0,
};

let apply_request = flags_resolver::ApplyFlagsRequest {
flags: vec![flags_resolver::AppliedFlag {
flag: "flags/INVALID-NAME".to_string(),
apply_time: Some(now.clone()),
}],
client_secret: SECRET.to_string(),
resolve_token: response.resolve_token,
send_time: Some(now),
sdk: None,
};

let result = resolver.apply_flags(&apply_request);
assert!(
result.is_ok(),
"Invalid flag names should be silently skipped"
);
}

#[test]
fn test_targeting_key_integer_supported() {
let state = ResolverState::from_proto(
Expand Down
Binary file not shown.