diff --git a/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp b/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
index 58bfdb9ac0..68e3e4d600 100644
--- a/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
+++ b/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
@@ -125,7 +125,7 @@ bool FmhaDispatcher::isSupported()
     {
         tllmRunnerParams.mSparseMla = true;
         tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
-        tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Dense;
+        tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
     }
 
     foundKernels = mTllmGenFMHARunner->isSupported(tllmRunnerParams);
@@ -237,7 +237,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
         tllmRunnerParams.mSparseMla = true;
         tllmRunnerParams.mSparseMlaTopK = runnerParams.sparse_params.sparse_mla_topk;
         tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
-        tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Dense;
+        tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
         tllmRunnerParams.kvPageIdxPtr = reinterpret_cast(runnerParams.sparse_params.sparse_attn_indices);
         tllmRunnerParams.kvPtr = runnerParams.sparse_params.sparse_mla_kv_cache_pool;
diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
index 4324e35838..524682f6d8 100644
--- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
+++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
@@ -71,7 +71,7 @@ private:
 private:
     int32_t mDivisor = 1;
     uint32_t mMultiplier = 0;
-    [[maybe_unused]] uint32_t mAdd = 0;
+    __attribute__((unused)) uint32_t mAdd = 0;
     int32_t mShift = 0;
 };
 
diff --git a/tests/unittest/_torch/attention/test_attention_mla.py b/tests/unittest/_torch/attention/test_attention_mla.py
index 1e0966e4ff..9f8a941c8b 100644
--- a/tests/unittest/_torch/attention/test_attention_mla.py
+++ b/tests/unittest/_torch/attention/test_attention_mla.py
@@ -269,7 +269,7 @@ class Scenario:
     rope_original_max_position_embeddings: int = 4096
     rope_type: str = "yarn"
    model_type: str = "deepseek_v3"
-    kv_cache_tokens_per_block: int = 64
+    kv_cache_tokens_per_block: int = 32
 
 
 @dataclass(kw_only=True, frozen=True)
diff --git a/tests/unittest/_torch/modeling/test_modeling_llama.py b/tests/unittest/_torch/modeling/test_modeling_llama.py
index cad865087b..50b1587690 100644
--- a/tests/unittest/_torch/modeling/test_modeling_llama.py
+++ b/tests/unittest/_torch/modeling/test_modeling_llama.py
@@ -407,7 +407,7 @@ class TestLlama(unittest.TestCase):
         llama = LlamaForCausalLM(model_config).to(dtype).to(device)
         llama.load_weights(hf_llama.state_dict())
         num_blocks = 1
-        tokens_per_block = 64
+        tokens_per_block = 32
         head_dim = llama.config.hidden_size // llama.config.num_attention_heads
         num_layers = llama.config.num_hidden_layers
         num_kv_heads = llama.config.num_key_value_heads
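Note on the mask-type switch in fmhaDispatcher.cpp: under standard attention semantics, a dense mask lets every query position attend to every KV position, while a causal mask only allows a query to attend to KV positions at or before its own (with queries aligned to the end of the KV cache during generation). The sketch below is a minimal, stand-alone illustration of that distinction; the names `MaskType` and `applyMask` are hypothetical and do not reproduce the trtllm-gen kernel code.

```cpp
// Minimal sketch: how a dense vs. causal attention mask treats the score
// matrix. Names (MaskType, applyMask) are illustrative, not TRT-LLM APIs.
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

enum class MaskType { Dense, Causal };

// scores: row-major [seqQ x seqKv] attention logits; assumes seqKv >= seqQ.
void applyMask(std::vector<float>& scores, std::size_t seqQ, std::size_t seqKv, MaskType mask)
{
    assert(seqKv >= seqQ && scores.size() == seqQ * seqKv);
    if (mask == MaskType::Dense)
    {
        return; // dense: every query attends to every KV position
    }
    // Causal: query i may only attend to KV position j with j <= i + (seqKv - seqQ);
    // the offset aligns the last query with the last KV entry (generation phase).
    std::size_t const offset = seqKv - seqQ;
    for (std::size_t i = 0; i < seqQ; ++i)
    {
        for (std::size_t j = offset + i + 1; j < seqKv; ++j)
        {
            scores[i * seqKv + j] = -std::numeric_limits<float>::infinity();
        }
    }
}
```

In this sketch, a single query token aligned to the last KV position (seqQ == 1) is allowed to see every KV index under either mask, so the two mask types coincide; the masks diverge only when a step processes more than one query token.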