diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py
index 0171713aae..961135ff9c 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py
@@ -243,12 +243,6 @@ class _FlashInferMLAPlanner:
             plan_params: Parameters for planning.
         """
         # Decode qo_indptr: [0, 1, 2, ..., batch_size] (1 token per sequence)
-        batch_size = kv_page_indptr.shape[0] - 1
-        qo_indptr = torch.arange(batch_size + 1, device=kv_page_indptr.device, dtype=torch.int32)
-
-        # Compute kv_len_arr for CUDA graph wrapper initialization
-        num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]
-        kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + kv_last_page_len

         # we want to plan during warm-up of cuda graph capture to ensure we have the plan cached
         if (
@@ -257,6 +251,14 @@ class _FlashInferMLAPlanner:
         ):
             # During CUDA graph capture, the metadata tensors provided by auto-deploy are stable.
             # Pass the buffer tensors to the wrapper for use_cuda_graph=True
+            batch_size = kv_page_indptr.shape[0] - 1
+            qo_indptr = torch.arange(
+                batch_size + 1, device=kv_page_indptr.device, dtype=torch.int32
+            )
+
+            # Compute kv_len_arr for CUDA graph wrapper initialization
+            num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]
+            kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + kv_last_page_len
             wrapper = self._init_decode_wrapper(
                 use_cuda_graph=True,
                 qo_indptr=qo_indptr,
@@ -276,6 +278,11 @@ class _FlashInferMLAPlanner:

         # Re-plan if plan_params changed
         if plan_params != self.plan_params_decode:
+            batch_size = kv_page_indptr.shape[0] - 1
+            qo_indptr = torch.arange(
+                batch_size + 1, device=kv_page_indptr.device, dtype=torch.int32
+            )
+
             self._plan_mla_wrapper(
                 self.decode_wrapper,
                 qo_indptr,
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
index 48c8053863..49a478bd19 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -50,6 +50,7 @@ from ...utils.node_utils import (
     is_any_moe_op,
     is_any_ssm_op,
     is_op,
+    is_weight_node,
     num_users_of_weight_node,
     shape,
     subgraph,
@@ -1855,6 +1856,7 @@ def _process_mla_sharding(
     - q_a_proj: # gather (simple shard, output is replicated)
     - q_b_proj: # column-sharding (output is head-distributed)
     - kv_a_proj # gather (simple shard, output is replicated)
+      # This one is actually absorbed by the MLA kernel.
     - kv_b_proj # column-sharding (output is head-distributed)
     - o_proj # row-sharding + all-reduce

@@ -1869,8 +1871,41 @@
     q_a_proj, kv_a_proj = layer_subgraph.opening_nodes
     # extract q_b_proj and kv_b_proj nodes
     lin_nodes = list(filtered_nodes(layer_subgraph.subgraph_nodes, is_any_lin_op))
-    assert len(lin_nodes) == 2, "Expecting exactly two linear nodes in the interior of the subgraph"
-    q_b_proj, kv_b_proj = lin_nodes
+    assert len(lin_nodes) <= 2, (
+        "Expecting at most two linear nodes in the interior of the MLA layer"
+    )
+
+    if len(lin_nodes) == 1:
+        # we don't have an explicit kv_b projection. Instead, it is
+        # absorbed by the MLA kernel.
+        mla_node = list(
+            filtered_nodes(layer_subgraph.subgraph_nodes, ops=torch.ops.auto_deploy.torch_mla)
+        )
+        assert len(mla_node) == 1, "Expecting exactly one MLA node"
+        mla_node = mla_node[0]
+        # torch_mla args:
+        # q_nope: [B, S, N, qk_nope_head_dim]
+        # q_pe: [B, S, N, qk_rope_head_dim]
+        # compressed_kv: [B, S, kv_lora_rank]
+        # kpe: [B, S, 1, qk_rope_head_dim]
+        # kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
+        # is_causal: bool
+        # scale: float
+        # layout: str
+
+        # we need to shard args[4]: kv_b_proj_weight
+        assert is_weight_node(mla_node.args[4]), "Expecting a weight node for kv_b_proj_weight"
+        # column-shard it
+        transform_container.add(
+            WeightShardingInfo.from_node(
+                mla_node.args[4],
+                split_dim=SplitDimension.COLUMN,
+                config=transform_container.config,
+                dist_op=None,
+                min_local_shape=1,
+                layer_type=LayerType.MLA,
+            )
+        )

     # extract o_proj node
     o_proj = layer_subgraph.terminating_node
@@ -1885,21 +1920,18 @@

     # extract the sub-subgraph from q_b_proj and kv_b_proj to o_proj
     sub_subgraph = subgraph(
-        sources=[q_b_proj, kv_b_proj],
+        sources=lin_nodes,
         boundary_condition=is_any_lin_op,
     )
     attention_subgraph = LayerSubgraph(
-        opening_nodes=[q_b_proj, kv_b_proj],
+        opening_nodes=lin_nodes,
         subgraph_nodes=sub_subgraph,
         terminating_node=o_proj,
         layer_type=LayerType.MLA,
         min_local_shape=layer_subgraph.min_local_shape,
     )
     # shard q_b_proj and kv_b_proj nodes
-    num_column_row_shards = _process_column_sharding(attention_subgraph, transform_container)
-    if num_column_row_shards < 2:
-        # it means that "someone else" already sharded these nodes. Skipping.
-        return 0
+    _process_column_sharding(attention_subgraph, transform_container)

     # update "empty" and "expand" nodes' args. Reference in modeling_deepseek.py:
     # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
@@ -1926,6 +1958,27 @@
             )
         )

+    # update reshape nodes' args. Reference in modeling_deepseek.py:
+    # Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim]
+    # attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+    # attn_output = self.o_proj(attn_output)
+    candidate_reshape = layer_subgraph.terminating_node.args[0]
+    if is_op(candidate_reshape, [torch.ops.aten.reshape]):
+        # reshape args are (attn_output, [bsz, q_len, num_heads * v_head_dim])
+        # set 3rd arg (num_heads * v_head_dim) to -1
+        reshape_args = list(candidate_reshape.args)
+        reshape_sizes = list(reshape_args[1])
+        reshape_sizes[2] = -1
+        reshape_args[1] = tuple(reshape_sizes)
+        transform_container.add(
+            ParameterUpdateInfo(
+                target_node=candidate_reshape.name,
+                config=transform_container.config,
+                args=tuple(reshape_args),
+            )
+        )
+        ad_logger.debug(f"\nUpdated reshape node {candidate_reshape} arguments to {reshape_args}")
+
     # shard o_proj node
     transform_container.add(
         WeightShardingInfo.from_node(
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
index ddf3474959..7c9fd4fb11 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
@@ -321,6 +321,32 @@ def extract_weight_nodes(node: Node) -> WeightNodes:
             ],
             biases=[],
         )
+    elif is_weight_node(node):
+        weights = []
+        biases = []
+
+        if node.target.endswith("bias"):
+            biases = [
+                WeightNode(
+                    node=node,
+                    node_key=node.target,
+                    tensor=get_param_or_buffer(node.target, gm),
+                    submod=gm.get_submodule(node.target.rpartition(".")[0]),
+                )
+            ]
+        else:
+            weights = [
+                WeightNode(
+                    node=node,
+                    node_key=node.target,
+                    tensor=get_param_or_buffer(node.target, gm),
+                    submod=gm.get_submodule(node.target.rpartition(".")[0]),
+                )
+            ]
+        return WeightNodes(
+            weights=weights,
+            biases=biases,
+        )
     # for other parametrized nodes, we need to find the weight node
     else:
         all_weight_nodes = [
@@ -524,6 +550,16 @@ def is_any_attention_op(node: Node) -> bool:
     )


+def is_any_mla_op(node: Node) -> bool:
+    """Check if the node is an MLA op."""
+    return is_op(
+        node,
+        ops=[
+            torch.ops.auto_deploy.torch_mla,
+        ],
+    )
+
+
 def is_linear_op(node: Node) -> bool:
     """Check if the node is a linear op.

@@ -1100,6 +1136,7 @@ def get_layer_after_linear_node(
     ]
     ssm_nodes = list(filtered_nodes(interior_nodes, is_any_ssm_op))
     attention_nodes = list(filtered_nodes(interior_nodes, is_any_attention_op))
+    mla_nodes = list(filtered_nodes(interior_nodes, is_any_mla_op))
     intermediate_lin_nodes = list(filtered_nodes(interior_nodes, is_any_lin_op))
     intermediate_weight_nodes = list(
         filtered_nodes(
@@ -1112,23 +1149,15 @@ def get_layer_after_linear_node(
     ####################################################

     def classify_layer_type() -> [LayerType, int]:
-        if len(ssm_nodes) + len(attention_nodes) > 1:
+        if len(ssm_nodes) + len(attention_nodes) + len(mla_nodes) > 1:
+            # ambiguous layer type
             return LayerType.UNKNOWN, 1

         if len(attention_nodes) == 1:
             head_size = shape(attention_nodes[0])[-1]
-            # check if this is MLA:
-            # these two intermediate linear nodes are the latent q and kv projections.
-            if len(intermediate_lin_nodes) == 2:
-                # MLA has a RMS norm inside, so it should have one (or two, couning biaas)
-                # intermediate weight nodes
-                if len(intermediate_weight_nodes) not in [1, 2]:
-                    return LayerType.UNKNOWN, 1
-                return LayerType.MLA, head_size
-            else:
-                if len(intermediate_lin_nodes) != 0:
-                    return LayerType.UNKNOWN, 1
-                return LayerType.ATTENTION, head_size
+            if len(intermediate_lin_nodes) > 0:
+                return LayerType.UNKNOWN, 1
+            return LayerType.ATTENTION, head_size

         if len(ssm_nodes) == 1:
             head_size = shape(ssm_nodes[0])[-1]
@@ -1146,6 +1175,16 @@ def get_layer_after_linear_node(
                 return LayerType.UNKNOWN, 1
             return LayerType.SSM, head_size

+        if len(mla_nodes) == 1:
+            head_size = shape(mla_nodes[0])[-1]
+            # MLA should have two intermediate linear nodes:
+            # kv_b_proj and q_b_proj, but:
+            # - kv_b_proj may be absorbed by the MLA op
+            # - q_b_proj is skipped if q_lora_rank is None
+            if len(intermediate_lin_nodes) > 2:
+                return LayerType.UNKNOWN, 1
+            return LayerType.MLA, head_size
+
         # if we reach here, it means the layer is a MLP.
         # MLP should not have any intermediate linear or weight nodes.
         if len(intermediate_lin_nodes) > 0 or len(intermediate_weight_nodes) > 0:
diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml
index d5836220d0..afbbfca18b 100644
--- a/tests/integration/defs/accuracy/references/gsm8k.yaml
+++ b/tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -324,6 +324,11 @@ zai-org/GLM-4.5-Air:
     spec_dec_algo: MTP
     kv_cache_quant_algo: FP8
     accuracy: 88.2
+GLM-4.7-Flash:
+  - accuracy: 83.434
+  - quant_algo: NVFP4
+    kv_cache_quant_algo: FP8
+    accuracy: 81.046
 bigcode/starcoder2-3b:
   - accuracy: 20.2
 bigcode/starcoder2-7b:
diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml
index 6698920605..73f1318c75 100644
--- a/tests/integration/defs/accuracy/references/mmlu.yaml
+++ b/tests/integration/defs/accuracy/references/mmlu.yaml
@@ -397,3 +397,8 @@ nvidia/NVIDIA-Nemotron-3-Super-120B-012726:
   - quant_algo: FP8
     kv_cache_quant_algo: FP8
     accuracy: 86.12
+GLM-4.7-Flash:
+  - accuracy: 73.562
+  - quant_algo: NVFP4
+    kv_cache_quant_algo: FP8
+    accuracy: 70.955
diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
index a26884b473..cbbe2078de 100644
--- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
+++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
@@ -418,20 +418,12 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):


 class TestGLM4Flash(LlmapiAccuracyTestHarness):
-    """Accuracy regression tests for GLM-4.7-Flash.
+    """Accuracy regression tests for GLM-4.7-Flash variants."""

-    TODO: enable in CI, see https://github.com/NVIDIA/TensorRT-LLM/issues/11117
+    MODEL_NAME = "GLM-4.7-Flash"
+    MODEL_PATH_BF16 = hf_id_to_local_model_dir("zai-org/GLM-4.7-Flash")
+    MODEL_PATH_NVFP4 = hf_id_to_local_model_dir("DeepInfra/GLM-4.7-Flash-NVFP4")

-    In the meantime, you should run this test locally:
-
-    ```
-    cd tests/integration/defs
-    TRTLLM_ACCURACY_NO_REFERENCE=1 pytest -svv "accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[True]"
-    ```
-    """
-
-    MODEL_NAME = "zai-org/GLM-4.7-Flash"
-    MODEL_PATH = MODEL_NAME  # Model is in HF_CACHE
     # Set minimum possible seq len + small buffer, for test speed & memory usage
     MAX_SEQ_LEN = max(MMLU.MAX_INPUT_LEN + MMLU.MAX_OUTPUT_LEN,
                       GSM8K.MAX_INPUT_LEN + GSM8K.MAX_OUTPUT_LEN)
@@ -488,10 +480,27 @@ class TestGLM4Flash(LlmapiAccuracyTestHarness):
     def test_auto_dtype(self, enable_chunked_prefill):
         kwargs = self.get_default_kwargs(enable_chunked_prefill)
         sampling_params = self.get_default_sampling_params()
-        with AutoDeployLLM(model=self.MODEL_PATH,
-                           tokenizer=self.MODEL_PATH,
+        with AutoDeployLLM(model=self.MODEL_PATH_BF16,
+                           tokenizer=self.MODEL_PATH_BF16,
                            **kwargs) as llm:
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm, sampling_params=sampling_params)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
+
+    @skip_pre_blackwell
+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.parametrize("enable_chunked_prefill", [True, False])
+    def test_nvfp4(self, enable_chunked_prefill):
+        kwargs = self.get_default_kwargs(enable_chunked_prefill)
+        sampling_params = self.get_default_sampling_params()
+        with AutoDeployLLM(model=self.MODEL_PATH_NVFP4,
+                           tokenizer=self.MODEL_PATH_NVFP4,
+                           **kwargs) as llm:
+            # Manually set quant_config for NVFP4 model to get the accuracy threshold
+            llm.args.quant_config.quant_algo = QuantAlgo.NVFP4
+            llm.args.quant_config.kv_cache_quant_algo = QuantAlgo.FP8
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=sampling_params)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
diff --git a/tests/integration/defs/examples/test_ad_speculative_decoding.py b/tests/integration/defs/examples/test_ad_speculative_decoding.py
index 0c316bfbb1..97ba96a9fc 100644
--- a/tests/integration/defs/examples/test_ad_speculative_decoding.py
+++ b/tests/integration/defs/examples/test_ad_speculative_decoding.py
@@ -81,13 +81,15 @@ def make_eagle3_config(spec_model_path: str):
     )


-def run_with_autodeploy(model, speculative_config, batch_size):
+def run_with_autodeploy(model, speculative_config, batch_size, transforms_override=None):
    """Run AutoDeploy with or without speculative decoding.

    Args:
        model: Path to the base model
        speculative_config: Speculative decoding config (None for baseline mode)
        batch_size: Number of prompts to process
+        transforms_override: Optional dict of transform config overrides to merge
+            into the default transforms config (e.g. {"fuse_add_rms_norm": {"enabled": False}})

    Returns:
        List of (prompt, output) tuples from prompts_and_outputs
@@ -112,6 +114,10 @@ def run_with_autodeploy(model, speculative_config, batch_size):
         "max_num_tokens": 64,
     }

+    # Apply any transform config overrides
+    if transforms_override:
+        llm_args["transforms"] = transforms_override
+
     # Configure experiment with prompts
     experiment_config = {
         "args": llm_args,
@@ -169,18 +175,34 @@ def test_autodeploy_spec_dec_output(spec_dec_mode):
     print(f"\nBase Model: {base_model}")
     print(f"Speculative Model ({spec_dec_mode}): {spec_model}")

+    # For eagle3, disable fuse_add_rms_norm for both runs to ensure numerical
+    # parity. The eagle3 hidden state capture replaces aten.add with
+    # residual_add_for_capture at certain layers, which prevents
+    # fuse_add_rms_norm from fusing those layers. This causes the target
+    # model to produce slightly different logits vs baseline (which fuses
+    # all layers), leading to greedy-divergent outputs on some hardware.
+    transforms_override = None
+    if spec_dec_mode == "eagle3":
+        transforms_override = {"fuse_add_rms_norm": {"enabled": False}}
+
     # Run with speculative decoding
     print("\n[1/2] Running with speculative decoding enabled...")
     spec_outputs = run_with_autodeploy(
         model=base_model,
         speculative_config=spec_config,
         batch_size=1,
+        transforms_override=transforms_override,
     )
     print(f"Generated {len(spec_outputs)} outputs with speculative decoding")

     # Run without speculative decoding (baseline)
     print("\n[2/2] Running without speculative decoding (baseline)...")
-    baseline_outputs = run_with_autodeploy(model=base_model, speculative_config=None, batch_size=1)
+    baseline_outputs = run_with_autodeploy(
+        model=base_model,
+        speculative_config=None,
+        batch_size=1,
+        transforms_override=transforms_override,
+    )
     print(f"Generated {len(baseline_outputs)} outputs in baseline mode")

     # Verify outputs are identical
diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml
index 30a48da2cb..1144fd9a62 100644
--- a/tests/integration/test_lists/test-db/l0_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_b200.yml
@@ -191,4 +191,6 @@ l0_b200:
   - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[flashinfer-False-1]
   - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[torch-True-1]
   - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
+  - accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[False]
+  - accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_nvfp4[False]
   - unittest/_torch/auto_deploy/unit/singlegpu
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index 80b2635162..693c618d32 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -445,6 +445,8 @@ l0_h100:
   - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[triton_ssm-True]
   - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
   - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16[1]
+  - accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[False]
+  - accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[True]
   - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target]
   - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3]
   - examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate
diff --git a/tests/test_common/llm_data.py b/tests/test_common/llm_data.py
index cc8d939e04..ded67cad2e 100644
--- a/tests/test_common/llm_data.py
+++ b/tests/test_common/llm_data.py
@@ -43,6 +43,8 @@ HF_ID_TO_LLM_MODELS_SUBDIR = {
     "nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3": "NVIDIA-Nemotron-Nano-31B-A3-v3",
     "nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024": "Nemotron-Nano-3-30B-A3.5B-dev-1024",
     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": "EAGLE3-LLaMA3.1-Instruct-8B",
+    "zai-org/GLM-4.7-Flash": "GLM-4.7-Flash",
+    "DeepInfra/GLM-4.7-Flash-NVFP4": "GLM-4.7-Flash-NVFP4",
 }

diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
index b66a6a54eb..e8352c88a8 100644
--- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
+++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
@@ -23,7 +23,7 @@ from tensorrt_llm._torch.auto_deploy.transform.library.sharding import (
     WeightShardingInfo,
 )
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
-from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op
+from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op, is_weight_node
 from tensorrt_llm.functional import AllReduceStrategy

 base_model_tp_plan = {
@@ -139,6 +139,7 @@ class MLA_Block(nn.Module):

     Based on DeepSeek MLA architecture with KV compression.
     This is a minimal, self-contained implementation for testing sharding patterns.
+    Based on models/custom/modeling_deepseek.py:DeepSeekV3Attention
     """

     def __init__(
@@ -163,41 +164,62 @@

         # KV compression path (not sharded - gather)
         self.kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim, bias=bias)
+        self.kv_a_layernorm = nn.LayerNorm(kv_lora_rank)

-        # KV decompression (sharded column-wise)
-        self.kv_b_proj = nn.Linear(
-            kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False
+        # KV decompression weight (absorbed into torch_mla, sharded column-wise)
+        # NOTE: This is nn.Parameter, not nn.Linear - the weight is passed directly to torch_mla
+        self.kv_b_proj_weight = nn.Parameter(
+            torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
         )

         # Query path (sharded column-wise)
         self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=bias)
         self.q_b_proj = nn.Linear(q_lora_rank, num_heads * self.qk_head_dim, bias=bias)
         self.q_a_layernorm = nn.LayerNorm(q_lora_rank)
+
         # Output projection (sharded row-wise)
         self.o_proj = nn.Linear(num_heads * v_head_dim, hidden_size, bias=bias)

+        # Softmax scale
+        self.softmax_scale = self.qk_head_dim ** (-0.5)
+    @torch.no_grad()
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         b, s, _ = x.shape

         # Compress KV to latent
         compressed_kv_rope = self.kv_a_proj_with_mqa(x)  # (b, s, kv_lora_rank + rope_dim)
-        compressed_kv = compressed_kv_rope[:, :, : self.kv_lora_rank]  # (b, s, kv_lora_rank)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv_rope, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )

-        # Decompress to full K and V
-        kv = self.kv_b_proj(compressed_kv)  # (b, s, num_heads * (qk_nope + v))
-        k_nope_v = kv.view(b, s, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope = k_nope_v[:, :, :, : self.qk_nope_head_dim]
-        v = k_nope_v[:, :, :, self.qk_nope_head_dim :]
+        # Apply layernorm to compressed KV
+        compressed_kv = self.kv_a_layernorm(compressed_kv)  # (b, s, kv_lora_rank)

-        # Query projection
-        # q = q_b_proj @ (layernorm(q_a_proj @ x))
+        # k_pe: shared across heads for simplified version
+        k_pe = k_pe.view(b, s, 1, self.qk_rope_head_dim)  # (b, s, 1, qk_rope_head_dim)
+
+        # Query projection: q = q_b_proj @ (layernorm(q_a_proj @ x))
         q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))  # (b, s, num_heads * qk_head_dim)
         q = q.view(b, s, self.num_heads, self.qk_head_dim)
-        q_nope = q[:, :, :, : self.qk_nope_head_dim]
+        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+        # Call MLA kernel with compressed KV and kv_b_proj weight absorbed
+        # NOTE: kv_b_proj weight is passed directly to the kernel instead of calling Linear
+        attn_out = torch.ops.auto_deploy.torch_mla(
+            q_nope,  # [B, S, N, qk_nope_head_dim]
+            q_pe,  # [B, S, N, qk_rope_head_dim]
+            compressed_kv,  # [B, S, kv_lora_rank]
+            k_pe,  # [B, S, 1, qk_rope_head_dim]
+            self.kv_b_proj_weight,  # [N*(qk_nope+v), kv_lora_rank] - absorbed weight
+            True,  # is_causal
+            self.softmax_scale,  # softmax_scale
+            "bsnd",  # layout
+        )
+
+        # Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim]
+        attn_out = attn_out.reshape(b, s, self.num_heads * self.v_head_dim)

-        attn_out = torch.ops.auto_deploy.torch_attention(q_nope, k_nope, v, is_causal=True)
-        attn_out = attn_out.contiguous().view(b, s, -1)
         # Output projection
         output = self.o_proj(attn_out)
         return output
@@ -616,6 +638,16 @@ def _run_pattern_detection_job(
                     layer_type=LayerType.MLA,
                 )
             )
+        if is_weight_node(node):
+            if "kv_b_proj_weight" in node.name:
+                expected_transformations.append(
+                    WeightShardingInfo(
+                        target_node=node.name,
+                        split_dim=SplitDimension.COLUMN,
+                        config=config,
+                        layer_type=LayerType.MLA,
+                    )
+                )

     # get detected transformations
     optimizer = InferenceOptimizer(
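For reference, a minimal, self-contained sketch of the decode-planning arithmetic that the flashinfer_mla.py change moves inside the CUDA-graph branch. The page size and metadata tensors below are made-up example values, not taken from the diff:

    import torch

    page_size = 64
    # Two sequences: the first owns pages [0, 3), the second pages [3, 5).
    kv_page_indptr = torch.tensor([0, 3, 5], dtype=torch.int32)
    # Number of valid tokens in the last page of each sequence.
    kv_last_page_len = torch.tensor([10, 64], dtype=torch.int32)

    batch_size = kv_page_indptr.shape[0] - 1                      # 2
    # Decode: one query token per sequence -> [0, 1, 2]
    qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32)
    num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]  # [3, 2]
    # All-but-last pages are full; the last page contributes kv_last_page_len tokens.
    kv_len_arr = (num_pages_per_seq - 1) * page_size + kv_last_page_len
    print(qo_indptr.tolist(), kv_len_arr.tolist())                # [0, 1, 2] [138, 128]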
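Likewise, a rough sketch of what column-sharding the absorbed kv_b_proj weight (mla_node.args[4] in _process_mla_sharding) amounts to, assuming SplitDimension.COLUMN splits the output (head) dimension, i.e. dim 0 of the [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank] weight. The sizes are illustrative, and the real split is carried out by WeightShardingInfo, not by this snippet:

    import torch

    num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank, tp_size = 8, 128, 128, 512, 2
    kv_b_proj_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

    # Keep each head's (qk_nope + v) rows contiguous so every rank gets whole heads.
    heads_per_rank = num_heads // tp_size
    rows_per_rank = heads_per_rank * (qk_nope_head_dim + v_head_dim)
    shards = torch.split(kv_b_proj_weight, rows_per_rank, dim=0)  # one shard per TP rank

    assert len(shards) == tp_size
    assert shards[0].shape == (rows_per_rank, kv_lora_rank)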
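A small illustration of why the new reshape-node update in _process_mla_sharding rewrites the flattened size to -1: after TP sharding, each rank presumably holds only num_heads / tp_size heads, so a hard-coded num_heads * v_head_dim no longer matches the local tensor, while -1 infers the local size. The numbers below are illustrative:

    import torch

    bsz, q_len, num_heads, v_head_dim, tp_size = 2, 16, 8, 128, 2
    local_heads = num_heads // tp_size
    attn_output = torch.randn(bsz, q_len, local_heads, v_head_dim)

    # attn_output.reshape(bsz, q_len, num_heads * v_head_dim) would fail here,
    # because this rank only holds local_heads * v_head_dim features per token.
    attn_output = attn_output.reshape(bsz, q_len, -1)
    assert attn_output.shape[-1] == local_heads * v_head_dim  # 512, not 1024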
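Finally, the transforms_override hook added to run_with_autodeploy expects a plain dict keyed by transform name, as the eagle3 branch of the test shows. A hypothetical usage sketch (the llm_args keys mirror those built inside the test; the dict itself is the example from the diff):

    # Hypothetical llm_args; only the keys shown in the test are assumed here.
    llm_args = {
        "model": "<base_model_path>",
        "max_batch_size": 1,
        "max_num_tokens": 64,
    }

    transforms_override = {"fuse_add_rms_norm": {"enabled": False}}
    if transforms_override:
        llm_args["transforms"] = transforms_override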