diff --git a/include/sframe/result.h b/include/sframe/result.h new file mode 100644 index 0000000..a14d179 --- /dev/null +++ b/include/sframe/result.h @@ -0,0 +1,237 @@ +#pragma once + +#include +#include +#include + +#include + +namespace SFRAME_NAMESPACE { + +// Error types to replace exceptions +enum class SFrameErrorType +{ + none = 0, + internal_error, + invalid_parameter_error, + buffer_too_small_error, + crypto_error, + unsupported_ciphersuite_error, + authentication_error, + invalid_key_usage_error, +}; + +class SFrameError +{ +public: + SFrameError() + : type_(SFrameErrorType::none) + , message_() + { + } + + explicit SFrameError(SFrameErrorType type) + : type_(type) + , message_() + { + } + + SFrameError(SFrameErrorType type, std::string message) + : type_(type) + , message_(std::move(message)) + { + } + + // Copy constructor + SFrameError(const SFrameError& other) + : type_(SFrameErrorType::none) + , message_(other.message_) + { + type_ = other.type_; + } + + // Copy assignment + SFrameError& operator=(const SFrameError& other) + { + if (this != &other) { + type_ = other.type_; + message_ = other.message_; + } + return *this; + } + + // Move constructor + SFrameError(SFrameError&& other) noexcept + : type_(other.type_) + , message_(std::move(other.message_)) + { + } + + // Move assignment + SFrameError& operator=(SFrameError&& other) noexcept + { + if (this != &other) { + type_ = other.type_; + message_ = std::move(other.message_); + } + return *this; + } + + SFrameErrorType type() const { return type_; } + + const char* message() const { return message_.c_str(); } + + bool ok() const { return type_ == SFrameErrorType::none; } + +private: + SFrameErrorType type_ = SFrameErrorType::none; + std::string message_; +}; + +// Helper to convert SFrameError to appropriate exception type +void +throw_on_error(const SFrameError& error); + +template +class Result +{ +public: + typedef T element_type; + + static Result ok(const T& value) { return Result(value); } + + static Result ok(T&& value) { return Result(std::move(value)); } + + static Result err(SFrameErrorType error, const std::string& message = "") + { + return Result(SFrameError(error, message)); + } + + static Result err(SFrameError&& error) + { + return Result(std::move(error)); + } + + Result(SFrameError error) + : data_(std::move(error)) + { + } + + Result(const T& value) + : data_(value) + { + } + + Result(T&& value) + : data_(std::move(value)) + { + } + + Result(const Result& other) = delete; + Result& operator=(const Result& other) = delete; + + Result(Result&& other) noexcept + : data_(std::move(other.data_)) + { + } + + Result& operator=(Result&& other) noexcept + { + data_ = std::move(other.data_); + return *this; + } + + template + Result(Result&& other) + : data_(std::move(other.data_)) + { + } + + template + Result& operator=(Result&& other) + { + data_ = std::move(other.data_); + return *this; + } + + T value() { return std::move(std::get(data_)); } + + SFrameError error() + { + if (std::holds_alternative(data_)) { + auto error = std::get(data_); + return error; + } + return SFrameError(); // Default OK error + } + + bool is_ok() const { return std::holds_alternative(data_); } + + bool is_err() const { return std::holds_alternative(data_); } + +private: + std::variant data_; +}; + +// Specialization for Result +template<> +class Result +{ +public: + typedef void element_type; + + static Result ok() { return Result(); } + + static Result err(SFrameErrorType error, + const std::string& message = "") + { + return Result(SFrameError(error, message)); + } + + static Result err(SFrameError&& error) + { + return Result(std::move(error)); + } + + Result() + : is_ok_(true) + , error_() + { + } + + Result(SFrameError error) + : is_ok_(false) + , error_(std::move(error)) + { + } + + Result(const Result& other) = delete; + Result& operator=(const Result& other) = delete; + + Result(Result&& other) noexcept + : is_ok_(other.is_ok_) + , error_(std::move(other.error_)) + { + } + + Result& operator=(Result&& other) noexcept + { + is_ok_ = other.is_ok_; + error_ = std::move(other.error_); + return *this; + } + + void value() { /* void has no value to move */ } + + SFrameError error() { return error_; } + + bool is_ok() const { return is_ok_; } + + bool is_err() const { return !is_ok_; } + +private: + bool is_ok_; + SFrameError error_; +}; + +} // namespace SFRAME_NAMESPACE \ No newline at end of file diff --git a/src/header.cpp b/src/header.cpp index dcf022c..69c89d6 100644 --- a/src/header.cpp +++ b/src/header.cpp @@ -28,18 +28,19 @@ encode_uint(uint64_t val, output_bytes buffer) } } -static uint64_t +static Result decode_uint(input_bytes data) { if (!data.empty() && data[0] == 0) { - throw invalid_parameter_error("Integer is not minimally encoded"); + return Result::err(SFrameErrorType::invalid_parameter_error, + "Integer is not minimally encoded"); } uint64_t val = 0; for (size_t i = 0; i < data.size(); i++) { val = (val << 8) + static_cast(data[i]); } - return val; + return Result::ok(val); } struct ValueOrLength @@ -77,17 +78,23 @@ struct ValueOrLength return value_or_length + 1; } - std::tuple read(input_bytes data) const + Result> read(input_bytes data) const { if (!is_length) { // Nothing to read; value is already in config byte - return { value_or_length, data }; + return Result>::ok( + std::make_tuple(value_or_length, data)); } const auto size = value_size(); - const auto value = decode_uint(data.subspan(0, size)); + auto value_result = decode_uint(data.subspan(0, size)); + if (!value_result.is_ok()) { + return value_result.error(); + } + const auto value = value_result.value(); const auto remaining = data.subspan(size); - return { value, remaining }; + return Result>::ok( + std::make_tuple(value, remaining)); } private: @@ -140,20 +147,32 @@ Header::Header(KeyID key_id_in, Counter counter_in) encode_uint(counter, after_kid.subspan(0, cfg.ctr.value_size())); } -Header +Result
Header::parse(input_bytes buffer) { if (buffer.size() < Header::min_size) { - throw buffer_too_small_error("Ciphertext too small to decode header"); + return Result
::err(SFrameErrorType::buffer_too_small_error, + "Ciphertext too small to decode header"); } const auto cfg = ConfigByte{ buffer[0] }; const auto after_cfg = buffer.subspan(1); - const auto [key_id, after_kid] = cfg.kid.read(after_cfg); - const auto [counter, _] = cfg.ctr.read(after_kid); + + auto read_result = cfg.kid.read(after_cfg); + if (!read_result.is_ok()) { + return read_result.error(); + } + auto [key_id, after_kid] = read_result.value(); + + read_result = cfg.ctr.read(after_kid); + if (!read_result.is_ok()) { + return read_result.error(); + } + auto [counter, _] = read_result.value(); + const auto encoded = buffer.subspan(0, cfg.encoded_size()); - return Header(key_id, counter, encoded); + return Result
::ok(Header(key_id, counter, encoded)); } Header::Header(KeyID key_id_in, Counter counter_in, input_bytes encoded_in) diff --git a/src/header.h b/src/header.h index 2c881eb..2adab6b 100644 --- a/src/header.h +++ b/src/header.h @@ -1,5 +1,6 @@ #pragma once +#include #include namespace SFRAME_NAMESPACE { @@ -14,7 +15,7 @@ class Header const Counter counter; Header(KeyID key_id_in, Counter counter_in); - static Header parse(input_bytes buffer); + static Result
parse(input_bytes buffer); input_bytes encoded() const { return _encoded; } size_t size() const { return _encoded.size(); } diff --git a/src/result.cpp b/src/result.cpp new file mode 100644 index 0000000..e1f7a17 --- /dev/null +++ b/src/result.cpp @@ -0,0 +1,29 @@ +#include +#include + +namespace SFRAME_NAMESPACE { + +void +throw_on_error(const SFrameError& error) +{ + switch (error.type()) { + case SFrameErrorType::none: + return; + case SFrameErrorType::buffer_too_small_error: + throw buffer_too_small_error(error.message()); + case SFrameErrorType::invalid_parameter_error: + throw invalid_parameter_error(error.message()); + case SFrameErrorType::crypto_error: + throw crypto_error(); + case SFrameErrorType::unsupported_ciphersuite_error: + throw unsupported_ciphersuite_error(); + case SFrameErrorType::authentication_error: + throw authentication_error(); + case SFrameErrorType::invalid_key_usage_error: + throw invalid_key_usage_error(error.message()); + default: + throw std::runtime_error(error.message()); + } +} + +} // namespace SFRAME_NAMESPACE diff --git a/src/sframe.cpp b/src/sframe.cpp index 7616bc1..08012ea 100644 --- a/src/sframe.cpp +++ b/src/sframe.cpp @@ -155,7 +155,11 @@ Context::unprotect(output_bytes plaintext, input_bytes ciphertext, input_bytes metadata) { - const auto header = Header::parse(ciphertext); + auto header_result = Header::parse(ciphertext); + if (!header_result.is_ok()) { + throw_on_error(header_result.error()); + } + const auto header = header_result.value(); const auto inner_ciphertext = ciphertext.subspan(header.size()); return Context::unprotect_inner( header, plaintext, inner_ciphertext, metadata); @@ -271,7 +275,11 @@ MLSContext::unprotect(output_bytes plaintext, input_bytes ciphertext, input_bytes metadata) { - const auto header = Header::parse(ciphertext); + auto header_result = Header::parse(ciphertext); + if (!header_result.is_ok()) { + throw_on_error(header_result.error()); + } + const auto header = header_result.value(); const auto inner_ciphertext = ciphertext.subspan(header.size()); ensure_key(header.key_id, KeyUsage::unprotect); diff --git a/test/header.cpp b/test/header.cpp index 7c8cab1..1231daa 100644 --- a/test/header.cpp +++ b/test/header.cpp @@ -37,7 +37,9 @@ TEST_CASE("Header Known-Answer") for (const auto& tc : cases) { // Decode - const auto decoded = Header::parse(tc.encoding); + auto decode_result = Header::parse(tc.encoding); + REQUIRE(decode_result.is_ok()); + const auto decoded = decode_result.value(); REQUIRE(decoded.key_id == tc.key_id); REQUIRE(decoded.counter == tc.counter); REQUIRE(decoded.size() == tc.encoding.size()); diff --git a/test/vectors.cpp b/test/vectors.cpp index a3e1c45..2bdac72 100644 --- a/test/vectors.cpp +++ b/test/vectors.cpp @@ -65,7 +65,9 @@ struct HeaderTestVector void verify() const { // Decode - const auto decoded = Header::parse(encoded); + auto decode_result = Header::parse(encoded); + REQUIRE(decode_result.is_ok()); + const auto decoded = decode_result.value(); REQUIRE(decoded.key_id == kid); REQUIRE(decoded.counter == ctr); REQUIRE(decoded.size() == encoded.data.size());