diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py
index a50d475681..4e860f6abb 100644
--- a/tensorrt_llm/_torch/attention_backend/interface.py
+++ b/tensorrt_llm/_torch/attention_backend/interface.py
@@ -356,6 +356,7 @@ class RopeParams:
     mscale_all_dim: float = 0.0
     short_factor: Optional[Tuple[float]] = None
     long_factor: Optional[Tuple[float]] = None
+    max_seq_len: Optional[int] = None

     @staticmethod
     def from_config(config) -> "RopeParams":
@@ -406,6 +407,8 @@
         # Workaround for DeepSeek V3 Lite since its rope_scaling is null in config.json.
         elif config.model_type == "deepseek_v3":
             rope_params.scale_type = RotaryScalingType.yarn
+        # Other metadata for RoPE.
+        rope_params.max_seq_len = getattr(config, 'max_seq_len', None)

         return rope_params

@@ -439,13 +442,14 @@
                 self.mscale_all_dim,
             )
         elif self.scale_type == RotaryScalingType.longrope:
-            rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_long_rope_for_attention_plugin(
+            rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_long_rope(
                 num_pos=self.max_positions,
                 dim=self.dim,
                 theta=self.theta,
                 original_max_pos=self.original_max_positions,
                 short_factor=self.short_factor,
                 long_factor=self.long_factor,
+                max_seq_len=self.max_seq_len,
             )
         else:
             rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 3d0175a3c2..b84b345f82 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -75,6 +75,7 @@ class ModelConfig(Generic[TConfig]):
     is_generation: bool = True
     max_num_tokens: int = 8192
+    max_seq_len: Optional[int] = None
     moe_max_num_tokens: Optional[int] = None
     moe_load_balancer: Optional[MoeLoadBalancerConfig] = None

diff --git a/tensorrt_llm/_torch/models/modeling_phi3.py b/tensorrt_llm/_torch/models/modeling_phi3.py
index 5e4221dd71..d778869a70 100644
--- a/tensorrt_llm/_torch/models/modeling_phi3.py
+++ b/tensorrt_llm/_torch/models/modeling_phi3.py
@@ -29,7 +29,8 @@ class Phi3Attention(Attention):
         layer_idx: Optional[int] = None,
     ):
         config = model_config.pretrained_config
-
+        # Pass max_seq_len to config for LongRoPE.
+        config.max_seq_len = model_config.max_seq_len
         rope_params = RopeParams.from_config(config)
         super().__init__(
             hidden_size=config.hidden_size,
@@ -66,6 +67,7 @@
             bias=False,
             dtype=config.torch_dtype,
             config=model_config,
+            layer_idx=layer_idx,
         )

         self.input_layernorm = RMSNorm(
@@ -107,7 +109,11 @@
         # Fully connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states, **kwargs)
+        hidden_states = self.mlp(
+            hidden_states,
+            lora_params=lora_params,
+            **kwargs,
+        )
         return hidden_states, residual

diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 2ba4cafeda..3ac132767f 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -978,6 +978,7 @@
             force_dynamic_quantization,
             spec_config=self.spec_config,
             max_num_tokens=max_num_tokens,
+            max_seq_len=self.max_seq_len,
             moe_max_num_tokens=moe_max_num_tokens,
             moe_load_balancer=moe_load_balancer,
             lora_config=lora_config,
diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py
index 0532b995c5..06880bc430 100755
--- a/tensorrt_llm/functional.py
+++ b/tensorrt_llm/functional.py
@@ -4779,7 +4779,7 @@
         return inv_freq, concat.reshape(1, -1).astype(dtype)

-    def create_sinusoidal_positions_long_rope(
+    def create_sinusoidal_positions_long_rope_for_attention_plugin(
             num_pos: int,
             num_orig_pos: int,
             dim: int,
@@ -4835,37 +4835,34 @@
             scaling_long_factors, False, True), short_mscale

     @staticmethod
-    def create_sinusoidal_positions_long_rope_for_attention_plugin(
+    def create_sinusoidal_positions_long_rope(
             num_pos: int,
             dim: int,
             theta: float,
             original_max_pos: int,
             short_factor: List[float],
             long_factor: List[float],
-            dtype=np.float32):
+            dtype=np.float32,
+            max_seq_len: Optional[int] = None):
         short_factor = np.array(short_factor, dtype=np.float32)
         long_factor = np.array(long_factor, dtype=np.float32)

         inv_freq = 1.0 / (theta**(np.arange(0, dim, 2, dtype=np.float32) / dim))
+        t_pos = np.arange(np.max([num_pos, original_max_pos]), dtype=np.float32)

-        # Short part
-        inv_freq_short = inv_freq / short_factor
-        t_short = np.arange(np.min([num_pos, original_max_pos]),
-                            dtype=np.float32)
-        freqs_short = np.einsum("i,j->ij", t_short, inv_freq_short)
-
-        # Long part
-        inv_freq_long = inv_freq / long_factor
-        t_long = np.arange(np.max([0, num_pos - original_max_pos]),
-                           dtype=np.float32) + original_max_pos
-        freqs_long = np.einsum("i,j->ij", t_long, inv_freq_long)
-
-        freqs = np.concatenate([freqs_short, freqs_long], axis=0)
+        # Choose the long or short rescaling factor based on max_seq_len.
+        factor = long_factor if max_seq_len is None or max_seq_len > original_max_pos else short_factor
+        inv_freq = inv_freq / factor
+        freqs = np.einsum("i,j->ij", t_pos, inv_freq)
         sinusoid_inp = freqs.astype(np.float32)[..., np.newaxis]

         # Apply scaling
         scale = num_pos / original_max_pos
-        scaling_factor = np.sqrt(1.0 + np.log(scale) / np.log(original_max_pos))
+        if scale <= 1.0:
+            scaling_factor = 1.0
+        else:
+            scaling_factor = np.sqrt(1.0 +
+                                     np.log(scale) / np.log(original_max_pos))

         # fuse cos/sin into float2 (cos, sin).
         concat = np.concatenate(
diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py
index ffe0759078..ebfaa8fdea 100755
--- a/tensorrt_llm/layers/attention.py
+++ b/tensorrt_llm/layers/attention.py
@@ -631,7 +631,7 @@ class Attention(Module):
             embed_positions, long_rope_embed_positions, \
             (rotary_inv_freq, embed_positions_for_gpt_attention), \
             (long_rope_rotary_inv_freq, long_rope_embed_positions_for_gpt_attention), mscale \
-                = RopeEmbeddingUtils.create_sinusoidal_positions_long_rope(
+                = RopeEmbeddingUtils.create_sinusoidal_positions_long_rope_for_attention_plugin(
                 max_position_embeddings, original_max_position_embeddings,
                 rotary_embedding_dim, rotary_embedding_base,
                 rope_scaling_short_factors,
diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml
index a0086fc2a4..2694c3ea1e 100644
--- a/tests/integration/defs/accuracy/references/gsm8k.yaml
+++ b/tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -136,5 +136,7 @@ mistralai/Mistral-Small-3.1-24B-Instruct-2503:
   - accuracy: 89.23
 microsoft/Phi-4-multimodal-instruct:
   - accuracy: 81.19
+microsoft/Phi-4-multimodal-instruct-long-rope:
+  - accuracy: 75.85
 microsoft/Phi-4-mini-instruct:
   - accuracy: 82.30
diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml
index 4e91d222bf..5c44fe18d5 100644
--- a/tests/integration/defs/accuracy/references/mmlu.yaml
+++ b/tests/integration/defs/accuracy/references/mmlu.yaml
@@ -220,3 +220,5 @@ mistralai/Ministral-8B-Instruct-2410:
     accuracy: 65.96
 microsoft/Phi-4-multimodal-instruct:
   - accuracy: 69.69
+microsoft/Phi-4-multimodal-instruct-long-rope:
+  - accuracy: 65.98
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 0dae6f7e97..8fc8dfe24b 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -2055,8 +2055,19 @@ class TestPhi4MM(LlmapiAccuracyTestHarness):
     MODEL_PATH = f"{llm_models_root()}/multimodals/Phi-4-multimodal-instruct"

     def test_auto_dtype(self):
-        with LLM(self.MODEL_PATH) as llm:
-            task = MMLU(self.MODEL_NAME)
+        # Set max_seq_len to 4096 to use the short rope factor.
+        model_name = "microsoft/Phi-4-multimodal-instruct"
+        with LLM(self.MODEL_PATH, max_seq_len=4096) as llm:
+            task = MMLU(model_name)
             task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
+            task = GSM8K(model_name)
+            task.evaluate(llm)
+
+    def test_auto_dtype_long_rope(self):
+        # Set max_seq_len larger than 4096 to use the long rope factor.
+        model_name = "microsoft/Phi-4-multimodal-instruct-long-rope"
+        with LLM(self.MODEL_PATH, max_seq_len=8192) as llm:
+            task = MMLU(model_name)
+            task.evaluate(llm)
+            task = GSM8K(model_name)
             task.evaluate(llm)
diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py
index b7aa5821a5..58d98d8604 100644
--- a/tests/integration/defs/test_e2e.py
+++ b/tests/integration/defs/test_e2e.py
@@ -2207,15 +2207,15 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
     }
     expected_keywords = {
         "image": [
-            ["clear", "sunny", "sky", "image", "object"],
-            ["road", "car", "lane", "strip", "bus"],
+            ["image", "depicts", "mountain", "half", "rock"],
+            ["road", "car", "lane", "traffic", "bus"],
         ],
         "audio": [
             ["what", "is", "the", "traffic", "sign", "in", "image"],
             ["what", "is", "shown", "in", "this", "image"],
         ],
         "image_audio": [
-            ["Half", "Dome", "Park", "natural", "image"],
+            ["image", "depicts", "Grand", "rock", "scene"],
         ],
     }

@@ -2229,6 +2229,8 @@
             *accuracy_inputs[modality]["prompt"],
             "--media",
             *accuracy_inputs[modality]["media"],
+            # Set max_seq_len to 4096 to use the short rope factor.
+            "--max_seq_len=4096",
             "--load_lora",
             "--auto_model_name",
             "Phi4MMForCausalLM",
diff --git a/tests/integration/test_lists/qa/examples_test_list.txt b/tests/integration/test_lists/qa/examples_test_list.txt
index eaebfb67c5..c485465e16 100644
--- a/tests/integration/test_lists/qa/examples_test_list.txt
+++ b/tests/integration/test_lists/qa/examples_test_list.txt
@@ -495,6 +495,7 @@
 accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_fp8
 accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_fp8
 accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype
+accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype_long_rope
 accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype
 test_e2e.py::test_llama_e2e[use_cpp_session-remove_input_padding-]