Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
fix some failing tests

Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>

parent fe96fd7524
commit 55a7b4db1d
@@ -125,7 +125,7 @@ bool FmhaDispatcher::isSupported()
     {
         tllmRunnerParams.mSparseMla = true;
         tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
-        tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Dense;
+        tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
     }

     foundKernels = mTllmGenFMHARunner->isSupported(tllmRunnerParams);
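Note on the mask change above: both FmhaDispatcher hunks swap the trtllm-gen mask type from Dense to Causal on the sparse-MLA generation path. As a reference for the semantics only (standalone sketch, not TensorRT-LLM code; the enum and function names below are made up here), a Dense mask leaves every key position visible, while a Causal mask hides keys past the query's own position:

// Standalone sketch (not TensorRT-LLM code): what Dense vs. Causal masking
// means for a [numQ x numK] attention score matrix.
#include <cstdio>
#include <limits>
#include <vector>

enum class MaskType { Dense, Causal };

// queryOffset: absolute position of query row 0 within the sequence, so a
// generation-phase query (numQ == 1) sits at the end of the KV history.
void applyMask(std::vector<float>& scores, int numQ, int numK, int queryOffset, MaskType mask)
{
    if (mask == MaskType::Dense)
    {
        return; // dense: every key stays visible
    }
    float const negInf = -std::numeric_limits<float>::infinity();
    for (int q = 0; q < numQ; ++q)
    {
        for (int k = 0; k < numK; ++k)
        {
            if (k > q + queryOffset) // causal: key lies in this query's future
            {
                scores[q * numK + k] = negInf;
            }
        }
    }
}

int main()
{
    int const numQ = 1, numK = 8;
    std::vector<float> scores(numQ * numK, 1.0f);
    applyMask(scores, numQ, numK, /*queryOffset=*/numK - 1, MaskType::Causal);
    for (float s : scores)
    {
        std::printf("%.1f ", s); // all eight keys remain visible for the last query
    }
    std::printf("\n");
    return 0;
}

For a single generation-step query at the end of the KV history the two masks admit the same keys, so the example prints eight unmasked scores; the commit itself does not state the motivation for the switch.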
@@ -237,7 +237,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
         tllmRunnerParams.mSparseMla = true;
         tllmRunnerParams.mSparseMlaTopK = runnerParams.sparse_params.sparse_mla_topk;
         tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
-        tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Dense;
+        tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
         tllmRunnerParams.kvPageIdxPtr
             = reinterpret_cast<int const*>(runnerParams.sparse_params.sparse_attn_indices);
         tllmRunnerParams.kvPtr = runnerParams.sparse_params.sparse_mla_kv_cache_pool;
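For orientation on the sparse parameters above: kvPageIdxPtr is pointed at sparse_attn_indices and kvPtr at the sparse-MLA KV cache pool. The sketch below is a hypothetical picture of an index-driven paged-KV lookup; apart from the parameter names quoted from the diff, the layout, struct, and helper are assumptions, not the kernel's actual implementation.

// Hypothetical paged-KV gather (layout assumed, not TensorRT-LLM's):
// a per-request list of page indices selects which pages of a shared KV pool
// a sparse attention kernel actually reads.
#include <cstddef>
#include <cstdint>

struct PagedKvView
{
    float const* pool;    // base of the KV pool (role of kvPtr in the diff)
    int const* pageIdx;   // selected page indices (role of kvPageIdxPtr)
    int tokensPerPage;    // page granularity
    int floatsPerToken;   // per-token payload, e.g. heads * headDim (assumed)
};

// Pointer to the KV payload of the i-th *selected* token.
inline float const* selectedToken(PagedKvView const& v, int i)
{
    int const page = v.pageIdx[i / v.tokensPerPage];
    int const slot = i % v.tokensPerPage;
    std::size_t const pageStride
        = static_cast<std::size_t>(v.tokensPerPage) * v.floatsPerToken;
    return v.pool + page * pageStride + static_cast<std::size_t>(slot) * v.floatsPerToken;
}

Under this assumed layout, mSparseMlaTopK would bound how many selected tokens the kernel reads per query.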
@@ -71,7 +71,7 @@ private:
 private:
     int32_t mDivisor = 1;
     uint32_t mMultiplier = 0;
-    [[maybe_unused]] uint32_t mAdd = 0;
+    __attribute__((unused)) uint32_t mAdd = 0;
     int32_t mShift = 0;
 };

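The member attribute swapped above is behavior-neutral. The commit does not say why [[maybe_unused]] is replaced with the GNU spelling; one plausible reading (an assumption here) is that some host compilers or CUDA front ends warn on the standard attribute in this position while accepting __attribute__((unused)). A guarded macro is one conventional way to keep both spellings available; everything below, including the enclosing struct name, is a sketch rather than repository code.

// Sketch only: pick an "unused member" attribute spelling per toolchain.
// The enclosing type name is hypothetical; the diff does not show it.
#include <cstdint>

#if defined(__GNUC__) || defined(__clang__)
#define SKETCH_UNUSED __attribute__((unused))
#else
#define SKETCH_UNUSED [[maybe_unused]]
#endif

struct DivmodSketch
{
    int32_t mDivisor = 1;
    uint32_t mMultiplier = 0;
    SKETCH_UNUSED uint32_t mAdd = 0; // kept even if unread on some paths (assumption)
    int32_t mShift = 0;
};

int main()
{
    DivmodSketch d;
    return d.mDivisor - 1; // 0
}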
@@ -269,7 +269,7 @@ class Scenario:
     rope_original_max_position_embeddings: int = 4096
     rope_type: str = "yarn"
     model_type: str = "deepseek_v3"
-    kv_cache_tokens_per_block: int = 64
+    kv_cache_tokens_per_block: int = 32


 @dataclass(kw_only=True, frozen=True)
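Halving kv_cache_tokens_per_block from 64 to 32 changes how many paged-KV blocks a sequence occupies and how large each block is. The arithmetic below is generic (illustrative dimensions, not values read from the test):

// Generic paged-KV sizing arithmetic (illustrative numbers, not repo code).
#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main()
{
    int const seqLen = 4096;
    int const numKvHeads = 1;  // MLA-style single latent head (assumption)
    int const headDim = 576;   // example value only
    int const bytesPerElt = 2; // fp16 / bf16

    for (int tokensPerBlock : {64, 32})
    {
        int const numBlocks = (seqLen + tokensPerBlock - 1) / tokensPerBlock; // ceil-div
        std::int64_t const bytesPerBlock = static_cast<std::int64_t>(tokensPerBlock)
            * numKvHeads * headDim * bytesPerElt;
        std::printf("tokens/block=%2d -> blocks=%3d, KV bytes/block=%lld\n",
            tokensPerBlock, numBlocks, static_cast<long long>(bytesPerBlock));
    }
    return 0;
}

Smaller blocks mean finer-grained allocation: more blocks cover the same sequence, and less memory is wasted in the final, partially filled block.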
@@ -407,7 +407,7 @@ class TestLlama(unittest.TestCase):
         llama = LlamaForCausalLM(model_config).to(dtype).to(device)
         llama.load_weights(hf_llama.state_dict())
         num_blocks = 1
-        tokens_per_block = 64
+        tokens_per_block = 32
         head_dim = llama.config.hidden_size // llama.config.num_attention_heads
         num_layers = llama.config.num_hidden_layers
         num_kv_heads = llama.config.num_key_value_heads
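The TestLlama change mirrors the Scenario default above. Since the test allocates num_blocks = 1, the block size directly caps how many tokens its cache can hold; a minimal capacity check under that reading (hypothetical helper, not the test's own code):

// Hypothetical capacity check: with a single paged block, the cache holds at
// most tokens_per_block tokens. Not code from the test.
#include <cassert>

constexpr int maxCachedTokens(int numBlocks, int tokensPerBlock)
{
    return numBlocks * tokensPerBlock;
}

int main()
{
    static_assert(maxCachedTokens(1, 32) == 32, "one block of 32 tokens after this commit");
    static_assert(maxCachedTokens(1, 64) == 64, "previous default");
    assert(maxCachedTokens(2, 32) == maxCachedTokens(1, 64)); // same capacity, finer paging
    return 0;
}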