generate_test_keys.py: Support ML-DSA, ML-KEM and SLH-DSA keys

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
Gilles Peskine
2026-01-21 17:53:04 +01:00
parent dec6c51f7d
commit d60e412a2d
+29 -13
View File
@@ -15,6 +15,7 @@ from mbedtls_framework import build_tree
BYTES_PER_LINE = 16
def c_byte_array_literal_content(array_name: str, key_data: bytes) -> Iterator[str]:
"""Return C code that defines array_name as a byte array with the given content."""
yield 'static const unsigned char '
yield array_name
yield '[] = {'
@@ -27,16 +28,23 @@ def c_byte_array_literal_content(array_name: str, key_data: bytes) -> Iterator[s
def convert_der_to_c(array_name: str, key_data: bytes) -> str:
return ''.join(c_byte_array_literal_content(array_name, key_data))
def get_key_type(key: str) -> str:
if re.match('PSA_KEY_TYPE_RSA_.*', key):
return "rsa"
elif re.match('PSA_KEY_TYPE_ECC_.*', key):
def get_key_type(key_type: str) -> str:
"""Short name for a PSA key type."""
if key_type.startswith('PSA_KEY_TYPE_ECC_'):
return "ec"
elif key_type.startswith('PSA_KEY_TYPE_ML_DSA_'):
return "mldsa"
elif key_type.startswith('PSA_KEY_TYPE_ML_KEM_'):
return "mlkem"
elif key_type.startswith('PSA_KEY_TYPE_RSA_'):
return "rsa"
elif key_type.startswith('PSA_KEY_TYPE_SLH_DSA_'):
return "slhdsa"
else:
print("Unhandled key type {}".format(key))
return "unknown"
raise Exception(f"Unhandled key type {key_type}")
def get_ec_key_family(key: str) -> str:
"""Extract "PSA_ECC_xxx" from "PSA_KEY_TYPE_ECC_ttt(PSA_ECC_xxx)"."""
match = re.search(r'.*\((.*)\)', key)
if match is None:
raise Exception("Unable to get EC family from {}".format(key))
@@ -70,6 +78,7 @@ EC_NAME_CONVERSION = {
}
def get_ec_curve_name(priv_key: str, bits: int) -> str:
"""Short name for an elliptic curve key type."""
ec_family = get_ec_key_family(priv_key)
try:
prefix = EC_NAME_CONVERSION[ec_family][bits][0]
@@ -78,8 +87,15 @@ def get_ec_curve_name(priv_key: str, bits: int) -> str:
return ""
return prefix + str(bits) + suffix
def get_slh_dsa_family(key_type: str) -> str:
"""Short name from an SLH-DSA family."""
m = re.search(r'PSA_SLH_FAMILY_(\w+)', key_type)
assert m
return m.group(1).replace('_', '').lower()
def get_look_up_table_entry(key_type: str, group_id_or_keybits: str,
priv_array_name: str, pub_array_name: str) -> Iterator[str]:
"""Yield C code lines for the definition of a key pair and its matching public key."""
if key_type == "ec":
yield " {{ {}, 0,\n".format(group_id_or_keybits)
else:
@@ -147,10 +163,6 @@ def collect_keys() -> Tuple[str, str]:
for priv_key in priv_keys:
key_type = get_key_type(priv_key)
# Ignore keys which are not EC or RSA
if key_type == "unknown":
continue
pub_key = re.sub('_KEY_PAIR', '_PUBLIC_KEY', priv_key)
for bits in ASYMMETRIC_KEY_DATA[priv_key]:
@@ -160,10 +172,13 @@ def collect_keys() -> Tuple[str, str]:
if curve == "":
continue
# Create output array name
if key_type == "rsa":
array_name_base = "_".join(["test", key_type, str(bits)])
else:
if key_type == "ec":
array_name_base = "_".join(["test", key_type, curve])
elif key_type == "slhdsa":
family = get_slh_dsa_family(priv_key)
array_name_base = "_".join(["test", key_type, family, str(bits)])
else:
array_name_base = "_".join(["test", key_type, str(bits)])
array_name_priv = array_name_base + "_priv"
array_name_pub = array_name_base + "_pub"
# Convert bytearray to C array
@@ -182,6 +197,7 @@ def collect_keys() -> Tuple[str, str]:
return ''.join(arrays), '\n'.join(look_up_table)
def main() -> None:
"""Command line entry point."""
default_output_path = build_tree.guess_project_root() + "/tests/include/test/test_keys.h"
argparser = argparse.ArgumentParser()