Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 31 additions & 19 deletions cryptography_suite/pqc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@

from __future__ import annotations

from typing import Tuple
from ..errors import EncryptionError, DecryptionError
from ..symmetric.kdf import derive_hkdf
from ..utils import KeyVault
import os
import hmac
import base64
import hmac
import os

from cryptography.hazmat.primitives.ciphers.aead import AESGCM

from ..errors import DecryptionError, EncryptionError
from ..symmetric.kdf import derive_hkdf
from ..utils import KeyVault

try: # pragma: no cover - optional dependency
from pqcrypto.kem import ml_kem_512, ml_kem_768, ml_kem_1024
from pqcrypto.sign import ml_dsa_44, ml_dsa_65, ml_dsa_87
Expand Down Expand Up @@ -54,7 +55,7 @@

def generate_kyber_keypair(
level: int = 512, *, sensitive: bool = True
) -> Tuple[bytes, KeyVault | bytes]:
) -> tuple[bytes, KeyVault | bytes]:
"""Generate a Kyber key pair for the given ``level``.

Parameters
Expand All @@ -81,7 +82,7 @@ def kyber_encrypt(
*,
level: int = 512,
raw_output: bool = False,
) -> Tuple[str | bytes, str | bytes]:
) -> tuple[str | bytes, str | bytes]:
"""Encrypt ``plaintext`` using Kyber and AES-GCM.

``level`` selects the ML-KEM security level (512, 768 or 1024).
Expand Down Expand Up @@ -143,28 +144,37 @@ def kyber_decrypt(
except Exception as exc: # pragma: no cover - defensive
raise DecryptionError(f"Invalid shared secret: {exc}") from exc

if len(ciphertext) < ct_size + 12 + 16:
min_ct_len = ct_size + 16 + 12 + 16
if len(ciphertext) < min_ct_len:
raise DecryptionError("Invalid ciphertext")

kem_ct = ciphertext[:ct_size]
salt = ciphertext[ct_size : ct_size + 16]
enc = ciphertext[ct_size + 16 :]
priv = bytes(private_key) if isinstance(private_key, KeyVault) else private_key
ss_check = alg.decrypt(priv, kem_ct)
try:
ss_check = alg.decrypt(priv, kem_ct)
except Exception as exc: # pragma: no cover - defensive
raise DecryptionError("Invalid ciphertext") from exc
if shared_secret is None:
shared_secret = ss_check
elif not hmac.compare_digest(ss_check, shared_secret):
raise DecryptionError("Shared secret mismatch")

key = derive_hkdf(shared_secret, salt, b"kyber-aes-key", 32)
with KeyVault(key) as key_buf:
aesgcm = AESGCM(bytes(key_buf))
nonce = enc[:12]
ct = enc[12:]
return aesgcm.decrypt(nonce, ct, None)


def generate_dilithium_keypair(*, sensitive: bool = True) -> Tuple[bytes, KeyVault | bytes]:
try:
with KeyVault(key) as key_buf:
aesgcm = AESGCM(bytes(key_buf))
nonce = enc[:12]
ct = enc[12:]
return aesgcm.decrypt(nonce, ct, None)
except Exception as exc:
raise DecryptionError("Invalid ciphertext") from exc


def generate_dilithium_keypair(
*, sensitive: bool = True
) -> tuple[bytes, KeyVault | bytes]:
"""Generate a Dilithium key pair using level 2 parameters.

When ``sensitive`` is ``True`` (default) the private key is wrapped in
Expand Down Expand Up @@ -218,7 +228,9 @@ def dilithium_verify(
return False


def generate_sphincs_keypair(*, sensitive: bool = True) -> Tuple[bytes, KeyVault | bytes]:
def generate_sphincs_keypair(
*, sensitive: bool = True
) -> tuple[bytes, KeyVault | bytes]:
"""Generate a SPHINCS+ key pair using a 128-bit security level.

When ``sensitive`` is ``True`` (default) the private key is wrapped in
Expand Down
35 changes: 31 additions & 4 deletions tests/test_pqc.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import unittest
from typing import cast

from cryptography_suite.errors import DecryptionError
from cryptography_suite.pqc import (
PQCRYPTO_AVAILABLE,
SPHINCS_AVAILABLE,
generate_kyber_keypair,
kyber_encrypt,
kyber_decrypt,
generate_dilithium_keypair,
dilithium_sign,
dilithium_verify,
generate_dilithium_keypair,
generate_kyber_keypair,
generate_sphincs_keypair,
kyber_decrypt,
kyber_encrypt,
sphincs_sign,
sphincs_verify,
)
Expand All @@ -35,6 +37,31 @@ def test_dilithium_signature(self):
self.assertIsInstance(sig, str)
self.assertTrue(dilithium_verify(pk, msg, sig))

def test_kyber_decrypt_short_ciphertext_raises_decryption_error(self):
_, sk = generate_kyber_keypair(level=512)
with self.assertRaises(DecryptionError):
kyber_decrypt(sk, b"\x00" * 10, level=512)

def test_kyber_decrypt_corrupt_nonce_or_tag_raises_decryption_error(self):
msg = b"nonce/tag corruption test"
pk, sk = generate_kyber_keypair(level=512)
ct, _ = kyber_encrypt(pk, msg, level=512, raw_output=True)
raw_ct = cast(bytes, ct)

# Corrupt nonce byte (first byte after KEM ciphertext + salt).
nonce_corrupt = bytearray(raw_ct)
kem_ct_size = len(raw_ct) - (16 + 12 + len(msg) + 16)
nonce_index = kem_ct_size + 16
nonce_corrupt[nonce_index] ^= 0x01
with self.assertRaises(DecryptionError):
kyber_decrypt(sk, bytes(nonce_corrupt), level=512)

# Corrupt tag byte (last byte of payload).
tag_corrupt = bytearray(raw_ct)
tag_corrupt[-1] ^= 0x01
with self.assertRaises(DecryptionError):
kyber_decrypt(sk, bytes(tag_corrupt), level=512)

@unittest.skipUnless(SPHINCS_AVAILABLE, "SPHINCS+ not available")
def test_sphincs_signature(self):
pk, sk = generate_sphincs_keypair()
Expand Down