Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Bo Deng <deemod@nvidia.com>
Co-authored-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
This commit is contained in:
Bo Deng 2026-02-04 11:24:21 +08:00 committed by GitHub
parent d248aef751
commit 910c070e88
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 2 deletions

View File

@ -120,6 +120,11 @@ __global__ void customMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues
auto warp = cg::tiled_partition<WARP_SIZE>(block);
BaseType minScore = BaseType{-INFINITY};
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaGridDependencySynchronize();
#endif
for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum)
{
auto scoreOffset = tokenId * numExperts;
@ -168,6 +173,10 @@ __global__ void customMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues
}
}
} // end for tokenId
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
int nextPowerOfTwo(int num)

View File

@ -39,6 +39,8 @@ from .modeling_utils import (DecoderModel, duplicate_kv_weight, filter_weights,
# Use TinyGEMM when the number of tokens is not larger than this threshold
MIN_LATENCY_TINYGEMM_NUM_TOKENS = 128
# Enable TinyGEMM optimization (disabled by default, set ENABLE_TINYGEMM=1 to enable)
ENABLE_TINYGEMM = os.environ.get('ENABLE_TINYGEMM', '0') == '1'
class AttentionBlock(Attention):
@ -226,7 +228,7 @@ class MLPBlock(torch.nn.Module):
dtype=pretrained_config.torch_dtype)
def compute_gate_output(self, x: torch.Tensor) -> torch.Tensor:
if get_sm_version() in [
if ENABLE_TINYGEMM and get_sm_version() in [
90, 100, 103
] and x.shape[0] <= MIN_LATENCY_TINYGEMM_NUM_TOKENS:
weight = self.gate.weight

View File

@ -1530,7 +1530,8 @@ class PyTorchModelEngine(ModelEngine):
num_draft_tokens = len(draft_tokens)
total_num_tokens = len(position_ids)
assert total_num_tokens <= self.max_num_tokens, (
"total_num_tokens should be less than or equal to max_num_tokens")
f"total_num_tokens ({total_num_tokens}) should be less than or equal to max_num_tokens ({self.max_num_tokens})"
)
# if exist requests that do not have previous batch, copy input_ids and draft_tokens
if num_tokens > 0:
input_ids = torch.tensor(input_ids,