[#11109][feat] AutoDeploy: GLM 4.7 Flash Improvements (#11414)

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: Gal Hubara-Agam <96368689+galagam@users.noreply.github.com>
Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Co-authored-by: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com>
Commit 1c065fbb3e (parent fedd7178d1)
Bala Marimuthu, 2026-02-17 08:43:59 -05:00, committed by GitHub
11 changed files with 236 additions and 58 deletions


@ -243,12 +243,6 @@ class _FlashInferMLAPlanner:
plan_params: Parameters for planning.
"""
# Decode qo_indptr: [0, 1, 2, ..., batch_size] (1 token per sequence)
batch_size = kv_page_indptr.shape[0] - 1
qo_indptr = torch.arange(batch_size + 1, device=kv_page_indptr.device, dtype=torch.int32)
# Compute kv_len_arr for CUDA graph wrapper initialization
num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]
kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + kv_last_page_len
# we want to plan during warm-up of cuda graph capture to ensure we have the plan cached
if (
@ -257,6 +251,14 @@ class _FlashInferMLAPlanner:
):
# During CUDA graph capture, the metadata tensors provided by auto-deploy are stable.
# Pass the buffer tensors to the wrapper for use_cuda_graph=True
batch_size = kv_page_indptr.shape[0] - 1
qo_indptr = torch.arange(
batch_size + 1, device=kv_page_indptr.device, dtype=torch.int32
)
# Compute kv_len_arr for CUDA graph wrapper initialization
num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]
kv_len_arr = (num_pages_per_seq - 1) * plan_params.page_size + kv_last_page_len
wrapper = self._init_decode_wrapper(
use_cuda_graph=True,
qo_indptr=qo_indptr,
@ -276,6 +278,11 @@ class _FlashInferMLAPlanner:
# Re-plan if plan_params changed
if plan_params != self.plan_params_decode:
batch_size = kv_page_indptr.shape[0] - 1
qo_indptr = torch.arange(
batch_size + 1, device=kv_page_indptr.device, dtype=torch.int32
)
self._plan_mla_wrapper(
self.decode_wrapper,
qo_indptr,
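For reference, the metadata computation above can be reproduced standalone; below is a minimal sketch with made-up page size and indptr values (only torch is required):

```
import torch

# Hypothetical decode batch: 3 sequences holding 2, 1, and 4 KV-cache pages.
page_size = 16
kv_page_indptr = torch.tensor([0, 2, 3, 7], dtype=torch.int32)  # prefix sum of pages per sequence
kv_last_page_len = torch.tensor([5, 16, 9], dtype=torch.int32)  # valid tokens in each last page

# qo_indptr for decode: one query token per sequence -> [0, 1, ..., batch_size]
batch_size = kv_page_indptr.shape[0] - 1
qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32)

# kv_len_arr: full pages contribute page_size tokens, the last page only its valid tokens
num_pages_per_seq = kv_page_indptr[1:] - kv_page_indptr[:-1]
kv_len_arr = (num_pages_per_seq - 1) * page_size + kv_last_page_len

print(qo_indptr.tolist())   # [0, 1, 2, 3]
print(kv_len_arr.tolist())  # [21, 16, 57]
```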


@ -50,6 +50,7 @@ from ...utils.node_utils import (
is_any_moe_op,
is_any_ssm_op,
is_op,
is_weight_node,
num_users_of_weight_node,
shape,
subgraph,
@ -1855,6 +1856,7 @@ def _process_mla_sharding(
- q_a_proj: # gather (simple shard, output is replicated)
- q_b_proj: # column-sharding (output is head-distributed)
- kv_a_proj # gather (simple shard, output is replicated)
# This one is actually absorbed by the MLA kernel.
- kv_b_proj # column-sharding (output is head-distributed)
- o_proj # row-sharding + all-reduce
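The column and row splits listed above compose into a tensor-parallel attention block as follows; here is a small numerical sketch with made-up shapes and a 2-way split (not taken from the actual transform):

```
import torch

torch.manual_seed(0)
hidden, heads, head_dim, world_size = 8, 4, 2, 2
x = torch.randn(1, hidden)
w_col = torch.randn(heads * head_dim, hidden)  # stands in for a q_b_proj / kv_b_proj weight
w_row = torch.randn(hidden, heads * head_dim)  # stands in for the o_proj weight

# Reference: the unsharded back-to-back projections.
ref = x @ w_col.T @ w_row.T

# Column-shard the first weight (split its output/head dim), row-shard the second
# (split its input dim), then sum the per-rank partials: the job of the all-reduce.
partials = []
for rank in range(world_size):
    w_col_shard = w_col.chunk(world_size, dim=0)[rank]
    w_row_shard = w_row.chunk(world_size, dim=1)[rank]
    partials.append(x @ w_col_shard.T @ w_row_shard.T)
tp_out = sum(partials)

torch.testing.assert_close(tp_out, ref)
```

The per-rank partials only need a single collective at the very end, which is why only o_proj carries the all-reduce.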
@ -1869,8 +1871,41 @@ def _process_mla_sharding(
q_a_proj, kv_a_proj = layer_subgraph.opening_nodes
# extract q_b_proj and kv_b_proj nodes
lin_nodes = list(filtered_nodes(layer_subgraph.subgraph_nodes, is_any_lin_op))
assert len(lin_nodes) == 2, "Expecting exactly two linear nodes in the interior of the subgraph"
q_b_proj, kv_b_proj = lin_nodes
assert len(lin_nodes) <= 2, (
"Expecting at most two linear nodes in the interior of the MLA layer"
)
if len(lin_nodes) == 1:
# we don't have an explicit kv_b projection. Instead, it is
# absorbed by the MLA kernel.
mla_node = list(
filtered_nodes(layer_subgraph.subgraph_nodes, ops=torch.ops.auto_deploy.torch_mla)
)
assert len(mla_node) == 1, "Expecting exactly one MLA node"
mla_node = mla_node[0]
# torch_mla args:
# q_nope: [B, S, N, qk_nope_head_dim]
# q_pe: [B, S, N, qk_rope_head_dim]
# compressed_kv: [B, S, kv_lora_rank]
# kpe: [B, S, 1, qk_rope_head_dim]
# kv_b_proj_weight: [N * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
# is_causal: bool
# scale: float
# layout: str
# we need to shard the 4th argument: kv_b_proj_weight
assert is_weight_node(mla_node.args[4]), "Expecting weight node for kv_b_proj_weight"
# column-shard it
transform_container.add(
WeightShardingInfo.from_node(
mla_node.args[4],
split_dim=SplitDimension.COLUMN,
config=transform_container.config,
dist_op=None,
min_local_shape=1,
layer_type=LayerType.MLA,
)
)
# extract o_proj node
o_proj = layer_subgraph.terminating_node
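To make the column split of the absorbed weight concrete: args[4] of torch_mla (kv_b_proj_weight) is split along dim 0 so each rank keeps the decompression rows for its own heads. A shape-only sketch with made-up dimensions (not the GLM-4.7-Flash values):

```
import torch

num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank, world_size = 8, 64, 64, 256, 4

# Absorbed decompression weight as documented above: [N * (qk_nope + v), kv_lora_rank]
kv_b_proj_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# COLUMN sharding splits dim 0, i.e. the per-head output blocks, across TP ranks.
shards = kv_b_proj_weight.chunk(world_size, dim=0)
heads_per_rank = num_heads // world_size
for rank, shard in enumerate(shards):
    # Each rank keeps whole heads: heads_per_rank * (qk_nope_head_dim + v_head_dim) rows.
    assert shard.shape == (heads_per_rank * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
```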
@ -1885,21 +1920,18 @@ def _process_mla_sharding(
# extract the sub-subgraph from q_b_proj and kv_b_proj to o_proj
sub_subgraph = subgraph(
sources=[q_b_proj, kv_b_proj],
sources=lin_nodes,
boundary_condition=is_any_lin_op,
)
attention_subgraph = LayerSubgraph(
opening_nodes=[q_b_proj, kv_b_proj],
opening_nodes=lin_nodes,
subgraph_nodes=sub_subgraph,
terminating_node=o_proj,
layer_type=LayerType.MLA,
min_local_shape=layer_subgraph.min_local_shape,
)
# shard q_b_proj and kv_b_proj nodes
num_column_row_shards = _process_column_sharding(attention_subgraph, transform_container)
if num_column_row_shards < 2:
# it means that "someone else" already sharded these nodes. Skipping.
return 0
_process_column_sharding(attention_subgraph, transform_container)
# update "empty" and "expand" nodes' args. Reference in modeling_deepseek.py:
# query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
@ -1926,6 +1958,27 @@ def _process_mla_sharding(
)
)
# update reshape nodes' args. Reference in modeling_deepseek.py:
# Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim]
# attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
# attn_output = self.o_proj(attn_output)
candidate_reshape = layer_subgraph.terminating_node.args[0]
if is_op(candidate_reshape, [torch.ops.aten.reshape]):
# reshape args are (attn_output, [bsz, q_len, num_heads * v_head_dim])
# set 3rd arg (num_heads * v_head_dim) to -1
reshape_args = list(candidate_reshape.args)
reshape_sizes = list(reshape_args[1])
reshape_sizes[2] = -1
reshape_args[1] = tuple(reshape_sizes)
transform_container.add(
ParameterUpdateInfo(
target_node=candidate_reshape.name,
config=transform_container.config,
args=tuple(reshape_args),
)
)
ad_logger.debug(f"\nUpdated reshape node {candidate_reshape} arguments to {reshape_args}")
# shard o_proj node
transform_container.add(
WeightShardingInfo.from_node(
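The -1 rewrite matters because, after column sharding, each rank holds only num_heads / tp_size heads, so the hard-coded num_heads * v_head_dim no longer matches the local attention output. A quick illustration with arbitrary sizes:

```
import torch

bsz, q_len, num_heads, v_head_dim, tp_size = 2, 3, 8, 16, 4
local_heads = num_heads // tp_size

attn_output = torch.randn(bsz, q_len, local_heads, v_head_dim)  # per-rank output after sharding

# The original graph hard-codes the full size and fails on the sharded tensor ...
try:
    attn_output.reshape(bsz, q_len, num_heads * v_head_dim)
except RuntimeError as e:
    print("hard-coded size fails:", e)

# ... while -1 lets reshape infer the local num_heads * v_head_dim.
flattened = attn_output.reshape(bsz, q_len, -1)
assert flattened.shape == (bsz, q_len, local_heads * v_head_dim)
```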


@ -321,6 +321,32 @@ def extract_weight_nodes(node: Node) -> WeightNodes:
],
biases=[],
)
elif is_weight_node(node):
weights = []
biases = []
if node.target.endswith("bias"):
biases = [
WeightNode(
node=node,
node_key=node.target,
tensor=get_param_or_buffer(node.target, gm),
submod=gm.get_submodule(node.target.rpartition(".")[0]),
)
]
else:
weights = [
WeightNode(
node=node,
node_key=node.target,
tensor=get_param_or_buffer(node.target, gm),
submod=gm.get_submodule(node.target.rpartition(".")[0]),
)
]
return WeightNodes(
weights=weights,
biases=biases,
)
# for other parametrized nodes, we need to find the weight node
else:
all_weight_nodes = [
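For context on the new branch above, this is roughly how a bare get_attr target is classified and resolved; the module and attribute names are hypothetical stand-ins, and plain nn.Module accessors replace the repo's get_param_or_buffer helper:

```
import torch
import torch.nn as nn

# Hypothetical module holding an "absorbed" weight as a bare nn.Parameter (no nn.Linear).
class Attn(nn.Module):
    def __init__(self):
        super().__init__()
        self.kv_b_proj_weight = nn.Parameter(torch.randn(4, 2))

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = Attn()

model = Model()
target = "attn.kv_b_proj_weight"                  # what node.target looks like for a get_attr node
is_bias = target.endswith("bias")                 # the classification used above: bias vs. weight
submod = model.get_submodule(target.rpartition(".")[0])  # owning submodule ("attn")
tensor = model.get_parameter(target)              # the parameter itself

print(is_bias, type(submod).__name__, tuple(tensor.shape))  # False Attn (4, 2)
```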
@ -524,6 +550,16 @@ def is_any_attention_op(node: Node) -> bool:
)
def is_any_mla_op(node: Node) -> bool:
"""Check if the node is a mla op."""
return is_op(
node,
ops=[
torch.ops.auto_deploy.torch_mla,
],
)
def is_linear_op(node: Node) -> bool:
"""Check if the node is a linear op.
@ -1100,6 +1136,7 @@ def get_layer_after_linear_node(
]
ssm_nodes = list(filtered_nodes(interior_nodes, is_any_ssm_op))
attention_nodes = list(filtered_nodes(interior_nodes, is_any_attention_op))
mla_nodes = list(filtered_nodes(interior_nodes, is_any_mla_op))
intermediate_lin_nodes = list(filtered_nodes(interior_nodes, is_any_lin_op))
intermediate_weight_nodes = list(
filtered_nodes(
@ -1112,23 +1149,15 @@ def get_layer_after_linear_node(
####################################################
def classify_layer_type() -> [LayerType, int]:
if len(ssm_nodes) + len(attention_nodes) > 1:
if len(ssm_nodes) + len(attention_nodes) + len(mla_nodes) > 1:
# ambiguous layer type
return LayerType.UNKNOWN, 1
if len(attention_nodes) == 1:
head_size = shape(attention_nodes[0])[-1]
# check if this is MLA:
# these two intermediate linear nodes are the latent q and kv projections.
if len(intermediate_lin_nodes) == 2:
# MLA has an RMS norm inside, so it should have one (or two, counting bias)
# intermediate weight nodes
if len(intermediate_weight_nodes) not in [1, 2]:
return LayerType.UNKNOWN, 1
return LayerType.MLA, head_size
else:
if len(intermediate_lin_nodes) != 0:
return LayerType.UNKNOWN, 1
return LayerType.ATTENTION, head_size
if len(intermediate_lin_nodes) > 0:
return LayerType.UNKNOWN, 1
return LayerType.ATTENTION, head_size
if len(ssm_nodes) == 1:
head_size = shape(ssm_nodes[0])[-1]
@ -1146,6 +1175,16 @@ def get_layer_after_linear_node(
return LayerType.UNKNOWN, 1
return LayerType.SSM, head_size
if len(mla_nodes) == 1:
head_size = shape(mla_nodes[0])[-1]
# MLA should have two intermediate linear nodes:
# kv_b_proj and q_b_proj, but:
# - kv_b_proj may be absorbed by the MLA op
# - q_b_proj is skipped if q_lora_rank is None
if len(intermediate_lin_nodes) > 2:
return LayerType.UNKNOWN, 1
return LayerType.MLA, head_size
# if we reach here, it means the layer is a MLP.
# MLP should not have any intermediate linear or weight nodes.
if len(intermediate_lin_nodes) > 0 or len(intermediate_weight_nodes) > 0:
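A simplified restatement of the decision logic above, limited to the branches shown in this hunk (SSM and MLP details are omitted); it only exists to make the allowed interior-linear counts explicit:

```
def classify_layer_type_sketch(n_ssm, n_attn, n_mla, n_interior_lin):
    """Simplified restatement of the decision above (SSM / MLP branches not reproduced)."""
    if n_ssm + n_attn + n_mla > 1:
        return "UNKNOWN"  # ambiguous: more than one attention-like op in the layer
    if n_attn == 1:
        # plain attention layers have no interior projections between qkv and o_proj
        return "ATTENTION" if n_interior_lin == 0 else "UNKNOWN"
    if n_mla == 1:
        # MLA may carry q_b_proj and/or kv_b_proj in the interior; both are optional
        # (kv_b_proj can be absorbed into torch_mla, q_b_proj needs q_lora_rank)
        return "MLA" if n_interior_lin <= 2 else "UNKNOWN"
    return "..."  # SSM / MLP handling omitted in this sketch

assert classify_layer_type_sketch(0, 0, 1, 1) == "MLA"      # kv_b_proj absorbed, q_b_proj present
assert classify_layer_type_sketch(0, 1, 0, 2) == "UNKNOWN"  # interior projections => not plain attention
```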


@ -324,6 +324,11 @@ zai-org/GLM-4.5-Air:
spec_dec_algo: MTP
kv_cache_quant_algo: FP8
accuracy: 88.2
GLM-4.7-Flash:
- accuracy: 83.434
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
accuracy: 81.046
bigcode/starcoder2-3b:
- accuracy: 20.2
bigcode/starcoder2-7b:


@ -397,3 +397,8 @@ nvidia/NVIDIA-Nemotron-3-Super-120B-012726:
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 86.12
GLM-4.7-Flash:
- accuracy: 73.562
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
accuracy: 70.955


@ -418,20 +418,12 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
class TestGLM4Flash(LlmapiAccuracyTestHarness):
"""Accuracy regression tests for GLM-4.7-Flash.
"""Accuracy regression tests for GLM-4.7-Flash variants"""
TODO: enable in CI, see https://github.com/NVIDIA/TensorRT-LLM/issues/11117
MODEL_NAME = "GLM-4.7-Flash"
MODEL_PATH_BF16 = hf_id_to_local_model_dir("zai-org/GLM-4.7-Flash")
MODEL_PATH_NVFP4 = hf_id_to_local_model_dir("DeepInfra/GLM-4.7-Flash-NVFP4")
In the meantime, you should run this test locally:
```
cd tests/integration/defs
TRTLLM_ACCURACY_NO_REFERENCE=1 pytest -svv "accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[True]"
```
"""
MODEL_NAME = "zai-org/GLM-4.7-Flash"
MODEL_PATH = MODEL_NAME # Model is in HF_CACHE
# Set minimum possible seq len + small buffer, for test speed & memory usage
MAX_SEQ_LEN = max(MMLU.MAX_INPUT_LEN + MMLU.MAX_OUTPUT_LEN,
GSM8K.MAX_INPUT_LEN + GSM8K.MAX_OUTPUT_LEN)
@ -488,10 +480,27 @@ class TestGLM4Flash(LlmapiAccuracyTestHarness):
def test_auto_dtype(self, enable_chunked_prefill):
kwargs = self.get_default_kwargs(enable_chunked_prefill)
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH,
tokenizer=self.MODEL_PATH,
with AutoDeployLLM(model=self.MODEL_PATH_BF16,
tokenizer=self.MODEL_PATH_BF16,
**kwargs) as llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=sampling_params)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@skip_pre_blackwell
@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
def test_nvfp4(self, enable_chunked_prefill):
kwargs = self.get_default_kwargs(enable_chunked_prefill)
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH_NVFP4,
tokenizer=self.MODEL_PATH_NVFP4,
**kwargs) as llm:
# Manually set quant_config for NVFP4 model to get the accuracy threshold
llm.args.quant_config.quant_algo = QuantAlgo.NVFP4
llm.args.quant_config.kv_cache_quant_algo = QuantAlgo.FP8
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=sampling_params)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)


@ -81,13 +81,15 @@ def make_eagle3_config(spec_model_path: str):
)
def run_with_autodeploy(model, speculative_config, batch_size):
def run_with_autodeploy(model, speculative_config, batch_size, transforms_override=None):
"""Run AutoDeploy with or without speculative decoding.
Args:
model: Path to the base model
speculative_config: Speculative decoding config (None for baseline mode)
batch_size: Number of prompts to process
transforms_override: Optional dict of transform config overrides to merge
into the default transforms config (e.g. {"fuse_add_rms_norm": {"enabled": False}})
Returns:
List of (prompt, output) tuples from prompts_and_outputs
@ -112,6 +114,10 @@ def run_with_autodeploy(model, speculative_config, batch_size):
"max_num_tokens": 64,
}
# Apply any transform config overrides
if transforms_override:
llm_args["transforms"] = transforms_override
# Configure experiment with prompts
experiment_config = {
"args": llm_args,
@ -169,18 +175,34 @@ def test_autodeploy_spec_dec_output(spec_dec_mode):
print(f"\nBase Model: {base_model}")
print(f"Speculative Model ({spec_dec_mode}): {spec_model}")
# For eagle3, disable fuse_add_rms_norm for both runs to ensure numerical
# parity. The eagle3 hidden state capture replaces aten.add with
# residual_add_for_capture at certain layers, which prevents
# fuse_add_rms_norm from fusing those layers. This causes the target
# model to produce slightly different logits vs baseline (which fuses
# all layers), leading to greedy-divergent outputs on some hardware.
transforms_override = None
if spec_dec_mode == "eagle3":
transforms_override = {"fuse_add_rms_norm": {"enabled": False}}
# Run with speculative decoding
print("\n[1/2] Running with speculative decoding enabled...")
spec_outputs = run_with_autodeploy(
model=base_model,
speculative_config=spec_config,
batch_size=1,
transforms_override=transforms_override,
)
print(f"Generated {len(spec_outputs)} outputs with speculative decoding")
# Run without speculative decoding (baseline)
print("\n[2/2] Running without speculative decoding (baseline)...")
baseline_outputs = run_with_autodeploy(model=base_model, speculative_config=None, batch_size=1)
baseline_outputs = run_with_autodeploy(
model=base_model,
speculative_config=None,
batch_size=1,
transforms_override=transforms_override,
)
print(f"Generated {len(baseline_outputs)} outputs in baseline mode")
# Verify outputs are identical
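A minimal sketch of how the override lands in the LLM args (the surrounding values are placeholders; the actual defaults live in run_with_autodeploy above):

```
# Placeholder llm_args; the test builds a fuller config internally.
llm_args = {
    "model": "path/to/base_model",
    "max_batch_size": 1,
    "max_num_tokens": 64,
}

transforms_override = {"fuse_add_rms_norm": {"enabled": False}}  # eagle3 parity workaround
if transforms_override:
    llm_args["transforms"] = transforms_override

print(llm_args["transforms"])  # {'fuse_add_rms_norm': {'enabled': False}}
```

Passing the same override to both the speculative and baseline runs keeps the fusion behavior identical, so the greedy outputs can be compared token-for-token.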


@ -191,4 +191,6 @@ l0_b200:
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[flashinfer-False-1]
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[torch-True-1]
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
- accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[False]
- accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_nvfp4[False]
- unittest/_torch/auto_deploy/unit/singlegpu


@ -445,6 +445,8 @@ l0_h100:
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[triton_ssm-True]
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16[1]
- accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[False]
- accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[True]
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target]
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3]
- examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate


@ -43,6 +43,8 @@ HF_ID_TO_LLM_MODELS_SUBDIR = {
"nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3": "NVIDIA-Nemotron-Nano-31B-A3-v3",
"nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024": "Nemotron-Nano-3-30B-A3.5B-dev-1024",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": "EAGLE3-LLaMA3.1-Instruct-8B",
"zai-org/GLM-4.7-Flash": "GLM-4.7-Flash",
"DeepInfra/GLM-4.7-Flash-NVFP4": "GLM-4.7-Flash-NVFP4",
}


@ -23,7 +23,7 @@ from tensorrt_llm._torch.auto_deploy.transform.library.sharding import (
WeightShardingInfo,
)
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op, is_weight_node
from tensorrt_llm.functional import AllReduceStrategy
base_model_tp_plan = {
@ -139,6 +139,7 @@ class MLA_Block(nn.Module):
Based on DeepSeek MLA architecture with KV compression.
This is a minimal, self-contained implementation for testing sharding patterns.
Based on models/custom/modeling_deepseek.py:DeepSeekV3Attention
"""
def __init__(
@ -163,41 +164,62 @@ class MLA_Block(nn.Module):
# KV compression path (not sharded - gather)
self.kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim, bias=bias)
self.kv_a_layernorm = nn.LayerNorm(kv_lora_rank)
# KV decompression (sharded column-wise)
self.kv_b_proj = nn.Linear(
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False
# KV decompression weight (absorbed into torch_mla, sharded column-wise)
# NOTE: This is nn.Parameter, not nn.Linear - the weight is passed directly to torch_mla
self.kv_b_proj_weight = nn.Parameter(
torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
)
# Query path (sharded column-wise)
self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=bias)
self.q_b_proj = nn.Linear(q_lora_rank, num_heads * self.qk_head_dim, bias=bias)
self.q_a_layernorm = nn.LayerNorm(q_lora_rank)
# Output projection (sharded row-wise)
self.o_proj = nn.Linear(num_heads * v_head_dim, hidden_size, bias=bias)
# Softmax scale
self.softmax_scale = self.qk_head_dim ** (-0.5)
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, s, _ = x.shape
# Compress KV to latent
compressed_kv_rope = self.kv_a_proj_with_mqa(x) # (b, s, kv_lora_rank + rope_dim)
compressed_kv = compressed_kv_rope[:, :, : self.kv_lora_rank] # (b, s, kv_lora_rank)
compressed_kv, k_pe = torch.split(
compressed_kv_rope, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
# Decompress to full K and V
kv = self.kv_b_proj(compressed_kv) # (b, s, num_heads * (qk_nope + v))
k_nope_v = kv.view(b, s, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope = k_nope_v[:, :, :, : self.qk_nope_head_dim]
v = k_nope_v[:, :, :, self.qk_nope_head_dim :]
# Apply layernorm to compressed KV
compressed_kv = self.kv_a_layernorm(compressed_kv) # (b, s, kv_lora_rank)
# Query projection
# q = q_b_proj @ (layernorm(q_a_proj @ x))
# k_pe: shared across heads for simplified version
k_pe = k_pe.view(b, s, 1, self.qk_rope_head_dim) # (b, s, 1, qk_rope_head_dim)
# Query projection: q = q_b_proj @ (layernorm(q_a_proj @ x))
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) # (b, s, num_heads * qk_head_dim)
q = q.view(b, s, self.num_heads, self.qk_head_dim)
q_nope = q[:, :, :, : self.qk_nope_head_dim]
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Call MLA kernel with compressed KV and kv_b_proj weight absorbed
# NOTE: kv_b_proj weight is passed directly to the kernel instead of calling Linear
attn_out = torch.ops.auto_deploy.torch_mla(
q_nope, # [B, S, N, qk_nope_head_dim]
q_pe, # [B, S, N, qk_rope_head_dim]
compressed_kv, # [B, S, kv_lora_rank]
k_pe, # [B, S, 1, qk_rope_head_dim]
self.kv_b_proj_weight, # [N*(qk_nope+v), kv_lora_rank] - absorbed weight
True, # is_causal
self.softmax_scale, # softmax_scale
"bsnd", # layout
)
# Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim]
attn_out = attn_out.reshape(b, s, self.num_heads * self.v_head_dim)
attn_out = torch.ops.auto_deploy.torch_attention(q_nope, k_nope, v, is_causal=True)
attn_out = attn_out.contiguous().view(b, s, -1)
# Output projection
output = self.o_proj(attn_out)
return output
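For comparison, here is a sketch of the non-absorbed decompression path that the torch_mla call replaces, i.e. roughly what self.kv_b_proj(compressed_kv) used to compute before the weight was handed to the kernel (dimensions are placeholders):

```
import torch
import torch.nn.functional as F

b, s, num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 2, 4, 8, 64, 64, 256

compressed_kv = torch.randn(b, s, kv_lora_rank)
kv_b_proj_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Non-absorbed path: explicit decompression to per-head K_nope and V.
kv = F.linear(compressed_kv, kv_b_proj_weight)                  # (b, s, N * (qk_nope + v))
kv = kv.view(b, s, num_heads, qk_nope_head_dim + v_head_dim)
k_nope, v = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)

assert k_nope.shape == (b, s, num_heads, qk_nope_head_dim)
assert v.shape == (b, s, num_heads, v_head_dim)
# The absorbed variant instead hands kv_b_proj_weight and compressed_kv straight to
# torch_mla, so the decompression happens inside the kernel.
```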
@ -616,6 +638,16 @@ def _run_pattern_detection_job(
layer_type=LayerType.MLA,
)
)
if is_weight_node(node):
if "kv_b_proj_weight" in node.name:
expected_transformations.append(
WeightShardingInfo(
target_node=node.name,
split_dim=SplitDimension.COLUMN,
config=config,
layer_type=LayerType.MLA,
)
)
# get detected transformations
optimizer = InferenceOptimizer(