Improve the construction of inputs to multipart APIs

Fix a bug whereby the chunks did not actually have the desired lengths.

Make the message content depend only on its length, and not how it is split
into chunks. This way, it'll be easier to notice and analyze bugs that cause
different outputs for different ways to split the input.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
Gilles Peskine
2026-04-23 15:43:05 +02:00
parent ffc707d4f6
commit 88cfd0cf47
@@ -85,6 +85,35 @@ class Generator:
def secret_is_seed(cls) -> bool:
return True
@staticmethod
def message_for_length(length: int) -> bytes:
# b'ABCDE...' wrapping around and repeating every 256 bytes
return bytes(b & 0xff for b in range(65, 65 + length))
def chunks_for_lengths(self,
lengths: Sequence[int],
arity: Optional[int] = None,
) -> Tuple[bytes, List[bytes]]:
"""Construct a message split in chunks of the given lengths.
The content of the message only depends on the total length.
If `arity` is specified and less than `len(lengths)`, pad the list
of chunks to that number. If `arity` is specified and larger than
`len(lengths)`, raise an exception.
"""
total_length = sum(lengths)
message = self.message_for_length(total_length)
chunks = []
offset = 0
for n in lengths:
chunks.append(message[offset:offset + n])
offset += n
if arity is not None:
assert len(lengths) <= arity
chunks += [b''] * (arity - len(lengths))
return (message, chunks)
def one_mldsa_public_key_from_seed(self, key: Key,
descr: str) -> test_case.TestCase:
"""Construct one test case for driver export_public_key()."""
@@ -263,10 +292,7 @@ class DriverGenerator(Generator):
The number of message chunks must be at most MULTIPART_ARITY.
"""
assert len(lengths) <= self.MULTIPART_ARITY
chunks = ([bytes(i) * n for i, n in enumerate(lengths, 65)] +
[b''] * (self.MULTIPART_ARITY - len(lengths)))
message = b''.join(chunks)
message, chunks = self.chunks_for_lengths(lengths, self.MULTIPART_ARITY)
descr = '+'.join(map(str, lengths)) if lengths else 'empty (no update)'
type_is_pair = not function.startswith('verif')
deterministic = True