fix some failing tests

Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
Author: Perkz Zheng
Date:   2025-12-24 13:06:47 +00:00
parent fe96fd7524
commit 55a7b4db1d
4 changed files with 5 additions and 5 deletions

@@ -125,7 +125,7 @@ bool FmhaDispatcher::isSupported()
     {
         tllmRunnerParams.mSparseMla = true;
         tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
-        tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Dense;
+        tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
     }
     foundKernels = mTllmGenFMHARunner->isSupported(tllmRunnerParams);
@@ -237,7 +237,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
         tllmRunnerParams.mSparseMla = true;
         tllmRunnerParams.mSparseMlaTopK = runnerParams.sparse_params.sparse_mla_topk;
         tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
-        tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Dense;
+        tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
         tllmRunnerParams.kvPageIdxPtr
             = reinterpret_cast<int const*>(runnerParams.sparse_params.sparse_attn_indices);
         tllmRunnerParams.kvPtr = runnerParams.sparse_params.sparse_mla_kv_cache_pool;
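
For context, the switch from Dense to Causal restricts each query position to attending only keys at or before its own position. A minimal sketch of what the two mask types mean (illustrative names, not TRT-LLM's kernel code):

#include <cstdint>

enum class MaskType { Dense, Causal };

// True if query position qIdx may attend to key/value position kvIdx.
bool isAllowed(MaskType mask, int32_t qIdx, int32_t kvIdx)
{
    // Dense: every position is visible; Causal: only past positions and self.
    return mask == MaskType::Dense || kvIdx <= qIdx;
}

For single-token generation the query sits at the last cached position, so a causal mask still covers the entire selected KV set.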

@@ -71,7 +71,7 @@ private:
 private:
     int32_t mDivisor = 1;
     uint32_t mMultiplier = 0;
-    [[maybe_unused]] uint32_t mAdd = 0;
+    __attribute__((unused)) uint32_t mAdd = 0;
     int32_t mShift = 0;
 };
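
The surrounding class (mDivisor, mMultiplier, mAdd, mShift) looks like a precomputed fast-division helper. A minimal sketch of the multiply-shift trick such fields typically drive, hard-coded for divisor 3 (an assumption about the scheme, not code from this file):

#include <cassert>
#include <cstdint>

uint32_t divBy3(uint32_t x)
{
    // 0xAAAAAAAB == ceil(2^33 / 3); (x * m) >> 33 equals x / 3 for all uint32 x.
    return static_cast<uint32_t>((static_cast<uint64_t>(x) * 0xAAAAAAABULL) >> 33);
}

int main()
{
    for (uint32_t x : {0u, 1u, 3u, 7u, 4294967295u})
    {
        assert(divBy3(x) == x / 3);
    }
}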

@@ -269,7 +269,7 @@ class Scenario:
     rope_original_max_position_embeddings: int = 4096
     rope_type: str = "yarn"
     model_type: str = "deepseek_v3"
-    kv_cache_tokens_per_block: int = 64
+    kv_cache_tokens_per_block: int = 32
 @dataclass(kw_only=True, frozen=True)
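
Halving kv_cache_tokens_per_block from 64 to 32 doubles the number of cache blocks a sequence of the same length occupies. A minimal sketch of the block-count arithmetic (hypothetical helper, not from the diff):

#include <cstdint>

// Blocks needed for seqLen tokens in a paged KV cache: ceiling division.
int32_t numKvBlocks(int32_t seqLen, int32_t tokensPerBlock)
{
    return (seqLen + tokensPerBlock - 1) / tokensPerBlock;
}

// e.g. numKvBlocks(4096, 64) == 64, while numKvBlocks(4096, 32) == 128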

@@ -407,7 +407,7 @@ class TestLlama(unittest.TestCase):
         llama = LlamaForCausalLM(model_config).to(dtype).to(device)
         llama.load_weights(hf_llama.state_dict())
         num_blocks = 1
-        tokens_per_block = 64
+        tokens_per_block = 32
         head_dim = llama.config.hidden_size // llama.config.num_attention_heads
         num_layers = llama.config.num_hidden_layers
         num_kv_heads = llama.config.num_key_value_heads
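
The test sizes its cache from num_blocks, tokens_per_block, head_dim, num_layers, and num_kv_heads. A minimal sketch of the element count those parameters imply (the exact pool layout is an assumption, not taken from the test; the factor 2 covers keys and values):

#include <cstdint>

int64_t kvPoolElems(int64_t numLayers, int64_t numBlocks, int64_t numKvHeads,
    int64_t tokensPerBlock, int64_t headDim)
{
    return numLayers * numBlocks * 2 * numKvHeads * tokensPerBlock * headDim;
}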