mirror of
https://github.com/Mbed-TLS/mbedtls-framework.git
synced 2026-06-06 05:25:18 +00:00
generate_test_keys.py: Support ML-DSA, ML-KEM and SLH-DSA keys
Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from mbedtls_framework import build_tree
|
||||
BYTES_PER_LINE = 16
|
||||
|
||||
def c_byte_array_literal_content(array_name: str, key_data: bytes) -> Iterator[str]:
|
||||
"""Return C code that defines array_name as a byte array with the given content."""
|
||||
yield 'static const unsigned char '
|
||||
yield array_name
|
||||
yield '[] = {'
|
||||
@@ -27,16 +28,23 @@ def c_byte_array_literal_content(array_name: str, key_data: bytes) -> Iterator[s
|
||||
def convert_der_to_c(array_name: str, key_data: bytes) -> str:
|
||||
return ''.join(c_byte_array_literal_content(array_name, key_data))
|
||||
|
||||
def get_key_type(key: str) -> str:
|
||||
if re.match('PSA_KEY_TYPE_RSA_.*', key):
|
||||
return "rsa"
|
||||
elif re.match('PSA_KEY_TYPE_ECC_.*', key):
|
||||
def get_key_type(key_type: str) -> str:
|
||||
"""Short name for a PSA key type."""
|
||||
if key_type.startswith('PSA_KEY_TYPE_ECC_'):
|
||||
return "ec"
|
||||
elif key_type.startswith('PSA_KEY_TYPE_ML_DSA_'):
|
||||
return "mldsa"
|
||||
elif key_type.startswith('PSA_KEY_TYPE_ML_KEM_'):
|
||||
return "mlkem"
|
||||
elif key_type.startswith('PSA_KEY_TYPE_RSA_'):
|
||||
return "rsa"
|
||||
elif key_type.startswith('PSA_KEY_TYPE_SLH_DSA_'):
|
||||
return "slhdsa"
|
||||
else:
|
||||
print("Unhandled key type {}".format(key))
|
||||
return "unknown"
|
||||
raise Exception(f"Unhandled key type {key_type}")
|
||||
|
||||
def get_ec_key_family(key: str) -> str:
|
||||
"""Extract "PSA_ECC_xxx" from "PSA_KEY_TYPE_ECC_ttt(PSA_ECC_xxx)"."""
|
||||
match = re.search(r'.*\((.*)\)', key)
|
||||
if match is None:
|
||||
raise Exception("Unable to get EC family from {}".format(key))
|
||||
@@ -70,6 +78,7 @@ EC_NAME_CONVERSION = {
|
||||
}
|
||||
|
||||
def get_ec_curve_name(priv_key: str, bits: int) -> str:
|
||||
"""Short name for an elliptic curve key type."""
|
||||
ec_family = get_ec_key_family(priv_key)
|
||||
try:
|
||||
prefix = EC_NAME_CONVERSION[ec_family][bits][0]
|
||||
@@ -78,8 +87,15 @@ def get_ec_curve_name(priv_key: str, bits: int) -> str:
|
||||
return ""
|
||||
return prefix + str(bits) + suffix
|
||||
|
||||
def get_slh_dsa_family(key_type: str) -> str:
|
||||
"""Short name from an SLH-DSA family."""
|
||||
m = re.search(r'PSA_SLH_FAMILY_(\w+)', key_type)
|
||||
assert m
|
||||
return m.group(1).replace('_', '').lower()
|
||||
|
||||
def get_look_up_table_entry(key_type: str, group_id_or_keybits: str,
|
||||
priv_array_name: str, pub_array_name: str) -> Iterator[str]:
|
||||
"""Yield C code lines for the definition of a key pair and its matching public key."""
|
||||
if key_type == "ec":
|
||||
yield " {{ {}, 0,\n".format(group_id_or_keybits)
|
||||
else:
|
||||
@@ -147,10 +163,6 @@ def collect_keys() -> Tuple[str, str]:
|
||||
|
||||
for priv_key in priv_keys:
|
||||
key_type = get_key_type(priv_key)
|
||||
# Ignore keys which are not EC or RSA
|
||||
if key_type == "unknown":
|
||||
continue
|
||||
|
||||
pub_key = re.sub('_KEY_PAIR', '_PUBLIC_KEY', priv_key)
|
||||
|
||||
for bits in ASYMMETRIC_KEY_DATA[priv_key]:
|
||||
@@ -160,10 +172,13 @@ def collect_keys() -> Tuple[str, str]:
|
||||
if curve == "":
|
||||
continue
|
||||
# Create output array name
|
||||
if key_type == "rsa":
|
||||
array_name_base = "_".join(["test", key_type, str(bits)])
|
||||
else:
|
||||
if key_type == "ec":
|
||||
array_name_base = "_".join(["test", key_type, curve])
|
||||
elif key_type == "slhdsa":
|
||||
family = get_slh_dsa_family(priv_key)
|
||||
array_name_base = "_".join(["test", key_type, family, str(bits)])
|
||||
else:
|
||||
array_name_base = "_".join(["test", key_type, str(bits)])
|
||||
array_name_priv = array_name_base + "_priv"
|
||||
array_name_pub = array_name_base + "_pub"
|
||||
# Convert bytearray to C array
|
||||
@@ -182,6 +197,7 @@ def collect_keys() -> Tuple[str, str]:
|
||||
return ''.join(arrays), '\n'.join(look_up_table)
|
||||
|
||||
def main() -> None:
|
||||
"""Command line entry point."""
|
||||
default_output_path = build_tree.guess_project_root() + "/tests/include/test/test_keys.h"
|
||||
|
||||
argparser = argparse.ArgumentParser()
|
||||
|
||||
Reference in New Issue
Block a user