Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 11 additions & 36 deletions src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
#[cfg(any(feature = "alloc", feature = "std"))]
use alloc::string::String;
use core::fmt;

#[cfg(feature = "std")]
use std::error;
#[cfg(feature = "std")]
use std::io;

/// Library generic result type.
pub type BcryptResult<T> = Result<T, BcryptError>;
Expand All @@ -14,22 +10,19 @@ pub type BcryptResult<T> = Result<T, BcryptError>;
/// All the errors we can encounter while hashing/verifying
/// passwords
pub enum BcryptError {
#[cfg(feature = "std")]
Io(io::Error),
/// Raised when the cost value is outside of the allowed 4-31 range.
///
/// Cost is provided as an argument to hashing functions, and extracted from the hash in
/// verification functions.
Comment on lines +15 to +16
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pointed out where the cost value comes from because if this error is raised from bcrypt::hash(), it was actually a programmer error, and if you have a hardcoded cost you can just handle that branch with unreachable!(). But during verification it means the hash was malformed which is something you have to deal with always.

CostNotAllowed(u32),
/// Raised when verifying against an incorrectly formatted hash.
#[cfg(any(feature = "alloc", feature = "std"))]
InvalidCost(String),
#[cfg(any(feature = "alloc", feature = "std"))]
InvalidPrefix(String),
#[cfg(any(feature = "alloc", feature = "std"))]
InvalidHash(String),
InvalidSaltLen(usize),
InvalidBase64(base64::DecodeError),
InvalidHash(&'static str),
/// Raised when an error occurs when generating a salt value.
#[cfg(any(feature = "alloc", feature = "std"))]
Rand(getrandom::Error),
/// Return this error if the input contains more than 72 bytes. This variant contains the
/// length of the input in bytes.
/// Only returned when calling `non_truncating_*` functions
/// Raised when the input to a `non_truncating_*` function contains more than 72 bytes.
/// This variant contains the length of the input in bytes.
Truncation(usize),
}

Expand All @@ -43,19 +36,12 @@ macro_rules! impl_from_error {
};
}

impl_from_error!(base64::DecodeError, BcryptError::InvalidBase64);
#[cfg(feature = "std")]
impl_from_error!(io::Error, BcryptError::Io);
#[cfg(any(feature = "alloc", feature = "std"))]
impl_from_error!(getrandom::Error, BcryptError::Rand);

impl fmt::Display for BcryptError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
#[cfg(feature = "std")]
BcryptError::Io(ref err) => write!(f, "IO error: {}", err),
#[cfg(any(feature = "alloc", feature = "std"))]
BcryptError::InvalidCost(ref cost) => write!(f, "Invalid Cost: {}", cost),
BcryptError::CostNotAllowed(ref cost) => write!(
f,
"Cost needs to be between {} and {}, got {}",
Expand All @@ -64,13 +50,7 @@ impl fmt::Display for BcryptError {
cost
),
#[cfg(any(feature = "alloc", feature = "std"))]
BcryptError::InvalidPrefix(ref prefix) => write!(f, "Invalid Prefix: {}", prefix),
#[cfg(any(feature = "alloc", feature = "std"))]
BcryptError::InvalidHash(ref hash) => write!(f, "Invalid hash: {}", hash),
BcryptError::InvalidBase64(ref err) => write!(f, "Base64 error: {}", err),
BcryptError::InvalidSaltLen(len) => {
write!(f, "Invalid salt len: expected 16, received {}", len)
}
BcryptError::InvalidHash(ref reason) => write!(f, "Invalid hash: {}", reason),
#[cfg(any(feature = "alloc", feature = "std"))]
BcryptError::Rand(ref err) => write!(f, "Rand error: {}", err),
BcryptError::Truncation(len) => {
Expand All @@ -84,14 +64,9 @@ impl fmt::Display for BcryptError {
impl error::Error for BcryptError {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match *self {
BcryptError::Io(ref err) => Some(err),
BcryptError::InvalidCost(_)
| BcryptError::CostNotAllowed(_)
| BcryptError::InvalidPrefix(_)
BcryptError::CostNotAllowed(_)
| BcryptError::InvalidHash(_)
| BcryptError::InvalidSaltLen(_)
| BcryptError::Truncation(_) => None,
BcryptError::InvalidBase64(ref err) => Some(err),
BcryptError::Rand(ref err) => Some(err),
}
}
Expand Down
25 changes: 16 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,25 +157,27 @@ fn split_hash(hash: &str) -> BcryptResult<HashParts> {
let raw_parts: Vec<_> = hash.split('$').filter(|s| !s.is_empty()).collect();

if raw_parts.len() != 3 {
return Err(BcryptError::InvalidHash(hash.to_string()));
return Err(BcryptError::InvalidHash("the hash format is malformed"));
}

if raw_parts[0] != "2y" && raw_parts[0] != "2b" && raw_parts[0] != "2a" && raw_parts[0] != "2x"
{
return Err(BcryptError::InvalidPrefix(raw_parts[0].to_string()));
return Err(BcryptError::InvalidHash(
"the hash prefix is not a bcrypt prefix",
));
}

if let Ok(c) = raw_parts[1].parse::<u32>() {
parts.cost = c;
} else {
return Err(BcryptError::InvalidCost(raw_parts[1].to_string()));
return Err(BcryptError::InvalidHash("the cost value is not a number"));
}

if raw_parts[2].len() == 53 && raw_parts[2].is_char_boundary(22) {
parts.salt = raw_parts[2][..22].chars().collect();
parts.hash = raw_parts[2][22..].chars().collect();
} else {
return Err(BcryptError::InvalidHash(hash.to_string()));
return Err(BcryptError::InvalidHash("the hash format is malformed"));
}

Ok(parts)
Expand Down Expand Up @@ -257,17 +259,22 @@ fn _verify<P: AsRef<[u8]>>(password: P, hash: &str, err_on_truncation: bool) ->
use subtle::ConstantTimeEq;

let parts = split_hash(hash)?;
let salt = BASE_64.decode(&parts.salt)?;
let salt_len = salt.len();
let salt = BASE_64
.decode(&parts.salt)
.map_err(|_| BcryptError::InvalidHash("the salt part is not valid base64"))?;
let generated = _hash_password(
password.as_ref(),
parts.cost,
salt.try_into()
.map_err(|_| BcryptError::InvalidSaltLen(salt_len))?,
.map_err(|_| BcryptError::InvalidHash("the salt length is not 16 bytes"))?,
err_on_truncation,
)?;
let source_decoded = BASE_64.decode(parts.hash)?;
let generated_decoded = BASE_64.decode(generated.hash)?;
let source_decoded = BASE_64
.decode(parts.hash)
.map_err(|_| BcryptError::InvalidHash("the hash to verify against is not valid base64"))?;
let generated_decoded = BASE_64.decode(generated.hash).map_err(|_| {
BcryptError::InvalidHash("the generated hash for the password is not valid base64")
})?;

Ok(source_decoded.ct_eq(&generated_decoded).into())
}
Expand Down