mirror of
https://github.com/Mbed-TLS/mbedtls-framework.git
synced 2026-06-05 21:15:09 +00:00
Merge pull request #300 from gilles-peskine-arm/mldsa-sign-multipart-dispatch-framework
Generate multipart ML-DSA tests
This commit is contained in:
@@ -4,7 +4,9 @@
|
|||||||
# Copyright The Mbed TLS Contributors
|
# Copyright The Mbed TLS Contributors
|
||||||
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
|
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
|
||||||
|
|
||||||
from typing import Iterator, List, Optional
|
import collections
|
||||||
|
import functools
|
||||||
|
from typing import Callable, Iterator, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
# pip install dilithium-py
|
# pip install dilithium-py
|
||||||
import dilithium_py.ml_dsa #type: ignore
|
import dilithium_py.ml_dsa #type: ignore
|
||||||
@@ -41,6 +43,7 @@ class Key:
|
|||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.public, self.secret = PURE[kl]._keygen_internal(seed)
|
self.public, self.secret = PURE[kl]._keygen_internal(seed)
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=9999)
|
||||||
def sign_message(self, message: bytes, deterministic: bool) -> bytes:
|
def sign_message(self, message: bytes, deterministic: bool) -> bytes:
|
||||||
PURE[self.kl].set_drbg_seed(bytes(48))
|
PURE[self.kl].set_drbg_seed(bytes(48))
|
||||||
return PURE[self.kl].sign(self.secret, message,
|
return PURE[self.kl].sign(self.secret, message,
|
||||||
@@ -82,6 +85,35 @@ class Generator:
|
|||||||
def secret_is_seed(cls) -> bool:
|
def secret_is_seed(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def message_for_length(length: int) -> bytes:
|
||||||
|
# b'ABCDE...' wrapping around and repeating every 256 bytes
|
||||||
|
return bytes(b & 0xff for b in range(65, 65 + length))
|
||||||
|
|
||||||
|
def chunks_for_lengths(self,
|
||||||
|
lengths: Sequence[int],
|
||||||
|
arity: Optional[int] = None,
|
||||||
|
) -> Tuple[bytes, List[bytes]]:
|
||||||
|
"""Construct a message split in chunks of the given lengths.
|
||||||
|
|
||||||
|
The content of the message only depends on the total length.
|
||||||
|
|
||||||
|
If `arity` is specified and less than `len(lengths)`, pad the list
|
||||||
|
of chunks to that number. If `arity` is specified and larger than
|
||||||
|
`len(lengths)`, raise an exception.
|
||||||
|
"""
|
||||||
|
total_length = sum(lengths)
|
||||||
|
message = self.message_for_length(total_length)
|
||||||
|
chunks = []
|
||||||
|
offset = 0
|
||||||
|
for n in lengths:
|
||||||
|
chunks.append(message[offset:offset + n])
|
||||||
|
offset += n
|
||||||
|
if arity is not None:
|
||||||
|
assert len(lengths) <= arity
|
||||||
|
chunks += [b''] * (arity - len(lengths))
|
||||||
|
return (message, chunks)
|
||||||
|
|
||||||
def one_mldsa_public_key_from_seed(self, key: Key,
|
def one_mldsa_public_key_from_seed(self, key: Key,
|
||||||
descr: str) -> test_case.TestCase:
|
descr: str) -> test_case.TestCase:
|
||||||
"""Construct one test case for driver export_public_key()."""
|
"""Construct one test case for driver export_public_key()."""
|
||||||
@@ -242,14 +274,83 @@ class DriverGenerator(Generator):
|
|||||||
return arguments
|
return arguments
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def final_arguments(cls) -> List[str]:
|
def final_arguments(cls, status: str = 'PSA_SUCCESS') -> List[str]:
|
||||||
return ['PSA_SUCCESS']
|
return [status]
|
||||||
|
|
||||||
def gen_key_management(self, kl: int) -> Iterator[test_case.TestCase]:
|
def gen_key_management(self, kl: int) -> Iterator[test_case.TestCase]:
|
||||||
"""Generate test cases for driver export_public_key()."""
|
"""Generate test cases for driver export_public_key()."""
|
||||||
for i, key in enumerate(KEYS[kl], 1):
|
for i, key in enumerate(KEYS[kl], 1):
|
||||||
yield self.one_mldsa_public_key_from_seed(key, f'key#{i}')
|
yield self.one_mldsa_public_key_from_seed(key, f'key#{i}')
|
||||||
|
|
||||||
|
MULTIPART_ARITY = 3
|
||||||
|
|
||||||
|
def one_multipart(self, function: str, key: Key,
|
||||||
|
lengths: Sequence[int],
|
||||||
|
tweak_signature: Optional[Callable[[bytes], Tuple[bytes, str, str]]] = None,
|
||||||
|
) -> test_case.TestCase:
|
||||||
|
"""Construct one test case for a multipart operation.
|
||||||
|
|
||||||
|
The number of message chunks must be at most MULTIPART_ARITY.
|
||||||
|
"""
|
||||||
|
message, chunks = self.chunks_for_lengths(lengths, self.MULTIPART_ARITY)
|
||||||
|
descr = '+'.join(map(str, lengths)) if lengths else 'empty (no update)'
|
||||||
|
type_is_pair = not function.startswith('verif')
|
||||||
|
deterministic = True
|
||||||
|
actual_signature = key.sign_message(message, deterministic=deterministic)
|
||||||
|
signature = actual_signature
|
||||||
|
status = 'PSA_SUCCESS'
|
||||||
|
more_descr = ''
|
||||||
|
if tweak_signature is not None:
|
||||||
|
signature, status, more_descr = tweak_signature(actual_signature)
|
||||||
|
if more_descr:
|
||||||
|
more_descr = ', ' + more_descr
|
||||||
|
tc = test_case.TestCase()
|
||||||
|
tc.set_function(self.function(function + '_multipart', key.kl))
|
||||||
|
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
|
||||||
|
tc.set_arguments(self.metadata_arguments(key.kl, type_is_pair, True) + [
|
||||||
|
test_case.hex_string(key.seed if type_is_pair else key.public),
|
||||||
|
str(len(lengths)),
|
||||||
|
] + [test_case.hex_string(chunk) for chunk in chunks] + [
|
||||||
|
test_case.hex_string(signature),
|
||||||
|
] + self.final_arguments(status=status))
|
||||||
|
tc.set_description(f'MLDSA-{key.kl} {function} multipart {descr}{more_descr}')
|
||||||
|
return tc
|
||||||
|
|
||||||
|
MANY_MULTIPART_LENGTHS: Sequence[Sequence[int]] = [
|
||||||
|
[], [0], [0, 0],
|
||||||
|
[1],
|
||||||
|
[3], [1, 2], [2, 1], [1, 1, 1],
|
||||||
|
[42], [0, 42], [42, 0], [41, 1],
|
||||||
|
[300], [100, 200], [200, 100], [100, 100, 100],
|
||||||
|
]
|
||||||
|
FEW_MULTIPART_LENGTHS: Sequence[Sequence[int]] = [[], [42], [41, 1]]
|
||||||
|
|
||||||
|
VERIFY_TWEAKS = collections.OrderedDict([
|
||||||
|
('sig=empty', lambda sig: b''),
|
||||||
|
('truncated sig', lambda sig: sig[:-1]),
|
||||||
|
('sig+garbage', lambda sig: sig + b'\x00'),
|
||||||
|
('sig[-1]^=1', lambda sig: sig[:-1] + bytes([sig[-1] ^ 1])),
|
||||||
|
('sig[0]^=1', lambda sig: bytes([sig[0] ^ 1]) + sig[1:]),
|
||||||
|
])
|
||||||
|
|
||||||
|
def gen_multipart(self, key: Key) -> Iterator[test_case.TestCase]:
|
||||||
|
"""Generate test cases for multipart sign and verify."""
|
||||||
|
for lengths in self.MANY_MULTIPART_LENGTHS:
|
||||||
|
yield self.one_multipart('sign_deterministic', key, lengths)
|
||||||
|
yield self.one_multipart('verify', key, lengths)
|
||||||
|
for descr, func in self.VERIFY_TWEAKS.items():
|
||||||
|
for lengths in self.FEW_MULTIPART_LENGTHS:
|
||||||
|
tweak = (lambda sig, func=func:
|
||||||
|
(func(sig), 'PSA_ERROR_INVALID_SIGNATURE', descr))
|
||||||
|
yield self.one_multipart('verify', key, lengths,
|
||||||
|
tweak_signature=tweak)
|
||||||
|
|
||||||
|
def gen_all(self, multipart: bool = False) -> Iterator[test_case.TestCase]:
|
||||||
|
"""Generate all the tests for this API."""
|
||||||
|
yield from super().gen_all()
|
||||||
|
if multipart:
|
||||||
|
for kl in sorted(KEYS.keys()):
|
||||||
|
yield from self.gen_multipart(KEYS[kl][0])
|
||||||
|
|
||||||
class DispatchGenerator(DriverGenerator):
|
class DispatchGenerator(DriverGenerator):
|
||||||
"""Test the driver dispatch layer."""
|
"""Test the driver dispatch layer."""
|
||||||
|
|||||||
Reference in New Issue
Block a user