mirror of https://github.com/NVIDIA/TensorRT-LLM.git

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: Gal Hubara-Agam <96368689+galagam@users.noreply.github.com>
Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Co-authored-by: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com>

commit 1c065fbb3e (parent fedd7178d1)
@@ -243,12 +243,6 @@ class _FlashInferMLAPlanner:
            plan_params: Parameters for planning.
        """
        # Decode qo_indptr: [0, 1, 2, ..., batch_size] (1 token per sequence)
        batch_size = kv_page_indptr.shape[0] - 1
        qo_indptr = torch.arange(batch_size + 1, device=kv_page_indptr.device, dtype=torch.int32)

        # Compute kv_len_arr for CUDA graph wrapper initialization
        num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]
        kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + kv_last_page_len

        # we want to plan during warm-up of cuda graph capture to ensure we have the plan cached
        if (
@@ -257,6 +251,14 @@ class _FlashInferMLAPlanner:
        ):
            # During CUDA graph capture, the metadata tensors provided by auto-deploy are stable.
            # Pass the buffer tensors to the wrapper for use_cuda_graph=True
            batch_size = kv_page_indptr.shape[0] - 1
            qo_indptr = torch.arange(
                batch_size + 1, device=kv_page_indptr.device, dtype=torch.int32
            )

            # Compute kv_len_arr for CUDA graph wrapper initialization
            num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]
            kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + kv_last_page_len
            wrapper = self._init_decode_wrapper(
                use_cuda_graph=True,
                qo_indptr=qo_indptr,
@@ -276,6 +278,11 @@ class _FlashInferMLAPlanner:

        # Re-plan if plan_params changed
        if plan_params != self.plan_params_decode:
            batch_size = kv_page_indptr.shape[0] - 1
            qo_indptr = torch.arange(
                batch_size + 1, device=kv_page_indptr.device, dtype=torch.int32
            )

            self._plan_mla_wrapper(
                self.decode_wrapper,
                qo_indptr,
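To make the planner's metadata arithmetic concrete, here is a minimal standalone sketch of how per-sequence KV lengths fall out of the paged metadata used above; the page size and tensor values are made up for illustration.

```
import torch

page_size = 64
# Page-table pointers for 3 sequences occupying 2, 1, and 3 pages respectively
kv_page_indptr = torch.tensor([0, 2, 3, 6], dtype=torch.int32)
# Number of valid tokens in each sequence's last page
kv_last_page_len = torch.tensor([10, 64, 1], dtype=torch.int32)

batch_size = kv_page_indptr.shape[0] - 1  # 3
# Decode plans one query token per sequence
qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32)

num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]  # [2, 1, 3]
# All pages except the last are full; the last page holds kv_last_page_len tokens
kv_len_arr = (num_pages_per_seq - 1) * page_size + kv_last_page_len
print(qo_indptr.tolist())  # [0, 1, 2, 3]
print(kv_len_arr.tolist())  # [74, 64, 129]
```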
@@ -50,6 +50,7 @@ from ...utils.node_utils import (
    is_any_moe_op,
    is_any_ssm_op,
    is_op,
    is_weight_node,
    num_users_of_weight_node,
    shape,
    subgraph,
@@ -1855,6 +1856,7 @@ def _process_mla_sharding(
    - q_a_proj: # gather (simple shard, output is replicated)
    - q_b_proj: # column-sharding (output is head-distributed)
    - kv_a_proj # gather (simple shard, output is replicated)
                # This one is actually absorbed by the MLA kernel.
    - kv_b_proj # column-sharding (output is head-distributed)
    - o_proj # row-sharding + all-reduce

@@ -1869,8 +1871,41 @@ def _process_mla_sharding(
    q_a_proj, kv_a_proj = layer_subgraph.opening_nodes
    # extract q_b_proj and kv_b_proj nodes
    lin_nodes = list(filtered_nodes(layer_subgraph.subgraph_nodes, is_any_lin_op))
    assert len(lin_nodes) == 2, "Expecting exactly two linear nodes in the interior of the subgraph"
    q_b_proj, kv_b_proj = lin_nodes
    assert len(lin_nodes) <= 2, (
        "Expecting at most two linear nodes in the interior of the MLA layer"
    )

    if len(lin_nodes) == 1:
        # we don't have explicit kv_b projection. Instead, it is
        # absorbed by the MLA kernel.
        mla_node = list(
            filtered_nodes(layer_subgraph.subgraph_nodes, ops=torch.ops.auto_deploy.torch_mla)
        )
        assert len(mla_node) == 1, "Expecting exactly one MLA node"
        mla_node = mla_node[0]
        # torch_mla args:
        # q_nope: [B, S, N, qk_nope_head_dim]
        # q_pe: [B, S, N, qk_rope_head_dim]
        # compressed_kv: [B, S, kv_lora_rank]
        # kpe: [B, S, 1, qk_rope_head_dim]
        # kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
        # is_causal: bool
        # scale: float
        # layout: str

        # we need to shard 4th argument: kv_b_proj_weight
        assert is_weight_node(mla_node.args[4]), "Expecting weight node for kv_b_proj_weight"
        # column-shard it
        transform_container.add(
            WeightShardingInfo.from_node(
                mla_node.args[4],
                split_dim=SplitDimension.COLUMN,
                config=transform_container.config,
                dist_op=None,
                min_local_shape=1,
                layer_type=LayerType.MLA,
            )
        )

    # extract o_proj node
    o_proj = layer_subgraph.terminating_node
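The column split of the absorbed kv_b_proj_weight works because its first dimension stacks the per-head (qk_nope_head_dim + v_head_dim) blocks, so chunking along dim 0 hands each rank a set of whole heads. A minimal sketch with assumed sizes (the tp_size value and the torch.chunk call are illustrative, not the transform's actual implementation):

```
import torch

num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 16, 128, 128, 512
tp_size = 4  # assumed tensor-parallel world size

# Shape as listed in the torch_mla args above, assuming heads are stacked contiguously on dim 0
kv_b_proj_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Column-sharding splits dim 0 into tp_size contiguous chunks, i.e. whole heads per rank
local_weight = torch.chunk(kv_b_proj_weight, tp_size, dim=0)[0]
local_heads = num_heads // tp_size
assert local_weight.shape == (local_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
```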
@@ -1885,21 +1920,18 @@ def _process_mla_sharding(

    # extract the sub-subgraph from q_b_proj and kv_b_proj to o_proj
    sub_subgraph = subgraph(
        sources=[q_b_proj, kv_b_proj],
        sources=lin_nodes,
        boundary_condition=is_any_lin_op,
    )
    attention_subgraph = LayerSubgraph(
        opening_nodes=[q_b_proj, kv_b_proj],
        opening_nodes=lin_nodes,
        subgraph_nodes=sub_subgraph,
        terminating_node=o_proj,
        layer_type=LayerType.MLA,
        min_local_shape=layer_subgraph.min_local_shape,
    )
    # shard q_b_proj and kv_b_proj nodes
    num_column_row_shards = _process_column_sharding(attention_subgraph, transform_container)
    if num_column_row_shards < 2:
        # it means that "someone else" already sharded these nodes. Skipping.
        return 0
    _process_column_sharding(attention_subgraph, transform_container)

    # update "empty" and "expand" nodes' args. Reference in modeling_deepseek.py:
    # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
@@ -1926,6 +1958,27 @@ def _process_mla_sharding(
        )
    )

    # update reshape nodes' args. Reference in modeling_deepseek.py:
    # Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim]
    # attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
    # attn_output = self.o_proj(attn_output)
    candidate_reshape = layer_subgraph.terminating_node.args[0]
    if is_op(candidate_reshape, [torch.ops.aten.reshape]):
        # reshape args are (attn_output, [bsz, q_len, num_heads * v_head_dim])
        # set 3rd arg (num_heads * v_head_dim) to -1
        reshape_args = list(candidate_reshape.args)
        reshape_sizes = list(reshape_args[1])
        reshape_sizes[2] = -1
        reshape_args[1] = tuple(reshape_sizes)
        transform_container.add(
            ParameterUpdateInfo(
                target_node=candidate_reshape.name,
                config=transform_container.config,
                args=tuple(reshape_args),
            )
        )
        ad_logger.debug(f"\nUpdated reshape node {candidate_reshape} arguments to {reshape_args}")

    # shard o_proj node
    transform_container.add(
        WeightShardingInfo.from_node(
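The reshape rewrite above matters because after column-sharding each rank only holds num_heads // tp_size heads, so a reshape that hard-codes num_heads * v_head_dim would no longer match the local tensor; setting that size to -1 lets it follow the local head count. A toy illustration with assumed sizes:

```
import torch

bsz, q_len, num_heads, v_head_dim, tp_size = 2, 8, 16, 128, 4
local_heads = num_heads // tp_size

# Per-rank attention output after head sharding: [B, S, N/tp, v_head_dim]
attn_output = torch.randn(bsz, q_len, local_heads, v_head_dim)

# Reshaping to [bsz, q_len, num_heads * v_head_dim] would fail on the smaller local tensor;
# with the last size rewritten to -1 the reshape collapses to the local width instead.
flat = attn_output.reshape(bsz, q_len, -1)
assert flat.shape == (bsz, q_len, local_heads * v_head_dim)
```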
@@ -321,6 +321,32 @@ def extract_weight_nodes(node: Node) -> WeightNodes:
            ],
            biases=[],
        )
    elif is_weight_node(node):
        weights = []
        biases = []

        if node.target.endswith("bias"):
            biases = [
                WeightNode(
                    node=node,
                    node_key=node.target,
                    tensor=get_param_or_buffer(node.target, gm),
                    submod=gm.get_submodule(node.target.rpartition(".")[0]),
                )
            ]
        else:
            weights = [
                WeightNode(
                    node=node,
                    node_key=node.target,
                    tensor=get_param_or_buffer(node.target, gm),
                    submod=gm.get_submodule(node.target.rpartition(".")[0]),
                )
            ]
        return WeightNodes(
            weights=weights,
            biases=biases,
        )
    # for other parametrized nodes, we need to find the weight node
    else:
        all_weight_nodes = [
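The bias/weight split above keys purely off the parameter's fully qualified target name, and the owning submodule path is that target minus its last component. A quick illustration with made-up target strings:

```
# Hypothetical FX parameter targets, for illustration only
for target in [
    "model.layers.0.self_attn.o_proj.weight",
    "model.layers.0.self_attn.o_proj.bias",
]:
    kind = "bias" if target.endswith("bias") else "weight"
    submodule_path = target.rpartition(".")[0]  # e.g. "model.layers.0.self_attn.o_proj"
    print(kind, submodule_path)
```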
@@ -524,6 +550,16 @@ def is_any_attention_op(node: Node) -> bool:
    )


def is_any_mla_op(node: Node) -> bool:
    """Check if the node is an MLA op."""
    return is_op(
        node,
        ops=[
            torch.ops.auto_deploy.torch_mla,
        ],
    )


def is_linear_op(node: Node) -> bool:
    """Check if the node is a linear op.

@@ -1100,6 +1136,7 @@ def get_layer_after_linear_node(
    ]
    ssm_nodes = list(filtered_nodes(interior_nodes, is_any_ssm_op))
    attention_nodes = list(filtered_nodes(interior_nodes, is_any_attention_op))
    mla_nodes = list(filtered_nodes(interior_nodes, is_any_mla_op))
    intermediate_lin_nodes = list(filtered_nodes(interior_nodes, is_any_lin_op))
    intermediate_weight_nodes = list(
        filtered_nodes(
@@ -1112,23 +1149,15 @@ def get_layer_after_linear_node(
    ####################################################

    def classify_layer_type() -> [LayerType, int]:
        if len(ssm_nodes) + len(attention_nodes) > 1:
        if len(ssm_nodes) + len(attention_nodes) + len(mla_nodes) > 1:
            # ambiguous layer type
            return LayerType.UNKNOWN, 1

        if len(attention_nodes) == 1:
            head_size = shape(attention_nodes[0])[-1]
            # check if this is MLA:
            # these two intermediate linear nodes are the latent q and kv projections.
            if len(intermediate_lin_nodes) == 2:
                # MLA has an RMS norm inside, so it should have one (or two, counting bias)
                # intermediate weight nodes
                if len(intermediate_weight_nodes) not in [1, 2]:
                    return LayerType.UNKNOWN, 1
                return LayerType.MLA, head_size
            else:
                if len(intermediate_lin_nodes) != 0:
                    return LayerType.UNKNOWN, 1
                return LayerType.ATTENTION, head_size
            if len(intermediate_lin_nodes) > 0:
                return LayerType.UNKNOWN, 1
            return LayerType.ATTENTION, head_size

        if len(ssm_nodes) == 1:
            head_size = shape(ssm_nodes[0])[-1]
@@ -1146,6 +1175,16 @@ def get_layer_after_linear_node(
                return LayerType.UNKNOWN, 1
            return LayerType.SSM, head_size

        if len(mla_nodes) == 1:
            head_size = shape(mla_nodes[0])[-1]
            # MLA should have two intermediate linear nodes:
            # kv_b_proj and q_b_proj, but:
            # - kv_b_proj may be absorbed by the MLA op
            # - q_b_proj is skipped if q_lora_rank is None
            if len(intermediate_lin_nodes) > 2:
                return LayerType.UNKNOWN, 1
            return LayerType.MLA, head_size

        # if we reach here, it means the layer is an MLP.
        # MLP should not have any intermediate linear or weight nodes.
        if len(intermediate_lin_nodes) > 0 or len(intermediate_weight_nodes) > 0:
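The classification above boils down to counting op kinds in the layer interior. A standalone restatement of those rules as shown in this hunk (a sketch only: the hypothetical LayerKind enum stands in for LayerType, and head-size extraction plus the extra SSM/MLP checks cut off by the hunk are omitted):

```
from enum import Enum, auto


class LayerKind(Enum):
    UNKNOWN = auto()
    ATTENTION = auto()
    SSM = auto()
    MLA = auto()
    MLP = auto()


def classify(num_ssm: int, num_attn: int, num_mla: int, num_interior_linear: int) -> LayerKind:
    # More than one "core" op makes the layer ambiguous
    if num_ssm + num_attn + num_mla > 1:
        return LayerKind.UNKNOWN
    if num_attn == 1:
        # Plain attention tolerates no extra interior linear nodes
        return LayerKind.ATTENTION if num_interior_linear == 0 else LayerKind.UNKNOWN
    if num_ssm == 1:
        return LayerKind.SSM
    if num_mla == 1:
        # q_b_proj / kv_b_proj may or may not appear explicitly, so allow up to two
        return LayerKind.MLA if num_interior_linear <= 2 else LayerKind.UNKNOWN
    return LayerKind.MLP


print(classify(num_ssm=0, num_attn=0, num_mla=1, num_interior_linear=1))  # LayerKind.MLA
```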
@@ -324,6 +324,11 @@ zai-org/GLM-4.5-Air:
    spec_dec_algo: MTP
    kv_cache_quant_algo: FP8
    accuracy: 88.2
GLM-4.7-Flash:
  - accuracy: 83.434
  - quant_algo: NVFP4
    kv_cache_quant_algo: FP8
    accuracy: 81.046
bigcode/starcoder2-3b:
  - accuracy: 20.2
bigcode/starcoder2-7b:

@@ -397,3 +397,8 @@ nvidia/NVIDIA-Nemotron-3-Super-120B-012726:
  - quant_algo: FP8
    kv_cache_quant_algo: FP8
    accuracy: 86.12
GLM-4.7-Flash:
  - accuracy: 73.562
  - quant_algo: NVFP4
    kv_cache_quant_algo: FP8
    accuracy: 70.955
@@ -418,20 +418,12 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):


class TestGLM4Flash(LlmapiAccuracyTestHarness):
    """Accuracy regression tests for GLM-4.7-Flash.
    """Accuracy regression tests for GLM-4.7-Flash variants"""

    TODO: enable in CI, see https://github.com/NVIDIA/TensorRT-LLM/issues/11117
    MODEL_NAME = "GLM-4.7-Flash"
    MODEL_PATH_BF16 = hf_id_to_local_model_dir("zai-org/GLM-4.7-Flash")
    MODEL_PATH_NVFP4 = hf_id_to_local_model_dir("DeepInfra/GLM-4.7-Flash-NVFP4")

    In the meantime, you should run this test locally:

    ```
    cd tests/integration/defs
    TRTLLM_ACCURACY_NO_REFERENCE=1 pytest -svv "accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[True]"
    ```
    """

    MODEL_NAME = "zai-org/GLM-4.7-Flash"
    MODEL_PATH = MODEL_NAME  # Model is in HF_CACHE
    # Set minimum possible seq len + small buffer, for test speed & memory usage
    MAX_SEQ_LEN = max(MMLU.MAX_INPUT_LEN + MMLU.MAX_OUTPUT_LEN,
                      GSM8K.MAX_INPUT_LEN + GSM8K.MAX_OUTPUT_LEN)
@@ -488,10 +480,27 @@ class TestGLM4Flash(LlmapiAccuracyTestHarness):
    def test_auto_dtype(self, enable_chunked_prefill):
        kwargs = self.get_default_kwargs(enable_chunked_prefill)
        sampling_params = self.get_default_sampling_params()
        with AutoDeployLLM(model=self.MODEL_PATH,
                           tokenizer=self.MODEL_PATH,
        with AutoDeployLLM(model=self.MODEL_PATH_BF16,
                           tokenizer=self.MODEL_PATH_BF16,
                           **kwargs) as llm:
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm, sampling_params=sampling_params)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)

    @skip_pre_blackwell
    @pytest.mark.skip_less_device_memory(32000)
    @pytest.mark.parametrize("enable_chunked_prefill", [True, False])
    def test_nvfp4(self, enable_chunked_prefill):
        kwargs = self.get_default_kwargs(enable_chunked_prefill)
        sampling_params = self.get_default_sampling_params()
        with AutoDeployLLM(model=self.MODEL_PATH_NVFP4,
                           tokenizer=self.MODEL_PATH_NVFP4,
                           **kwargs) as llm:
            # Manually set quant_config for NVFP4 model to get the accuracy threshold
            llm.args.quant_config.quant_algo = QuantAlgo.NVFP4
            llm.args.quant_config.kv_cache_quant_algo = QuantAlgo.FP8
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm, sampling_params=sampling_params)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)
@@ -81,13 +81,15 @@ def make_eagle3_config(spec_model_path: str):
    )


def run_with_autodeploy(model, speculative_config, batch_size):
def run_with_autodeploy(model, speculative_config, batch_size, transforms_override=None):
    """Run AutoDeploy with or without speculative decoding.

    Args:
        model: Path to the base model
        speculative_config: Speculative decoding config (None for baseline mode)
        batch_size: Number of prompts to process
        transforms_override: Optional dict of transform config overrides to merge
            into the default transforms config (e.g. {"fuse_add_rms_norm": {"enabled": False}})

    Returns:
        List of (prompt, output) tuples from prompts_and_outputs
@@ -112,6 +114,10 @@ def run_with_autodeploy(model, speculative_config, batch_size):
        "max_num_tokens": 64,
    }

    # Apply any transform config overrides
    if transforms_override:
        llm_args["transforms"] = transforms_override

    # Configure experiment with prompts
    experiment_config = {
        "args": llm_args,
@@ -169,18 +175,34 @@ def test_autodeploy_spec_dec_output(spec_dec_mode):
    print(f"\nBase Model: {base_model}")
    print(f"Speculative Model ({spec_dec_mode}): {spec_model}")

    # For eagle3, disable fuse_add_rms_norm for both runs to ensure numerical
    # parity. The eagle3 hidden state capture replaces aten.add with
    # residual_add_for_capture at certain layers, which prevents
    # fuse_add_rms_norm from fusing those layers. This causes the target
    # model to produce slightly different logits vs baseline (which fuses
    # all layers), leading to greedy-divergent outputs on some hardware.
    transforms_override = None
    if spec_dec_mode == "eagle3":
        transforms_override = {"fuse_add_rms_norm": {"enabled": False}}

    # Run with speculative decoding
    print("\n[1/2] Running with speculative decoding enabled...")
    spec_outputs = run_with_autodeploy(
        model=base_model,
        speculative_config=spec_config,
        batch_size=1,
        transforms_override=transforms_override,
    )
    print(f"Generated {len(spec_outputs)} outputs with speculative decoding")

    # Run without speculative decoding (baseline)
    print("\n[2/2] Running without speculative decoding (baseline)...")
    baseline_outputs = run_with_autodeploy(model=base_model, speculative_config=None, batch_size=1)
    baseline_outputs = run_with_autodeploy(
        model=base_model,
        speculative_config=None,
        batch_size=1,
        transforms_override=transforms_override,
    )
    print(f"Generated {len(baseline_outputs)} outputs in baseline mode")

    # Verify outputs are identical
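Note that the docstring describes the override as being merged into the default transforms config, while the hunk above simply assigns the dict to llm_args["transforms"]. If a true per-key merge were ever wanted, a minimal recursive helper might look like the sketch below (hypothetical, not part of this change):

```
def deep_merge(base: dict, override: dict) -> dict:
    """Return a copy of base with override merged in, recursing into nested dicts."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


defaults = {"fuse_add_rms_norm": {"enabled": True}, "fuse_moe": {"enabled": True}}
print(deep_merge(defaults, {"fuse_add_rms_norm": {"enabled": False}}))
# {'fuse_add_rms_norm': {'enabled': False}, 'fuse_moe': {'enabled': True}}
```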
@@ -191,4 +191,6 @@ l0_b200:
      - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[flashinfer-False-1]
      - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[torch-True-1]
      - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
      - accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[False]
      - accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_nvfp4[False]
      - unittest/_torch/auto_deploy/unit/singlegpu

@@ -445,6 +445,8 @@ l0_h100:
      - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[triton_ssm-True]
      - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
      - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16[1]
      - accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[False]
      - accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[True]
      - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target]
      - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3]
      - examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate

@@ -43,6 +43,8 @@ HF_ID_TO_LLM_MODELS_SUBDIR = {
    "nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3": "NVIDIA-Nemotron-Nano-31B-A3-v3",
    "nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024": "Nemotron-Nano-3-30B-A3.5B-dev-1024",
    "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": "EAGLE3-LLaMA3.1-Instruct-8B",
    "zai-org/GLM-4.7-Flash": "GLM-4.7-Flash",
    "DeepInfra/GLM-4.7-Flash-NVFP4": "GLM-4.7-Flash-NVFP4",
}
@@ -23,7 +23,7 @@ from tensorrt_llm._torch.auto_deploy.transform.library.sharding import (
    WeightShardingInfo,
)
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op, is_weight_node
from tensorrt_llm.functional import AllReduceStrategy

base_model_tp_plan = {
@@ -139,6 +139,7 @@ class MLA_Block(nn.Module):

    Based on DeepSeek MLA architecture with KV compression.
    This is a minimal, self-contained implementation for testing sharding patterns.
    Based on models/custom/modeling_deepseek.py:DeepSeekV3Attention
    """

    def __init__(
@@ -163,41 +164,62 @@ class MLA_Block(nn.Module):

        # KV compression path (not sharded - gather)
        self.kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim, bias=bias)
        self.kv_a_layernorm = nn.LayerNorm(kv_lora_rank)

        # KV decompression (sharded column-wise)
        self.kv_b_proj = nn.Linear(
            kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False
        # KV decompression weight (absorbed into torch_mla, sharded column-wise)
        # NOTE: This is nn.Parameter, not nn.Linear - the weight is passed directly to torch_mla
        self.kv_b_proj_weight = nn.Parameter(
            torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
        )

        # Query path (sharded column-wise)
        self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=bias)
        self.q_b_proj = nn.Linear(q_lora_rank, num_heads * self.qk_head_dim, bias=bias)
        self.q_a_layernorm = nn.LayerNorm(q_lora_rank)

        # Output projection (sharded row-wise)
        self.o_proj = nn.Linear(num_heads * v_head_dim, hidden_size, bias=bias)

        # Softmax scale
        self.softmax_scale = self.qk_head_dim ** (-0.5)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape

        # Compress KV to latent
        compressed_kv_rope = self.kv_a_proj_with_mqa(x)  # (b, s, kv_lora_rank + rope_dim)
        compressed_kv = compressed_kv_rope[:, :, : self.kv_lora_rank]  # (b, s, kv_lora_rank)
        compressed_kv, k_pe = torch.split(
            compressed_kv_rope, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )

        # Decompress to full K and V
        kv = self.kv_b_proj(compressed_kv)  # (b, s, num_heads * (qk_nope + v))
        k_nope_v = kv.view(b, s, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = k_nope_v[:, :, :, : self.qk_nope_head_dim]
        v = k_nope_v[:, :, :, self.qk_nope_head_dim :]
        # Apply layernorm to compressed KV
        compressed_kv = self.kv_a_layernorm(compressed_kv)  # (b, s, kv_lora_rank)

        # Query projection
        # q = q_b_proj @ (layernorm(q_a_proj @ x))
        # k_pe: shared across heads for simplified version
        k_pe = k_pe.view(b, s, 1, self.qk_rope_head_dim)  # (b, s, 1, qk_rope_head_dim)

        # Query projection: q = q_b_proj @ (layernorm(q_a_proj @ x))
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))  # (b, s, num_heads * qk_head_dim)
        q = q.view(b, s, self.num_heads, self.qk_head_dim)
        q_nope = q[:, :, :, : self.qk_nope_head_dim]
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Call MLA kernel with compressed KV and kv_b_proj weight absorbed
        # NOTE: kv_b_proj weight is passed directly to the kernel instead of calling Linear
        attn_out = torch.ops.auto_deploy.torch_mla(
            q_nope,  # [B, S, N, qk_nope_head_dim]
            q_pe,  # [B, S, N, qk_rope_head_dim]
            compressed_kv,  # [B, S, kv_lora_rank]
            k_pe,  # [B, S, 1, qk_rope_head_dim]
            self.kv_b_proj_weight,  # [N*(qk_nope+v), kv_lora_rank] - absorbed weight
            True,  # is_causal
            self.softmax_scale,  # softmax_scale
            "bsnd",  # layout
        )

        # Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim]
        attn_out = attn_out.reshape(b, s, self.num_heads * self.v_head_dim)

        attn_out = torch.ops.auto_deploy.torch_attention(q_nope, k_nope, v, is_causal=True)
        attn_out = attn_out.contiguous().view(b, s, -1)
        # Output projection
        output = self.o_proj(attn_out)
        return output
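For reference, the decompression that the absorbed weight stands for is the explicit kv_b_proj linear path the old test block used above: expanding the compressed KV back into per-head k_nope and v. A minimal sketch of that shape flow with made-up sizes (this mirrors the removed reference code, not the torch_mla kernel internals):

```
import torch

b, s, num_heads = 2, 4, 8
qk_nope_head_dim, v_head_dim, kv_lora_rank = 64, 64, 256

kv_b_proj_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
compressed_kv = torch.randn(b, s, kv_lora_rank)

# Explicit decompression, as the removed nn.Linear(kv_b_proj) version did
kv = compressed_kv @ kv_b_proj_weight.T  # (b, s, num_heads * (qk_nope + v))
k_nope_v = kv.view(b, s, num_heads, qk_nope_head_dim + v_head_dim)
k_nope = k_nope_v[..., :qk_nope_head_dim]
v = k_nope_v[..., qk_nope_head_dim:]
assert k_nope.shape == (b, s, num_heads, qk_nope_head_dim)
assert v.shape == (b, s, num_heads, v_head_dim)
```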
@@ -616,6 +638,16 @@ def _run_pattern_detection_job(
                    layer_type=LayerType.MLA,
                )
            )
        if is_weight_node(node):
            if "kv_b_proj_weight" in node.name:
                expected_transformations.append(
                    WeightShardingInfo(
                        target_node=node.name,
                        split_dim=SplitDimension.COLUMN,
                        config=config,
                        layer_type=LayerType.MLA,
                    )
                )

    # get detected transformations
    optimizer = InferenceOptimizer(