diff --git a/util/mbedtls_maintainer/mldsa_test_generator.py b/util/mbedtls_maintainer/mldsa_test_generator.py index 24226d77c..5b7bc9077 100644 --- a/util/mbedtls_maintainer/mldsa_test_generator.py +++ b/util/mbedtls_maintainer/mldsa_test_generator.py @@ -4,7 +4,7 @@ # Copyright The Mbed TLS Contributors # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later -from typing import Iterable, List, Optional +from typing import Iterator, List, Optional # pip install dilithium-py import dilithium_py.ml_dsa #type: ignore @@ -122,7 +122,7 @@ class Generator: tc.set_description(f'MLDSA-{key.kl} verify {variant} {descr}') return tc - def gen_mldsa_pure(self, kl: int) -> Iterable[test_case.TestCase]: + def gen_mldsa_pure(self, kl: int) -> Iterator[test_case.TestCase]: """Generate all test cases for pure ML-DSA signature and verification.""" for i, key in enumerate(KEYS[kl], 1): yield self.one_mldsa_sign_deterministic_pure(key, MESSAGES[0][0], @@ -143,6 +143,16 @@ class Generator: yield self.one_mldsa_verify_pure(KEYS[kl][0], message, False, f'key#1 {descr}') + def gen_key_management(self, kl: int) -> Iterator[test_case.TestCase]: + """Generate all key management test cases for the given parameter set.""" + raise NotImplementedError + + def gen_all(self) -> Iterator[test_case.TestCase]: + """Generate all the tests for this API.""" + for kl in sorted(KEYS.keys()): + yield from self.gen_key_management(kl) + yield from self.gen_mldsa_pure(kl) + class PQCPGenerator(Generator): """Test mldsa-native entry points.""" @@ -177,15 +187,13 @@ class PQCPGenerator(Generator): tc.set_description(f'MLDSA-{key.kl} key pair from seed {descr}') return tc - def gen_pqcp_key_management(self, kl: int) -> Iterable[test_case.TestCase]: + def gen_key_management(self, kl: int) -> Iterator[test_case.TestCase]: """Generate test cases for mldsa-native keypair_internal().""" for i, key in enumerate(KEYS[kl], 1): yield self.one_mldsa_key_pair_from_seed(key, f'key#{i}') -def gen_pqcp_mldsa_all() -> Iterable[test_case.TestCase]: +def gen_pqcp_mldsa_all() -> Iterator[test_case.TestCase]: """Generate all test cases for mldsa-native.""" generator = PQCPGenerator() - for kl in sorted(KEYS.keys()): - yield from generator.gen_pqcp_key_management(kl) - yield from generator.gen_mldsa_pure(kl) + yield from generator.gen_all()