diff --git a/scripts/generate_test_keys.py b/scripts/generate_test_keys.py index 7cec7b391..d53b53f1e 100755 --- a/scripts/generate_test_keys.py +++ b/scripts/generate_test_keys.py @@ -15,6 +15,7 @@ from mbedtls_framework import build_tree BYTES_PER_LINE = 16 def c_byte_array_literal_content(array_name: str, key_data: bytes) -> Iterator[str]: + """Return C code that defines array_name as a byte array with the given content.""" yield 'static const unsigned char ' yield array_name yield '[] = {' @@ -27,16 +28,23 @@ def c_byte_array_literal_content(array_name: str, key_data: bytes) -> Iterator[s def convert_der_to_c(array_name: str, key_data: bytes) -> str: return ''.join(c_byte_array_literal_content(array_name, key_data)) -def get_key_type(key: str) -> str: - if re.match('PSA_KEY_TYPE_RSA_.*', key): - return "rsa" - elif re.match('PSA_KEY_TYPE_ECC_.*', key): +def get_key_type(key_type: str) -> str: + """Short name for a PSA key type.""" + if key_type.startswith('PSA_KEY_TYPE_ECC_'): return "ec" + elif key_type.startswith('PSA_KEY_TYPE_ML_DSA_'): + return "mldsa" + elif key_type.startswith('PSA_KEY_TYPE_ML_KEM_'): + return "mlkem" + elif key_type.startswith('PSA_KEY_TYPE_RSA_'): + return "rsa" + elif key_type.startswith('PSA_KEY_TYPE_SLH_DSA_'): + return "slhdsa" else: - print("Unhandled key type {}".format(key)) - return "unknown" + raise Exception(f"Unhandled key type {key_type}") def get_ec_key_family(key: str) -> str: + """Extract "PSA_ECC_xxx" from "PSA_KEY_TYPE_ECC_ttt(PSA_ECC_xxx)".""" match = re.search(r'.*\((.*)\)', key) if match is None: raise Exception("Unable to get EC family from {}".format(key)) @@ -70,6 +78,7 @@ EC_NAME_CONVERSION = { } def get_ec_curve_name(priv_key: str, bits: int) -> str: + """Short name for an elliptic curve key type.""" ec_family = get_ec_key_family(priv_key) try: prefix = EC_NAME_CONVERSION[ec_family][bits][0] @@ -78,8 +87,15 @@ def get_ec_curve_name(priv_key: str, bits: int) -> str: return "" return prefix + str(bits) + suffix +def get_slh_dsa_family(key_type: str) -> str: + """Short name from an SLH-DSA family.""" + m = re.search(r'PSA_SLH_FAMILY_(\w+)', key_type) + assert m + return m.group(1).replace('_', '').lower() + def get_look_up_table_entry(key_type: str, group_id_or_keybits: str, priv_array_name: str, pub_array_name: str) -> Iterator[str]: + """Yield C code lines for the definition of a key pair and its matching public key.""" if key_type == "ec": yield " {{ {}, 0,\n".format(group_id_or_keybits) else: @@ -147,10 +163,6 @@ def collect_keys() -> Tuple[str, str]: for priv_key in priv_keys: key_type = get_key_type(priv_key) - # Ignore keys which are not EC or RSA - if key_type == "unknown": - continue - pub_key = re.sub('_KEY_PAIR', '_PUBLIC_KEY', priv_key) for bits in ASYMMETRIC_KEY_DATA[priv_key]: @@ -160,10 +172,13 @@ def collect_keys() -> Tuple[str, str]: if curve == "": continue # Create output array name - if key_type == "rsa": - array_name_base = "_".join(["test", key_type, str(bits)]) - else: + if key_type == "ec": array_name_base = "_".join(["test", key_type, curve]) + elif key_type == "slhdsa": + family = get_slh_dsa_family(priv_key) + array_name_base = "_".join(["test", key_type, family, str(bits)]) + else: + array_name_base = "_".join(["test", key_type, str(bits)]) array_name_priv = array_name_base + "_priv" array_name_pub = array_name_base + "_pub" # Convert bytearray to C array @@ -182,6 +197,7 @@ def collect_keys() -> Tuple[str, str]: return ''.join(arrays), '\n'.join(look_up_table) def main() -> None: + """Command line entry point.""" default_output_path = build_tree.guess_project_root() + "/tests/include/test/test_keys.h" argparser = argparse.ArgumentParser()