mirror of
https://github.com/Mbed-TLS/mbedtls-framework.git
synced 2026-06-05 21:15:09 +00:00
generate_test_keys.py: Support ML-DSA, ML-KEM and SLH-DSA keys
Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from mbedtls_framework import build_tree
|
|||||||
BYTES_PER_LINE = 16
|
BYTES_PER_LINE = 16
|
||||||
|
|
||||||
def c_byte_array_literal_content(array_name: str, key_data: bytes) -> Iterator[str]:
|
def c_byte_array_literal_content(array_name: str, key_data: bytes) -> Iterator[str]:
|
||||||
|
"""Return C code that defines array_name as a byte array with the given content."""
|
||||||
yield 'static const unsigned char '
|
yield 'static const unsigned char '
|
||||||
yield array_name
|
yield array_name
|
||||||
yield '[] = {'
|
yield '[] = {'
|
||||||
@@ -27,16 +28,23 @@ def c_byte_array_literal_content(array_name: str, key_data: bytes) -> Iterator[s
|
|||||||
def convert_der_to_c(array_name: str, key_data: bytes) -> str:
|
def convert_der_to_c(array_name: str, key_data: bytes) -> str:
|
||||||
return ''.join(c_byte_array_literal_content(array_name, key_data))
|
return ''.join(c_byte_array_literal_content(array_name, key_data))
|
||||||
|
|
||||||
def get_key_type(key: str) -> str:
|
def get_key_type(key_type: str) -> str:
|
||||||
if re.match('PSA_KEY_TYPE_RSA_.*', key):
|
"""Short name for a PSA key type."""
|
||||||
return "rsa"
|
if key_type.startswith('PSA_KEY_TYPE_ECC_'):
|
||||||
elif re.match('PSA_KEY_TYPE_ECC_.*', key):
|
|
||||||
return "ec"
|
return "ec"
|
||||||
|
elif key_type.startswith('PSA_KEY_TYPE_ML_DSA_'):
|
||||||
|
return "mldsa"
|
||||||
|
elif key_type.startswith('PSA_KEY_TYPE_ML_KEM_'):
|
||||||
|
return "mlkem"
|
||||||
|
elif key_type.startswith('PSA_KEY_TYPE_RSA_'):
|
||||||
|
return "rsa"
|
||||||
|
elif key_type.startswith('PSA_KEY_TYPE_SLH_DSA_'):
|
||||||
|
return "slhdsa"
|
||||||
else:
|
else:
|
||||||
print("Unhandled key type {}".format(key))
|
raise Exception(f"Unhandled key type {key_type}")
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
def get_ec_key_family(key: str) -> str:
|
def get_ec_key_family(key: str) -> str:
|
||||||
|
"""Extract "PSA_ECC_xxx" from "PSA_KEY_TYPE_ECC_ttt(PSA_ECC_xxx)"."""
|
||||||
match = re.search(r'.*\((.*)\)', key)
|
match = re.search(r'.*\((.*)\)', key)
|
||||||
if match is None:
|
if match is None:
|
||||||
raise Exception("Unable to get EC family from {}".format(key))
|
raise Exception("Unable to get EC family from {}".format(key))
|
||||||
@@ -70,6 +78,7 @@ EC_NAME_CONVERSION = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def get_ec_curve_name(priv_key: str, bits: int) -> str:
|
def get_ec_curve_name(priv_key: str, bits: int) -> str:
|
||||||
|
"""Short name for an elliptic curve key type."""
|
||||||
ec_family = get_ec_key_family(priv_key)
|
ec_family = get_ec_key_family(priv_key)
|
||||||
try:
|
try:
|
||||||
prefix = EC_NAME_CONVERSION[ec_family][bits][0]
|
prefix = EC_NAME_CONVERSION[ec_family][bits][0]
|
||||||
@@ -78,8 +87,15 @@ def get_ec_curve_name(priv_key: str, bits: int) -> str:
|
|||||||
return ""
|
return ""
|
||||||
return prefix + str(bits) + suffix
|
return prefix + str(bits) + suffix
|
||||||
|
|
||||||
|
def get_slh_dsa_family(key_type: str) -> str:
|
||||||
|
"""Short name from an SLH-DSA family."""
|
||||||
|
m = re.search(r'PSA_SLH_FAMILY_(\w+)', key_type)
|
||||||
|
assert m
|
||||||
|
return m.group(1).replace('_', '').lower()
|
||||||
|
|
||||||
def get_look_up_table_entry(key_type: str, group_id_or_keybits: str,
|
def get_look_up_table_entry(key_type: str, group_id_or_keybits: str,
|
||||||
priv_array_name: str, pub_array_name: str) -> Iterator[str]:
|
priv_array_name: str, pub_array_name: str) -> Iterator[str]:
|
||||||
|
"""Yield C code lines for the definition of a key pair and its matching public key."""
|
||||||
if key_type == "ec":
|
if key_type == "ec":
|
||||||
yield " {{ {}, 0,\n".format(group_id_or_keybits)
|
yield " {{ {}, 0,\n".format(group_id_or_keybits)
|
||||||
else:
|
else:
|
||||||
@@ -147,10 +163,6 @@ def collect_keys() -> Tuple[str, str]:
|
|||||||
|
|
||||||
for priv_key in priv_keys:
|
for priv_key in priv_keys:
|
||||||
key_type = get_key_type(priv_key)
|
key_type = get_key_type(priv_key)
|
||||||
# Ignore keys which are not EC or RSA
|
|
||||||
if key_type == "unknown":
|
|
||||||
continue
|
|
||||||
|
|
||||||
pub_key = re.sub('_KEY_PAIR', '_PUBLIC_KEY', priv_key)
|
pub_key = re.sub('_KEY_PAIR', '_PUBLIC_KEY', priv_key)
|
||||||
|
|
||||||
for bits in ASYMMETRIC_KEY_DATA[priv_key]:
|
for bits in ASYMMETRIC_KEY_DATA[priv_key]:
|
||||||
@@ -160,10 +172,13 @@ def collect_keys() -> Tuple[str, str]:
|
|||||||
if curve == "":
|
if curve == "":
|
||||||
continue
|
continue
|
||||||
# Create output array name
|
# Create output array name
|
||||||
if key_type == "rsa":
|
if key_type == "ec":
|
||||||
array_name_base = "_".join(["test", key_type, str(bits)])
|
|
||||||
else:
|
|
||||||
array_name_base = "_".join(["test", key_type, curve])
|
array_name_base = "_".join(["test", key_type, curve])
|
||||||
|
elif key_type == "slhdsa":
|
||||||
|
family = get_slh_dsa_family(priv_key)
|
||||||
|
array_name_base = "_".join(["test", key_type, family, str(bits)])
|
||||||
|
else:
|
||||||
|
array_name_base = "_".join(["test", key_type, str(bits)])
|
||||||
array_name_priv = array_name_base + "_priv"
|
array_name_priv = array_name_base + "_priv"
|
||||||
array_name_pub = array_name_base + "_pub"
|
array_name_pub = array_name_base + "_pub"
|
||||||
# Convert bytearray to C array
|
# Convert bytearray to C array
|
||||||
@@ -182,6 +197,7 @@ def collect_keys() -> Tuple[str, str]:
|
|||||||
return ''.join(arrays), '\n'.join(look_up_table)
|
return ''.join(arrays), '\n'.join(look_up_table)
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
"""Command line entry point."""
|
||||||
default_output_path = build_tree.guess_project_root() + "/tests/include/test/test_keys.h"
|
default_output_path = build_tree.guess_project_root() + "/tests/include/test/test_keys.h"
|
||||||
|
|
||||||
argparser = argparse.ArgumentParser()
|
argparser = argparse.ArgumentParser()
|
||||||
|
|||||||
Reference in New Issue
Block a user