Move most of generate_mldsa_tests.py into a module

We are moving the command line entry point to the consuming branch.

The move will be completed when the consuming branch no longer needs the
script to exist in the framework.
https://github.com/Mbed-TLS/mbedtls-framework/issues/294

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
Gilles Peskine
2026-03-31 13:21:45 +02:00
parent 7f537471bd
commit 72f178bdbd
2 changed files with 196 additions and 186 deletions
+6 -186
View File
@@ -1,203 +1,23 @@
#!/usr/bin/env python3
"""Generate ML-DSA test cases.
This is a transitional script that does not handle different feature sets
in different states of TF-PSA-Crypto. The live version of this script
is `scripts/maintainer/generate_mldsa_tests.py` in TF-PSA-Crypto.
"""
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
import sys
from typing import Iterable, List, Optional
# pip install dilithium-py
import dilithium_py.ml_dsa #type: ignore
import scripts_path # pylint: disable=unused-import
from mbedtls_framework import test_case
from mbedtls_framework import test_data_generation
# ML_DSA instances for pure ML-DSA
PURE = {
#44: dilithium_py.ml_dsa.ML_DSA_44,
#65: dilithium_py.ml_dsa.ML_DSA_65,
87: dilithium_py.ml_dsa.ML_DSA_87,
}
# ML_DSA instances for HashML-DSA
HASH = {
#44: dilithium_py.ml_dsa.HASH_ML_DSA_44_WITH_SHA512,
#65: dilithium_py.ml_dsa.HASH_ML_DSA_65_WITH_SHA512,
87: dilithium_py.ml_dsa.HASH_ML_DSA_87_WITH_SHA512,
}
# Seeds (i.e. private keys) to test with.
SEEDS = [
b'There was once upon a time a ...',
b'\x00' * 32,
]
class Key:
"""An MLDSA key pair."""
#pylint: disable=too-few-public-methods
def __init__(self, kl: int, seed: bytes) -> None:
self.kl = kl #pylint: disable=invalid-name
self.seed = seed
self.public, self.secret = PURE[kl]._keygen_internal(seed)
def sign_message(self, message: bytes, deterministic: bool) -> bytes:
PURE[self.kl].set_drbg_seed(bytes(48))
return PURE[self.kl].sign(self.secret, message,
deterministic=deterministic)
# Key pairs to test with.
KEYS = {kl: [Key(kl, seed) for seed in SEEDS]
for kl in sorted(PURE.keys())}
# Input messages to test with.
MESSAGES = [
(b'This is a test', ''),
(b'', 'empty message'),
(b'\x00', '"\\x00"'),
(b'\x01', '"\\x01"'),
(b'ACBDEFGHIJ' * 100, '1000B'),
]
class API:
"""Abstract base class for the interface of the test functions."""
@classmethod
def function(cls, func: str, kl: int) -> str:
raise NotImplementedError
@classmethod
def metadata_arguments(cls,
kl: int,
pair: bool,
deterministic: Optional[bool]) -> List[str]:
raise NotImplementedError
@classmethod
def final_arguments(cls) -> List[str]:
return []
@classmethod
def secret_is_seed(cls) -> bool:
return True
class PQCPAPI(API):
"""Test mldsa-native entry points."""
@classmethod
def function(cls, func: str, kl: int) -> str:
return f'{func}_{kl}'
@classmethod
def metadata_arguments(cls,
_kl: int,
_pair: bool,
_deterministic: Optional[bool]) -> List[str]:
return []
@classmethod
def secret_is_seed(cls) -> bool:
return False
def one_mldsa_key_pair_from_seed(key: Key,
descr: str) -> test_case.TestCase:
"""Construct one test case for mldsa-native keypair_internal()."""
tc = test_case.TestCase()
tc.set_function(f'key_pair_from_seed_{key.kl}')
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments([
test_case.hex_string(key.seed),
test_case.hex_string(key.secret),
test_case.hex_string(key.public),
])
tc.set_description(f'MLDSA-{key.kl} key pair from seed {descr}')
return tc
def gen_pqcp_key_management(kl: int) -> Iterable[test_case.TestCase]:
"""Generate test cases for mldsa-native keypair_internal()."""
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_key_pair_from_seed(key, f'key#{i}')
def one_mldsa_sign_deterministic_pure(api: API,
key: Key,
message: bytes,
descr: str) -> test_case.TestCase:
"""Construct one test case for deterministic signature."""
signature = key.sign_message(message, deterministic=True)
tc = test_case.TestCase()
tc.set_function(api.function('sign_deterministic_pure', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(api.metadata_arguments(key.kl, True, True) + [
test_case.hex_string(key.seed if api.secret_is_seed() else key.secret),
test_case.hex_string(message),
test_case.hex_string(signature),
] + api.final_arguments())
tc.set_description(f'MLDSA-{key.kl} sign deterministic {descr}')
return tc
def one_mldsa_verify_pure(api: API,
key: Key,
message: bytes,
deterministic: bool,
descr: str) -> test_case.TestCase:
"""Construct one test case for verification.
When deterministic is true, the test case is a deterministic signature.
When deterministic is false, the test case is some other valid signature.
"""
signature = key.sign_message(message, deterministic=deterministic)
tc = test_case.TestCase()
tc.set_function(api.function('verify_pure', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(api.metadata_arguments(key.kl, False, True) + [
test_case.hex_string(key.public),
test_case.hex_string(message),
test_case.hex_string(signature),
] + api.final_arguments())
variant = "deterministic" if deterministic else "randomized"
tc.set_description(f'MLDSA-{key.kl} verify {variant} {descr}')
return tc
def gen_mldsa_pure(api: API, kl: int) -> Iterable[test_case.TestCase]:
"""Generate all test cases for pure ML-DSA signature and verification."""
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_sign_deterministic_pure(api, key, MESSAGES[0][0],
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_sign_deterministic_pure(api, KEYS[kl][0], message,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], True,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_verify_pure(api, KEYS[kl][0], message, True,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], False,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_verify_pure(api, KEYS[kl][0], message, False,
f'key#1 {descr}')
def gen_pqcp_mldsa_all() -> Iterable[test_case.TestCase]:
"""Generate all test cases for mldsa-native."""
api = PQCPAPI()
for kl in sorted(KEYS.keys()):
yield from gen_pqcp_key_management(kl)
yield from gen_mldsa_pure(api, kl)
from mbedtls_maintainer import mldsa_test_generator
class MLDSATestGenerator(test_data_generation.TestGenerator):
"""Generate test cases for ML-DSA."""
def __init__(self, settings) -> None:
self.targets = {
'test_suite_pqcp_mldsa.dilithium_py': gen_pqcp_mldsa_all,
'test_suite_pqcp_mldsa.dilithium_py': mldsa_test_generator.gen_pqcp_mldsa_all,
}
super().__init__(settings)
@@ -0,0 +1,190 @@
"""Generate ML-DSA test cases.
"""
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
from typing import Iterable, List, Optional
# pip install dilithium-py
import dilithium_py.ml_dsa #type: ignore
import scripts_path # pylint: disable=unused-import
from mbedtls_framework import test_case
# ML_DSA instances for pure ML-DSA
PURE = {
#44: dilithium_py.ml_dsa.ML_DSA_44,
#65: dilithium_py.ml_dsa.ML_DSA_65,
87: dilithium_py.ml_dsa.ML_DSA_87,
}
# ML_DSA instances for HashML-DSA
HASH = {
#44: dilithium_py.ml_dsa.HASH_ML_DSA_44_WITH_SHA512,
#65: dilithium_py.ml_dsa.HASH_ML_DSA_65_WITH_SHA512,
87: dilithium_py.ml_dsa.HASH_ML_DSA_87_WITH_SHA512,
}
# Seeds (i.e. private keys) to test with.
SEEDS = [
b'There was once upon a time a ...',
b'\x00' * 32,
]
class Key:
"""An MLDSA key pair."""
#pylint: disable=too-few-public-methods
def __init__(self, kl: int, seed: bytes) -> None:
self.kl = kl #pylint: disable=invalid-name
self.seed = seed
self.public, self.secret = PURE[kl]._keygen_internal(seed)
def sign_message(self, message: bytes, deterministic: bool) -> bytes:
PURE[self.kl].set_drbg_seed(bytes(48))
return PURE[self.kl].sign(self.secret, message,
deterministic=deterministic)
# Key pairs to test with.
KEYS = {kl: [Key(kl, seed) for seed in SEEDS]
for kl in sorted(PURE.keys())}
# Input messages to test with.
MESSAGES = [
(b'This is a test', ''),
(b'', 'empty message'),
(b'\x00', '"\\x00"'),
(b'\x01', '"\\x01"'),
(b'ACBDEFGHIJ' * 100, '1000B'),
]
class API:
"""Abstract base class for the interface of the test functions."""
@classmethod
def function(cls, func: str, kl: int) -> str:
raise NotImplementedError
@classmethod
def metadata_arguments(cls,
kl: int,
pair: bool,
deterministic: Optional[bool]) -> List[str]:
raise NotImplementedError
@classmethod
def final_arguments(cls) -> List[str]:
return []
@classmethod
def secret_is_seed(cls) -> bool:
return True
class PQCPAPI(API):
"""Test mldsa-native entry points."""
@classmethod
def function(cls, func: str, kl: int) -> str:
return f'{func}_{kl}'
@classmethod
def metadata_arguments(cls,
_kl: int,
_pair: bool,
_deterministic: Optional[bool]) -> List[str]:
return []
@classmethod
def secret_is_seed(cls) -> bool:
return False
def one_mldsa_key_pair_from_seed(key: Key,
descr: str) -> test_case.TestCase:
"""Construct one test case for mldsa-native keypair_internal()."""
tc = test_case.TestCase()
tc.set_function(f'key_pair_from_seed_{key.kl}')
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments([
test_case.hex_string(key.seed),
test_case.hex_string(key.secret),
test_case.hex_string(key.public),
])
tc.set_description(f'MLDSA-{key.kl} key pair from seed {descr}')
return tc
def gen_pqcp_key_management(kl: int) -> Iterable[test_case.TestCase]:
"""Generate test cases for mldsa-native keypair_internal()."""
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_key_pair_from_seed(key, f'key#{i}')
def one_mldsa_sign_deterministic_pure(api: API,
key: Key,
message: bytes,
descr: str) -> test_case.TestCase:
"""Construct one test case for deterministic signature."""
signature = key.sign_message(message, deterministic=True)
tc = test_case.TestCase()
tc.set_function(api.function('sign_deterministic_pure', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(api.metadata_arguments(key.kl, True, True) + [
test_case.hex_string(key.seed if api.secret_is_seed() else key.secret),
test_case.hex_string(message),
test_case.hex_string(signature),
] + api.final_arguments())
tc.set_description(f'MLDSA-{key.kl} sign deterministic {descr}')
return tc
def one_mldsa_verify_pure(api: API,
key: Key,
message: bytes,
deterministic: bool,
descr: str) -> test_case.TestCase:
"""Construct one test case for verification.
When deterministic is true, the test case is a deterministic signature.
When deterministic is false, the test case is some other valid signature.
"""
signature = key.sign_message(message, deterministic=deterministic)
tc = test_case.TestCase()
tc.set_function(api.function('verify_pure', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(api.metadata_arguments(key.kl, False, True) + [
test_case.hex_string(key.public),
test_case.hex_string(message),
test_case.hex_string(signature),
] + api.final_arguments())
variant = "deterministic" if deterministic else "randomized"
tc.set_description(f'MLDSA-{key.kl} verify {variant} {descr}')
return tc
def gen_mldsa_pure(api: API, kl: int) -> Iterable[test_case.TestCase]:
"""Generate all test cases for pure ML-DSA signature and verification."""
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_sign_deterministic_pure(api, key, MESSAGES[0][0],
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_sign_deterministic_pure(api, KEYS[kl][0], message,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], True,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_verify_pure(api, KEYS[kl][0], message, True,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], False,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_verify_pure(api, KEYS[kl][0], message, False,
f'key#1 {descr}')
def gen_pqcp_mldsa_all() -> Iterable[test_case.TestCase]:
"""Generate all test cases for mldsa-native."""
api = PQCPAPI()
for kl in sorted(KEYS.keys()):
yield from gen_pqcp_key_management(kl)
yield from gen_mldsa_pure(api, kl)