from typing import List def next_positive_power_of_2(x: int) -> int: if x < 1: return 1 return 1 << (x - 1).bit_length() def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]: max_num_tokens = next_positive_power_of_2(max_num_tokens) num_token_buckets = [] m = 1 while m <= max_num_tokens: num_token_buckets.append(m) m *= 2 return num_token_buckets