Merge pull request #282 from gilles-peskine-arm/mldsa-pqcp-driver-framework

Generate MLDSA test cases for the driver and dispatch layers
This commit is contained in:
Gilles Peskine
2026-04-08 15:50:19 +02:00
committed by GitHub
6 changed files with 337 additions and 195 deletions
+11 -2
View File
@@ -55,6 +55,10 @@ elif [ "$1" = "--can-mypy" ]; then
fi
echo 'Running pylint ...'
# Exclude `maintainer` subdirectories, because they can contain code
# that does not work with the versions of pylint and mypy we use on the CI.
# https://github.com/Mbed-TLS/mbedtls-framework/issues/293
#
# When we move Python code between repositories, there is a transition
# period during which code is duplicated between the old repository and
# the new repository.
@@ -64,7 +68,9 @@ echo 'Running pylint ...'
# runs of pylint: one for the A files, and one for the others.
# Remove exceptions below once the A file (or the moved code in the A file)
# has been removed from all consuming branches.
find framework/scripts scripts tests/scripts -name '*.py' \( \
find framework/scripts scripts tests/scripts \
-name maintainer -prune -o \
-name '*.py' \( \
! -path scripts/abi_check.py \
! -path scripts/code_size_compare.py \
! -path scripts/ecp_comb_table.py \
@@ -87,7 +93,10 @@ $PYTHON -m mypy framework/scripts || {
ret=1
}
$PYTHON -m mypy scripts tests/scripts || {
# Exclude `maintainer` subdirectories, because they can contain code
# that does not work with the versions of pylint and mypy we use on the CI.
# https://github.com/Mbed-TLS/mbedtls-framework/issues/293
$PYTHON -m mypy --exclude maintainer scripts tests/scripts || {
echo >&2 "mypy reported errors in the parent repository"
ret=1
}
+6 -7
View File
@@ -37,14 +37,13 @@ class read_file_lines:
except that if process(line) raises an exception, then the read_file_lines
snippet annotates the exception with the file name and line number.
"""
def __init__(self, filename: str, binary: bool = False) -> None:
def __init__(self, filename: str) -> None:
self.filename = filename
self.file = None #type: Optional[IO[str]]
self.line_number = 'entry' #type: Union[int, str]
self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
self.binary = binary
def __enter__(self) -> 'read_file_lines':
self.file = open(self.filename, 'rb' if self.binary else 'r')
self.file = open(self.filename)
self.generator = enumerate(self.file)
return self
def __iter__(self) -> Iterator[str]:
@@ -517,10 +516,10 @@ enumerate
_nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
def parse_header(self, filename: str) -> None:
"""Parse a C header file, looking for "#define PSA_xxx"."""
with read_file_lines(filename, binary=True) as lines:
for line in lines:
line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
self.parse_header_line(line)
with open(filename, 'rb') as input_:
for line in input_:
line = re.sub(self._nonascii_re, rb'', line)
self.parse_header_line(line.decode('ascii'))
_macro_identifier_re = re.compile(r'[A-Z]\w+')
def generate_undeclared_names(self, expr: str) -> Iterable[str]:
+52
View File
@@ -39,6 +39,30 @@
#include LIBTESTDRIVER1_PSA_DRIVER_INTERNAL_HEADER(psa_crypto_rsa.h)
#endif
/* This file is part of the framework and needs to be compatible with all
* maintained branches of Mbed TLS and TF-PSA-Crypto.
*
* - Until shortly before TF-PSA-Crypto 1.1.0, ML-DSA does not exist at all.
* - In TF-PSA-Crypto 1.1.0, TF_PSA_CRYPTO_PQCP_MLDSA_ENABLED exists, but
* there is no driver dispatch for it yet, so this driver doesn't need to
* worry about ML-DSA.
* - Shortly after TF-PSA-Crypto 1.1.0, in
* https://github.com/Mbed-TLS/TF-PSA-Crypto/pull/700, we introduced
* driver dispatch for ML-DSA, but the macro PSA_ALG_IS_ML_DSA is not
* in the API yet, only in a private header. Including this private header
* is a pain due to how our various build scripts set up include paths, so
* we don't do it. Instead, define PSA_ALG_IS_ML_DSA manually: it's the
* only thing we need.
* - Later we will add ML-DSA to the API, including the definition of
* PSA_ALG_IS_ML_DSA. After that we may also add driver dispatch testing
* for ML-DSA.
*/
#if !defined(PSA_ALG_IS_ML_DSA)
/* Pure ML-DSA (hedged or deterministic) */
#define PSA_ALG_IS_ML_DSA(alg) \
((alg) == 0x06004400u || (alg) == 0x06004500u)
#endif
#include <string.h>
mbedtls_test_driver_signature_hooks_t
@@ -213,6 +237,20 @@ psa_status_t mbedtls_test_transparent_signature_sign_message(
return PSA_SUCCESS;
}
#if defined(TF_PSA_CRYPTO_PQCP_MLDSA_ENABLED)
/* Pure ML-DSA is not a sign-the-hash algorithm. At the moment, this
* function only knows how to deal with sign-the-hash algorithms.
* So give up and let the next driver in the chain handle the algorithm.
* For pure ML-DSA, this will be the pqcp driver, which does not have
* a libtestdriver1 variant, meaning that we can't test "driver-only"
* builds for pure ML-DSA, but we can have ML-DSA enabled in builds that
* dispatch through the test driver.
*/
if (PSA_ALG_IS_ML_DSA(alg)) {
return PSA_ERROR_NOT_SUPPORTED;
}
#endif
#if defined(MBEDTLS_TEST_LIBTESTDRIVER1) && \
defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_HASH)
status = libtestdriver1_mbedtls_psa_hash_compute(
@@ -280,6 +318,20 @@ psa_status_t mbedtls_test_transparent_signature_verify_message(
return mbedtls_test_driver_signature_verify_hooks.forced_status;
}
#if defined(TF_PSA_CRYPTO_PQCP_MLDSA_ENABLED)
/* Pure ML-DSA is not a sign-the-hash algorithm. At the moment, this
* function only knows how to deal with sign-the-hash algorithms.
* So give up and let the next driver in the chain handle the algorithm.
* For pure ML-DSA, this will be the pqcp driver, which does not have
* a libtestdriver1 variant, meaning that we can't test "driver-only"
* builds for pure ML-DSA, but we can have ML-DSA enabled in builds that
* dispatch through the test driver.
*/
if (PSA_ALG_IS_ML_DSA(alg)) {
return PSA_ERROR_NOT_SUPPORTED;
}
#endif
#if defined(MBEDTLS_TEST_LIBTESTDRIVER1) && \
defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_HASH)
status = libtestdriver1_mbedtls_psa_hash_compute(
+6 -186
View File
@@ -1,203 +1,23 @@
#!/usr/bin/env python3
"""Generate ML-DSA test cases.
This is a transitional script that does not handle different feature sets
in different states of TF-PSA-Crypto. The live version of this script
is `scripts/maintainer/generate_mldsa_tests.py` in TF-PSA-Crypto.
"""
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
import sys
from typing import Iterable, List, Optional
# pip install dilithium-py
import dilithium_py.ml_dsa #type: ignore
import scripts_path # pylint: disable=unused-import
from mbedtls_framework import test_case
from mbedtls_framework import test_data_generation
# ML_DSA instances for pure ML-DSA
PURE = {
#44: dilithium_py.ml_dsa.ML_DSA_44,
#65: dilithium_py.ml_dsa.ML_DSA_65,
87: dilithium_py.ml_dsa.ML_DSA_87,
}
# ML_DSA instances for HashML-DSA
HASH = {
#44: dilithium_py.ml_dsa.HASH_ML_DSA_44_WITH_SHA512,
#65: dilithium_py.ml_dsa.HASH_ML_DSA_65_WITH_SHA512,
87: dilithium_py.ml_dsa.HASH_ML_DSA_87_WITH_SHA512,
}
# Seeds (i.e. private keys) to test with.
SEEDS = [
b'There was once upon a time a ...',
b'\x00' * 32,
]
class Key:
"""An MLDSA key pair."""
#pylint: disable=too-few-public-methods
def __init__(self, kl: int, seed: bytes) -> None:
self.kl = kl #pylint: disable=invalid-name
self.seed = seed
self.public, self.secret = PURE[kl]._keygen_internal(seed)
def sign_message(self, message: bytes, deterministic: bool) -> bytes:
PURE[self.kl].set_drbg_seed(bytes(48))
return PURE[self.kl].sign(self.secret, message,
deterministic=deterministic)
# Key pairs to test with.
KEYS = {kl: [Key(kl, seed) for seed in SEEDS]
for kl in sorted(PURE.keys())}
# Input messages to test with.
MESSAGES = [
(b'This is a test', ''),
(b'', 'empty message'),
(b'\x00', '"\\x00"'),
(b'\x01', '"\\x01"'),
(b'ACBDEFGHIJ' * 100, '1000B'),
]
class API:
"""Abstract base class for the interface of the test functions."""
@classmethod
def function(cls, func: str, kl: int) -> str:
raise NotImplementedError
@classmethod
def metadata_arguments(cls,
kl: int,
pair: bool,
deterministic: Optional[bool]) -> List[str]:
raise NotImplementedError
@classmethod
def final_arguments(cls) -> List[str]:
return []
@classmethod
def secret_is_seed(cls) -> bool:
return True
class PQCPAPI(API):
"""Test mldsa-native entry points."""
@classmethod
def function(cls, func: str, kl: int) -> str:
return f'{func}_{kl}'
@classmethod
def metadata_arguments(cls,
_kl: int,
_pair: bool,
_deterministic: Optional[bool]) -> List[str]:
return []
@classmethod
def secret_is_seed(cls) -> bool:
return False
def one_mldsa_key_pair_from_seed(key: Key,
descr: str) -> test_case.TestCase:
"""Construct one test case for mldsa-native keypair_internal()."""
tc = test_case.TestCase()
tc.set_function(f'key_pair_from_seed_{key.kl}')
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments([
test_case.hex_string(key.seed),
test_case.hex_string(key.secret),
test_case.hex_string(key.public),
])
tc.set_description(f'MLDSA-{key.kl} key pair from seed {descr}')
return tc
def gen_pqcp_key_management(kl: int) -> Iterable[test_case.TestCase]:
"""Generate test cases for mldsa-native keypair_internal()."""
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_key_pair_from_seed(key, f'key#{i}')
def one_mldsa_sign_deterministic_pure(api: API,
key: Key,
message: bytes,
descr: str) -> test_case.TestCase:
"""Construct one test case for deterministic signature."""
signature = key.sign_message(message, deterministic=True)
tc = test_case.TestCase()
tc.set_function(api.function('sign_deterministic_pure', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(api.metadata_arguments(key.kl, True, True) + [
test_case.hex_string(key.seed if api.secret_is_seed() else key.secret),
test_case.hex_string(message),
test_case.hex_string(signature),
] + api.final_arguments())
tc.set_description(f'MLDSA-{key.kl} sign deterministic {descr}')
return tc
def one_mldsa_verify_pure(api: API,
key: Key,
message: bytes,
deterministic: bool,
descr: str) -> test_case.TestCase:
"""Construct one test case for verification.
When deterministic is true, the test case is a deterministic signature.
When deterministic is false, the test case is some other valid signature.
"""
signature = key.sign_message(message, deterministic=deterministic)
tc = test_case.TestCase()
tc.set_function(api.function('verify_pure', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(api.metadata_arguments(key.kl, False, True) + [
test_case.hex_string(key.public),
test_case.hex_string(message),
test_case.hex_string(signature),
] + api.final_arguments())
variant = "deterministic" if deterministic else "randomized"
tc.set_description(f'MLDSA-{key.kl} verify {variant} {descr}')
return tc
def gen_mldsa_pure(api: API, kl: int) -> Iterable[test_case.TestCase]:
"""Generate all test cases for pure ML-DSA signature and verification."""
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_sign_deterministic_pure(api, key, MESSAGES[0][0],
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_sign_deterministic_pure(api, KEYS[kl][0], message,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], True,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_verify_pure(api, KEYS[kl][0], message, True,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield one_mldsa_verify_pure(api, key, MESSAGES[0][0], False,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield one_mldsa_verify_pure(api, KEYS[kl][0], message, False,
f'key#1 {descr}')
def gen_pqcp_mldsa_all() -> Iterable[test_case.TestCase]:
"""Generate all test cases for mldsa-native."""
api = PQCPAPI()
for kl in sorted(KEYS.keys()):
yield from gen_pqcp_key_management(kl)
yield from gen_mldsa_pure(api, kl)
from mbedtls_maintainer import mldsa_test_generator
class MLDSATestGenerator(test_data_generation.TestGenerator):
"""Generate test cases for ML-DSA."""
def __init__(self, settings) -> None:
self.targets = {
'test_suite_pqcp_mldsa.dilithium_py': gen_pqcp_mldsa_all,
'test_suite_pqcp_mldsa.dilithium_py': mldsa_test_generator.gen_pqcp_mldsa_all,
}
super().__init__(settings)
+3
View File
@@ -0,0 +1,3 @@
# This file needs to exist to make mbedtls_maintainer a package.
# Among other things, this allows modules in this directory to make
# relative imports.
@@ -0,0 +1,259 @@
"""Generate ML-DSA test cases.
"""
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
from typing import Iterator, List, Optional
# pip install dilithium-py
import dilithium_py.ml_dsa #type: ignore
import scripts_path # pylint: disable=unused-import
from mbedtls_framework import test_case
# ML_DSA instances for pure ML-DSA
PURE = {
#44: dilithium_py.ml_dsa.ML_DSA_44,
#65: dilithium_py.ml_dsa.ML_DSA_65,
87: dilithium_py.ml_dsa.ML_DSA_87,
}
# ML_DSA instances for HashML-DSA
HASH = {
#44: dilithium_py.ml_dsa.HASH_ML_DSA_44_WITH_SHA512,
#65: dilithium_py.ml_dsa.HASH_ML_DSA_65_WITH_SHA512,
87: dilithium_py.ml_dsa.HASH_ML_DSA_87_WITH_SHA512,
}
# Seeds (i.e. private keys) to test with.
SEEDS = [
b'There was once upon a time a ...',
b'\x00' * 32,
]
class Key:
"""An MLDSA key pair."""
#pylint: disable=too-few-public-methods
def __init__(self, kl: int, seed: bytes) -> None:
self.kl = kl #pylint: disable=invalid-name
self.seed = seed
self.public, self.secret = PURE[kl]._keygen_internal(seed)
def sign_message(self, message: bytes, deterministic: bool) -> bytes:
PURE[self.kl].set_drbg_seed(bytes(48))
return PURE[self.kl].sign(self.secret, message,
deterministic=deterministic)
# Key pairs to test with.
KEYS = {kl: [Key(kl, seed) for seed in SEEDS]
for kl in sorted(PURE.keys())}
# Input messages to test with.
MESSAGES = [
(b'This is a test', ''),
(b'', 'empty message'),
(b'\x00', '"\\x00"'),
(b'\x01', '"\\x01"'),
(b'ACBDEFGHIJ' * 100, '1000B'),
]
class Generator:
"""Abstract base class to generate tests for one API."""
@classmethod
def function(cls, func: str, kl: int) -> str:
raise NotImplementedError
@classmethod
def metadata_arguments(cls,
kl: int,
pair: bool,
deterministic: Optional[bool]) -> List[str]:
raise NotImplementedError
@classmethod
def final_arguments(cls) -> List[str]:
return []
@classmethod
def secret_is_seed(cls) -> bool:
return True
def one_mldsa_public_key_from_seed(self, key: Key,
descr: str) -> test_case.TestCase:
"""Construct one test case for driver export_public_key()."""
tc = test_case.TestCase()
tc.set_function('export_public_key')
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(self.metadata_arguments(key.kl, True, None) + [
test_case.hex_string(key.seed),
test_case.hex_string(key.public),
] + self.final_arguments())
tc.set_description(f'MLDSA-{key.kl} export public key from seed {descr}')
return tc
def one_mldsa_sign_deterministic_pure(self,
key: Key,
message: bytes,
descr: str) -> test_case.TestCase:
"""Construct one test case for deterministic signature."""
signature = key.sign_message(message, deterministic=True)
tc = test_case.TestCase()
tc.set_function(self.function('sign_message_deterministic', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(self.metadata_arguments(key.kl, True, True) + [
test_case.hex_string(key.seed if self.secret_is_seed() else key.secret),
test_case.hex_string(message),
test_case.hex_string(signature),
] + self.final_arguments())
tc.set_description(f'MLDSA-{key.kl} sign deterministic {descr}')
return tc
def one_mldsa_verify_pure(self,
key: Key,
message: bytes,
deterministic: bool,
descr: str) -> test_case.TestCase:
"""Construct one test case for verification.
When deterministic is true, the test case is a deterministic signature.
When deterministic is false, the test case is some other valid signature.
"""
signature = key.sign_message(message, deterministic=deterministic)
tc = test_case.TestCase()
tc.set_function(self.function('verify_message', key.kl))
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments(self.metadata_arguments(key.kl, False, True) + [
test_case.hex_string(key.public),
test_case.hex_string(message),
test_case.hex_string(signature),
] + self.final_arguments())
variant = "deterministic" if deterministic else "randomized"
tc.set_description(f'MLDSA-{key.kl} verify {variant} {descr}')
return tc
def gen_mldsa_pure(self, kl: int) -> Iterator[test_case.TestCase]:
"""Generate all test cases for pure ML-DSA signature and verification."""
for i, key in enumerate(KEYS[kl], 1):
yield self.one_mldsa_sign_deterministic_pure(key, MESSAGES[0][0],
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield self.one_mldsa_sign_deterministic_pure(KEYS[kl][0], message,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield self.one_mldsa_verify_pure(key, MESSAGES[0][0], True,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield self.one_mldsa_verify_pure(KEYS[kl][0], message, True,
f'key#1 {descr}')
for i, key in enumerate(KEYS[kl], 1):
yield self.one_mldsa_verify_pure(key, MESSAGES[0][0], False,
f'key#{i}')
for message, descr in MESSAGES[1:]:
yield self.one_mldsa_verify_pure(KEYS[kl][0], message, False,
f'key#1 {descr}')
def gen_key_management(self, kl: int) -> Iterator[test_case.TestCase]:
"""Generate all key management test cases for the given parameter set."""
raise NotImplementedError
def gen_all(self) -> Iterator[test_case.TestCase]:
"""Generate all the tests for this API."""
for kl in sorted(KEYS.keys()):
yield from self.gen_key_management(kl)
yield from self.gen_mldsa_pure(kl)
class PQCPGenerator(Generator):
"""Test mldsa-native entry points."""
@classmethod
def function(cls, func: str, kl: int) -> str:
if func == 'verify_message':
func = 'verify_pure'
elif func == 'sign_message_deterministic':
func = 'sign_deterministic_pure'
return f'{func}_{kl}'
@classmethod
def metadata_arguments(cls,
_kl: int,
_pair: bool,
_deterministic: Optional[bool]) -> List[str]:
return []
@classmethod
def secret_is_seed(cls) -> bool:
return False
@staticmethod
def one_mldsa_key_pair_from_seed(key: Key,
descr: str) -> test_case.TestCase:
"""Construct one test case for mldsa-native keypair_internal()."""
tc = test_case.TestCase()
tc.set_function(f'key_pair_from_seed_{key.kl}')
tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
tc.set_arguments([
test_case.hex_string(key.seed),
test_case.hex_string(key.secret),
test_case.hex_string(key.public),
])
tc.set_description(f'MLDSA-{key.kl} key pair from seed {descr}')
return tc
def gen_key_management(self, kl: int) -> Iterator[test_case.TestCase]:
"""Generate test cases for mldsa-native keypair_internal()."""
for i, key in enumerate(KEYS[kl], 1):
yield self.one_mldsa_key_pair_from_seed(key, f'key#{i}')
def gen_pqcp_mldsa_all() -> Iterator[test_case.TestCase]:
"""Generate all test cases for mldsa-native."""
generator = PQCPGenerator()
yield from generator.gen_all()
class DriverGenerator(Generator):
"""Test driver entry points."""
@classmethod
def function(cls, func: str, _kl: int) -> str:
if func == 'verify_message':
func = 'verify_pure'
elif func == 'sign_message_deterministic':
func = 'sign_deterministic_pure'
return func
@classmethod
def metadata_arguments(cls,
kl: int,
pair: bool,
deterministic: Optional[bool]) -> List[str]:
arguments = []
arguments.append('PSA_KEY_TYPE_ML_DSA_KEY_PAIR' if pair else
'PSA_KEY_TYPE_ML_DSA_PUBLIC_KEY')
arguments.append(str(kl))
if deterministic is not None:
arguments.append('PSA_ALG_DETERMINISTIC_ML_DSA' if deterministic else
'PSA_ALG_ML_DSA')
return arguments
@classmethod
def final_arguments(cls) -> List[str]:
return ['PSA_SUCCESS']
def gen_key_management(self, kl: int) -> Iterator[test_case.TestCase]:
"""Generate test cases for driver export_public_key()."""
for i, key in enumerate(KEYS[kl], 1):
yield self.one_mldsa_public_key_from_seed(key, f'key#{i}')
class DispatchGenerator(DriverGenerator):
"""Test the driver dispatch layer."""
@classmethod
def function(cls, func: str, _kl: int) -> str:
return func