Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][fix] Cherry-pick conflict changes for PR 7999 PR 8515 (#9446)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
This commit is contained in:
  parent d8b5aeb061
  commit 7e4cef9def
@@ -411,7 +411,7 @@ TRTLLM_NAMESPACE_END

 TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
-    m.def("fp8_block_scaling_gemm(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale) -> Tensor");
+    m.def("fp8_block_scaling_gemm_impl(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale) -> Tensor");
     m.def(
         "fp8_block_scaling_bmm(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale, ScalarType? "
         "out_dtype=None) -> Tensor");
@@ -425,7 +425,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)

 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
 {
-    m.impl("fp8_block_scaling_gemm", &tensorrt_llm::torch_ext::fp8_block_scaling_gemm);
+    m.impl("fp8_block_scaling_gemm_impl", &tensorrt_llm::torch_ext::fp8_block_scaling_gemm);
     m.impl("fp8_block_scaling_bmm", &tensorrt_llm::torch_ext::fp8_block_scaling_bmm);
     m.impl("fp8_block_scaling_bmm_out", &tensorrt_llm::torch_ext::fp8_block_scaling_bmm_out);
     m.impl("fp8_block_scaling_moe_gemm", &tensorrt_llm::torch_ext::fp8_block_scaling_moe_gemm);
@@ -201,7 +201,7 @@ def _register_fake():
     def _(input, force_applying_finalize):
         return torch.empty_like(input)

-    @torch.library.register_fake("trtllm::fp8_block_scaling_gemm")
+    @torch.library.register_fake("trtllm::fp8_block_scaling_gemm_impl")
     def _(a, b, a_scale, b_scale):
         m = a.shape[0]
         n = b.shape[0]
@@ -1441,7 +1441,7 @@ def _(
     return input.new_empty((M, N), dtype=output_dtype)


-def fp8_swap_ab_gen_tuning_buckets(x: int):
+def deep_gemm_gen_tuning_buckets(x: int):
     buckets = tuple(range(8, 128, 8))
     if x >= 128:
         buckets += tuple(range(128, x, 128))
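Editor's note: the renamed helper keeps the same bucketing behavior — token counts are bucketed in steps of 8 below 128 and in steps of 128 from 128 upward. A minimal standalone sketch of that behavior (illustrative only; the `return buckets` line falls outside the hunk and is assumed here):

# Standalone sketch of the bucket generator shown in the hunk above.
def deep_gemm_gen_tuning_buckets(x: int):
    buckets = tuple(range(8, 128, 8))         # 8, 16, ..., 120
    if x >= 128:
        buckets += tuple(range(128, x, 128))  # 128, 256, ... below x
    return buckets

# Example: for tune_max_num_tokens = 4096 the autotuner profiles
# 15 small buckets (8..120) plus 31 larger ones (128..3968).
print(len(deep_gemm_gen_tuning_buckets(4096)))  # 46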
@@ -1451,7 +1451,7 @@ def fp8_swap_ab_gen_tuning_buckets(x: int):
 class fp8SwapABGemmRunner(TunableRunner):
     tuning_config = TuningConfig(
         dynamic_tensor_specs=(DynamicTensorSpec(
-            0, 0, fp8_swap_ab_gen_tuning_buckets), ),
+            0, 0, deep_gemm_gen_tuning_buckets), ),
         tune_max_num_tokens=4096,
     )
@@ -1536,6 +1536,78 @@ def _(
     return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)


+# The runner is used to trigger the DeepGEMM JIT during autotune.
+class Fp8BlockScalingGemmRunner(TunableRunner):
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, deep_gemm_gen_tuning_buckets), ),
+        tune_max_num_tokens=4096,
+    )
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+    ) -> List[int]:
+        return [0]
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> torch.Tensor:
+        a, b, a_scale, b_scale = inputs
+        return torch.ops.trtllm.fp8_block_scaling_gemm_impl(
+            a, b, a_scale, b_scale)
+
+
+def get_fp8_block_scaling_gemm_constraint_spec():
+    # The implementation aligns with the fp8_quantize_1x128 custom op.
+    def fp8_quantize_1x128_sm90_constrant(inputs: List[List[int]]):
+        pad_m = fp4_utils.pad_up(inputs[0][0], 4)
+        blocked_n = (inputs[0][1] + 127) // 128
+        return fp4_utils.pad_up(pad_m * blocked_n * 4, 128) // 4
+
+    if get_sm_version() >= 100:
+        return (ConstraintSpec(2, 1, lambda inputs: inputs[0][0]), )
+    else:
+        return (ConstraintSpec(2, 0, fp8_quantize_1x128_sm90_constrant), )
+
+
+@torch.library.custom_op("trtllm::fp8_block_scaling_gemm", mutates_args=())
+def fp8_block_scaling_gemm(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
+    tune_max_num_tokens: int = 4096,
+) -> torch.Tensor:
+    tuner = AutoTuner.get()
+    fp8_block_scaling_gemm_runner = Fp8BlockScalingGemmRunner()
+    Fp8BlockScalingGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+
+    Fp8BlockScalingGemmRunner.tuning_config.constraint_specs = get_fp8_block_scaling_gemm_constraint_spec(
+    )
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::fp8_block_scaling_gemm",
+        [fp8_block_scaling_gemm_runner],
+        Fp8BlockScalingGemmRunner.tuning_config,
+        [a, b, a_scale, b_scale],
+    )
+    return fp8_block_scaling_gemm_runner(
+        inputs=[a, b, a_scale, b_scale],
+        tactic=best_tactic,
+    )
+
+
+@fp8_block_scaling_gemm.register_fake
+def _(a, b, a_scale, b_scale, tune_max_num_tokens=4096):
+    m = a.shape[0]
+    n = b.shape[0]
+    return a.new_empty((m, n), dtype=torch.bfloat16)
+
+
 @torch.library.custom_op("trtllm::silu_and_mul", mutates_args=())
 def silu_and_mul(x: torch.Tensor,
                  scale: Optional[torch.Tensor] = None,
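Editor's note: the net effect of the hunks above is that the raw kernel is now registered as trtllm::fp8_block_scaling_gemm_impl, while trtllm::fp8_block_scaling_gemm becomes a Python-level custom op that consults the AutoTuner before dispatching. A rough sketch of exercising only the fake (meta) kernel registered above, which is what shape propagation under torch.compile relies on; it assumes importing tensorrt_llm registers these ops, and the scale tensor shapes below are placeholders rather than the real 1x128 scale layout:

import torch
import tensorrt_llm  # assumed to register the trtllm custom ops on import
from torch._subclasses.fake_tensor import FakeTensorMode

# Only the fake kernel runs here: shapes and dtypes are propagated without
# launching DeepGEMM.
with FakeTensorMode():
    a = torch.empty(256, 512, dtype=torch.float8_e4m3fn)
    b = torch.empty(1024, 512, dtype=torch.float8_e4m3fn)
    # Placeholder scale shapes; the real layout comes from fp8_quantize_1x128.
    a_scale = torch.empty(4, 256, dtype=torch.float32)
    b_scale = torch.empty(8, 4, dtype=torch.float32)
    out = torch.ops.trtllm.fp8_block_scaling_gemm(a, b, a_scale, b_scale)

print(out.shape, out.dtype)  # torch.Size([256, 1024]) torch.bfloat16, per the register_fake above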
@@ -155,9 +155,17 @@ class KvCacheCreator:
             dummy_mm_prompt = input_processor.get_dummy_prompt(input_seq_len)

             if dummy_mm_prompt is not None:
-                prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor(
+                prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor_with_hash(
                     dummy_mm_prompt, sampling_params=None)

+                multimodal_input = extra_processed_inputs.get(
+                    'multimodal_input')
+                multimodal_data = extra_processed_inputs.get('multimodal_data')
+                req_mm_input = trtllm.MultimodalInput(
+                    multimodal_hashes=multimodal_input.multimodal_hashes,
+                    multimodal_positions=multimodal_input.multimodal_positions,
+                    multimodal_lengths=multimodal_input.multimodal_lengths
+                ) if multimodal_input else None
+
                 request = trtllm.Request(prompt_token_ids,
                                          max_tokens=1,
@@ -165,7 +173,8 @@ class KvCacheCreator:
                                          sampling_config=trtllm.SamplingConfig(
                                              beam_width=max_beam_width, ),
                                          output_config=trtllm.OutputConfig(),
-                                         end_id=-1)
+                                         end_id=-1,
+                                         multimodal_input=req_mm_input)
+                request.py_multimodal_data = multimodal_data
             else:
                 # Fall back to text-only prompt when we could not find the small image size.
@@ -266,9 +275,29 @@ class KvCacheCreator:
             # Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size.
             num_cache_blocks += (num_req_tokens + self._tokens_per_block -
                                  1) // self._tokens_per_block

+        # Max CUDA graph warmup required tokens.
+        max_cuda_graph_bs = min(self._model_engine.batch_size,
+                                self._model_engine._max_cuda_graph_batch_size)
+        cuda_graph_warmup_block = (
+            self._model_engine.max_seq_len +
+            1) // self._tokens_per_block + max_cuda_graph_bs - 1
+        num_cache_blocks = max(cuda_graph_warmup_block, num_cache_blocks)
+
+        # This is the minimal number of blocks required to run with the max batch size.
+        # If self._model_engine.batch_size blocks cannot be allocated, the max batch size should be adjusted.
+        num_cache_blocks = max(num_cache_blocks, self._model_engine.batch_size)
+
+        free_mem, total_mem = torch.cuda.mem_get_info()
+        max_memory = self._kv_cache_config.free_gpu_memory_fraction * free_mem
+        max_num_tokens_in_memory = max_memory // self._get_kv_size_per_token(
+        ) // self._tokens_per_block * self._tokens_per_block
+
         # Multiply by beam width to prevent rescaling of max_seq_len caused by the beam width during preparation for KV cache estimation.
-        return num_cache_blocks * self._tokens_per_block * self._dummy_reqs[
-            0].sampling_config.beam_width
+        return min(
+            num_cache_blocks * self._tokens_per_block *
+            self._dummy_reqs[0].sampling_config.beam_width,
+            max_num_tokens_in_memory)

     def try_prepare_estimation(self) -> bool:
         """Prepare for possible KV cache capacity estimation.
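Editor's note: the hunk above adds two lower bounds (enough blocks for the CUDA graph warmup batch and for the max batch size) and one upper bound (the fraction of currently free GPU memory) to the token estimate. A worked-arithmetic sketch of those formulas; attribute names follow the hunk, the concrete numbers are made up:

# Illustrative arithmetic only; values are hypothetical.
tokens_per_block = 32
batch_size = 64                  # self._model_engine.batch_size
max_cuda_graph_batch_size = 32   # self._model_engine._max_cuda_graph_batch_size
max_seq_len = 4096               # self._model_engine.max_seq_len
num_cache_blocks = 70            # rounded-up blocks for the dummy requests

# Blocks needed so the CUDA graph warmup batch fits.
max_cuda_graph_bs = min(batch_size, max_cuda_graph_batch_size)       # 32
cuda_graph_warmup_block = ((max_seq_len + 1) // tokens_per_block
                           + max_cuda_graph_bs - 1)                   # 128 + 31 = 159
num_cache_blocks = max(cuda_graph_warmup_block, num_cache_blocks)     # 159
num_cache_blocks = max(num_cache_blocks, batch_size)                  # 159

# Memory-based cap: fraction of free memory, floored to whole blocks.
free_mem = 40 * 1024**3
free_gpu_memory_fraction = 0.9
kv_size_per_token = 160 * 1024   # bytes per token, hypothetical
max_memory = free_gpu_memory_fraction * free_mem
max_num_tokens_in_memory = (max_memory // kv_size_per_token
                            // tokens_per_block * tokens_per_block)

beam_width = 1
estimate = min(num_cache_blocks * tokens_per_block * beam_width,
               max_num_tokens_in_memory)                              # 5088 here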
@@ -279,8 +308,10 @@ class KvCacheCreator:
         estimating_kv_cache = False
         if 'cp_type' not in self._mapping.cp_config:
             estimating_kv_cache = True
-            self._kv_cache_config.max_tokens = self._get_token_num_for_estimation(
-            )
+            estimate_max_tokens = self._get_token_num_for_estimation()
+            self._kv_cache_config.max_tokens = min(
+                estimate_max_tokens, self._kv_cache_config.max_tokens
+            ) if self._kv_cache_config.max_tokens is not None else estimate_max_tokens
         model_config = self._model_engine.model.model_config
         if model_config.attn_backend == "VANILLA":
             logger.info(
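Editor's note: with this change, a user-supplied kv_cache_config.max_tokens is treated as an upper bound on the estimate instead of being overwritten. A tiny illustration of the conditional (values hypothetical):

def resolve_max_tokens(user_max_tokens, estimate_max_tokens):
    # Mirrors the conditional in try_prepare_estimation above.
    return (min(estimate_max_tokens, user_max_tokens)
            if user_max_tokens is not None else estimate_max_tokens)

print(resolve_max_tokens(None, 5088))  # 5088 -> estimate used as-is
print(resolve_max_tokens(4096, 5088))  # 4096 -> user setting caps the estimate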
@@ -602,43 +602,54 @@ class PyTorchModelEngine(ModelEngine):
         # Set the value back to the original value after all warmups are complete
         self.enable_spec_decode = self.is_spec_decode

-    def _run_torch_compile_warmup(self, resource_manager: ResourceManager):
-        """Runs warmup iterations to specialize torch.compile kernels."""
-        if not self._torch_compile_enabled:
-            return
-
-        logger.info("Running torch.compile warmup...")
+    def _general_warmup(self,
+                        resource_manager: ResourceManager,
+                        reverse: bool = False):
         kv_cache_manager = resource_manager.get_resource_manager(
             self.kv_cache_manager_key)
         curr_max_num_tokens = min(
             kv_cache_manager.get_num_available_tokens(
                 self.original_max_draft_len), self.max_num_tokens,
             self.batch_size * (self.max_seq_len - 1))
+        max_batch_size = min(
+            self.batch_size,
+            curr_max_num_tokens // (1 + self.runtime_draft_len))

         warmup_requests_configs = {
             (1, 1),  # Specialize for 1 token.
-            (self.batch_size,
-             self.batch_size),  # max_batch_size, pure generation
+            (max_batch_size, max_batch_size),  # max_batch_size, pure generation
             (2, 0),  # Non-one, pure context
             (curr_max_num_tokens, 0),  # max_num_tokens, pure context
         }

+        warmup_requests_configs = sorted(list(warmup_requests_configs),
+                                         reverse=reverse)
+
+        for num_tokens, num_gen_tokens in warmup_requests_configs:
+            with self._release_batch_context(
+                    self._create_warmup_request(resource_manager, num_tokens,
+                                                num_gen_tokens),
+                    resource_manager) as batch:
+                if batch is None:
+                    continue  # Not enough KV cache space
+                logger.info(
+                    f"Run warmup with {num_tokens} tokens, include {num_gen_tokens} generation tokens"
+                )
+                self.forward(batch,
+                             new_tensors_device=None,
+                             resource_manager=resource_manager)
+                torch.cuda.synchronize()
+
+    def _run_torch_compile_warmup(self, resource_manager: ResourceManager):
+        """Runs warmup iterations to specialize torch.compile kernels."""
+        if not self._torch_compile_enabled:
+            return
+
+        logger.info("Running torch.compile warmup...")
+
         # Disable cuda graph capture here so that we can properly capture it later
         with self.no_cuda_graph():
-            for num_tokens, num_gen_tokens in warmup_requests_configs:
-                with self._release_batch_context(
-                        self._create_warmup_request(resource_manager,
-                                                    num_tokens, num_gen_tokens),
-                        resource_manager) as batch:
-                    if batch is None:
-                        continue  # Not enough KV cache space
-                    logger.info(
-                        f"Run warmup with {num_tokens} tokens, include {num_gen_tokens} generation tokens"
-                    )
-                    self.forward(batch,
-                                 new_tensors_device=None,
-                                 resource_manager=resource_manager)
-                    torch.cuda.synchronize()
+            self._general_warmup(resource_manager)

     def _run_autotuner_warmup(self, resource_manager: ResourceManager):
         """Runs a forward pass to populate the autotuner cache."""
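Editor's note: the refactor above moves the warmup-batch construction into a reusable _general_warmup and caps the pure-generation config by the available KV cache tokens. A standalone sketch of the (num_tokens, num_gen_tokens) configs it builds, with hypothetical engine values:

# Illustration only; engine values below are made up.
batch_size = 16
max_seq_len = 2048
max_num_tokens = 8192
available_kv_tokens = 6000   # kv_cache_manager.get_num_available_tokens(...)
runtime_draft_len = 0

curr_max_num_tokens = min(available_kv_tokens, max_num_tokens,
                          batch_size * (max_seq_len - 1))              # 6000
max_batch_size = min(batch_size,
                     curr_max_num_tokens // (1 + runtime_draft_len))   # 16

warmup_requests_configs = {
    (1, 1),                            # specialize for a single token
    (max_batch_size, max_batch_size),  # max batch size, pure generation
    (2, 0),                            # non-one, pure context
    (curr_max_num_tokens, 0),          # max num tokens, pure context
}
# reverse=False sorts smallest-first: [(1, 1), (2, 0), (16, 16), (6000, 0)]
print(sorted(warmup_requests_configs, reverse=False))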
@@ -762,7 +773,7 @@ class PyTorchModelEngine(ModelEngine):
                     resource_manager) as batch:
                 if batch is None:
                     # No KV cache space, cannot continue capturing graphs
-                    return
+                    continue

                 logger.info(
                     f"Run generation-only CUDA graph warmup for batch size={bs}, draft_len={draft_len}, max_seq_len={max_seq_len}"
@@ -812,6 +823,27 @@ class PyTorchModelEngine(ModelEngine):
             gc.collect()
             torch.cuda.empty_cache()

+        # When using piecewise CUDA graphs, the logits can suffer from severe memory
+        # fragmentation: as the number of requests grows, the blocks allocated by torch
+        # cannot be reused. So after piecewise CUDA graph capture, a warmup request with
+        # the most requests is triggered to make sure large enough blocks are allocated and can be reused correctly.
+        for num_tokens in piecewise_cuda_graph_num_tokens:
+            warmup_request = self._create_warmup_request(resource_manager,
+                                                         num_tokens,
+                                                         0,
+                                                         least_requests=False)
+            with self._release_batch_context(warmup_request,
+                                             resource_manager) as batch:
+                if batch is None:
+                    continue
+                logger.info(
+                    f"Run piecewise CUDA graph warmup for num tokens={num_tokens} with most requests"
+                )
+                self.forward(batch,
+                             new_tensors_device=None,
+                             resource_manager=resource_manager)
+                torch.cuda.synchronize()
+
     ### Helper methods promoted from the original warmup method ###

     @contextlib.contextmanager
@@ -842,8 +874,11 @@ class PyTorchModelEngine(ModelEngine):
         return 0

     def _create_warmup_request(
-            self, resource_manager: ResourceManager, num_tokens: int,
-            num_gen_tokens: int) -> Optional[ScheduledRequests]:
+            self,
+            resource_manager: ResourceManager,
+            num_tokens: int,
+            num_gen_requests: int,
+            least_requests: bool = True) -> Optional[ScheduledRequests]:
         """Creates a generic dummy ScheduledRequests object for warmup."""
         kv_cache_manager = resource_manager.get_resource_manager(
             self.kv_cache_manager_key)
@@ -853,6 +888,9 @@ class PyTorchModelEngine(ModelEngine):
         available_tokens = kv_cache_manager.get_num_available_tokens(
             self.runtime_draft_len)
         available_blocks = kv_cache_manager.get_num_free_blocks()
+        print(
+            f"available_tokens: {available_tokens}, num_tokens: {num_tokens}, num_gen_requests: {num_gen_requests}"
+        )
         if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
             return None
@@ -860,6 +898,12 @@ class PyTorchModelEngine(ModelEngine):
         if num_extra_decoding_steps > 0:
             return None  # Disable autotuning for fused drafting loops for now.

+        if num_gen_requests > self.batch_size:
+            return None
+        num_gen_tokens = num_gen_requests * (1 + self.runtime_draft_len)
+        if num_gen_tokens > self.max_num_tokens:
+            return None
+
         num_ctx_tokens = num_tokens - num_gen_tokens
         num_ctx_requests = 0
         ctx_requests = []
@@ -869,19 +913,33 @@ class PyTorchModelEngine(ModelEngine):
         num_full_seqs = 0
         num_left_over_tokens = 0

+        max_context_requests = self.batch_size - num_gen_requests
+        if max_context_requests * max_seq_len < num_ctx_tokens:
+            return None
+
         if num_ctx_tokens > 0:
-            num_full_seqs = num_ctx_tokens // max_seq_len
-            num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
+            if least_requests:
+                num_full_seqs = num_ctx_tokens // max_seq_len
+                num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
+            else:
+                max_bs = min(num_ctx_tokens, max_context_requests)
+                if num_ctx_tokens % max_bs == 0:
+                    num_full_seqs = max_bs
+                else:
+                    num_full_seqs = max_bs - 1
+                max_seq_len = num_ctx_tokens // num_full_seqs
+                num_left_over_tokens = num_ctx_tokens - max_seq_len * num_full_seqs
         num_ctx_requests = num_full_seqs + (1 if num_left_over_tokens > 0
                                             else 0)

-        if num_ctx_requests + num_gen_tokens > self.batch_size:
+        if num_ctx_requests + num_gen_requests > self.batch_size:
             return None  # Not enough batch size to fill the request

         blocks_to_use = num_full_seqs * math.ceil(
             max_seq_len / kv_cache_manager.tokens_per_block) + math.ceil(
                 num_left_over_tokens /
-                kv_cache_manager.tokens_per_block) + num_gen_tokens
+                kv_cache_manager.tokens_per_block) + num_gen_requests

         if blocks_to_use > available_blocks:
             return None
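Editor's note: the new least_requests flag controls how context tokens are split into dummy requests — the default packs them into as few full-length sequences as possible, while least_requests=False (used by the piecewise CUDA graph warmup above) spreads them over as many requests as the batch size allows. A standalone sketch mirroring that branch, with made-up numbers:

# Mirrors the context-splitting logic in _create_warmup_request above.
def split_ctx_tokens(num_ctx_tokens, max_seq_len, max_context_requests,
                     least_requests=True):
    if least_requests:
        # Fewest requests: fill whole max_seq_len sequences, remainder in one more.
        num_full_seqs = num_ctx_tokens // max_seq_len
        num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
    else:
        # Most requests: spread the tokens over as many requests as allowed.
        max_bs = min(num_ctx_tokens, max_context_requests)
        if num_ctx_tokens % max_bs == 0:
            num_full_seqs = max_bs
        else:
            num_full_seqs = max_bs - 1
        max_seq_len = num_ctx_tokens // num_full_seqs
        num_left_over_tokens = num_ctx_tokens - max_seq_len * num_full_seqs
    num_ctx_requests = num_full_seqs + (1 if num_left_over_tokens > 0 else 0)
    return num_ctx_requests, num_full_seqs, max_seq_len, num_left_over_tokens

# 1000 context tokens, max_seq_len=512, up to 16 context requests:
print(split_ctx_tokens(1000, 512, 16, least_requests=True))   # (2, 1, 512, 488)
print(split_ctx_tokens(1000, 512, 16, least_requests=False))  # (16, 15, 66, 10)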
@@ -902,17 +960,21 @@ class PyTorchModelEngine(ModelEngine):
                 spec_resource_manager.add_dummy_requests(
                     request_ids=list(range(num_ctx_requests)))

-        if num_gen_tokens > 0:
+        if num_gen_requests > 0:
             gen_requests = kv_cache_manager.add_dummy_requests(
-                list(range(num_ctx_requests,
-                           num_ctx_requests + num_gen_tokens)),
-                token_nums=[1] * num_gen_tokens,
+                list(
+                    range(num_ctx_requests,
+                          num_ctx_requests + num_gen_requests)),
+                token_nums=[1] * num_gen_requests,
                 is_gen=True,
                 max_num_draft_tokens=self.max_total_draft_tokens,
-                use_mrope=self.use_mrope)
+                use_mrope=self.use_mrope,
+                max_beam_width=self.max_beam_width,
+                num_extra_decoding_steps=num_extra_decoding_steps)
             if spec_resource_manager is not None:
                 spec_resource_manager.add_dummy_requests(request_ids=list(
-                    range(num_ctx_requests, num_ctx_requests + num_gen_tokens)))
+                    range(num_ctx_requests, num_ctx_requests +
+                          num_gen_requests)))

         result = ScheduledRequests()
         result.context_requests = ctx_requests