Merge pull request #278 from gilles-peskine-arm/generate_mldsa_tests-create

Support committed generated test data and generate PQCP test data
This commit is contained in:
Valerio Setti
2026-02-27 11:12:44 +01:00
committed by GitHub
7 changed files with 476 additions and 14 deletions
+22 -2
View File
@@ -6,9 +6,10 @@
import glob
import os
import re
from typing import FrozenSet, Iterable, Iterator
from typing import FrozenSet, Iterable, Iterator, List
from . import build_tree
from . import generate_files_helper
class ConfigMacros:
@@ -34,7 +35,7 @@ class ConfigMacros:
for line in input_)
class Current(ConfigMacros):
class Current(ConfigMacros, generate_files_helper.Generator):
"""Information about config-like macros parsed from the source code."""
_SHADOW_FILE = 'scripts/data_files/config-options-current.txt'
@@ -136,6 +137,25 @@ class Current(ConfigMacros):
for name in sorted(self.live_config_options()):
out.write(name + '\n')
# Implement the generate_files_helper.Generator interface
def generator_name(self) -> str:
    """Return this object's name in its generate_files_helper.Generator role.

    The name is the constant string 'options'.
    """
    return 'options'
def target_files(self) -> List[str]:
    """Return a one-element list containing the generated file's path.

    The path is `self._submodule` joined with `self._SHADOW_FILE`.
    """
    shadow_path = os.path.join(self._submodule, self._SHADOW_FILE)
    return [shadow_path]
def outdated_files(self) -> List[str]:
    """Return the generated file's path if it is stale, otherwise nothing.

    The result is an empty list when is_shadow_file_up_to_date() reports
    true, and the result of target_files() otherwise.
    """
    return [] if self.is_shadow_file_up_to_date() else self.target_files()
def update(self, always: bool) -> None:
    """Bring the shadow file in line with the live config file.

    This simply forwards `always` to update_shadow_file().
    """
    self.update_shadow_file(always)
class History(ConfigMacros):
@@ -0,0 +1,182 @@
"""Utilities for intermediate files that are generated, but platform-independent
and configuration-independent.
"""
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
import argparse
import os
import subprocess
import sys
from typing import Dict, Iterable, List, Sequence, Set
class Generator:
    """Interface implemented by every generator of intermediate files.

    Each method in this base class raises NotImplementedError; concrete
    generators are expected to override all of them.
    """

    def generator_name(self) -> str:
        """Return the name that identifies this generator.

        Names must be unique among generators, and should differ from the
        name of every target.
        """
        raise NotImplementedError

    def target_files(self) -> List[str]:
        """Return the files that this generator produces.

        The file names are relative to the project root.
        """
        raise NotImplementedError

    def outdated_files(self) -> Iterable[str]:
        """Return the targets that are currently out of date.

        A target that does not exist counts as out of date.
        Right after update() has run, the result is empty.
        """
        raise NotImplementedError

    def update(self, always: bool) -> None:
        """Bring the target(s) of this generator up to date.

        When `always` is false, an output file that already has the desired
        content should be left unchanged. When `always` is true, the time
        stamp of the output file must be refreshed even if its content is
        already as desired.
        """
        raise NotImplementedError
class TestDataGenerator(Generator):
    """Generator backed by a test data generator script.

    The script is executed in a child Python process rather than imported,
    even though it is written in Python: its output depends on the program
    name, since it writes sys.argv[0] in a comment in the .data file.
    """

    def __init__(self, script: str) -> None:
        """Remember which test generator script to run.

        The script is assumed to be written in Python and to offer the
        command line interface of test_data_generation.py.
        """
        self.script = script

    def _lines_from(self, option: str) -> List[str]:
        """Run the script with a single option and return its output lines."""
        listing = subprocess.check_output(
            [sys.executable, self.script, option],
            encoding='utf-8')
        return listing.splitlines()

    def generator_name(self) -> str:
        """Return the base name of the script."""
        return os.path.basename(self.script)

    def target_files(self) -> List[str]:
        """Return what the script prints when invoked with --list."""
        return self._lines_from('--list')

    def outdated_files(self) -> List[str]:
        """Return what the script prints when invoked with --list-outdated."""
        return self._lines_from('--list-outdated')

    def update(self, _always) -> None:
        """Run the script without arguments; the argument is ignored."""
        subprocess.check_call([sys.executable, self.script])
def assemble(available: Iterable[Generator]) -> Dict[str, Generator]:
    """Index the generators by every identifier that can designate them.

    The resulting dictionary maps each generator name, and each target file
    name, to the generator it belongs to. An identifier that is claimed
    twice (whether as a name or as a target) raises an Exception.
    """
    by_ident: Dict[str, Generator] = {}

    for generator in available:
        # First claim the generator's own name. This happens before
        # target_files() is called for this generator.
        ident = generator.generator_name()
        if ident in by_ident:
            raise Exception(f'Generator conflict: name "{ident}" of {generator} '
                            f'already recorded for {by_ident[ident]}')
        by_ident[ident] = generator

        # Then claim each of its targets. Names and targets share one
        # namespace.
        for ident in generator.target_files():
            if ident in by_ident:
                raise Exception(f'Generator conflict: target "{ident}" of {generator} '
                                f'already recorded for {by_ident[ident]}')
            by_ident[ident] = generator

    return by_ident
def list_names(available: Iterable[Generator]) -> List[str]:
    """Collect the name of every generator, in sorted order."""
    names = [generator.generator_name() for generator in available]
    names.sort()
    return names
def list_targets(available: Iterable[Generator]) -> List[str]:
    """Collect the targets of every generator, in sorted order."""
    targets = [] #type: List[str]
    for generator in available:
        targets.extend(generator.target_files())
    targets.sort()
    return targets
def select(available: Dict[str, Generator],
           wanted: Iterable[str]) -> List[Generator]:
    """Select generators by name or target.

    `available` maps identifiers (generator names and target names) to
    generators, and `wanted` lists the identifiers to look up.

    Return the designated generators, ordered by the sorted identifiers
    that selected them. Each generator appears at most once, even when
    several of the wanted identifiers designate it (for example its name
    together with one of its targets), so that callers do not run or check
    the same generator twice.

    Raise an Exception for the first wanted identifier that is not a key
    of `available`.
    """
    wanted_idents = set() #type: Set[str]
    for ident in wanted:
        if ident not in available:
            raise Exception(f'No generator found for {ident}')
        wanted_idents.add(ident)

    selected = [] #type: List[Generator]
    # Track object identity rather than equality or hashing, so that this
    # places no requirement on the generator classes.
    seen = set() #type: Set[int]
    for ident in sorted(wanted_idents):
        generator = available[ident]
        if id(generator) not in seen:
            seen.add(id(generator))
            selected.append(generator)
    return selected
def main(generators: Sequence[Generator],
         description: str) -> None:
    #pylint: disable=too-many-branches
    """Command line entry point.

    Parse the command line from sys.argv, using `description` as the
    description in the help text, and act on `generators`:

    * With --list, --list-names or --list-targets: print the requested
      names and/or targets on standard output and return.
    * With --update or --always-update: run update() on the selected
      generators.
    * Otherwise: print the out-of-date targets of the selected generators
      on standard output. If there are any, also write a hint on standard
      error and exit with status 1.

    The selected generators are the ones designated by the positional
    arguments (generator names or targets), or all of `generators` if
    there are no positional arguments.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--always-update', '-U',
                        action='store_true',
                        help=('Update target files unconditionally '
                              '(overrides --update)'))
    parser.add_argument('--list',
                        action='store_true',
                        help='List generator names and targets and exit')
    parser.add_argument('--list-names',
                        action='store_true',
                        help='List generator names and exit')
    parser.add_argument('--list-targets',
                        action='store_true',
                        help='List generator targets and exit')
    parser.add_argument('--update', '-u',
                        action='store_true',
                        help='Update target files if needed')
    parser.add_argument('--verbose', '-v',
                        action='store_true',
                        help='Be more verbose')
    parser.add_argument('idents', nargs='*', metavar='NAME|TARGET',
                        help='List of generator names or targets (all targets if empty)')
    args = parser.parse_args()

    # --list is shorthand for --list-names --list-targets.
    if args.list:
        args.list_names = True
        args.list_targets = True
    if args.list_names:
        for name in list_names(generators):
            print(name)
    if args.list_targets:
        for target in list_targets(generators):
            print(target)
    if args.list_names or args.list_targets:
        return

    # Only build the identifier index when the user asked for specific
    # generators: building it calls target_files() on every generator.
    if args.idents:
        available = assemble(generators)
        wanted = select(available, args.idents) #type: Sequence[Generator]
    else:
        wanted = generators

    if args.update or args.always_update:
        for generator in wanted:
            if args.verbose:
                sys.stderr.write(f'Running generator {generator.generator_name()}...\n')
            generator.update(args.always_update)
    else:
        outdated = [] #type: List[str]
        for generator in wanted:
            if args.verbose:
                sys.stderr.write(f'Checking targets of generator {generator.generator_name()}...\n')
            outdated += generator.outdated_files()
        if outdated:
            sys.stderr.write('Some targets are missing or out of date.\n')
            for target in outdated:
                print(target)
            # End the hint with a newline so that it does not run into
            # whatever is printed next on the terminal.
            sys.stderr.write(f'Run {sys.argv[0]} -u and commit the result.\n')
            sys.exit(1)
+13 -7
View File
@@ -127,6 +127,18 @@ class TestCase:
out.write(prefix + self.function + ':' +
':'.join(self.arguments) + '\n')
def write_data_stream(out,
                      test_cases: Iterable[TestCase],
                      caller: Optional[str] = None) -> None:
    """Emit a complete .data file body on the given output stream.

    The output consists of a header comment naming `caller`, then each
    test case (written by its own write() method), then a trailer comment.
    When `caller` is None, the base name of sys.argv[0] is used instead.
    """
    generator = os.path.basename(sys.argv[0]) if caller is None else caller
    header = ('# Automatically generated by {}. Do not edit!\n'
              .format(generator))
    out.write(header)
    for case in test_cases:
        case.write(out)
    out.write('\n# End of automatically generated file.\n')
def write_data_file(filename: str,
test_cases: Iterable[TestCase],
caller: Optional[str] = None) -> None:
@@ -134,15 +146,9 @@ def write_data_file(filename: str,
If the file already exists, it is overwritten.
"""
if caller is None:
caller = os.path.basename(sys.argv[0])
tempfile = filename + '.new'
with open(tempfile, 'w') as out:
out.write('# Automatically generated by {}. Do not edit!\n'
.format(caller))
for tc in test_cases:
tc.write(out)
out.write('\n# End of automatically generated file.\n')
write_data_stream(out, test_cases, caller)
os.replace(tempfile, filename)
def psa_or_3_6_feature_macro(psa_name: str,
@@ -11,6 +11,7 @@ These are used both by generate_psa_tests.py and generate_bignum_tests.py.
#
import argparse
import io
import os
import posixpath
import re
@@ -139,6 +140,11 @@ class BaseTarget:
class TestGenerator:
"""Generate test cases and write to data files."""
# Note that targets whose names contain 'test_format' have their content
# validated by `abi_check.py`.
targets = {} # type: Dict[str, Callable[..., Iterable[test_case.TestCase]]]
def __init__(self, options) -> None:
self.test_suite_directory = options.directory
# Update `targets` with an entry for each child class of BaseTarget.
@@ -163,10 +169,6 @@ class TestGenerator:
filename = self.filename_for(basename)
test_case.write_data_file(filename, test_cases)
# Note that targets whose names contain 'test_format' have their content
# validated by `abi_check.py`.
targets = {} # type: Dict[str, Callable[..., Iterable[test_case.TestCase]]]
def generate_target(self, name: str, *target_args) -> None:
"""Generate cases and write to data file for a target.
@@ -176,6 +178,22 @@ class TestGenerator:
test_cases = self.targets[name](*target_args)
self.write_test_data_file(name, test_cases)
def is_up_to_date(self, target) -> bool:
    """Tell whether the file for `target` already holds the expected content.

    Return False if the file does not exist. Otherwise, generate the test
    cases for the target in memory, and return whether the result is
    identical to what the file currently contains.
    """
    filename = self.filename_for(target)
    if not os.path.exists(filename):
        return False

    # Render the expected content into a string, without touching the disk.
    with io.StringIO() as rendered:
        test_case.write_data_stream(rendered, self.targets[target]())
        new_content = rendered.getvalue()

    with open(filename) as current_file:
        return current_file.read() == new_content
def main(args, description: str, generator_class: Type[TestGenerator] = TestGenerator):
"""Command line entry point."""
parser = argparse.ArgumentParser(description=description)
@@ -183,6 +201,9 @@ def main(args, description: str, generator_class: Type[TestGenerator] = TestGene
help='List available targets and exit')
parser.add_argument('--list-for-cmake', action='store_true',
help='Print \';\'-separated list of available targets and exit')
parser.add_argument('--list-outdated', action='store_true',
help=('List outdated targets and exit '
'(succeeds even if there are outdated or missing targets)'))
# If specified explicitly, this option may be a path relative to the
# current directory when the script is invoked. The default value
# is relative to the mbedtls root, which we don't know yet. So we
@@ -221,4 +242,8 @@ def main(args, description: str, generator_class: Type[TestGenerator] = TestGene
else:
options.targets = sorted(generator.targets)
for target in options.targets:
generator.generate_target(target)
if options.list_outdated:
if not generator.is_up_to_date(target):
print(generator.filename_for(target))
else:
generator.generate_target(target)
+206
View File
@@ -0,0 +1,206 @@
#!/usr/bin/env python3
"""Generate ML-DSA test cases.
"""
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
import sys
from typing import Iterable, List, Optional
# pip install dilithium-py
import dilithium_py.ml_dsa #type: ignore
import scripts_path # pylint: disable=unused-import
from mbedtls_framework import test_case
from mbedtls_framework import test_data_generation
# ML_DSA instances for pure ML-DSA, keyed by the number in the variant name.
# Only the 87 entry is active; the 44 and 65 entries are commented out.
# NOTE(review): the reason for disabling 44 and 65 is not stated here.
PURE = {
    #44: dilithium_py.ml_dsa.ML_DSA_44,
    #65: dilithium_py.ml_dsa.ML_DSA_65,
    87: dilithium_py.ml_dsa.ML_DSA_87,
}

# ML_DSA instances for HashML-DSA, with the same keys as PURE.
# NOTE(review): nothing in the visible part of this script reads HASH.
HASH = {
    #44: dilithium_py.ml_dsa.HASH_ML_DSA_44_WITH_SHA512,
    #65: dilithium_py.ml_dsa.HASH_ML_DSA_65_WITH_SHA512,
    87: dilithium_py.ml_dsa.HASH_ML_DSA_87_WITH_SHA512,
}

# Seeds (i.e. private keys) to test with. Each seed is 32 bytes long.
SEEDS = [
    b'There was once upon a time a ...',
    b'\x00' * 32,
]
class Key:
    """An MLDSA key pair, together with the seed it was derived from."""
    #pylint: disable=too-few-public-methods

    def __init__(self, kl: int, seed: bytes) -> None:
        """Derive the key pair for the PURE[kl] instance from `seed`."""
        scheme = PURE[kl]
        self.kl = kl #pylint: disable=invalid-name
        self.seed = seed
        self.public, self.secret = scheme._keygen_internal(seed)

    def sign_message(self, message: bytes, deterministic: bool) -> bytes:
        """Sign `message` with this key's secret and return the signature.

        The DRBG of the ML-DSA instance is reseeded with 48 zero bytes
        before every signature.
        NOTE(review): presumably this makes non-deterministic signatures
        reproducible from one run to the next -- confirm.
        """
        scheme = PURE[self.kl]
        scheme.set_drbg_seed(bytes(48))
        return scheme.sign(self.secret, message,
                           deterministic=deterministic)
# Key pairs to test with: for each key of PURE, one Key per entry of SEEDS,
# in the same order as SEEDS. The key pairs are derived when this module
# is loaded.
KEYS = {kl: [Key(kl, seed) for seed in SEEDS]
        for kl in sorted(PURE.keys())}

# Input messages to test with, as (message, description) pairs.
# The description is a human-readable label that goes into test case
# descriptions. The first entry's label is empty: gen_mldsa_pure() only
# reads the message of the first entry.
MESSAGES = [
    (b'This is a test', ''),
    (b'', 'empty message'),
    (b'\x00', '"\\x00"'),
    (b'\x01', '"\\x01"'),
    (b'ACBDEFGHIJ' * 100, '1000B'),
]
class API:
    """Abstract description of the interface of a family of test functions.

    Subclasses must override function() and metadata_arguments(); the
    other methods have defaults.
    """

    @classmethod
    def function(cls, func: str, kl: int) -> str:
        """Return the test function name for operation `func` and size `kl`."""
        raise NotImplementedError

    @classmethod
    def metadata_arguments(cls,
                           kl: int,
                           pair: bool,
                           deterministic: Optional[bool]) -> List[str]:
        """Return the arguments placed before the key in a test case."""
        raise NotImplementedError

    @classmethod
    def final_arguments(cls) -> List[str]:
        """Return the arguments placed at the end of a test case: none."""
        return []

    @classmethod
    def secret_is_seed(cls) -> bool:
        """Return whether test cases pass the seed as the secret key: yes."""
        return True
class PQCPAPI(API):
    """Interface of the test functions for the mldsa-native entry points."""

    @classmethod
    def function(cls, func: str, kl: int) -> str:
        """Return the operation name with the size appended after '_'."""
        return f'{func}_{kl}'

    @classmethod
    def metadata_arguments(cls,
                           _kl: int,
                           _pair: bool,
                           _deterministic: Optional[bool]) -> List[str]:
        """Return no arguments: all three parameters are ignored."""
        return []

    @classmethod
    def secret_is_seed(cls) -> bool:
        """Return False: test cases pass the secret key, not the seed."""
        return False
def one_mldsa_key_pair_from_seed(key: Key,
                                 descr: str) -> test_case.TestCase:
    """Build the test case for mldsa-native keypair_internal() on one key.

    The arguments of the test case are the seed, the secret key and the
    public key, in this order, as hex strings.
    """
    hex_arguments = [
        test_case.hex_string(key.seed),
        test_case.hex_string(key.secret),
        test_case.hex_string(key.public),
    ]
    tc = test_case.TestCase()
    tc.set_function(f'key_pair_from_seed_{key.kl}')
    tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])
    tc.set_arguments(hex_arguments)
    tc.set_description(f'MLDSA-{key.kl} key pair from seed {descr}')
    return tc
def gen_pqcp_key_management(kl: int) -> Iterable[test_case.TestCase]:
    """Yield one keypair_internal() test case per key of size `kl`.

    Keys are numbered from 1 in the descriptions.
    """
    yield from (one_mldsa_key_pair_from_seed(key, f'key#{i}')
                for i, key in enumerate(KEYS[kl], start=1))
def one_mldsa_sign_deterministic_pure(api: API,
                                      key: Key,
                                      message: bytes,
                                      descr: str) -> test_case.TestCase:
    """Build one deterministic signature test case.

    The expected signature is computed here, by signing `message` with
    `key` in deterministic mode. Whether the test case receives the seed
    or the secret key is decided by `api`.
    """
    signature = key.sign_message(message, deterministic=True)

    tc = test_case.TestCase()
    tc.set_function(api.function('sign_deterministic_pure', key.kl))
    tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])

    # API-specific arguments go around the (key, message, signature) core.
    leading = api.metadata_arguments(key.kl, True, True)
    signing_key = key.seed if api.secret_is_seed() else key.secret
    core = [
        test_case.hex_string(signing_key),
        test_case.hex_string(message),
        test_case.hex_string(signature),
    ]
    tc.set_arguments(leading + core + api.final_arguments())

    tc.set_description(f'MLDSA-{key.kl} sign deterministic {descr}')
    return tc
def one_mldsa_verify_pure(api: API,
                          key: Key,
                          message: bytes,
                          deterministic: bool,
                          descr: str) -> test_case.TestCase:
    """Build one signature verification test case.

    The signature to verify is computed here. If `deterministic` is true,
    it is the deterministic signature of `message`; if it is false, it is
    some other valid signature.
    """
    signature = key.sign_message(message, deterministic=deterministic)

    tc = test_case.TestCase()
    tc.set_function(api.function('verify_pure', key.kl))
    tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED'])

    # NOTE(review): the metadata call below always passes True as its
    # `deterministic` argument, even when this test case uses a randomized
    # signature -- confirm that this is intended.
    leading = api.metadata_arguments(key.kl, False, True)
    core = [
        test_case.hex_string(key.public),
        test_case.hex_string(message),
        test_case.hex_string(signature),
    ]
    tc.set_arguments(leading + core + api.final_arguments())

    if deterministic:
        variant = "deterministic"
    else:
        variant = "randomized"
    tc.set_description(f'MLDSA-{key.kl} verify {variant} {descr}')
    return tc
def gen_mldsa_pure(api: API, kl: int) -> Iterable[test_case.TestCase]:
    """Yield every pure ML-DSA signature and verification test case.

    There are three passes, in this order: deterministic signature,
    verification of a deterministic signature, verification of a
    randomized signature. Each pass covers every key with the first
    message, then the first key with each of the remaining messages.
    """
    # Deterministic signature.
    for i, key in enumerate(KEYS[kl], start=1):
        yield one_mldsa_sign_deterministic_pure(api, key, MESSAGES[0][0],
                                                f'key#{i}')
    for message, descr in MESSAGES[1:]:
        yield one_mldsa_sign_deterministic_pure(api, KEYS[kl][0], message,
                                                f'key#1 {descr}')

    # Verification: deterministic signatures first, then randomized ones.
    for deterministic in (True, False):
        for i, key in enumerate(KEYS[kl], start=1):
            yield one_mldsa_verify_pure(api, key, MESSAGES[0][0],
                                        deterministic,
                                        f'key#{i}')
        for message, descr in MESSAGES[1:]:
            yield one_mldsa_verify_pure(api, KEYS[kl][0], message,
                                        deterministic,
                                        f'key#1 {descr}')
def gen_pqcp_mldsa_all() -> Iterable[test_case.TestCase]:
    """Yield every test case for mldsa-native.

    Sizes are processed in increasing order. For each size, the key
    management test cases come before the signature and verification ones.
    """
    pqcp = PQCPAPI()
    for kl in sorted(KEYS):
        yield from gen_pqcp_key_management(kl)
        yield from gen_mldsa_pure(pqcp, kl)
class MLDSATestGenerator(test_data_generation.TestGenerator):
    """Test data generator for the ML-DSA test suites."""

    def __init__(self, settings) -> None:
        """Declare the ML-DSA target, then run the base class constructor."""
        # The target table is set on the instance before calling the base
        # class constructor.
        # NOTE(review): presumably so that the base constructor, which
        # updates `targets`, does not touch the table shared at class
        # level -- confirm.
        self.targets = {
            'test_suite_pqcp_mldsa.dilithium_py': gen_pqcp_mldsa_all,
        }
        super().__init__(settings)
# When run as a script, hand over to the shared test data generation
# command line, with the module docstring as the description and
# MLDSATestGenerator as the generator class.
if __name__ == '__main__':
    test_data_generation.main(sys.argv[1:], __doc__, MLDSATestGenerator)
+5
View File
@@ -0,0 +1,5 @@
# Python module requirements for maintainer utilities
# For generate_mldsa_tests.py
pycryptodome
dilithium-py >= 1.3.0; python_version >= "3.9"
+18
View File
@@ -0,0 +1,18 @@
"""Add our Python library directory to the module search path.
Usage:
import scripts_path # pylint: disable=unused-import
from mbedtls_framework import ...
"""
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
#
import os
import sys
# Add the `scripts` directory located beside this file's own directory
# (i.e. <directory of this file>/../scripts) to the module search path.
# The entry is appended, so it has lower priority than existing entries.
sys.path.append(os.path.join(os.path.dirname(__file__),
                             os.path.pardir,
                             'scripts'))