test_driver.py: Various small code improvements

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
This commit is contained in:
Ronald Cron
2025-12-05 18:10:46 +01:00
parent e8dcf71c2b
commit ed9c4a31e4
+20 -21
View File
@@ -37,7 +37,7 @@ def get_parsearg_base() -> argparse.ArgumentParser:
class TestDriverGenerator:
"""A TF-PSA-Crypto test driver generator"""
def __init__(self, src_dir: Path, dst_dir: Path, driver: str, \
exclude_files: Optional[Set[str]] = None) -> None:
exclude_files: Optional[Iterable[str]] = None) -> None:
"""
Initialize a test driver generator.
@@ -68,7 +68,7 @@ class TestDriverGenerator:
self.src_dir = src_dir
self.dst_dir = dst_dir
self.driver = driver
self.exclude_files = set()
self.exclude_files: Iterable[str] = ()
if exclude_files is not None:
self.exclude_files = exclude_files
@@ -79,16 +79,16 @@ class TestDriverGenerator:
raise RuntimeError(f'"src" directory in {src_dir} not found')
def write_list_vars_for_cmake(self, fname: str) -> None:
src_relpaths = self.__iter_src_code_files()
src_relpaths = self.__get_src_code_files()
with open(self.dst_dir / fname, "w") as f:
f.write(f"set({self.driver}_input_files " + \
" ".join(str(path) for path in src_relpaths) + ")\n\n")
"\n".join(str(path) for path in src_relpaths) + ")\n\n")
f.write(f"set({self.driver}_files " + \
" ".join(str(self.__get_dst_relpath(path.relative_to(self.src_dir))) \
"\n".join(str(self.__get_dst_relpath(path.relative_to(self.src_dir))) \
for path in src_relpaths) + ")\n\n")
f.write(f"set({self.driver}_src_files " + \
" ".join(str(path.relative_to(self.src_dir)) \
for path in src_relpaths if path.suffix == ".c") + ")")
"\n".join(str(path.relative_to(self.src_dir)) \
for path in src_relpaths if path.suffix == ".c") + ")\n")
def get_identifiers_to_prefix(self, prefixes: Set[str]) -> Set[str]:
@@ -117,7 +117,7 @@ class TestDriverGenerator:
Set[str]: The default set of identifiers to rename.
"""
identifiers = set()
for file in self.__iter_code_files(self.dst_dir):
for file in self.__get_code_files(self.dst_dir):
identifiers.update(self.get_c_identifiers(file))
identifiers_with_prefixes = set()
@@ -126,7 +126,7 @@ class TestDriverGenerator:
identifiers_with_prefixes.add(identifier)
return identifiers_with_prefixes
def build_tree(self, prefixes: Set[str]) -> None:
def create_test_driver_tree(self, prefixes: Set[str]) -> None:
"""
Build a test driver tree from `self.src_dir` into `self.dst_dir`.
@@ -166,7 +166,7 @@ class TestDriverGenerator:
shutil.rmtree(self.dst_dir / "src")
# Clone the source tree into `dst_dir`
for file in self.__iter_src_code_files():
for file in self.__get_src_code_files():
dst = self.dst_dir / \
self.__get_dst_relpath(file.relative_to(self.src_dir))
dst.parent.mkdir(parents=True, exist_ok=True)
@@ -180,31 +180,30 @@ class TestDriverGenerator:
}
identifiers_to_prefix = self.get_identifiers_to_prefix(prefixes)
for f in self.__iter_code_files(self.dst_dir):
for f in self.__get_code_files(self.dst_dir):
self.__rewrite_test_driver_file(f, headers,\
src_include_dir_name,
identifiers_to_prefix, self.driver)
@staticmethod
def __iter_code_files(root: Path) -> Iterable[Path]:
def __get_code_files(root: Path) -> List[Path]:
"""
Iterate over all "*.c" and "*.h" files found recursively under the
Return all "*.c" and "*.h" files found recursively under the
`include` and `src` subdirectories of `root`.
"""
for directory in ("include", "src"):
directory_path = root / directory
for ext in (".c", ".h"):
yield from directory_path.rglob(f"*{ext}")
return sorted(path
for directory in ('include', 'src')
for path in (root / directory).rglob('*.[hc]'))
def __iter_src_code_files(self) -> List[Path]:
def __get_src_code_files(self) -> List[Path]:
"""
Iterate over all "*.c" and "*.h" files found recursively under the
Return all "*.c" and "*.h" files found recursively under the
`include` and `src` subdirectories of the source directory `self.src_dir`
excluding the files whose basename match any of the patterns in
`self.exclude_files`.
"""
out = []
for file in self.__iter_code_files(self.src_dir):
for file in self.__get_code_files(self.src_dir):
if not any(fnmatch(file.name, pattern) for pattern in self.exclude_files):
out.append(file)
return out
@@ -284,7 +283,7 @@ class TestDriverGenerator:
text = file.read_text(encoding="utf-8")
include_line_re = re.compile(
fr'^\s*#\s*include\s*([<"])\s*{src_include_dir}/([^>"]+)\s*([>"])',
fr'^\s*#\s*include\s*([<"]){src_include_dir}/([^>"]+)([>"])',
re.MULTILINE
)
def repl_header_inclusion(m: Match) -> str: