[TRTLLM-6342][fix] Fixed triggering BMM sharding (#7389)

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
This commit is contained in:
Grzegorz Kwasniewski 2025-09-04 08:01:27 +02:00 committed by GitHub
parent c622f61609
commit 3755f8ab7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 3 additions and 2 deletions

View File

@@ -56,7 +56,7 @@ transforms:
stage: sharding
simple_shard_only: false
use_sharding_from_factory: false
sharding_dims: ['tp', 'ep', 'dp']
sharding_dims: ['tp', 'ep', 'bmm']
# TODO: (hg) need to ensure run_shape_prop after sharding.
sharding_transform_executor:
stage: sharding

View File

@@ -166,7 +166,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
)
sharding_dims: List[str] = Field(
default=["tp", "ep", "dp"],
default=["tp", "ep", "bmm"],
description="The sharding methods to apply by the heuristic sharding stage.",
)

View File

@@ -67,6 +67,7 @@ def _run_job(
"detect_sharding": {
"stage": "sharding",
"use_sharding_from_factory": False,
"sharding_dims": ["bmm"],
},
"sharding_transform_executor": {
"stage": "sharding",