[None][fix] Cherry-pick conflict changes for PR 7999 and PR 8515 (#9446)

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Jin Li 2025-12-25 23:23:04 +08:00 committed by GitHub
parent d8b5aeb061
commit 7e4cef9def
5 changed files with 211 additions and 46 deletions

View File

@@ -411,7 +411,7 @@ TRTLLM_NAMESPACE_END
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("fp8_block_scaling_gemm(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale) -> Tensor");
m.def("fp8_block_scaling_gemm_impl(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale) -> Tensor");
m.def(
"fp8_block_scaling_bmm(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale, ScalarType? "
"out_dtype=None) -> Tensor");
@@ -425,7 +425,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("fp8_block_scaling_gemm", &tensorrt_llm::torch_ext::fp8_block_scaling_gemm);
m.impl("fp8_block_scaling_gemm_impl", &tensorrt_llm::torch_ext::fp8_block_scaling_gemm);
m.impl("fp8_block_scaling_bmm", &tensorrt_llm::torch_ext::fp8_block_scaling_bmm);
m.impl("fp8_block_scaling_bmm_out", &tensorrt_llm::torch_ext::fp8_block_scaling_bmm_out);
m.impl("fp8_block_scaling_moe_gemm", &tensorrt_llm::torch_ext::fp8_block_scaling_moe_gemm);

View File

@@ -201,7 +201,7 @@ def _register_fake():
def _(input, force_applying_finalize):
return torch.empty_like(input)
@torch.library.register_fake("trtllm::fp8_block_scaling_gemm")
@torch.library.register_fake("trtllm::fp8_block_scaling_gemm_impl")
def _(a, b, a_scale, b_scale):
m = a.shape[0]
n = b.shape[0]

View File

@@ -1441,7 +1441,7 @@ def _(
return input.new_empty((M, N), dtype=output_dtype)
def fp8_swap_ab_gen_tuning_buckets(x: int):
def deep_gemm_gen_tuning_buckets(x: int):
buckets = tuple(range(8, 128, 8))
if x >= 128:
buckets += tuple(range(128, x, 128))
@@ -1451,7 +1451,7 @@ def fp8_swap_ab_gen_tuning_buckets(x: int):
class fp8SwapABGemmRunner(TunableRunner):
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, fp8_swap_ab_gen_tuning_buckets), ),
0, 0, deep_gemm_gen_tuning_buckets), ),
tune_max_num_tokens=4096,
)
@@ -1536,6 +1536,78 @@ def _(
return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)
# This runner is used to trigger the DeepGEMM JIT compilation during autotuning.
class Fp8BlockScalingGemmRunner(TunableRunner):
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, deep_gemm_gen_tuning_buckets), ),
tune_max_num_tokens=4096,
)
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
return [0]
def forward(
self,
inputs: List[torch.Tensor],
tactic: int = -1,
) -> torch.Tensor:
a, b, a_scale, b_scale = inputs
return torch.ops.trtllm.fp8_block_scaling_gemm_impl(
a, b, a_scale, b_scale)
def get_fp8_block_scaling_gemm_constraint_spec():
# The implementation aligns with the fp8_quantize_1x128 custom op.
def fp8_quantize_1x128_sm90_constraint(inputs: List[List[int]]):
pad_m = fp4_utils.pad_up(inputs[0][0], 4)
blocked_n = (inputs[0][1] + 127) // 128
return fp4_utils.pad_up(pad_m * blocked_n * 4, 128) // 4
if get_sm_version() >= 100:
return (ConstraintSpec(2, 1, lambda inputs: inputs[0][0]), )
else:
return (ConstraintSpec(2, 0, fp8_quantize_1x128_sm90_constraint), )
@torch.library.custom_op("trtllm::fp8_block_scaling_gemm", mutates_args=())
def fp8_block_scaling_gemm(
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor,
b_scale: torch.Tensor,
tune_max_num_tokens: int = 4096,
) -> torch.Tensor:
tuner = AutoTuner.get()
fp8_block_scaling_gemm_runner = Fp8BlockScalingGemmRunner()
Fp8BlockScalingGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
Fp8BlockScalingGemmRunner.tuning_config.constraint_specs = get_fp8_block_scaling_gemm_constraint_spec(
)
_, best_tactic = tuner.choose_one(
"trtllm::fp8_block_scaling_gemm",
[fp8_block_scaling_gemm_runner],
Fp8BlockScalingGemmRunner.tuning_config,
[a, b, a_scale, b_scale],
)
return fp8_block_scaling_gemm_runner(
inputs=[a, b, a_scale, b_scale],
tactic=best_tactic,
)
@fp8_block_scaling_gemm.register_fake
def _(a, b, a_scale, b_scale, tune_max_num_tokens=4096):
m = a.shape[0]
n = b.shape[0]
return a.new_empty((m, n), dtype=torch.bfloat16)
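Note: the padding arithmetic in the SM90 constraint above is easy to misread, so here is a minimal standalone sketch of it (not part of the diff). It assumes inputs[0] is the activation of shape [M, K], and pad_up is a local stand-in assumed to behave like fp4_utils.pad_up.

def pad_up(x: int, align: int) -> int:
    # Stand-in for fp4_utils.pad_up: round x up to a multiple of align.
    return (x + align - 1) // align * align

def sm90_scale_dim0(m: int, k: int) -> int:
    # Mirrors fp8_quantize_1x128_sm90_constraint: pad M to a multiple of 4,
    # count 128-wide blocks along K, then pad the flattened scale count
    # (in units of 4) up to a multiple of 128.
    pad_m = pad_up(m, 4)
    blocked_k = (k + 127) // 128
    return pad_up(pad_m * blocked_k * 4, 128) // 4

# Example: M=301, K=512 -> M padded to 304, 4 blocks of 128 -> 304 * 4 = 1216 scale rows.
assert sm90_scale_dim0(301, 512) == 1216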
@torch.library.custom_op("trtllm::silu_and_mul", mutates_args=())
def silu_and_mul(x: torch.Tensor,
scale: Optional[torch.Tensor] = None,

View File

@@ -155,9 +155,17 @@ class KvCacheCreator:
dummy_mm_prompt = input_processor.get_dummy_prompt(input_seq_len)
if dummy_mm_prompt is not None:
prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor(
prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor_with_hash(
dummy_mm_prompt, sampling_params=None)
multimodal_input = extra_processed_inputs.get(
'multimodal_input')
multimodal_data = extra_processed_inputs.get('multimodal_data')
req_mm_input = trtllm.MultimodalInput(
multimodal_hashes=multimodal_input.multimodal_hashes,
multimodal_positions=multimodal_input.multimodal_positions,
multimodal_lengths=multimodal_input.multimodal_lengths
) if multimodal_input else None
request = trtllm.Request(prompt_token_ids,
max_tokens=1,
@@ -165,7 +173,8 @@ class KvCacheCreator:
sampling_config=trtllm.SamplingConfig(
beam_width=max_beam_width, ),
output_config=trtllm.OutputConfig(),
end_id=-1)
end_id=-1,
multimodal_input=req_mm_input)
request.py_multimodal_data = multimodal_data
else:
# Fall back to text-only prompt when we could not find the small image size.
@@ -266,9 +275,29 @@ class KvCacheCreator:
# Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size.
num_cache_blocks += (num_req_tokens + self._tokens_per_block -
1) // self._tokens_per_block
# Maximum number of KV cache blocks required for CUDA graph warmup
max_cuda_graph_bs = min(self._model_engine.batch_size,
self._model_engine._max_cuda_graph_batch_size)
cuda_graph_warmup_block = (
self._model_engine.max_seq_len +
1) // self._tokens_per_block + max_cuda_graph_bs - 1
num_cache_blocks = max(cuda_graph_warmup_block, num_cache_blocks)
# This is the minimum number of blocks required to run with the max batch size.
# If self._model_engine.batch_size blocks cannot be allocated, the max batch size should be adjusted.
num_cache_blocks = max(num_cache_blocks, self._model_engine.batch_size)
free_mem, total_mem = torch.cuda.mem_get_info()
max_memory = self._kv_cache_config.free_gpu_memory_fraction * free_mem
max_num_tokens_in_memory = max_memory // self._get_kv_size_per_token(
) // self._tokens_per_block * self._tokens_per_block
# Multiply by beam width to prevent max_seq_len from being rescaled due to the beam width while preparing for KV cache estimation.
return num_cache_blocks * self._tokens_per_block * self._dummy_reqs[
0].sampling_config.beam_width
return min(
num_cache_blocks * self._tokens_per_block *
self._dummy_reqs[0].sampling_config.beam_width,
max_num_tokens_in_memory)
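For reference, a minimal sketch (not from the diff) of the token-budget arithmetic that _get_token_num_for_estimation now performs; the function name, argument names, and example numbers below are invented for illustration.

def estimation_token_budget(req_token_counts, tokens_per_block, batch_size,
                            max_cuda_graph_bs, max_seq_len, beam_width,
                            free_mem, kv_bytes_per_token,
                            free_gpu_memory_fraction):
    # Dummy requests cannot share KV cache blocks: round each request up to whole blocks.
    num_cache_blocks = sum((t + tokens_per_block - 1) // tokens_per_block
                           for t in req_token_counts)
    # CUDA graph warmup needs one max-seq-len request plus one block per extra
    # request in the largest captured batch.
    cuda_graph_warmup_blocks = ((max_seq_len + 1) // tokens_per_block +
                                max_cuda_graph_bs - 1)
    num_cache_blocks = max(num_cache_blocks, cuda_graph_warmup_blocks, batch_size)
    # Cap the estimate by the free-memory budget, rounded down to whole blocks.
    max_memory = free_gpu_memory_fraction * free_mem
    max_tokens_in_memory = (int(max_memory // kv_bytes_per_token) //
                            tokens_per_block * tokens_per_block)
    return min(num_cache_blocks * tokens_per_block * beam_width,
               max_tokens_in_memory)

# Two 1000-token dummy requests, 32-token blocks, batch 64, graph batch 32,
# max_seq_len 2048, beam width 1, 8 GB free, 100 kB per token, fraction 0.9:
assert estimation_token_budget([1000, 1000], 32, 64, 32, 2048, 1,
                               8_000_000_000, 100_000, 0.9) == 3040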
def try_prepare_estimation(self) -> bool:
"""Prepare for possible KV cache capacity estimation.
@@ -279,8 +308,10 @@ class KvCacheCreator:
estimating_kv_cache = False
if 'cp_type' not in self._mapping.cp_config:
estimating_kv_cache = True
self._kv_cache_config.max_tokens = self._get_token_num_for_estimation(
)
estimate_max_tokens = self._get_token_num_for_estimation()
self._kv_cache_config.max_tokens = min(
estimate_max_tokens, self._kv_cache_config.max_tokens
) if self._kv_cache_config.max_tokens is not None else estimate_max_tokens
model_config = self._model_engine.model.model_config
if model_config.attn_backend == "VANILLA":
logger.info(

View File

@@ -602,43 +602,54 @@ class PyTorchModelEngine(ModelEngine):
# Set the value back to the original value after all warmups are complete
self.enable_spec_decode = self.is_spec_decode
def _run_torch_compile_warmup(self, resource_manager: ResourceManager):
"""Runs warmup iterations to specialize torch.compile kernels."""
if not self._torch_compile_enabled:
return
logger.info("Running torch.compile warmup...")
def _general_warmup(self,
resource_manager: ResourceManager,
reverse: bool = False):
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
curr_max_num_tokens = min(
kv_cache_manager.get_num_available_tokens(
self.original_max_draft_len), self.max_num_tokens,
self.batch_size * (self.max_seq_len - 1))
max_batch_size = min(
self.batch_size,
curr_max_num_tokens // (1 + self.runtime_draft_len))
warmup_requests_configs = {
(1, 1), # Specialize for 1 token.
(self.batch_size,
self.batch_size), # max_batch_size, pure generation
(max_batch_size, max_batch_size), # max_batch_size, pure generation
(2, 0), # Non-one, pure context
(curr_max_num_tokens, 0), # max_num_tokens, pure context
}
warmup_requests_configs = sorted(list(warmup_requests_configs),
reverse=reverse)
for num_tokens, num_gen_tokens in warmup_requests_configs:
with self._release_batch_context(
self._create_warmup_request(resource_manager, num_tokens,
num_gen_tokens),
resource_manager) as batch:
if batch is None:
continue # Not enough KV cache space
logger.info(
f"Run warmup with {num_tokens} tokens, include {num_gen_tokens} generation tokens"
)
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()
def _run_torch_compile_warmup(self, resource_manager: ResourceManager):
"""Runs warmup iterations to specialize torch.compile kernels."""
if not self._torch_compile_enabled:
return
logger.info("Running torch.compile warmup...")
# Disable cuda graph capture here so that we can properly capture it later
with self.no_cuda_graph():
for num_tokens, num_gen_tokens in warmup_requests_configs:
with self._release_batch_context(
self._create_warmup_request(resource_manager,
num_tokens, num_gen_tokens),
resource_manager) as batch:
if batch is None:
continue # Not enough KV cache space
logger.info(
f"Run warmup with {num_tokens} tokens, include {num_gen_tokens} generation tokens"
)
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()
self._general_warmup(resource_manager)
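To make the warmup schedule built in _general_warmup above concrete, here is a small sketch with invented numbers; it only reproduces the config selection (runtime_draft_len assumed 0), not the forward passes or the reverse ordering.

def warmup_configs(batch_size, max_num_tokens, max_seq_len,
                   available_tokens, runtime_draft_len=0):
    curr_max_num_tokens = min(available_tokens, max_num_tokens,
                              batch_size * (max_seq_len - 1))
    max_batch_size = min(batch_size,
                         curr_max_num_tokens // (1 + runtime_draft_len))
    return sorted({
        (1, 1),  # single decode token
        (max_batch_size, max_batch_size),  # full-batch pure generation
        (2, 0),  # small pure-context batch
        (curr_max_num_tokens, 0),  # max-token pure-context batch
    })

# e.g. batch_size=64, max_num_tokens=8192, max_seq_len=2048, 100k free KV tokens:
assert warmup_configs(64, 8192, 2048, 100_000) == [(1, 1), (2, 0), (64, 64), (8192, 0)]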
def _run_autotuner_warmup(self, resource_manager: ResourceManager):
"""Runs a forward pass to populate the autotuner cache."""
@@ -762,7 +773,7 @@ class PyTorchModelEngine(ModelEngine):
resource_manager) as batch:
if batch is None:
# No KV cache space, cannot continue capturing graphs
return
continue
logger.info(
f"Run generation-only CUDA graph warmup for batch size={bs}, draft_len={draft_len}, max_seq_len={max_seq_len}"
@@ -812,6 +823,27 @@ class PyTorchModelEngine(ModelEngine):
gc.collect()
torch.cuda.empty_cache()
# When using piecewise CUDA graphs, the logits allocations can suffer from severe memory fragmentation.
# As the number of requests grows, the blocks allocated by torch cannot be reused.
# So after piecewise CUDA graph capture, a warmup pass with the maximum number of requests is triggered
# to make sure that sufficiently large blocks are allocated and can be reused correctly.
for num_tokens in piecewise_cuda_graph_num_tokens:
warmup_request = self._create_warmup_request(resource_manager,
num_tokens,
0,
least_requests=False)
with self._release_batch_context(warmup_request,
resource_manager) as batch:
if batch is None:
continue
logger.info(
f"Run piecewise CUDA graph warmup for num tokens={num_tokens} with most requests"
)
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()
### Helper methods promoted from the original warmup method ###
@contextlib.contextmanager
@@ -842,8 +874,11 @@ class PyTorchModelEngine(ModelEngine):
return 0
def _create_warmup_request(
self, resource_manager: ResourceManager, num_tokens: int,
num_gen_tokens: int) -> Optional[ScheduledRequests]:
self,
resource_manager: ResourceManager,
num_tokens: int,
num_gen_requests: int,
least_requests: bool = True) -> Optional[ScheduledRequests]:
"""Creates a generic dummy ScheduledRequests object for warmup."""
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
@@ -853,6 +888,9 @@ class PyTorchModelEngine(ModelEngine):
available_tokens = kv_cache_manager.get_num_available_tokens(
self.runtime_draft_len)
available_blocks = kv_cache_manager.get_num_free_blocks()
logger.debug(
f"available_tokens: {available_tokens}, num_tokens: {num_tokens}, num_gen_requests: {num_gen_requests}"
)
if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
return None
@@ -860,6 +898,12 @@ class PyTorchModelEngine(ModelEngine):
if num_extra_decoding_steps > 0:
return None # Disable autotuning for fused drafting loops for now.
if num_gen_requests > self.batch_size:
return None
num_gen_tokens = num_gen_requests * (1 + self.runtime_draft_len)
if num_gen_tokens > self.max_num_tokens:
return None
num_ctx_tokens = num_tokens - num_gen_tokens
num_ctx_requests = 0
ctx_requests = []
@@ -869,19 +913,33 @@ class PyTorchModelEngine(ModelEngine):
num_full_seqs = 0
num_left_over_tokens = 0
max_context_requests = self.batch_size - num_gen_requests
if max_context_requests * max_seq_len < num_ctx_tokens:
return None
if num_ctx_tokens > 0:
num_full_seqs = num_ctx_tokens // max_seq_len
num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
if least_requests:
num_full_seqs = num_ctx_tokens // max_seq_len
num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
else:
max_bs = min(num_ctx_tokens, max_context_requests)
if num_ctx_tokens % max_bs == 0:
num_full_seqs = max_bs
else:
num_full_seqs = max_bs - 1
max_seq_len = num_ctx_tokens // num_full_seqs
num_left_over_tokens = num_ctx_tokens - max_seq_len * num_full_seqs
num_ctx_requests = num_full_seqs + (1 if num_left_over_tokens > 0
else 0)
if num_ctx_requests + num_gen_tokens > self.batch_size:
if num_ctx_requests + num_gen_requests > self.batch_size:
return None # Not enough batch size to fill the request
blocks_to_use = num_full_seqs * math.ceil(
max_seq_len / kv_cache_manager.tokens_per_block) + math.ceil(
num_left_over_tokens /
kv_cache_manager.tokens_per_block) + num_gen_tokens
kv_cache_manager.tokens_per_block) + num_gen_requests
if blocks_to_use > available_blocks:
return None
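A standalone sketch (not part of the diff) of how the context-token split above behaves in the two modes; least_requests=False is what the piecewise CUDA graph warmup relies on to allocate with as many requests as possible. Names and numbers are illustrative.

def split_ctx_tokens(num_ctx_tokens: int, max_seq_len: int,
                     max_context_requests: int, least_requests: bool = True):
    # Returns the per-request context lengths a warmup batch would use.
    if num_ctx_tokens == 0:
        return []
    if least_requests:
        # Pack tokens into as few max-length sequences as possible.
        num_full_seqs = num_ctx_tokens // max_seq_len
        leftover = num_ctx_tokens - num_full_seqs * max_seq_len
    else:
        # Spread tokens over as many requests as the batch size allows.
        max_bs = min(num_ctx_tokens, max_context_requests)
        num_full_seqs = max_bs if num_ctx_tokens % max_bs == 0 else max_bs - 1
        max_seq_len = num_ctx_tokens // num_full_seqs
        leftover = num_ctx_tokens - max_seq_len * num_full_seqs
    return [max_seq_len] * num_full_seqs + ([leftover] if leftover else [])

# 10 context tokens, at most 4 context requests, max_seq_len 8:
assert split_ctx_tokens(10, 8, 4, least_requests=True) == [8, 2]
assert split_ctx_tokens(10, 8, 4, least_requests=False) == [3, 3, 3, 1]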
@@ -902,17 +960,21 @@ class PyTorchModelEngine(ModelEngine):
spec_resource_manager.add_dummy_requests(
request_ids=list(range(num_ctx_requests)))
if num_gen_tokens > 0:
if num_gen_requests > 0:
gen_requests = kv_cache_manager.add_dummy_requests(
list(range(num_ctx_requests,
num_ctx_requests + num_gen_tokens)),
token_nums=[1] * num_gen_tokens,
list(
range(num_ctx_requests,
num_ctx_requests + num_gen_requests)),
token_nums=[1] * num_gen_requests,
is_gen=True,
max_num_draft_tokens=self.max_total_draft_tokens,
use_mrope=self.use_mrope)
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests(request_ids=list(
range(num_ctx_requests, num_ctx_requests + num_gen_tokens)))
range(num_ctx_requests, num_ctx_requests +
num_gen_requests)))
result = ScheduledRequests()
result.context_requests = ctx_requests