diff --git a/src/errors.rs b/src/errors.rs index 02334a2..22dc7e7 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,11 +1,7 @@ -#[cfg(any(feature = "alloc", feature = "std"))] -use alloc::string::String; use core::fmt; #[cfg(feature = "std")] use std::error; -#[cfg(feature = "std")] -use std::io; /// Library generic result type. pub type BcryptResult = Result; @@ -14,22 +10,19 @@ pub type BcryptResult = Result; /// All the errors we can encounter while hashing/verifying /// passwords pub enum BcryptError { - #[cfg(feature = "std")] - Io(io::Error), + /// Raised when the cost value is outside of the allowed 4-31 range. + /// + /// Cost is provided as an argument to hashing functions, and extracted from the hash in + /// verification functions. CostNotAllowed(u32), + /// Raised when verifying against an incorrectly formatted hash. #[cfg(any(feature = "alloc", feature = "std"))] - InvalidCost(String), - #[cfg(any(feature = "alloc", feature = "std"))] - InvalidPrefix(String), - #[cfg(any(feature = "alloc", feature = "std"))] - InvalidHash(String), - InvalidSaltLen(usize), - InvalidBase64(base64::DecodeError), + InvalidHash(&'static str), + /// Raised when an error occurs when generating a salt value. #[cfg(any(feature = "alloc", feature = "std"))] Rand(getrandom::Error), - /// Return this error if the input contains more than 72 bytes. This variant contains the - /// length of the input in bytes. - /// Only returned when calling `non_truncating_*` functions + /// Raised when the input to a `non_truncating_*` function contains more than 72 bytes. + /// This variant contains the length of the input in bytes. Truncation(usize), } @@ -43,19 +36,12 @@ macro_rules! impl_from_error { }; } -impl_from_error!(base64::DecodeError, BcryptError::InvalidBase64); -#[cfg(feature = "std")] -impl_from_error!(io::Error, BcryptError::Io); #[cfg(any(feature = "alloc", feature = "std"))] impl_from_error!(getrandom::Error, BcryptError::Rand); impl fmt::Display for BcryptError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - #[cfg(feature = "std")] - BcryptError::Io(ref err) => write!(f, "IO error: {}", err), - #[cfg(any(feature = "alloc", feature = "std"))] - BcryptError::InvalidCost(ref cost) => write!(f, "Invalid Cost: {}", cost), BcryptError::CostNotAllowed(ref cost) => write!( f, "Cost needs to be between {} and {}, got {}", @@ -64,13 +50,7 @@ impl fmt::Display for BcryptError { cost ), #[cfg(any(feature = "alloc", feature = "std"))] - BcryptError::InvalidPrefix(ref prefix) => write!(f, "Invalid Prefix: {}", prefix), - #[cfg(any(feature = "alloc", feature = "std"))] - BcryptError::InvalidHash(ref hash) => write!(f, "Invalid hash: {}", hash), - BcryptError::InvalidBase64(ref err) => write!(f, "Base64 error: {}", err), - BcryptError::InvalidSaltLen(len) => { - write!(f, "Invalid salt len: expected 16, received {}", len) - } + BcryptError::InvalidHash(ref reason) => write!(f, "Invalid hash: {}", reason), #[cfg(any(feature = "alloc", feature = "std"))] BcryptError::Rand(ref err) => write!(f, "Rand error: {}", err), BcryptError::Truncation(len) => { @@ -84,14 +64,9 @@ impl fmt::Display for BcryptError { impl error::Error for BcryptError { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match *self { - BcryptError::Io(ref err) => Some(err), - BcryptError::InvalidCost(_) - | BcryptError::CostNotAllowed(_) - | BcryptError::InvalidPrefix(_) + BcryptError::CostNotAllowed(_) | BcryptError::InvalidHash(_) - | BcryptError::InvalidSaltLen(_) | BcryptError::Truncation(_) => None, - BcryptError::InvalidBase64(ref err) => Some(err), BcryptError::Rand(ref err) => Some(err), } } diff --git a/src/lib.rs b/src/lib.rs index c147580..1833145 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -157,25 +157,27 @@ fn split_hash(hash: &str) -> BcryptResult { let raw_parts: Vec<_> = hash.split('$').filter(|s| !s.is_empty()).collect(); if raw_parts.len() != 3 { - return Err(BcryptError::InvalidHash(hash.to_string())); + return Err(BcryptError::InvalidHash("the hash format is malformed")); } if raw_parts[0] != "2y" && raw_parts[0] != "2b" && raw_parts[0] != "2a" && raw_parts[0] != "2x" { - return Err(BcryptError::InvalidPrefix(raw_parts[0].to_string())); + return Err(BcryptError::InvalidHash( + "the hash prefix is not a bcrypt prefix", + )); } if let Ok(c) = raw_parts[1].parse::() { parts.cost = c; } else { - return Err(BcryptError::InvalidCost(raw_parts[1].to_string())); + return Err(BcryptError::InvalidHash("the cost value is not a number")); } if raw_parts[2].len() == 53 && raw_parts[2].is_char_boundary(22) { parts.salt = raw_parts[2][..22].chars().collect(); parts.hash = raw_parts[2][22..].chars().collect(); } else { - return Err(BcryptError::InvalidHash(hash.to_string())); + return Err(BcryptError::InvalidHash("the hash format is malformed")); } Ok(parts) @@ -257,17 +259,22 @@ fn _verify>(password: P, hash: &str, err_on_truncation: bool) -> use subtle::ConstantTimeEq; let parts = split_hash(hash)?; - let salt = BASE_64.decode(&parts.salt)?; - let salt_len = salt.len(); + let salt = BASE_64 + .decode(&parts.salt) + .map_err(|_| BcryptError::InvalidHash("the salt part is not valid base64"))?; let generated = _hash_password( password.as_ref(), parts.cost, salt.try_into() - .map_err(|_| BcryptError::InvalidSaltLen(salt_len))?, + .map_err(|_| BcryptError::InvalidHash("the salt length is not 16 bytes"))?, err_on_truncation, )?; - let source_decoded = BASE_64.decode(parts.hash)?; - let generated_decoded = BASE_64.decode(generated.hash)?; + let source_decoded = BASE_64 + .decode(parts.hash) + .map_err(|_| BcryptError::InvalidHash("the hash to verify against is not valid base64"))?; + let generated_decoded = BASE_64.decode(generated.hash).map_err(|_| { + BcryptError::InvalidHash("the generated hash for the password is not valid base64") + })?; Ok(source_decoded.ct_eq(&generated_decoded).into()) }