From 0434db5bf75dfd01fe575a79c27d9260b597f167 Mon Sep 17 00:00:00 2001
From: Daniil
Date: Wed, 21 Jan 2026 22:42:09 -0500
Subject: [PATCH] [None][feat] GLM-4.5-Air support (#10653)

Signed-off-by: Daniil Kulko
---
 tensorrt_llm/_torch/models/modeling_glm.py      |  35 +++++-
 .../defs/accuracy/references/gsm8k.yaml         |   8 ++
 .../defs/accuracy/test_llm_api_pytorch.py       | 106 ++++++++++++++++++
 .../test_lists/qa/llm_function_core.txt         |   4 +
 4 files changed, 149 insertions(+), 4 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_glm.py b/tensorrt_llm/_torch/models/modeling_glm.py
index 3c97997352..53d45d3edc 100644
--- a/tensorrt_llm/_torch/models/modeling_glm.py
+++ b/tensorrt_llm/_torch/models/modeling_glm.py
@@ -24,6 +24,7 @@ from ..distributed import (
     MoEAllReduceParams,
 )
 from ..model_config import ModelConfig
+from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import MoE, MoEWeightLoadingMode, create_moe
@@ -201,6 +202,32 @@ class Glm4Attention(QKNormRoPEAttention):
         )
 
 
+class Glm4AirAttention(Attention):
+    def __init__(
+        self,
+        model_config: ModelConfig[PretrainedConfig],
+        layer_idx: Optional[int] = None,
+    ):
+        config = model_config.pretrained_config
+        pos_embd_params = PositionalEmbeddingParams(
+            type=PositionEmbeddingType.yarn,
+            rope=RopeParams.from_config(config),
+        )
+
+        super().__init__(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
+            max_position_embeddings=config.max_position_embeddings,
+            bias=config.attention_bias,
+            pos_embd_params=pos_embd_params,
+            layer_idx=layer_idx,
+            dtype=config.torch_dtype,
+            dense_bias=False,
+            config=model_config,
+        )
+
+
 class Glm4MoE(nn.Module):
     def __init__(
         self,
@@ -436,10 +463,10 @@ class Glm4DecoderLayer(DecoderLayer):
 
         # KVCacheManager only support 1 layer for separate draft engine
         layer_idx_for_attention = layer_idx - model_config.pretrained_config.num_hidden_layers
-        self.self_attn = Glm4Attention(
-            model_config,
-            layer_idx=layer_idx_for_attention,
-        )
+        if getattr(config, "use_qk_norm", False):
+            self.self_attn = Glm4Attention(model_config, layer_idx=layer_idx_for_attention)
+        else:
+            self.self_attn = Glm4AirAttention(model_config, layer_idx=layer_idx_for_attention)
 
         self.enable_attention_dp = mapping.enable_attention_dp
         self.mlp_tp_size = mapping.tp_size
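A minimal self-contained sketch (editor's illustration, not part of the patch) of the dispatch the decoder-layer hunk above introduces: GLM-4.6-style checkpoints set use_qk_norm and keep the QK-norm attention path, while GLM-4.5-Air checkpoints omit the flag and fall through to the new plain-Attention subclass with yarn RoPE. FakeConfig is a hypothetical stand-in for the HuggingFace PretrainedConfig, and class names are returned as strings purely for illustration.

# Illustrative sketch only; FakeConfig stands in for transformers.PretrainedConfig.
class FakeConfig:
    pass

def pick_attention_cls_name(config):
    # GLM-4.6 checkpoints set use_qk_norm=True -> QK-norm RoPE attention;
    # GLM-4.5-Air checkpoints leave it unset -> plain Attention + yarn RoPE.
    if getattr(config, "use_qk_norm", False):
        return "Glm4Attention"
    return "Glm4AirAttention"

glm46 = FakeConfig()
glm46.use_qk_norm = True
air = FakeConfig()  # no use_qk_norm attribute at all

assert pick_attention_cls_name(glm46) == "Glm4Attention"
assert pick_attention_cls_name(air) == "Glm4AirAttention"
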
MODEL_PATH = f"{llm_models_root()}/GLM-4.5-Air" + + @pytest.mark.timeout(14400) + @pytest.mark.skip_less_device_memory(80000) + @pytest.mark.skip_less_device(2) + @parametrize_with_ids("mtp_nextn", [0, 2]) + @parametrize_with_ids("overlap_scheduler", [False, True]) + @parametrize_with_ids("tp_size, ep_size", [(2, 2), (2, 1)]) + @parametrize_with_ids("max_batch_size, moe_backend", [(4, "CUTLASS")]) + def test_bfloat16_2gpus(self, tp_size, ep_size, mtp_nextn, + overlap_scheduler, max_batch_size, moe_backend): + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + moe_config=MoeConfig(backend=moe_backend), + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70) + + mtp_config = None + if mtp_nextn > 0: + mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn) + + with LLM(self.MODEL_PATH, + max_batch_size=max_batch_size, + tensor_parallel_size=tp_size, + moe_expert_parallel_size=ep_size, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=True, + max_num_tokens=512, + **pytorch_config, + speculative_config=mtp_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @pytest.mark.skip_less_device(2) + @pytest.mark.parametrize( + "tp_size,pp_size,mtp_nextn,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend", + [ + pytest.param(2, 1, 2, True, True, True, 16, "CUTLASS"), + pytest.param(2, 1, 2, True, True, True, 16, "TRTLLM") + ], + ids=["throughput", "throughput_trtllm"]) + def test_nvfp4_multi_gpus(self, tp_size, pp_size, mtp_nextn, cuda_graph, + overlap_scheduler, chunked_prefill, + max_batch_size, moe_backend): + + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70) + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, + moe_config=MoeConfig(backend=moe_backend)) + + mtp_config = None + if mtp_nextn > 0: + mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn) + + with LLM(f"{llm_models_root()}/glm-4.5-air-fp4", + max_batch_size=max_batch_size, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + kv_cache_config=kv_cache_config, + **pytorch_config, + speculative_config=mtp_config, + enable_chunked_prefill=chunked_prefill) as llm: + + assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4 + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @pytest.mark.skip_less_device(2) + @pytest.mark.parametrize( + "tp_size,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend", + [ + pytest.param(2, True, True, True, 16, "CUTLASS"), + pytest.param(2, True, True, True, 16, "TRTLLM"), + ], + ids=["2model", "2model_trtllm"]) + def test_nvfp4_2_model_mtp(self, tp_size, cuda_graph, overlap_scheduler, + chunked_prefill, max_batch_size, moe_backend): + + model_path = f"{llm_models_root()}/glm-4.5-air-fp4" + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70) + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, + moe_config=MoeConfig(backend=moe_backend)) + + mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3, + mtp_eagle_one_model=False, + speculative_model_dir=model_path) + + with LLM(model_path, + max_batch_size=max_batch_size, + tensor_parallel_size=tp_size, + kv_cache_config=kv_cache_config, + **pytorch_config, + speculative_config=mtp_config, + enable_chunked_prefill=chunked_prefill) as llm: + + assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4 + task = 
+            task.evaluate(llm)
+
+
 @pytest.mark.threadleak(enabled=False)
 @pytest.mark.timeout(10800)
 @pytest.mark.skip_less_device_memory(100000)
diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt
index cf7157dc69..be3744956a 100644
--- a/tests/integration/test_lists/qa/llm_function_core.txt
+++ b/tests/integration/test_lists/qa/llm_function_core.txt
@@ -226,6 +226,10 @@ accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput]
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput_trtllm]
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model]
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model_trtllm]
+accuracy/test_llm_api_pytorch.py::TestGLM4_5Air::test_nvfp4_multi_gpus[throughput]
+accuracy/test_llm_api_pytorch.py::TestGLM4_5Air::test_nvfp4_multi_gpus[throughput_trtllm]
+accuracy/test_llm_api_pytorch.py::TestGLM4_5Air::test_nvfp4_2_model_mtp[2model]
+accuracy/test_llm_api_pytorch.py::TestGLM4_5Air::test_nvfp4_2_model_mtp[2model_trtllm]
 accuracy/test_llm_api_pytorch.py::TestSeedOss_36B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKanana_Instruct::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]
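
For reference, a hedged usage sketch (editor's addition, not part of the patch) mirroring the "throughput" case of TestGLM4_5Air.test_nvfp4_multi_gpus above. Assumptions are marked in the comments: the checkpoint path is a placeholder for a local NVFP4 checkpoint, and the config classes are assumed importable from tensorrt_llm.llmapi as in current releases.

# Hedged sketch: run GLM-4.5-Air NVFP4 through the LLM API with MTP speculation.
# "/path/to/glm-4.5-air-fp4" is a placeholder; import locations are assumptions.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig, MoeConfig,
                                 MTPDecodingConfig)

with LLM("/path/to/glm-4.5-air-fp4",
         max_batch_size=16,
         tensor_parallel_size=2,
         pipeline_parallel_size=1,
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.70),
         cuda_graph_config=CudaGraphConfig(),
         moe_config=MoeConfig(backend="CUTLASS"),
         speculative_config=MTPDecodingConfig(num_nextn_predict_layers=2),
         enable_chunked_prefill=True) as llm:
    # Single-prompt smoke test; the accuracy tests above run GSM8K instead.
    out = llm.generate("Natalia sold clips to 48 of her friends. How many clips?")
    print(out.outputs[0].text)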