check_names: add type annotations

I needed that to understand how the data is represented (str vs Match, list
vs set vs tuple, ...).

No semantic change.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
Gilles Peskine
2026-01-15 17:57:47 +01:00
parent 8caa0e42ab
commit 2f01eca203
+105 -64
View File
@@ -47,7 +47,7 @@ import subprocess
import logging
import tempfile
import typing
from typing import List, Optional
from typing import Dict, List, Pattern, Optional, Set, Tuple, Union
import project_scripts # pylint: disable=unused-import
from mbedtls_framework import build_tree
@@ -55,10 +55,11 @@ from mbedtls_framework import build_tree
# Naming patterns to check against. These are defined outside the NameCheck
# class for ease of modification.
PUBLIC_MACRO_PATTERN = r"^(MBEDTLS|PSA|TF_PSA_CRYPTO)_[0-9A-Z_]*[0-9A-Z]$"
INTERNAL_MACRO_PATTERN = r"^[0-9A-Za-z_]*[0-9A-Z]$"
PUBLIC_MACRO_PATTERN = re.compile(r"^(MBEDTLS|PSA|TF_PSA_CRYPTO)_[0-9A-Z_]*[0-9A-Z]$")
INTERNAL_MACRO_PATTERN = re.compile(r"^[0-9A-Za-z_]*[0-9A-Z]$")
CONSTANTS_PATTERN = PUBLIC_MACRO_PATTERN
IDENTIFIER_PATTERN = r"^(mbedtls|psa|tf_psa_crypto)_[0-9a-z_]*[0-9a-z]$"
IDENTIFIER_PATTERN = re.compile(r"^(mbedtls|psa|tf_psa_crypto)_[0-9a-z_]*[0-9a-z]$")
class Match(): # pylint: disable=too-few-public-methods
"""
@@ -71,7 +72,9 @@ class Match(): # pylint: disable=too-few-public-methods
* pos: a tuple of (start, end) positions on the line where the match is.
* name: the match itself.
"""
def __init__(self, filename, line, line_no, pos, name):
def __init__(self,
filename: str, line: str, line_no: int, pos: Tuple[int, int],
name: str) -> None:
# pylint: disable=too-many-arguments
self.filename = filename
self.line = line
@@ -79,7 +82,7 @@ class Match(): # pylint: disable=too-few-public-methods
self.pos = pos
self.name = name
def __str__(self):
def __str__(self) -> str:
"""
Return a formatted code listing representation of the erroneous line.
"""
@@ -127,14 +130,15 @@ class Problem(abc.ABC): # pylint: disable=too-few-public-methods
"""
# Class variable to control the quietness of all problems
quiet = False
def __init__(self):
def __init__(self) -> None:
self.textwrapper = textwrap.TextWrapper(break_on_hyphens=False,
break_long_words=False)
self.textwrapper.width = 80
self.textwrapper.initial_indent = " > "
self.textwrapper.subsequent_indent = " "
def __str__(self):
def __str__(self) -> str:
"""
Unified string representation method for all Problems.
"""
@@ -143,19 +147,20 @@ class Problem(abc.ABC): # pylint: disable=too-few-public-methods
return self.verbose_output()
@abc.abstractmethod
def quiet_output(self):
def quiet_output(self) -> str:
"""
The output when --quiet is enabled.
"""
pass
@abc.abstractmethod
def verbose_output(self):
def verbose_output(self) -> str:
"""
The default output with explanation and code snippet if appropriate.
"""
pass
class SymbolNotInHeader(Problem): # pylint: disable=too-few-public-methods
"""
A problem that occurs when an exported/available symbol in the object file
@@ -165,19 +170,20 @@ class SymbolNotInHeader(Problem): # pylint: disable=too-few-public-methods
Fields:
* symbol_name: the name of the symbol.
"""
def __init__(self, symbol_name):
def __init__(self, symbol_name: str) -> None:
self.symbol_name = symbol_name
Problem.__init__(self)
def quiet_output(self):
def quiet_output(self) -> str:
return "{0}".format(self.symbol_name)
def verbose_output(self):
def verbose_output(self) -> str:
return self.textwrapper.fill(
"'{0}' was found as an available symbol in the output of nm, "
"however it was not declared in any header files."
.format(self.symbol_name))
class PatternMismatch(Problem): # pylint: disable=too-few-public-methods
"""
A problem that occurs when something doesn't match the expected pattern.
@@ -187,19 +193,19 @@ class PatternMismatch(Problem): # pylint: disable=too-few-public-methods
* pattern: the expected regex pattern
* match: the Match object in question
"""
def __init__(self, pattern, match):
def __init__(self, pattern: Union[Pattern, str], match: Match) -> None:
self.pattern = pattern
self.match = match
Problem.__init__(self)
def quiet_output(self):
def quiet_output(self) -> str:
return (
"{0}:{1}:{2}"
.format(self.match.filename, self.match.line_no, self.match.name)
)
def verbose_output(self):
def verbose_output(self) -> str:
return self.textwrapper.fill(
"{0}:{1}: '{2}' does not match the required pattern '{3}'."
.format(
@@ -210,6 +216,7 @@ class PatternMismatch(Problem): # pylint: disable=too-few-public-methods
)
) + "\n" + str(self.match)
class Typo(Problem): # pylint: disable=too-few-public-methods
"""
A problem that occurs when a word using MBED or PSA doesn't
@@ -219,17 +226,17 @@ class Typo(Problem): # pylint: disable=too-few-public-methods
Fields:
* match: the Match object of the MBED|PSA name in question.
"""
def __init__(self, match):
def __init__(self, match: Match) -> None:
self.match = match
Problem.__init__(self)
def quiet_output(self):
def quiet_output(self) -> str:
return (
"{0}:{1}:{2}"
.format(self.match.filename, self.match.line_no, self.match.name)
)
def verbose_output(self):
def verbose_output(self) -> str:
return self.textwrapper.fill(
"{0}:{1}: '{2}' looks like a typo. It was not found in any "
"macros or any enums. If this is not a typo, put "
@@ -237,26 +244,29 @@ class Typo(Problem): # pylint: disable=too-few-public-methods
.format(self.match.filename, self.match.line_no, self.match.name)
) + "\n" + str(self.match)
class CodeParser():
"""
Class for retrieving files and parsing the code. This can be used
independently of the checks that NameChecker performs, for example for
list_internal_identifiers.py.
"""
def __init__(self, log):
def __init__(self, log: logging.Logger) -> None:
self.log = log
if not build_tree.looks_like_root(os.getcwd()):
raise Exception("This script must be run from Mbed TLS or TF-PSA-Crypto root")
# Memo for storing "glob expression": set(filepaths)
self.files = {}
# Globally excluded filenames.
# Note that "*" can match directory separators in exclude lists.
self.excluded_files = ["*/bn_mul", "*/compat-2.x.h"]
def _parse(self, all_macros, enum_consts, identifiers,
excluded_identifiers, mbed_psa_words, symbols):
def _parse(self,
all_macros: Dict[str, List[Match]],
enum_consts: List[Match],
identifiers: List[Match],
excluded_identifiers: List[Match],
mbed_psa_words: List[Match],
symbols: List[str]) -> ParseResult:
# pylint: disable=too-many-arguments
"""
Parse macros, enums, identifiers, excluded identifiers, Mbed PSA word and Symbols.
@@ -272,7 +282,7 @@ class CodeParser():
# Remove identifier macros like mbedtls_printf or mbedtls_calloc
identifiers_justname = [x.name for x in identifiers]
actual_macros = {"public": [], "internal": []}
actual_macros = {"public": [], "internal": []} #type: Dict[str, List[Match]]
for scope in actual_macros:
for macro in all_macros[scope]:
if macro.name not in identifiers_justname:
@@ -299,16 +309,20 @@ class CodeParser():
mbed_psa_words=mbed_psa_words,
)
def is_file_excluded(self, path, exclude_wildcards):
def is_file_excluded(self,
path: str,
exclude_wildcards: Optional[List[str]]) -> bool:
"""Whether the given file path is excluded."""
# exclude_wildcards may be None. Also, consider the global exclusions.
exclude_wildcards = (exclude_wildcards or []) + self.excluded_files
for pattern in exclude_wildcards:
for pattern in (exclude_wildcards or []) + self.excluded_files:
if fnmatch.fnmatch(path, pattern):
return True
return False
def get_all_files(self, include_wildcards, exclude_wildcards):
def get_all_files(self,
include_wildcards: List[str],
exclude_wildcards: Optional[List[str]],
) -> Tuple[List[str], List[str]]:
"""
Get all files that match any of the included UNIX-style wildcards
and filter them into included and excluded lists.
@@ -333,8 +347,10 @@ class CodeParser():
* inc_files: A List of relative filepaths for included files.
* exc_files: A List of relative filepaths for excluded files.
"""
accumulator = set()
all_wildcards = include_wildcards + (exclude_wildcards or [])
if exclude_wildcards is None:
exclude_wildcards = []
accumulator = set() #type: Set[str]
all_wildcards = include_wildcards + exclude_wildcards
for wildcard in all_wildcards:
accumulator = accumulator.union(glob.iglob(wildcard, recursive=True))
@@ -347,7 +363,10 @@ class CodeParser():
inc_files.append(path)
return (sorted(inc_files), sorted(exc_files))
def get_included_files(self, include_wildcards, exclude_wildcards):
def get_included_files(self,
include_wildcards: List[str],
exclude_wildcards: Optional[List[str]],
) -> List[str]:
"""
Get all files that match any of the included UNIX-style wildcards.
While the check_names script is designed only for use on UNIX/macOS
@@ -360,7 +379,7 @@ class CodeParser():
Returns a List of relative filepaths.
"""
accumulator = set()
accumulator = set() #type: Set[str]
for include_wildcard in include_wildcards:
accumulator = accumulator.union(glob.iglob(include_wildcard,
@@ -369,7 +388,10 @@ class CodeParser():
return sorted(path for path in accumulator
if not self.is_file_excluded(path, exclude_wildcards))
def parse_macros(self, include, exclude=None):
def parse_macros(self,
include: List[str],
exclude: Optional[List[str]] = None,
) -> List[Match]:
"""
Parse all macros defined by #define preprocessor directives.
@@ -405,7 +427,10 @@ class CodeParser():
return macros
def parse_mbed_psa_words(self, include, exclude=None):
def parse_mbed_psa_words(self,
include: List[str],
exclude: Optional[List[str]] = None,
) -> List[Match]:
"""
Parse all words in the file that begin with MBED|PSA, in and out of
macros, comments, anything.
@@ -444,7 +469,10 @@ class CodeParser():
return mbed_psa_words
def parse_enum_consts(self, include, exclude=None):
def parse_enum_consts(self,
include: List[str],
exclude: Optional[List[str]] = None,
) -> List[Match]:
"""
Parse all enum value constants that are declared.
@@ -507,7 +535,8 @@ class CodeParser():
r'(?P<string>")(?:[^\\\"]|\\.)*"', # string literal
]))
def strip_comments_and_literals(self, line, in_block_comment):
def strip_comments_and_literals(self, line: str,
in_block_comment: bool) -> Tuple[str, bool]:
"""Strip comments and string literals from line.
Continuation lines are not supported.
@@ -573,7 +602,9 @@ class CodeParser():
r"#",
]))
def parse_identifiers_in_file(self, header_file, identifiers):
def parse_identifiers_in_file(self,
header_file: str,
identifiers: List[Match]) -> None:
"""
Parse all lines of a header where a function/enum/struct/union/typedef
identifier is declared, based on some regex and heuristics. Highly
@@ -635,7 +666,10 @@ class CodeParser():
identifier.span(),
group))
def parse_identifiers(self, include, exclude=None):
def parse_identifiers(self,
include: List[str],
exclude: Optional[List[str]] = None,
) -> Tuple[List[Match], List[Match]]:
"""
Parse all lines of a header where a function/enum/struct/union/typedef
identifier is declared, based on some regex and heuristics. Highly
@@ -658,19 +692,19 @@ class CodeParser():
self.log.debug("Looking for included identifiers in {} files".format \
(len(included_files)))
included_identifiers = []
included_identifiers = [] #type: List[Match]
for header_file in included_files:
self.parse_identifiers_in_file(header_file, included_identifiers)
self.log.debug("Looking for excluded identifiers in {} files".format \
(len(excluded_files)))
excluded_identifiers = []
excluded_identifiers = [] #type: List[Match]
for header_file in excluded_files:
self.parse_identifiers_in_file(header_file, excluded_identifiers)
return (included_identifiers, excluded_identifiers)
def parse_symbols(self):
def parse_symbols(self) -> List[str]:
"""
Compile a library, and parse the object files using nm to retrieve the
list of referenced symbols. Exceptions thrown here are rethrown because
@@ -681,7 +715,7 @@ class CodeParser():
"""
raise NotImplementedError("parse_symbols must be implemented by a code parser")
def comprehensive_parse(self):
def comprehensive_parse(self) -> ParseResult:
"""
(Must be defined as a class method)
Comprehensive ("default") function to call each parsing function and
@@ -691,7 +725,7 @@ class CodeParser():
"""
raise NotImplementedError("comprehension_parse must be implemented by a code parser")
def parse_symbols_from_nm(self, object_files):
def parse_symbols_from_nm(self, object_files: List[str]) -> List[str]:
"""
Run nm to retrieve the list of referenced symbols in each object file.
Does not return the position data since it is of no use.
@@ -705,7 +739,7 @@ class CodeParser():
nm_undefined_regex = re.compile(r"^\S+: +U |^$|^\S+:$")
nm_valid_regex = re.compile(r"^\S+( [0-9A-Fa-f]+)* . _*(?P<symbol>\w+)")
exclusions = ("FStar", "Hacl")
symbols = []
symbols = [] #type: List[str]
# Gather all outputs of nm
nm_output = ""
for lib in object_files:
@@ -725,13 +759,14 @@ class CodeParser():
self.log.error(line)
return symbols
class TFPSACryptoCodeParser(CodeParser):
"""
Class for retrieving files and parsing TF-PSA-Crypto code. This can be used
independently of the checks that NameChecker performs.
"""
def __init__(self, log):
def __init__(self, log: logging.Logger) -> None:
super().__init__(log)
if not build_tree.looks_like_tf_psa_crypto_root(os.getcwd()):
raise Exception("This script must be run from TF-PSA-Crypto root.")
@@ -761,14 +796,14 @@ class TFPSACryptoCodeParser(CodeParser):
"drivers/*/src/*.c",
]
def comprehensive_parse(self):
def comprehensive_parse(self) -> ParseResult:
"""
Comprehensive ("default") function to call each parsing function and
retrieve various elements of the code, together with the source location.
Returns a dict of parsed item key to the corresponding List of Matches.
"""
all_macros = {"public": [], "internal": [], "private":[]}
all_macros = {"public": [], "internal": [], "private":[]} #type: Dict[str, List[Match]]
all_macros["public"] = self.parse_macros(self.H_PUBLIC,
self.H_PUBLIC_EXCLUDE)
all_macros["internal"] = self.parse_macros(self.H_INTERNAL +
@@ -788,7 +823,7 @@ class TFPSACryptoCodeParser(CodeParser):
return self._parse(all_macros, enum_consts, identifiers,
excluded_identifiers, mbed_psa_words, symbols)
def parse_symbols(self):
def parse_symbols(self) -> List[str]:
"""
Compile the TF-PSA-Crypto libraries, and parse the
object files using nm to retrieve the list of referenced symbols.
@@ -854,25 +889,26 @@ class TFPSACryptoCodeParser(CodeParser):
return symbols
class MBEDTLSCodeParser(CodeParser):
"""
Class for retrieving files and parsing Mbed TLS code. This can be used
independently of the checks that NameChecker performs.
"""
def __init__(self, log):
def __init__(self, log: logging.Logger) -> None:
super().__init__(log)
if not build_tree.looks_like_mbedtls_root(os.getcwd()):
raise Exception("This script must be run from Mbed TLS root.")
def comprehensive_parse(self):
def comprehensive_parse(self) -> ParseResult:
"""
Comprehensive ("default") function to call each parsing function and
retrieve various elements of the code, together with the source location.
Returns a dict of parsed item key to the corresponding List of Matches.
"""
all_macros = {"public": [], "internal": [], "private":[]}
all_macros = {"public": [], "internal": [], "private":[]} #type: Dict[str, List[Match]]
# TF-PSA-Crypto is in the same repo in 3.6 so initalise variable here.
tf_psa_crypto_parse_result = None
@@ -956,7 +992,7 @@ class MBEDTLSCodeParser(CodeParser):
mbed_psa_words, symbols)
return mbedtls_parse_result.add(tf_psa_crypto_parse_result)
def parse_symbols(self):
def parse_symbols(self) -> List[str]:
"""
Compile the Mbed TLS libraries, and parse the TLS, Crypto, and x509
object files using nm to retrieve the list of referenced symbols.
@@ -1031,15 +1067,18 @@ class MBEDTLSCodeParser(CodeParser):
return symbols
class NameChecker():
"""
Representation of the core name checking operation performed by this script.
"""
def __init__(self, parse_result, log):
def __init__(self,
parse_result: ParseResult,
log: logging.Logger) -> None:
self.parse_result = parse_result
self.log = log
def perform_checks(self, quiet=False):
def perform_checks(self, quiet=False) -> int:
"""
A comprehensive checker that performs each check in order, and outputs
a final verdict.
@@ -1075,7 +1114,7 @@ class NameChecker():
self.log.info("PASS")
return 0
def check_symbols_declared_in_header(self):
def check_symbols_declared_in_header(self) -> int:
"""
Perform a check that all detected symbols in the library object files
are properly declared in headers.
@@ -1083,7 +1122,7 @@ class NameChecker():
Returns the number of problems that need fixing.
"""
problems = []
problems = [] #type: List[Problem]
all_identifiers = self.parse_result.identifiers + \
self.parse_result.excluded_identifiers
@@ -1100,7 +1139,8 @@ class NameChecker():
self.output_check_result("All symbols in header", problems)
return len(problems)
def check_match_pattern(self, group_to_check, check_pattern):
def check_match_pattern(self, group_to_check: str,
check_pattern: Pattern) -> int:
"""
Perform a check that all items of a group conform to a regex pattern.
Assumes parse_names_in_source() was called before this.
@@ -1111,7 +1151,7 @@ class NameChecker():
Returns the number of problems that need fixing.
"""
problems = []
problems = [] #type: List[Problem]
for item_match in getattr(self.parse_result, group_to_check):
if not re.search(check_pattern, item_match.name):
@@ -1126,7 +1166,7 @@ class NameChecker():
problems)
return len(problems)
def check_for_typos(self):
def check_for_typos(self) -> int:
"""
Perform a check that all words in the source code beginning with MBED are
either defined as macros, or as enum constants.
@@ -1134,7 +1174,7 @@ class NameChecker():
Returns the number of problems that need fixing.
"""
problems = []
problems = [] #type: List[Problem]
# Set comprehension, equivalent to a list comprehension wrapped by set()
all_caps_names = {
@@ -1167,7 +1207,7 @@ class NameChecker():
self.output_check_result("Likely typos", problems)
return len(problems)
def output_check_result(self, name, problems):
def output_check_result(self, name: str, problems: List[Problem]) -> None:
"""
Write out the PASS/FAIL status of a performed check depending on whether
there were problems.
@@ -1183,7 +1223,8 @@ class NameChecker():
else:
self.log.info("{}: PASS".format(name))
def main():
def main() -> None:
"""
Perform argument parsing, and create an instance of CodeParser and
NameChecker to begin the core operation.