Merge pull request #300 from gilles-peskine-arm/mldsa-sign-multipart-dispatch-framework

Generate multipart ML-DSA tests
This commit is contained in:
Bence Szépkúti
2026-06-03 15:07:48 +02:00
committed by GitHub
+104 -3
View File
@@ -4,7 +4,9 @@
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
from typing import Iterator, List, Optional
import collections
import functools
from typing import Callable, Iterator, List, Optional, Sequence, Tuple
# pip install dilithium-py
import dilithium_py.ml_dsa #type: ignore
@@ -41,6 +43,7 @@ class Key:
self.seed = seed
self.public, self.secret = PURE[kl]._keygen_internal(seed)
@functools.lru_cache(maxsize=9999)
def sign_message(self, message: bytes, deterministic: bool) -> bytes:
PURE[self.kl].set_drbg_seed(bytes(48))
return PURE[self.kl].sign(self.secret, message,
@@ -82,6 +85,35 @@ class Generator:
def secret_is_seed(cls) -> bool:
return True
@staticmethod
def message_for_length(length: int) -> bytes:
# b'ABCDE...' wrapping around and repeating every 256 bytes
return bytes(b & 0xff for b in range(65, 65 + length))
def chunks_for_lengths(self,
lengths: Sequence[int],
arity: Optional[int] = None,
) -> Tuple[bytes, List[bytes]]:
"""Construct a message split in chunks of the given lengths.
The content of the message only depends on the total length.
If `arity` is specified and less than `len(lengths)`, pad the list
of chunks to that number. If `arity` is specified and larger than
`len(lengths)`, raise an exception.
"""
total_length = sum(lengths)
message = self.message_for_length(total_length)
chunks = []
offset = 0
for n in lengths:
chunks.append(message[offset:offset + n])
offset += n
if arity is not None:
assert len(lengths) <= arity
chunks += [b''] * (arity - len(lengths))
return (message, chunks)
def one_mldsa_public_key_from_seed(self, key: Key,
descr: str) -> test_case.TestCase:
"""Construct one test case for driver export_public_key()."""
@@ -242,14 +274,83 @@ class DriverGenerator(Generator):
return arguments
@classmethod
def final_arguments(cls) -> List[str]:
return ['PSA_SUCCESS']
def final_arguments(cls, status: str = 'PSA_SUCCESS') -> List[str]:
return [status]
def gen_key_management(self, kl: int) -> Iterator[test_case.TestCase]:
"""Generate test cases for driver export_public_key()."""
for i, key in enumerate(KEYS[kl], 1):
yield self.one_mldsa_public_key_from_seed(key, f'key#{i}')
MULTIPART_ARITY = 3
def one_multipart(self, function: str, key: Key,
lengths: Sequence[int],
tweak_signature: Optional[Callable[[bytes], Tuple[bytes, str, str]]] = None,
) -> test_case.TestCase:
"""Construct one test case for a multipart operation.
The number of message chunks must be at most MULTIPART_ARITY.
"""
message, chunks = self.chunks_for_lengths(lengths, self.MULTIPART_ARITY)
descr = '+'.join(map(str, lengths)) if lengths else 'empty (no update)'
type_is_pair = not function.startswith('verif')
deterministic = True
actual_signature = key.sign_message(message, deterministic=deterministic)
signature = actual_signature
status = 'PSA_SUCCESS'
more_descr = ''
if tweak_signature is not None:
signature, status, more_descr = tweak_signature(actual_signature)
if more_descr:
more_descr = ', ' + more_descr
tc = test_case.TestCase()
tc.set_function(self.function(function + '_multipart', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(self.metadata_arguments(key.kl, type_is_pair, True) + [
test_case.hex_string(key.seed if type_is_pair else key.public),
str(len(lengths)),
] + [test_case.hex_string(chunk) for chunk in chunks] + [
test_case.hex_string(signature),
] + self.final_arguments(status=status))
tc.set_description(f'MLDSA-{key.kl} {function} multipart {descr}{more_descr}')
return tc
MANY_MULTIPART_LENGTHS: Sequence[Sequence[int]] = [
[], [0], [0, 0],
[1],
[3], [1, 2], [2, 1], [1, 1, 1],
[42], [0, 42], [42, 0], [41, 1],
[300], [100, 200], [200, 100], [100, 100, 100],
]
FEW_MULTIPART_LENGTHS: Sequence[Sequence[int]] = [[], [42], [41, 1]]
VERIFY_TWEAKS = collections.OrderedDict([
('sig=empty', lambda sig: b''),
('truncated sig', lambda sig: sig[:-1]),
('sig+garbage', lambda sig: sig + b'\x00'),
('sig[-1]^=1', lambda sig: sig[:-1] + bytes([sig[-1] ^ 1])),
('sig[0]^=1', lambda sig: bytes([sig[0] ^ 1]) + sig[1:]),
])
def gen_multipart(self, key: Key) -> Iterator[test_case.TestCase]:
"""Generate test cases for multipart sign and verify."""
for lengths in self.MANY_MULTIPART_LENGTHS:
yield self.one_multipart('sign_deterministic', key, lengths)
yield self.one_multipart('verify', key, lengths)
for descr, func in self.VERIFY_TWEAKS.items():
for lengths in self.FEW_MULTIPART_LENGTHS:
tweak = (lambda sig, func=func:
(func(sig), 'PSA_ERROR_INVALID_SIGNATURE', descr))
yield self.one_multipart('verify', key, lengths,
tweak_signature=tweak)
def gen_all(self, multipart: bool = False) -> Iterator[test_case.TestCase]:
"""Generate all the tests for this API."""
yield from super().gen_all()
if multipart:
for kl in sorted(KEYS.keys()):
yield from self.gen_multipart(KEYS[kl][0])
class DispatchGenerator(DriverGenerator):
"""Test the driver dispatch layer."""