Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-05 02:31:33 +08:00)
commit 0434db5bf7 (parent bd56b4e1e3)

[None][feat] GLM-4.5-Air support (#10653)

Signed-off-by: Daniil Kulko <kulkodaniil@gmail.com>
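
For orientation, a minimal sketch (not part of the commit) of how a GLM-4.5-Air checkpoint could be served through the PyTorch-backend LLM API that the accuracy tests in this change exercise. The checkpoint path and prompt are placeholders, and the parallelism settings mirror the two-GPU test configuration rather than a recommended deployment.

# Illustrative sketch only; checkpoint path and prompt are placeholders.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM("zai-org/GLM-4.5-Air",          # or a local checkpoint directory
          tensor_parallel_size=2,          # mirrors the 2-GPU test setup
          moe_expert_parallel_size=2,
          kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.70))
with llm:
    out = llm.generate(["What is the capital of France?"])
    print(out[0].outputs[0].text)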
@@ -24,6 +24,7 @@ from ..distributed import (
     MoEAllReduceParams,
 )
 from ..model_config import ModelConfig
+from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import MoE, MoEWeightLoadingMode, create_moe
@@ -201,6 +202,32 @@ class Glm4Attention(QKNormRoPEAttention):
         )


+class Glm4AirAttention(Attention):
+    def __init__(
+        self,
+        model_config: ModelConfig[PretrainedConfig],
+        layer_idx: Optional[int] = None,
+    ):
+        config = model_config.pretrained_config
+        pos_embd_params = PositionalEmbeddingParams(
+            type=PositionEmbeddingType.yarn,
+            rope=RopeParams.from_config(config),
+        )
+
+        super().__init__(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
+            max_position_embeddings=config.max_position_embeddings,
+            bias=config.attention_bias,
+            pos_embd_params=pos_embd_params,
+            layer_idx=layer_idx,
+            dtype=config.torch_dtype,
+            dense_bias=False,
+            config=model_config,
+        )
+
+
 class Glm4MoE(nn.Module):
     def __init__(
         self,
@@ -436,10 +463,10 @@ class Glm4DecoderLayer(DecoderLayer):
             # KVCacheManager only support 1 layer for separate draft engine
             layer_idx_for_attention = layer_idx - model_config.pretrained_config.num_hidden_layers

-        self.self_attn = Glm4Attention(
-            model_config,
-            layer_idx=layer_idx_for_attention,
-        )
+        if getattr(config, "use_qk_norm", False) and config.use_qk_norm:
+            self.self_attn = Glm4Attention(model_config, layer_idx=layer_idx_for_attention)
+        else:
+            self.self_attn = Glm4AirAttention(model_config, layer_idx=layer_idx_for_attention)
         self.enable_attention_dp = mapping.enable_attention_dp

         self.mlp_tp_size = mapping.tp_size
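
The dispatch above keys off the checkpoint's use_qk_norm flag: configs that set it keep Glm4Attention (the QKNormRoPEAttention variant), while configs without it fall through to the new plain-Attention Glm4AirAttention with YaRN rope parameters, which is presumably the path GLM-4.5-Air takes given the class name. A quick way to see which branch a given checkpoint would select (illustrative only; the model id is a placeholder, not from the commit):

# Illustrative check, not part of the commit: inspect the HF config flag that
# the decoder layer above branches on.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("zai-org/GLM-4.5-Air")  # placeholder model id
if getattr(cfg, "use_qk_norm", False):
    print("selects Glm4Attention (QK-norm path)")
else:
    print("selects Glm4AirAttention (plain attention path)")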
@@ -293,6 +293,14 @@ zai-org/GLM-4.6:
   - quant_algo: NVFP4
     spec_dec_algo: MTP
     accuracy: 88.0
+zai-org/GLM-4.5-Air:
+  - accuracy: 91.0
+  - spec_dec_algo: MTP
+    accuracy: 91.2
+  - quant_algo: NVFP4
+    spec_dec_algo: MTP
+    kv_cache_quant_algo: FP8
+    accuracy: 88.2
 bigcode/starcoder2-3b:
   - accuracy: 20.2
 bigcode/starcoder2-7b:
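
The reference file lists one entry per evaluated configuration, distinguished by optional quant_algo / spec_dec_algo / kv_cache_quant_algo keys. A small sketch of how the new GLM-4.5-Air block parses (the inline YAML is copied from the hunk above; this is not the harness's own loader):

# Sketch only: shows the shape of the new reference entries.
import yaml

text = """
zai-org/GLM-4.5-Air:
  - accuracy: 91.0
  - spec_dec_algo: MTP
    accuracy: 91.2
  - quant_algo: NVFP4
    spec_dec_algo: MTP
    kv_cache_quant_algo: FP8
    accuracy: 88.2
"""
for entry in yaml.safe_load(text)["zai-org/GLM-4.5-Air"]:
    print(entry)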
@@ -2971,6 +2971,112 @@ class TestGLM4_6(LlmapiAccuracyTestHarness):
             task.evaluate(llm)


+@skip_pre_blackwell
+class TestGLM4_5Air(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "zai-org/GLM-4.5-Air"
+    MODEL_PATH = f"{llm_models_root()}/GLM-4.5-Air"
+
+    @pytest.mark.timeout(14400)
+    @pytest.mark.skip_less_device_memory(80000)
+    @pytest.mark.skip_less_device(2)
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids("overlap_scheduler", [False, True])
+    @parametrize_with_ids("tp_size, ep_size", [(2, 2), (2, 1)])
+    @parametrize_with_ids("max_batch_size, moe_backend", [(4, "CUTLASS")])
+    def test_bfloat16_2gpus(self, tp_size, ep_size, mtp_nextn,
+                            overlap_scheduler, max_batch_size, moe_backend):
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            moe_config=MoeConfig(backend=moe_backend),
+        )
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70)
+
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+
+        with LLM(self.MODEL_PATH,
+                 max_batch_size=max_batch_size,
+                 tensor_parallel_size=tp_size,
+                 moe_expert_parallel_size=ep_size,
+                 kv_cache_config=kv_cache_config,
+                 enable_chunked_prefill=True,
+                 max_num_tokens=512,
+                 **pytorch_config,
+                 speculative_config=mtp_config) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @pytest.mark.skip_less_device(2)
+    @pytest.mark.parametrize(
+        "tp_size,pp_size,mtp_nextn,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend",
+        [
+            pytest.param(2, 1, 2, True, True, True, 16, "CUTLASS"),
+            pytest.param(2, 1, 2, True, True, True, 16, "TRTLLM")
+        ],
+        ids=["throughput", "throughput_trtllm"])
+    def test_nvfp4_multi_gpus(self, tp_size, pp_size, mtp_nextn, cuda_graph,
+                              overlap_scheduler, chunked_prefill,
+                              max_batch_size, moe_backend):
+
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+            moe_config=MoeConfig(backend=moe_backend))
+
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+
+        with LLM(f"{llm_models_root()}/glm-4.5-air-fp4",
+                 max_batch_size=max_batch_size,
+                 tensor_parallel_size=tp_size,
+                 pipeline_parallel_size=pp_size,
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 speculative_config=mtp_config,
+                 enable_chunked_prefill=chunked_prefill) as llm:
+
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @pytest.mark.skip_less_device(2)
+    @pytest.mark.parametrize(
+        "tp_size,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend",
+        [
+            pytest.param(2, True, True, True, 16, "CUTLASS"),
+            pytest.param(2, True, True, True, 16, "TRTLLM"),
+        ],
+        ids=["2model", "2model_trtllm"])
+    def test_nvfp4_2_model_mtp(self, tp_size, cuda_graph, overlap_scheduler,
+                               chunked_prefill, max_batch_size, moe_backend):
+
+        model_path = f"{llm_models_root()}/glm-4.5-air-fp4"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+            moe_config=MoeConfig(backend=moe_backend))
+
+        mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3,
+                                       mtp_eagle_one_model=False,
+                                       speculative_model_dir=model_path)
+
+        with LLM(model_path,
+                 max_batch_size=max_batch_size,
+                 tensor_parallel_size=tp_size,
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 speculative_config=mtp_config,
+                 enable_chunked_prefill=chunked_prefill) as llm:
+
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+
 @pytest.mark.threadleak(enabled=False)
 @pytest.mark.timeout(10800)
 @pytest.mark.skip_less_device_memory(100000)
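
Outside the test harness, the two-model MTP configuration exercised by test_nvfp4_2_model_mtp could be reproduced roughly as below; this is a sketch under the assumption that the same llmapi config classes are used directly, and the checkpoint path is a placeholder rather than the path used in CI.

# Sketch of the two-model MTP setup from the test above; path is a placeholder.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig, MoeConfig,
                                 MTPDecodingConfig)

model_path = "/models/glm-4.5-air-fp4"  # placeholder checkpoint location
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3,
                               mtp_eagle_one_model=False,
                               speculative_model_dir=model_path)

llm = LLM(model_path,
          tensor_parallel_size=2,
          kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.70),
          cuda_graph_config=CudaGraphConfig(),
          moe_config=MoeConfig(backend="TRTLLM"),
          speculative_config=mtp_config,
          enable_chunked_prefill=True)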
@@ -226,6 +226,10 @@ accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput]
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput_trtllm]
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model]
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model_trtllm]
+accuracy/test_llm_api_pytorch.py::TestGLM4_5Air::test_nvfp4_multi_gpus[throughput]
+accuracy/test_llm_api_pytorch.py::TestGLM4_5Air::test_nvfp4_multi_gpus[throughput_trtllm]
+accuracy/test_llm_api_pytorch.py::TestGLM4_5Air::test_nvfp4_2_model_mtp[2model]
+accuracy/test_llm_api_pytorch.py::TestGLM4_5Air::test_nvfp4_2_model_mtp[2model_trtllm]
 accuracy/test_llm_api_pytorch.py::TestSeedOss_36B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKanana_Instruct::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]