diff --git a/util/mbedtls_maintainer/mldsa_test_generator.py b/util/mbedtls_maintainer/mldsa_test_generator.py index 72b20ce35..c7ec7d591 100644 --- a/util/mbedtls_maintainer/mldsa_test_generator.py +++ b/util/mbedtls_maintainer/mldsa_test_generator.py @@ -4,7 +4,9 @@ # Copyright The Mbed TLS Contributors # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later -from typing import Iterator, List, Optional +import collections +import functools +from typing import Callable, Iterator, List, Optional, Sequence, Tuple # pip install dilithium-py import dilithium_py.ml_dsa #type: ignore @@ -41,6 +43,7 @@ class Key: self.seed = seed self.public, self.secret = PURE[kl]._keygen_internal(seed) + @functools.lru_cache(maxsize=9999) def sign_message(self, message: bytes, deterministic: bool) -> bytes: PURE[self.kl].set_drbg_seed(bytes(48)) return PURE[self.kl].sign(self.secret, message, @@ -82,6 +85,35 @@ class Generator: def secret_is_seed(cls) -> bool: return True + @staticmethod + def message_for_length(length: int) -> bytes: + # b'ABCDE...' wrapping around and repeating every 256 bytes + return bytes(b & 0xff for b in range(65, 65 + length)) + + def chunks_for_lengths(self, + lengths: Sequence[int], + arity: Optional[int] = None, + ) -> Tuple[bytes, List[bytes]]: + """Construct a message split in chunks of the given lengths. + + The content of the message only depends on the total length. + + If `arity` is specified and less than `len(lengths)`, pad the list + of chunks to that number. If `arity` is specified and larger than + `len(lengths)`, raise an exception. + """ + total_length = sum(lengths) + message = self.message_for_length(total_length) + chunks = [] + offset = 0 + for n in lengths: + chunks.append(message[offset:offset + n]) + offset += n + if arity is not None: + assert len(lengths) <= arity + chunks += [b''] * (arity - len(lengths)) + return (message, chunks) + def one_mldsa_public_key_from_seed(self, key: Key, descr: str) -> test_case.TestCase: """Construct one test case for driver export_public_key().""" @@ -242,14 +274,83 @@ class DriverGenerator(Generator): return arguments @classmethod - def final_arguments(cls) -> List[str]: - return ['PSA_SUCCESS'] + def final_arguments(cls, status: str = 'PSA_SUCCESS') -> List[str]: + return [status] def gen_key_management(self, kl: int) -> Iterator[test_case.TestCase]: """Generate test cases for driver export_public_key().""" for i, key in enumerate(KEYS[kl], 1): yield self.one_mldsa_public_key_from_seed(key, f'key#{i}') + MULTIPART_ARITY = 3 + + def one_multipart(self, function: str, key: Key, + lengths: Sequence[int], + tweak_signature: Optional[Callable[[bytes], Tuple[bytes, str, str]]] = None, + ) -> test_case.TestCase: + """Construct one test case for a multipart operation. + + The number of message chunks must be at most MULTIPART_ARITY. + """ + message, chunks = self.chunks_for_lengths(lengths, self.MULTIPART_ARITY) + descr = '+'.join(map(str, lengths)) if lengths else 'empty (no update)' + type_is_pair = not function.startswith('verif') + deterministic = True + actual_signature = key.sign_message(message, deterministic=deterministic) + signature = actual_signature + status = 'PSA_SUCCESS' + more_descr = '' + if tweak_signature is not None: + signature, status, more_descr = tweak_signature(actual_signature) + if more_descr: + more_descr = ', ' + more_descr + tc = test_case.TestCase() + tc.set_function(self.function(function + '_multipart', key.kl)) + tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED']) + tc.set_arguments(self.metadata_arguments(key.kl, type_is_pair, True) + [ + test_case.hex_string(key.seed if type_is_pair else key.public), + str(len(lengths)), + ] + [test_case.hex_string(chunk) for chunk in chunks] + [ + test_case.hex_string(signature), + ] + self.final_arguments(status=status)) + tc.set_description(f'MLDSA-{key.kl} {function} multipart {descr}{more_descr}') + return tc + + MANY_MULTIPART_LENGTHS: Sequence[Sequence[int]] = [ + [], [0], [0, 0], + [1], + [3], [1, 2], [2, 1], [1, 1, 1], + [42], [0, 42], [42, 0], [41, 1], + [300], [100, 200], [200, 100], [100, 100, 100], + ] + FEW_MULTIPART_LENGTHS: Sequence[Sequence[int]] = [[], [42], [41, 1]] + + VERIFY_TWEAKS = collections.OrderedDict([ + ('sig=empty', lambda sig: b''), + ('truncated sig', lambda sig: sig[:-1]), + ('sig+garbage', lambda sig: sig + b'\x00'), + ('sig[-1]^=1', lambda sig: sig[:-1] + bytes([sig[-1] ^ 1])), + ('sig[0]^=1', lambda sig: bytes([sig[0] ^ 1]) + sig[1:]), + ]) + + def gen_multipart(self, key: Key) -> Iterator[test_case.TestCase]: + """Generate test cases for multipart sign and verify.""" + for lengths in self.MANY_MULTIPART_LENGTHS: + yield self.one_multipart('sign_deterministic', key, lengths) + yield self.one_multipart('verify', key, lengths) + for descr, func in self.VERIFY_TWEAKS.items(): + for lengths in self.FEW_MULTIPART_LENGTHS: + tweak = (lambda sig, func=func: + (func(sig), 'PSA_ERROR_INVALID_SIGNATURE', descr)) + yield self.one_multipart('verify', key, lengths, + tweak_signature=tweak) + + def gen_all(self, multipart: bool = False) -> Iterator[test_case.TestCase]: + """Generate all the tests for this API.""" + yield from super().gen_all() + if multipart: + for kl in sorted(KEYS.keys()): + yield from self.gen_multipart(KEYS[kl][0]) class DispatchGenerator(DriverGenerator): """Test the driver dispatch layer."""