Files
mbedtls-framework/util/generate_mldsa_tests.py
T
Gilles Peskine d6dce9b3f8 Make the output deterministic
Tests using randomized signatures were currently randomized. We need the
output of the script to be reproducible, so this won't do.

Force randomized signatures to use a specific RNG which is reset before
constructing each test case. This way, the script's output is reproducible.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
2026-02-25 19:02:22 +01:00

207 lines
7.1 KiB
Python
Executable File

#!/usr/bin/env python3
"""Generate ML-DSA test cases.
"""
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
import sys
from typing import Iterable, List, Optional
# pip install dilithium-py
import dilithium_py.ml_dsa #type: ignore
import scripts_path # pylint: disable=unused-import
from mbedtls_framework import test_case
from mbedtls_framework import test_data_generation
# ML_DSA instances for pure ML-DSA
PURE = {
#44: dilithium_py.ml_dsa.ML_DSA_44,
#65: dilithium_py.ml_dsa.ML_DSA_65,
87: dilithium_py.ml_dsa.ML_DSA_87,
}
# ML_DSA instances for HashML-DSA
HASH = {
#44: dilithium_py.ml_dsa.HASH_ML_DSA_44_WITH_SHA512,
#65: dilithium_py.ml_dsa.HASH_ML_DSA_65_WITH_SHA512,
87: dilithium_py.ml_dsa.HASH_ML_DSA_87_WITH_SHA512,
}
# Seeds (i.e. private keys) to test with.
SEEDS = [
b'There was once upon a time a ...',
b'\x00' * 32,
]
class Key:
"""An MLDSA key pair."""
#pylint: disable=too-few-public-methods
def __init__(self, kl: int, seed: bytes) -> None:
self.kl = kl #pylint: disable=invalid-name
self.seed = seed
self.public, self.secret = PURE[kl]._keygen_internal(seed)
def sign_message(self, message: bytes, deterministic: bool) -> bytes:
PURE[self.kl].set_drbg_seed(bytes(48))
return PURE[self.kl].sign(self.secret, message,
deterministic=deterministic)
# Key pairs to test with.
KEYS = {kl: [Key(kl, seed) for seed in SEEDS]
for kl in sorted(PURE.keys())}
# Input messages to test with.
MESSAGES = [
(b'This is a test', ''),
(b'', 'empty message'),
(b'\x00', '"\\x00"'),
(b'\x01', '"\\x01"'),
(b'ACBDEFGHIJ' * 100, '1000B'),
]
class API:
"""Abstract base class for the interface of the test functions."""
@classmethod
def function(cls, func: str, kl: int) -> str:
raise NotImplementedError
@classmethod
def metadata_arguments(cls,
kl: int,
pair: bool,
deterministic: Optional[bool]) -> List[str]:
raise NotImplementedError
@classmethod
def final_arguments(cls) -> List[str]:
return []
@classmethod
def secret_is_seed(cls) -> bool:
return True
class PQCPAPI(API):
"""Test mldsa-native entry points."""
@classmethod
def function(cls, func: str, kl: int) -> str:
return f'{func}_{kl}'
@classmethod
def metadata_arguments(cls,
_kl: int,
_pair: bool,
_deterministic: Optional[bool]) -> List[str]:
return []
@classmethod
def secret_is_seed(cls) -> bool:
return False
def one_mldsa_key_pair_from_seed(key: Key,
descr: str) -> test_case.TestCase:
"""Construct one test case for mldsa-native keypair_internal()."""
tc = test_case.TestCase()
tc.set_function(f'key_pair_from_seed_{key.kl}')
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments([
test_case.hex_string(key.seed),
test_case.hex_string(key.secret),
test_case.hex_string(key.public),
])
tc.set_description(f'MLDSA-{key.kl} key pair from seed {descr}')
return tc
def gen_pqcp_key_management(kl: int) -> Iterable[test_case.TestCase]:
"""Generate test cases for mldsa-native keypair_internal()."""
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_key_pair_from_seed(key, f'key#{i}')
def one_mldsa_sign_deterministic_pure(api: API,
key: Key,
message: bytes,
descr: str) -> test_case.TestCase:
"""Construct one test case for deterministic signature."""
signature = key.sign_message(message, deterministic=True)
tc = test_case.TestCase()
tc.set_function(api.function('sign_deterministic_pure', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(api.metadata_arguments(key.kl, True, True) + [
test_case.hex_string(key.seed if api.secret_is_seed() else key.secret),
test_case.hex_string(message),
test_case.hex_string(signature),
] + api.final_arguments())
tc.set_description(f'MLDSA-{key.kl} sign deterministic {descr}')
return tc
def one_mldsa_verify_pure(api: API,
key: Key,
message: bytes,
deterministic: bool,
descr: str) -> test_case.TestCase:
"""Construct one test case for verification.
When deterministic is true, the test case is a deterministic signature.
When deterministic is false, the test case is some other valid signature.
"""
signature = key.sign_message(message, deterministic=deterministic)
tc = test_case.TestCase()
tc.set_function(api.function('verify_pure', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(api.metadata_arguments(key.kl, False, True) + [
test_case.hex_string(key.public),
test_case.hex_string(message),
test_case.hex_string(signature),
] + api.final_arguments())
variant = "deterministic" if deterministic else "randomized"
tc.set_description(f'MLDSA-{key.kl} verify {variant} {descr}')
return tc
def gen_mldsa_pure(api: API, kl: int) -> Iterable[test_case.TestCase]:
"""Generate all test cases for pure ML-DSA signature and verification."""
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_sign_deterministic_pure(api, key, MESSAGES[0][0],
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_sign_deterministic_pure(api, KEYS[kl][0], message,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], True,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_verify_pure(api, KEYS[kl][0], message, True,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], False,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_verify_pure(api, KEYS[kl][0], message, False,
f'key#1 {descr}')
def gen_pqcp_mldsa_all() -> Iterable[test_case.TestCase]:
"""Generate all test cases for mldsa-native."""
api = PQCPAPI()
for kl in sorted(KEYS.keys()):
yield from gen_pqcp_key_management(kl)
yield from gen_mldsa_pure(api, kl)
class MLDSATestGenerator(test_data_generation.TestGenerator):
"""Generate test cases for ML-DSA."""
def __init__(self, settings) -> None:
self.targets = {
'test_suite_pqcp_mldsa.dilithium_py': gen_pqcp_mldsa_all,
}
super().__init__(settings)
if __name__ == '__main__':
test_data_generation.main(sys.argv[1:], __doc__, MLDSATestGenerator)