mldsa_test_generator refactor: extend API class to Generator

Turn functions that generate test cases into methods of the generator class.
Functions that used to take an API argument are now implemented in the
generic class, and API-specific functions are now methods of the
corresponding API-specific concrete class.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
Gilles Peskine
2026-03-31 13:41:42 +02:00
parent 31bae441f6
commit 38d9d83ae7
+94 -80
View File
@@ -60,8 +60,17 @@ MESSAGES = [
]
class API:
"""Abstract base class for the interface of the test functions."""
def one_mldsa_sign_deterministic_pure(api, *args):
return api.one_mldsa_sign_deterministic_pure(*args)
def one_mldsa_verify_pure(api, *args):
return api.one_mldsa_verify_pure(*args)
def gen_mldsa_pure(api, *args):
return api.gen_mldsa_pure(*args)
class Generator:
"""Abstract base class to generate tests for one API."""
@classmethod
def function(cls, func: str, kl: int) -> str:
@@ -82,68 +91,72 @@ class API:
def secret_is_seed(cls) -> bool:
return True
def one_mldsa_sign_deterministic_pure(api: API,
key: Key,
message: bytes,
descr: str) -> test_case.TestCase:
"""Construct one test case for deterministic signature."""
signature = key.sign_message(message, deterministic=True)
tc = test_case.TestCase()
tc.set_function(api.function('sign_deterministic_pure', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(api.metadata_arguments(key.kl, True, True) + [
test_case.hex_string(key.seed if api.secret_is_seed() else key.secret),
test_case.hex_string(message),
test_case.hex_string(signature),
] + api.final_arguments())
tc.set_description(f'MLDSA-{key.kl} sign deterministic {descr}')
return tc
def one_mldsa_sign_deterministic_pure(self,
key: Key,
message: bytes,
descr: str) -> test_case.TestCase:
"""Construct one test case for deterministic signature."""
api = self
signature = key.sign_message(message, deterministic=True)
tc = test_case.TestCase()
tc.set_function(api.function('sign_deterministic_pure', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(api.metadata_arguments(key.kl, True, True) + [
test_case.hex_string(key.seed if api.secret_is_seed() else key.secret),
test_case.hex_string(message),
test_case.hex_string(signature),
] + api.final_arguments())
tc.set_description(f'MLDSA-{key.kl} sign deterministic {descr}')
return tc
def one_mldsa_verify_pure(api: API,
key: Key,
message: bytes,
deterministic: bool,
descr: str) -> test_case.TestCase:
"""Construct one test case for verification.
def one_mldsa_verify_pure(self,
key: Key,
message: bytes,
deterministic: bool,
descr: str) -> test_case.TestCase:
"""Construct one test case for verification.
When deterministic is true, the test case is a deterministic signature.
When deterministic is false, the test case is some other valid signature.
"""
signature = key.sign_message(message, deterministic=deterministic)
tc = test_case.TestCase()
tc.set_function(api.function('verify_pure', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(api.metadata_arguments(key.kl, False, True) + [
test_case.hex_string(key.public),
test_case.hex_string(message),
test_case.hex_string(signature),
] + api.final_arguments())
variant = "deterministic" if deterministic else "randomized"
tc.set_description(f'MLDSA-{key.kl} verify {variant} {descr}')
return tc
When deterministic is true, the test case is a deterministic signature.
When deterministic is false, the test case is some other valid signature.
"""
api = self
signature = key.sign_message(message, deterministic=deterministic)
tc = test_case.TestCase()
tc.set_function(api.function('verify_pure', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(api.metadata_arguments(key.kl, False, True) + [
test_case.hex_string(key.public),
test_case.hex_string(message),
test_case.hex_string(signature),
] + api.final_arguments())
variant = "deterministic" if deterministic else "randomized"
tc.set_description(f'MLDSA-{key.kl} verify {variant} {descr}')
return tc
def gen_mldsa_pure(api: API, kl: int) -> Iterable[test_case.TestCase]:
"""Generate all test cases for pure ML-DSA signature and verification."""
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_sign_deterministic_pure(api, key, MESSAGES[0][0],
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_sign_deterministic_pure(api, KEYS[kl][0], message,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], True,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_verify_pure(api, KEYS[kl][0], message, True,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], False,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_verify_pure(api, KEYS[kl][0], message, False,
f'key#1 {descr}')
def gen_mldsa_pure(self, kl: int) -> Iterable[test_case.TestCase]:
"""Generate all test cases for pure ML-DSA signature and verification."""
api = self
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_sign_deterministic_pure(api, key, MESSAGES[0][0],
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_sign_deterministic_pure(api, KEYS[kl][0], message,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], True,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_verify_pure(api, KEYS[kl][0], message, True,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], False,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_verify_pure(api, KEYS[kl][0], message, False,
f'key#1 {descr}')
class PQCPAPI(API):
class PQCPGenerator(Generator):
"""Test mldsa-native entry points."""
@classmethod
@@ -161,29 +174,30 @@ class PQCPAPI(API):
def secret_is_seed(cls) -> bool:
return False
@staticmethod
def one_mldsa_key_pair_from_seed(key: Key,
descr: str) -> test_case.TestCase:
"""Construct one test case for mldsa-native keypair_internal()."""
tc = test_case.TestCase()
tc.set_function(f'key_pair_from_seed_{key.kl}')
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments([
test_case.hex_string(key.seed),
test_case.hex_string(key.secret),
test_case.hex_string(key.public),
])
tc.set_description(f'MLDSA-{key.kl} key pair from seed {descr}')
return tc
def one_mldsa_key_pair_from_seed(key: Key,
descr: str) -> test_case.TestCase:
"""Construct one test case for mldsa-native keypair_internal()."""
tc = test_case.TestCase()
tc.set_function(f'key_pair_from_seed_{key.kl}')
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments([
test_case.hex_string(key.seed),
test_case.hex_string(key.secret),
test_case.hex_string(key.public),
])
tc.set_description(f'MLDSA-{key.kl} key pair from seed {descr}')
return tc
def gen_pqcp_key_management(self, kl: int) -> Iterable[test_case.TestCase]:
"""Generate test cases for mldsa-native keypair_internal()."""
for i, key in enumerate(KEYS[kl], 1):
yield self.one_mldsa_key_pair_from_seed(key, f'key#{i}')
def gen_pqcp_key_management(kl: int) -> Iterable[test_case.TestCase]:
"""Generate test cases for mldsa-native keypair_internal()."""
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_key_pair_from_seed(key, f'key#{i}')
def gen_pqcp_mldsa_all() -> Iterable[test_case.TestCase]:
"""Generate all test cases for mldsa-native."""
api = PQCPAPI()
generator = PQCPGenerator()
for kl in sorted(KEYS.keys()):
yield from gen_pqcp_key_management(kl)
yield from gen_mldsa_pure(api, kl)
yield from generator.gen_pqcp_key_management(kl)
yield from gen_mldsa_pure(generator, kl)