From 72f178bdbd1457a5470acd3885d8c9fa81e39bc2 Mon Sep 17 00:00:00 2001 From: Gilles Peskine Date: Tue, 31 Mar 2026 13:21:45 +0200 Subject: [PATCH] Move most of generate_mldsa_tests.py into a module We are moving the command line entry point to the consuming branch. The move will be completed when the consuming branch no longer needs the script to exist in the framework. https://github.com/Mbed-TLS/mbedtls-framework/issues/294 Signed-off-by: Gilles Peskine --- util/generate_mldsa_tests.py | 192 +----------------- .../mldsa_test_generator.py | 190 +++++++++++++++++ 2 files changed, 196 insertions(+), 186 deletions(-) create mode 100644 util/mbedtls_maintainer/mldsa_test_generator.py diff --git a/util/generate_mldsa_tests.py b/util/generate_mldsa_tests.py index 4cfe3e763..15854d31d 100755 --- a/util/generate_mldsa_tests.py +++ b/util/generate_mldsa_tests.py @@ -1,203 +1,23 @@ #!/usr/bin/env python3 """Generate ML-DSA test cases. + +This is a transitional script that does not handle different feature sets +in different states of TF-PSA-Crypto. The live version of this script +is `scripts/maintainer/generate_mldsa_tests.py` in TF-PSA-Crypto. """ -# Copyright The Mbed TLS Contributors -# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later - import sys -from typing import Iterable, List, Optional - -# pip install dilithium-py -import dilithium_py.ml_dsa #type: ignore import scripts_path # pylint: disable=unused-import -from mbedtls_framework import test_case from mbedtls_framework import test_data_generation - -# ML_DSA instances for pure ML-DSA -PURE = { - #44: dilithium_py.ml_dsa.ML_DSA_44, - #65: dilithium_py.ml_dsa.ML_DSA_65, - 87: dilithium_py.ml_dsa.ML_DSA_87, -} - -# ML_DSA instances for HashML-DSA -HASH = { - #44: dilithium_py.ml_dsa.HASH_ML_DSA_44_WITH_SHA512, - #65: dilithium_py.ml_dsa.HASH_ML_DSA_65_WITH_SHA512, - 87: dilithium_py.ml_dsa.HASH_ML_DSA_87_WITH_SHA512, -} - -# Seeds (i.e. private keys) to test with. -SEEDS = [ - b'There was once upon a time a ...', - b'\x00' * 32, -] - -class Key: - """An MLDSA key pair.""" - #pylint: disable=too-few-public-methods - - def __init__(self, kl: int, seed: bytes) -> None: - self.kl = kl #pylint: disable=invalid-name - self.seed = seed - self.public, self.secret = PURE[kl]._keygen_internal(seed) - - def sign_message(self, message: bytes, deterministic: bool) -> bytes: - PURE[self.kl].set_drbg_seed(bytes(48)) - return PURE[self.kl].sign(self.secret, message, - deterministic=deterministic) - -# Key pairs to test with. -KEYS = {kl: [Key(kl, seed) for seed in SEEDS] - for kl in sorted(PURE.keys())} - -# Input messages to test with. -MESSAGES = [ - (b'This is a test', ''), - (b'', 'empty message'), - (b'\x00', '"\\x00"'), - (b'\x01', '"\\x01"'), - (b'ACBDEFGHIJ' * 100, '1000B'), -] - - -class API: - """Abstract base class for the interface of the test functions.""" - - @classmethod - def function(cls, func: str, kl: int) -> str: - raise NotImplementedError - - @classmethod - def metadata_arguments(cls, - kl: int, - pair: bool, - deterministic: Optional[bool]) -> List[str]: - raise NotImplementedError - - @classmethod - def final_arguments(cls) -> List[str]: - return [] - - @classmethod - def secret_is_seed(cls) -> bool: - return True - - -class PQCPAPI(API): - """Test mldsa-native entry points.""" - - @classmethod - def function(cls, func: str, kl: int) -> str: - return f'{func}_{kl}' - - @classmethod - def metadata_arguments(cls, - _kl: int, - _pair: bool, - _deterministic: Optional[bool]) -> List[str]: - return [] - - @classmethod - def secret_is_seed(cls) -> bool: - return False - - -def one_mldsa_key_pair_from_seed(key: Key, - descr: str) -> test_case.TestCase: - """Construct one test case for mldsa-native keypair_internal().""" - tc = test_case.TestCase() - tc.set_function(f'key_pair_from_seed_{key.kl}') - tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED']) - tc.set_arguments([ - test_case.hex_string(key.seed), - test_case.hex_string(key.secret), - test_case.hex_string(key.public), - ]) - tc.set_description(f'MLDSA-{key.kl} key pair from seed {descr}') - return tc - -def gen_pqcp_key_management(kl: int) -> Iterable[test_case.TestCase]: - """Generate test cases for mldsa-native keypair_internal().""" - for i, key in enumerate(KEYS[kl], 1): - yield one_mldsa_key_pair_from_seed(key, f'key#{i}') - -def one_mldsa_sign_deterministic_pure(api: API, - key: Key, - message: bytes, - descr: str) -> test_case.TestCase: - """Construct one test case for deterministic signature.""" - signature = key.sign_message(message, deterministic=True) - tc = test_case.TestCase() - tc.set_function(api.function('sign_deterministic_pure', key.kl)) - tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED']) - tc.set_arguments(api.metadata_arguments(key.kl, True, True) + [ - test_case.hex_string(key.seed if api.secret_is_seed() else key.secret), - test_case.hex_string(message), - test_case.hex_string(signature), - ] + api.final_arguments()) - tc.set_description(f'MLDSA-{key.kl} sign deterministic {descr}') - return tc - -def one_mldsa_verify_pure(api: API, - key: Key, - message: bytes, - deterministic: bool, - descr: str) -> test_case.TestCase: - """Construct one test case for verification. - - When deterministic is true, the test case is a deterministic signature. - When deterministic is false, the test case is some other valid signature. - """ - signature = key.sign_message(message, deterministic=deterministic) - tc = test_case.TestCase() - tc.set_function(api.function('verify_pure', key.kl)) - tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED']) - tc.set_arguments(api.metadata_arguments(key.kl, False, True) + [ - test_case.hex_string(key.public), - test_case.hex_string(message), - test_case.hex_string(signature), - ] + api.final_arguments()) - variant = "deterministic" if deterministic else "randomized" - tc.set_description(f'MLDSA-{key.kl} verify {variant} {descr}') - return tc - -def gen_mldsa_pure(api: API, kl: int) -> Iterable[test_case.TestCase]: - """Generate all test cases for pure ML-DSA signature and verification.""" - for i, key in enumerate(KEYS[kl], 1): - yield one_mldsa_sign_deterministic_pure(api, key, MESSAGES[0][0], - f'key#{i}') - for message, descr in MESSAGES[1:]: - yield one_mldsa_sign_deterministic_pure(api, KEYS[kl][0], message, - f'key#1 {descr}') - for i, key in enumerate(KEYS[kl], 1): - yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], True, - f'key#{i}') - for message, descr in MESSAGES[1:]: - yield one_mldsa_verify_pure(api, KEYS[kl][0], message, True, - f'key#1 {descr}') - for i, key in enumerate(KEYS[kl], 1): - yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], False, - f'key#{i}') - for message, descr in MESSAGES[1:]: - yield one_mldsa_verify_pure(api, KEYS[kl][0], message, False, - f'key#1 {descr}') - -def gen_pqcp_mldsa_all() -> Iterable[test_case.TestCase]: - """Generate all test cases for mldsa-native.""" - api = PQCPAPI() - for kl in sorted(KEYS.keys()): - yield from gen_pqcp_key_management(kl) - yield from gen_mldsa_pure(api, kl) +from mbedtls_maintainer import mldsa_test_generator class MLDSATestGenerator(test_data_generation.TestGenerator): """Generate test cases for ML-DSA.""" def __init__(self, settings) -> None: self.targets = { - 'test_suite_pqcp_mldsa.dilithium_py': gen_pqcp_mldsa_all, + 'test_suite_pqcp_mldsa.dilithium_py': mldsa_test_generator.gen_pqcp_mldsa_all, } super().__init__(settings) diff --git a/util/mbedtls_maintainer/mldsa_test_generator.py b/util/mbedtls_maintainer/mldsa_test_generator.py new file mode 100644 index 000000000..384fd8de6 --- /dev/null +++ b/util/mbedtls_maintainer/mldsa_test_generator.py @@ -0,0 +1,190 @@ +"""Generate ML-DSA test cases. +""" + +# Copyright The Mbed TLS Contributors +# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later + +from typing import Iterable, List, Optional + +# pip install dilithium-py +import dilithium_py.ml_dsa #type: ignore + +import scripts_path # pylint: disable=unused-import +from mbedtls_framework import test_case + +# ML_DSA instances for pure ML-DSA +PURE = { + #44: dilithium_py.ml_dsa.ML_DSA_44, + #65: dilithium_py.ml_dsa.ML_DSA_65, + 87: dilithium_py.ml_dsa.ML_DSA_87, +} + +# ML_DSA instances for HashML-DSA +HASH = { + #44: dilithium_py.ml_dsa.HASH_ML_DSA_44_WITH_SHA512, + #65: dilithium_py.ml_dsa.HASH_ML_DSA_65_WITH_SHA512, + 87: dilithium_py.ml_dsa.HASH_ML_DSA_87_WITH_SHA512, +} + +# Seeds (i.e. private keys) to test with. +SEEDS = [ + b'There was once upon a time a ...', + b'\x00' * 32, +] + +class Key: + """An MLDSA key pair.""" + #pylint: disable=too-few-public-methods + + def __init__(self, kl: int, seed: bytes) -> None: + self.kl = kl #pylint: disable=invalid-name + self.seed = seed + self.public, self.secret = PURE[kl]._keygen_internal(seed) + + def sign_message(self, message: bytes, deterministic: bool) -> bytes: + PURE[self.kl].set_drbg_seed(bytes(48)) + return PURE[self.kl].sign(self.secret, message, + deterministic=deterministic) + +# Key pairs to test with. +KEYS = {kl: [Key(kl, seed) for seed in SEEDS] + for kl in sorted(PURE.keys())} + +# Input messages to test with. +MESSAGES = [ + (b'This is a test', ''), + (b'', 'empty message'), + (b'\x00', '"\\x00"'), + (b'\x01', '"\\x01"'), + (b'ACBDEFGHIJ' * 100, '1000B'), +] + + +class API: + """Abstract base class for the interface of the test functions.""" + + @classmethod + def function(cls, func: str, kl: int) -> str: + raise NotImplementedError + + @classmethod + def metadata_arguments(cls, + kl: int, + pair: bool, + deterministic: Optional[bool]) -> List[str]: + raise NotImplementedError + + @classmethod + def final_arguments(cls) -> List[str]: + return [] + + @classmethod + def secret_is_seed(cls) -> bool: + return True + + +class PQCPAPI(API): + """Test mldsa-native entry points.""" + + @classmethod + def function(cls, func: str, kl: int) -> str: + return f'{func}_{kl}' + + @classmethod + def metadata_arguments(cls, + _kl: int, + _pair: bool, + _deterministic: Optional[bool]) -> List[str]: + return [] + + @classmethod + def secret_is_seed(cls) -> bool: + return False + + +def one_mldsa_key_pair_from_seed(key: Key, + descr: str) -> test_case.TestCase: + """Construct one test case for mldsa-native keypair_internal().""" + tc = test_case.TestCase() + tc.set_function(f'key_pair_from_seed_{key.kl}') + tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED']) + tc.set_arguments([ + test_case.hex_string(key.seed), + test_case.hex_string(key.secret), + test_case.hex_string(key.public), + ]) + tc.set_description(f'MLDSA-{key.kl} key pair from seed {descr}') + return tc + +def gen_pqcp_key_management(kl: int) -> Iterable[test_case.TestCase]: + """Generate test cases for mldsa-native keypair_internal().""" + for i, key in enumerate(KEYS[kl], 1): + yield one_mldsa_key_pair_from_seed(key, f'key#{i}') + +def one_mldsa_sign_deterministic_pure(api: API, + key: Key, + message: bytes, + descr: str) -> test_case.TestCase: + """Construct one test case for deterministic signature.""" + signature = key.sign_message(message, deterministic=True) + tc = test_case.TestCase() + tc.set_function(api.function('sign_deterministic_pure', key.kl)) + tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED']) + tc.set_arguments(api.metadata_arguments(key.kl, True, True) + [ + test_case.hex_string(key.seed if api.secret_is_seed() else key.secret), + test_case.hex_string(message), + test_case.hex_string(signature), + ] + api.final_arguments()) + tc.set_description(f'MLDSA-{key.kl} sign deterministic {descr}') + return tc + +def one_mldsa_verify_pure(api: API, + key: Key, + message: bytes, + deterministic: bool, + descr: str) -> test_case.TestCase: + """Construct one test case for verification. + + When deterministic is true, the test case is a deterministic signature. + When deterministic is false, the test case is some other valid signature. + """ + signature = key.sign_message(message, deterministic=deterministic) + tc = test_case.TestCase() + tc.set_function(api.function('verify_pure', key.kl)) + tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED']) + tc.set_arguments(api.metadata_arguments(key.kl, False, True) + [ + test_case.hex_string(key.public), + test_case.hex_string(message), + test_case.hex_string(signature), + ] + api.final_arguments()) + variant = "deterministic" if deterministic else "randomized" + tc.set_description(f'MLDSA-{key.kl} verify {variant} {descr}') + return tc + +def gen_mldsa_pure(api: API, kl: int) -> Iterable[test_case.TestCase]: + """Generate all test cases for pure ML-DSA signature and verification.""" + for i, key in enumerate(KEYS[kl], 1): + yield one_mldsa_sign_deterministic_pure(api, key, MESSAGES[0][0], + f'key#{i}') + for message, descr in MESSAGES[1:]: + yield one_mldsa_sign_deterministic_pure(api, KEYS[kl][0], message, + f'key#1 {descr}') + for i, key in enumerate(KEYS[kl], 1): + yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], True, + f'key#{i}') + for message, descr in MESSAGES[1:]: + yield one_mldsa_verify_pure(api, KEYS[kl][0], message, True, + f'key#1 {descr}') + for i, key in enumerate(KEYS[kl], 1): + yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], False, + f'key#{i}') + for message, descr in MESSAGES[1:]: + yield one_mldsa_verify_pure(api, KEYS[kl][0], message, False, + f'key#1 {descr}') + +def gen_pqcp_mldsa_all() -> Iterable[test_case.TestCase]: + """Generate all test cases for mldsa-native.""" + api = PQCPAPI() + for kl in sorted(KEYS.keys()): + yield from gen_pqcp_key_management(kl) + yield from gen_mldsa_pure(api, kl)