Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Commit 86eccf87b6: Merge 1184b5f1c6 into 6df2c8a074
@@ -719,6 +719,7 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars)
"tensorrt_llm/_torch/pyexecutor/_util.py",
"tensorrt_llm/_torch/pyexecutor/model_engine.py",
"tensorrt_llm/_torch/pyexecutor/py_executor.py",
"tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py",
"tensorrt_llm/evaluate/json_mode_eval.py",
"tensorrt_llm/evaluate/mmlu.py",
"tensorrt_llm/executor/",
@@ -740,6 +741,7 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars)
"tests/integration/defs/accuracy/test_disaggregated_serving.py",
"tests/unittest/_torch/ray_orchestrator/multi_gpu/",
"tests/integration/defs/examples/test_ray.py",
"tests/integration/defs/accuracy/test_llm_api_autodeploy.py",
"tests/unittest/llmapi/test_async_llm.py",
]

@@ -268,7 +268,7 @@ class WeightShardingInfo(ShardingTransformInfo):
min_local_shape: int = 1
layer_type: LayerType = LayerType.MLP
# used for TP sharding of fused weights
- fused_weight_dims: Optional[list] = None
+ fused_weight_dims: Optional[tuple] = None

def quantization_cb(
self,
@@ -1229,7 +1229,7 @@ def _shard_parameter_node(
config: ShardingTransformConfig,
add_dist: bool = False,
min_local_shape: int = 1,
- fused_weight_dims: Optional[list] = None,
+ fused_weight_dims: Optional[tuple] = None,
quantization_cb: Optional[
Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None]
] = None,
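
The two hunks above change `fused_weight_dims` from `Optional[list]` to `Optional[tuple]` on `WeightShardingInfo` and in `_shard_parameter_node`. A minimal sketch of what this means at call sites, based on the `tuple(...)` conversions shown later in this diff; the values and variable names below are illustrative only:

from typing import Optional, Tuple

# Hypothetical per-slice output sizes of a fused projection weight (illustrative values).
fused_weight_dims_list = [32, 32, 16, 16, 4]

# After this change the field expects an immutable tuple, mirroring the
# tuple(fused_weight_dims["in_proj"]) conversion in _process_ssm_sharding below.
fused_weight_dims: Optional[Tuple[int, ...]] = tuple(fused_weight_dims_list)

assert sum(fused_weight_dims) == 100  # slices must add up to the full output dimension
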
@@ -1365,6 +1365,7 @@ def _insert_sharded_moe(
num_experts = len(args[3])

experts_per_rank = num_experts // ep_size
# ad_logger.info(f"MoE sharding: Experts per rank: {experts_per_rank}, EP rank: {ep_rank}, EP size: {ep_size}")

with gm.graph.inserting_before(node):
lower = experts_per_rank * ep_rank
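
The expert-parallel split in `_insert_sharded_moe` is plain integer arithmetic: each rank owns a contiguous block of `num_experts // ep_size` experts starting at `experts_per_rank * ep_rank`. A small worked sketch; the expert count and EP size here are illustrative, not taken from the diff:

num_experts = 64   # illustrative
ep_size = 8        # number of expert-parallel ranks

experts_per_rank = num_experts // ep_size    # 8 experts per rank
for ep_rank in range(ep_size):
    lower = experts_per_rank * ep_rank       # first expert owned by this rank
    upper = lower + experts_per_rank         # one past the last owned expert
    print(f"rank {ep_rank}: experts [{lower}, {upper})")
# rank 0 -> [0, 8), rank 1 -> [8, 16), ..., rank 7 -> [56, 64)
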
@@ -1633,7 +1634,7 @@ def _process_ssm_sharding(
config=config,
dist_op=None,
min_local_shape=1,
- fused_weight_dims=fused_weight_dims["in_proj"],
+ fused_weight_dims=tuple(fused_weight_dims["in_proj"]),
layer_type=LayerType.SSM,
)
):
@@ -1702,7 +1703,7 @@ def _process_ssm_sharding(
fused_dims = None
for k, v in fused_weight_dims.items():
if k in weight_key:
- fused_dims = v
+ fused_dims = tuple(v)
break

# Shard the weight tensor (also updates the parameter in the module)
@@ -1887,7 +1888,7 @@ def _determine_fused_weight_dims(
ad_logger.warning(
f"Fused weight dims {fused_weight_dims} do not sum to weight dim {weight_dim}. Skipping."
)
- return
+ return None
chunk_nodes = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk))
if len(chunk_nodes) > 0:
assert len(linear_nodes) == 1
@@ -1896,6 +1897,8 @@ def _determine_fused_weight_dims(
num_chunks = chunk_nodes[0].args[1]
weight_dim = shape(linear_node)[2]
fused_weight_dims = [weight_dim // num_chunks] * num_chunks
+ if fused_weight_dims is not None:
+ fused_weight_dims = tuple(fused_weight_dims)
return fused_weight_dims

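
In `_determine_fused_weight_dims` above, when the fused linear output is consumed by a single `aten.chunk`, the per-slice sizes are inferred as `num_chunks` equal splits of the output dimension and then frozen into a tuple. A short illustrative re-statement of that inference, not the module's actual API:

from typing import Optional, Tuple

def equal_fused_dims(weight_dim: int, num_chunks: int) -> Optional[Tuple[int, ...]]:
    # num_chunks equal slices of the fused output dimension, returned as a tuple.
    if num_chunks <= 0 or weight_dim % num_chunks != 0:
        return None  # cannot split evenly; caller would skip fused handling
    return tuple([weight_dim // num_chunks] * num_chunks)

# e.g. a fused projection of width 3072 consumed by chunk(..., 2):
assert equal_fused_dims(3072, 2) == (1536, 1536)
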
@@ -203,7 +203,8 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
use_beam_search=beam_width > 1)

@pytest.mark.skip_less_device_memory(32000)
- def test_bf16(self):
+ @pytest.mark.parametrize("world_size", [1, 4])
+ def test_bf16(self, world_size):
kwargs = self.get_default_kwargs()
# TODO: multi-stream MOE seems to increase the memory usage
kwargs["max_batch_size"] = 32
@@ -211,6 +212,7 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH_BF16,
tokenizer=self.MODEL_PATH_BF16,
+ world_size=world_size,
**kwargs) as llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=sampling_params)
@@ -218,10 +220,12 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
task.evaluate(llm)

@pytest.mark.skip_less_device_memory(32000)
- def test_fp8(self):
+ @pytest.mark.parametrize("world_size", [1, 4])
+ def test_fp8(self, world_size):
kwargs = self.get_default_kwargs()
with AutoDeployLLM(model=self.MODEL_PATH_FP8,
tokenizer=self.MODEL_PATH_FP8,
+ world_size=world_size,
**kwargs) as llm:
# Manually set quant_config for FP8 model to get the accuracy threshold
llm.args.quant_config.quant_algo = QuantAlgo.FP8

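
Parametrizing `test_bf16` and `test_fp8` over `world_size` is what produces the bracketed test IDs (`test_bf16[1]`, `test_bf16[4]`, `test_fp8[1]`, `test_fp8[4]`) referenced in the test-list YAML hunks below. A minimal standalone sketch of the pattern, outside the real accuracy harness:

import pytest

@pytest.mark.parametrize("world_size", [1, 4])
def test_bf16(world_size):
    # pytest generates the node IDs test_bf16[1] and test_bf16[4]; the l0_* test
    # lists below select a specific world size by that bracketed ID.
    assert world_size in (1, 4)
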
@@ -97,7 +97,7 @@ l0_b200:
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_fp8_blockwise_deepgemm[enable_configurable_moe-dtype1-72-256-2560-DefaultMoeRoutingMethod]
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-1]
- - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
- unittest/_torch/auto_deploy/unit/singlegpu
- condition:
ranges:

@@ -32,7 +32,7 @@ l0_dgx_b200:
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
# ------------- AutoDeploy tests ---------------
- - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16[1]
- condition:
ranges:
system_gpu_count:

@@ -126,7 +126,10 @@ l0_dgx_h100:
- disaggregated/test_auto_scaling.py::test_minimal_instances[http-round_robin]
- disaggregated/test_auto_scaling.py::test_disagg_server_restart[http-round_robin]
# ------------- AutoDeploy tests ---------------
- - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16[1]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16[4]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[4]
- condition:
ranges:
system_gpu_count:

@@ -133,7 +133,7 @@ l0_dgx_h200:
- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-fp8]
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-4]
- - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16[1]
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16
- condition:
ranges:

@@ -120,8 +120,8 @@ l0_h100:
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[True-1]
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[False]
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[True]
- - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8
- - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16[1]
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target]
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3]
- examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate

@@ -1,6 +1,7 @@
"""Tests for basic graph sharding."""

from functools import partial
from types import SimpleNamespace
from typing import Type

import pytest
@@ -13,6 +14,7 @@ from _model_test_utils import FakeFP8Linear

import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_nemotron_h import NemotronHMamba2Mixer
from tensorrt_llm._torch.auto_deploy.transform.library.sharding import (
FP8WeightShardingInfo,
LayerType,
@@ -35,6 +37,14 @@ base_model_tp_plan = {
"linear1": "colwise",
"linear2": "rowwise",
"linear": "gather",
# Mamba2 specific projections
"in_proj": "mamba",
"out_proj": "rowwise",
# MLA specific projections
"q_a_proj": "gather",
"q_b_proj": "colwise",
"kv_a_proj_with_mqa": "gather",
"kv_b_proj": "colwise",
# "input_layernorm.weight": "sequence_parallel",
# "post_attention_layernorm.weight": "sequence_parallel",
# "norm.weight": "sequence_parallel",
@@ -50,7 +60,6 @@
}

predefined_config = {
"head_dim": 8,
"tp_plan": base_model_tp_plan,
}

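
The new MLA entries in `base_model_tp_plan` map each projection to a sharding strategy string. Judging from the WeightShardingInfo expectations further down in this diff, `colwise` corresponds to a column split with no collective, `rowwise` to a row split followed by `all_reduce`, and `gather` to a simple shard completed by `all_gather`. A compact sketch of that mapping; the dict below is illustrative and not an API of the sharding library:

# Illustrative mapping of tp_plan strings to (split dimension, collective op),
# inferred from the WeightShardingInfo expectations later in this diff.
TP_PLAN_SEMANTICS = {
    "colwise": ("COLUMN", None),         # e.g. q_b_proj, kv_b_proj, linear1
    "rowwise": ("ROW", "all_reduce"),    # e.g. o_proj, out_proj, linear2
    "gather": ("COLUMN", "all_gather"),  # simple shard, e.g. q_a_proj, kv_a_proj_with_mqa
    "mamba": ("COLUMN", None),           # in_proj, split according to fused_weight_dims
}
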
@@ -125,6 +134,75 @@ class FP8MLP(nn.Module):
return self.linear2(y)


class MLA_Block(nn.Module):
"""Multi-Latent Attention block - simplified standalone version.

Based on DeepSeek MLA architecture with KV compression.
This is a minimal, self-contained implementation for testing sharding patterns.
"""

def __init__(
self,
hidden_size: int,
num_heads: int,
q_lora_rank: int,
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
bias: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.kv_lora_rank = kv_lora_rank
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim

# KV compression path (not sharded - gather)
self.kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim, bias=bias)

# KV decompression (sharded column-wise)
self.kv_b_proj = nn.Linear(
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False
)

# Query path (sharded column-wise)
self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=bias)
self.q_b_proj = nn.Linear(q_lora_rank, num_heads * self.qk_head_dim, bias=bias)
self.q_a_layernorm = nn.LayerNorm(q_lora_rank)
# Output projection (sharded row-wise)
self.o_proj = nn.Linear(num_heads * v_head_dim, hidden_size, bias=bias)

@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, s, _ = x.shape

# Compress KV to latent
compressed_kv_rope = self.kv_a_proj_with_mqa(x) # (b, s, kv_lora_rank + rope_dim)
compressed_kv = compressed_kv_rope[:, :, : self.kv_lora_rank] # (b, s, kv_lora_rank)

# Decompress to full K and V
kv = self.kv_b_proj(compressed_kv) # (b, s, num_heads * (qk_nope + v))
k_nope_v = kv.view(b, s, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope = k_nope_v[:, :, :, : self.qk_nope_head_dim]
v = k_nope_v[:, :, :, self.qk_nope_head_dim :]

# Query projection
# q = q_b_proj @ (layernorm(q_a_proj @ x))
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) # (b, s, num_heads * qk_head_dim)
q = q.view(b, s, self.num_heads, self.qk_head_dim)
q_nope = q[:, :, :, : self.qk_nope_head_dim]

attn_out = torch.ops.auto_deploy.torch_attention(q_nope, k_nope, v, is_causal=True)
attn_out = attn_out.contiguous().view(b, s, -1)
# Output projection
output = self.o_proj(attn_out)
return output


def _run_sharding_execution_job(
model_cls: nn.Module,
dist_op_expected: str,
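
A minimal usage sketch for the MLA_Block test module above, assuming the constructor signature shown in the hunk, a CUDA device, and that the custom op torch.ops.auto_deploy.torch_attention is registered (e.g. via the tensorrt_llm._torch.auto_deploy imports used in this test file); the dimensions mirror the scaled-down values used in the pattern-detection job:

import torch

block = MLA_Block(
    hidden_size=32,
    num_heads=4,
    q_lora_rank=8,
    kv_lora_rank=8,
    qk_nope_head_dim=2,
    qk_rope_head_dim=1,
    v_head_dim=2,
).to(device="cuda", dtype=torch.float16)

x = torch.randn(4, 8, 32, device="cuda", dtype=torch.float16)  # (batch, seq, hidden)
y = block(x)
assert y.shape == x.shape  # o_proj maps num_heads * v_head_dim back to hidden_size
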
@@ -137,6 +215,7 @@ def _run_sharding_execution_job(
batch_size = 4
sequence_len = 8
num_features = 32
+ skip_output_assert = False

# GQA specific parameters
num_heads = 4
@@ -150,6 +229,54 @@ def _run_sharding_execution_job(
).to(device="cuda", dtype=torch.float16)
elif model_cls == FP8MLP:
model = model_cls(num_features, num_features, bias=bias).to("cuda")
elif model_cls == NemotronHMamba2Mixer:
# Create config for Mamba2 based on Nemotron models
# Scaled down from typical values: hidden_size=5120, ssm_state_size=128
mamba_config = SimpleNamespace(
hidden_size=num_features,
ssm_state_size=16, # Scaled from 128
mamba_num_heads=num_heads,
mamba_head_dim=num_features // num_heads, # 8
n_groups=1, # Typical value
chunk_size=256,
conv_kernel=4,
use_conv_bias=bias,
use_bias=bias,
mamba_hidden_act="silu",
layer_norm_epsilon=1e-5,
time_step_limit=(0.0, float("inf")),
time_step_min=0.001,
time_step_max=0.1,
time_step_floor=1e-4,
initializer_range=0.02,
rescale_prenorm_residual=False,
residual_in_fp32=False,
num_hidden_layers=1,
)
model = model_cls(mamba_config, layer_idx=0).to(device="cuda", dtype=torch.float16)
elif model_cls == MLA_Block:
# Use actual DeepSeek-V3/R1 production values
# From HuggingFace config (HunYuanPretrainedConfig defaults):
# hidden_size=4096, num_attention_heads=32
# kv_lora_rank=512, q_lora_rank=1536
# qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128
num_heads_mla = 16
qk_nope_head_dim = 64
qk_rope_head_dim = 32
v_head_dim = 64
kv_lora_rank = 256
skip_output_assert = True

model = model_cls(
hidden_size=num_features,
num_heads=num_heads_mla,
q_lora_rank=kv_lora_rank,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
bias=bias,
).to(device="cuda", dtype=torch.float16)
else:
model = model_cls(num_features, num_features, bias=bias).to(
device="cuda", dtype=torch.float16
@@ -178,6 +305,11 @@ def _run_sharding_execution_job(
num_params = W_q_local_size + W_k_local_size + W_v_local_size + W_o_local_size
else:
num_params = num_p_og // world_size + num_update
+ if model_cls == MLA_Block:
+ # since q_a_proj is simple-sharded and followed by q_a_layernorm, the layernorm params
+ # are NOT sharded - they have to be replicated. To account for this, we need to add the
+ # number of parameters of the layernorm (weight and bias) to the number of parameters of the model.
+ num_params += 2 * kv_lora_rank * (world_size - 1) // world_size
return num_params

def verify_local_weight_sizes(gm) -> bool:
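
The correction added above accounts for `q_a_layernorm` being replicated rather than sharded: its `2 * kv_lora_rank` parameters (weight and bias, with `q_lora_rank == kv_lora_rank` in this test) are fully present on every rank, while the naive `num_p_og // world_size` estimate counts only a `1/world_size` share of them. A quick numeric check using the kv_lora_rank from the execution job and an assumed world size of 4:

kv_lora_rank = 256   # q_a_layernorm has weight + bias = 2 * 256 = 512 params
world_size = 4       # assumed TP size for this check

naive_share = 2 * kv_lora_rank // world_size                     # 128, what num_p_og // world_size counts
actual_local = 2 * kv_lora_rank                                  # 512, replicated on every rank
correction = 2 * kv_lora_rank * (world_size - 1) // world_size   # 384
assert naive_share + correction == actual_local
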
@@ -223,6 +355,7 @@ def _run_sharding_execution_job(
gm_transformed,
check_transformed_graph=combined_graph_check,
_get_expected_num_params=_get_expected_num_params,
+ skip_output_assert=skip_output_assert,
)

@@ -248,6 +381,47 @@ def _run_pattern_detection_job(
hidden_size=num_features,
num_key_value_heads=num_key_value_heads,
).to(device="cuda", dtype=torch.float16)
elif model_cls == NemotronHMamba2Mixer:
# Create config for Mamba2
mamba_config = SimpleNamespace(
hidden_size=num_features,
ssm_state_size=16,
mamba_num_heads=num_heads,
mamba_head_dim=num_features // num_heads,
n_groups=1,
chunk_size=256,
conv_kernel=4,
use_conv_bias=bias,
use_bias=bias,
mamba_hidden_act="silu",
layer_norm_epsilon=1e-5,
time_step_limit=(0.0, float("inf")),
time_step_min=0.001,
time_step_max=0.1,
time_step_floor=1e-4,
initializer_range=0.02,
rescale_prenorm_residual=False,
residual_in_fp32=False,
num_hidden_layers=1,
)
model = model_cls(mamba_config, layer_idx=0).to(device="cuda", dtype=torch.float16)
elif model_cls == MLA_Block:
# Create simplified MLA based on DeepSeek-V3 architecture
qk_nope_head_dim = 2
qk_rope_head_dim = 1
v_head_dim = 2
kv_lora_rank = 8

model = model_cls(
hidden_size=num_features,
num_heads=num_heads,
q_lora_rank=kv_lora_rank,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
bias=bias,
).to(device="cuda", dtype=torch.float16)
else:
model = model_cls(num_features, num_features, bias=bias).to(
device="cuda", dtype=torch.float16
@@ -344,6 +518,96 @@ def _run_pattern_detection_job(
min_local_shape=1,
)
)
elif model_cls == NemotronHMamba2Mixer:
for node in gm.graph.nodes:
if is_linear_op(node):
# in_proj should be sharded column-wise
# out_proj should be sharded row-wise with all_reduce
if "out_proj" in node.args[1].name:
dim = SplitDimension.ROW
dist_op = "all_reduce"
fused_weight_dims = None
else:
dim = SplitDimension.COLUMN
dist_op = None
fused_weight_dims = (num_features, num_features, 16, 16, num_heads)
expected_transformations.append(
WeightShardingInfo(
target_node=node.name,
split_dim=dim,
config=config,
dist_op=dist_op,
min_local_shape=1,
layer_type=LayerType.SSM,
fused_weight_dims=fused_weight_dims,
)
)
elif is_op(node, torch.ops.auto_deploy.torch_causal_conv1d):
expected_transformations.append(
WeightShardingInfo(
target_node=node.name,
split_dim=SplitDimension.COLUMN,
config=config,
dist_op=None,
min_local_shape=1,
layer_type=LayerType.SSM,
fused_weight_dims=(num_features, 16, 16),
)
)
elif is_op(node, torch.ops.auto_deploy.torch_ssm):
expected_transformations.append(
WeightShardingInfo(
target_node=node.name,
split_dim=SplitDimension.COLUMN,
config=config,
dist_op=None,
min_local_shape=1,
layer_type=LayerType.SSM,
fused_weight_dims=None,
)
)
elif len(node.args) > 1 and "norm_weight" in node.args[0].name:
expected_transformations.append(
WeightShardingInfo(
target_node=node.name,
split_dim=SplitDimension.COLUMN,
config=config,
dist_op=None,
min_local_shape=1,
layer_type=LayerType.SSM,
fused_weight_dims=None,
)
)
elif model_cls == MLA_Block:
for node in gm.graph.nodes:
if is_linear_op(node):
# kv_a_proj_with_mqa: gather (no sharding)
# q_b_proj/kv_b_proj: column-wise
# o_proj: row-wise with all_reduce
min_local_shape = 2
if "o_proj" in node.args[1].name:
dim = SplitDimension.ROW
dist_op = "all_reduce"
elif (
"kv_a_proj_with_mqa" in node.args[1].name or "q_a_proj" in node.args[1].name
):
# This is simple-shard gather
dim = SplitDimension.COLUMN
dist_op = "all_gather"
min_local_shape = 1
else:
dim = SplitDimension.COLUMN
dist_op = None
expected_transformations.append(
WeightShardingInfo(
target_node=node.name,
split_dim=dim,
config=config,
dist_op=dist_op,
min_local_shape=min_local_shape,
layer_type=LayerType.MLA,
)
)

# get detected transformations
optimizer = InferenceOptimizer(
@@ -378,6 +642,8 @@ def _run_pattern_detection_job(
(FP8MLP, "torch_dist_all_reduce"),
(nn.Linear, "torch_dist_all_gather"),
(GQA_Block, "torch_dist_all_reduce"),
+ (NemotronHMamba2Mixer, "torch_dist_all_reduce"),
+ (MLA_Block, "torch_dist_all_reduce"),
),
)
def test_sharding(
@@ -403,6 +669,8 @@ def _run_pattern_detection_job(
(FP8MLP, "torch_dist_all_reduce"),
(nn.Linear, "torch_dist_all_gather"),
(GQA_Block, "torch_dist_all_reduce"),
+ (NemotronHMamba2Mixer, "torch_dist_all_reduce"),
+ (MLA_Block, "torch_dist_all_reduce"),
),
)
def test_sharding_pattern_detection(