diff --git a/scripts/generate_psa_tests.py b/scripts/generate_psa_tests.py index c6c068e2a..caf239477 100755 --- a/scripts/generate_psa_tests.py +++ b/scripts/generate_psa_tests.py @@ -341,23 +341,27 @@ class OpFail: def test_cases_for_algorithm( self, alg: crypto_knowledge.Algorithm, + categories: Iterable[crypto_knowledge.AlgorithmCategory] ) -> Iterator[test_case.TestCase]: """Generate operation failure test cases for the specified algorithm.""" - for category in crypto_knowledge.AlgorithmCategory: - if category == crypto_knowledge.AlgorithmCategory.PAKE: - # PAKE operations are not implemented yet - pass - elif category.requires_key(): + for category in categories: + if category.requires_key(): yield from self.one_key_test_cases(alg, category) else: yield from self.no_key_test_cases(alg, category) def all_test_cases(self) -> Iterator[test_case.TestCase]: """Generate all test cases for operations that must fail.""" - algorithms = sorted(self.constructors.algorithms) - for expr in self.constructors.generate_expressions(algorithms): - alg = crypto_knowledge.Algorithm(expr) - yield from self.test_cases_for_algorithm(alg) + algorithm_constructors = sorted(self.constructors.algorithms) + algorithms = [crypto_knowledge.Algorithm(alg) + for alg in self.constructors.generate_expressions( + algorithm_constructors)] + categories = [ + cat for cat in crypto_knowledge.AlgorithmCategory + if cat != crypto_knowledge.AlgorithmCategory.PAKE # not implemented yet + ] + for alg in algorithms: + yield from self.test_cases_for_algorithm(alg, categories) class StorageKey(psa_storage.Key):