[TRTLLM-6342][feat] Support for partial sharding from factory (#7393)

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com>
Grzegorz Kwasniewski 2025-09-19 18:07:42 +02:00 committed by GitHub
parent 8fcd11515d
commit 8adaf0bb78
4 changed files with 69 additions and 8 deletions


@@ -69,6 +69,7 @@ transforms:
     stage: sharding
     simple_shard_only: false
     use_sharding_from_factory: false
+    support_partial_config: false
     sharding_dims: ['tp', 'ep', 'bmm']
     # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:
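For context on how the new knob is meant to be used, here is a minimal sketch, not part of this commit, of the sharding transform options with factory sharding and partial-config support both enabled, written as the Python dict that loading the YAML entry above would produce (option names are taken from this diff; the variable name is hypothetical):

# Minimal sketch, assuming the option names shown in the YAML diff above.
sharding_options = {
    "stage": "sharding",
    "simple_shard_only": False,
    "use_sharding_from_factory": True,  # take the TP plan from the model factory config
    "support_partial_config": True,     # new: accept a factory plan that only covers TP
    "sharding_dims": ["tp", "ep", "bmm"],  # heuristics still run for these dims when the plan is partial
}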


@@ -164,7 +164,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
     )
     sharding_dims: List[str] = Field(
-        default=["tp", "ep", "bmm"],
+        default=["tp", "ep", "dp"],
         description="The sharding methods to apply by the heuristic sharding stage.",
     )


@@ -127,6 +127,7 @@ class ShardingTransformConfig(TransformConfig):
     simple_shard_only: bool = Field(default=False)
     use_sharding_from_factory: bool = Field(default=False)
+    support_partial_config: bool = Field(default=False)
     # Which sharding families to run: any subset of {"tp", "ep", "bmm"}
     sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"])
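To make the defaults concrete, the following is a trimmed-down stand-in for the transform config shown above (the class name here is hypothetical; the fields are copied from this diff). It only illustrates that the new flag stays off unless enabled explicitly:

# Sketch only: a reduced mirror of the pydantic fields above, not the real class.
from typing import List
from pydantic import BaseModel, Field

class MiniShardingTransformConfig(BaseModel):
    simple_shard_only: bool = Field(default=False)
    use_sharding_from_factory: bool = Field(default=False)
    support_partial_config: bool = Field(default=False)
    sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"])

print(MiniShardingTransformConfig().support_partial_config)  # -> False
print(MiniShardingTransformConfig(support_partial_config=True).sharding_dims)  # -> ['tp', 'ep', 'bmm']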
@@ -185,6 +186,9 @@ class Sharding(BaseTransform):
             else ShardingConfigSource.UNKNOWN
         )
         shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only
+        shared_config.sharding_config.support_partial_config = self.config.support_partial_config
+        shared_config.sharding_config.sharding_dims = self.config.sharding_dims
         shared_config.sharding_config.use_sharding_from_factory = (
             self.config.use_sharding_from_factory
         )
@@ -200,8 +204,6 @@
             factory_info = detect_sharding_from_factory_config(gm, sharding_config)
             return gm, factory_info
-        shared_config.sharding_config.sharding_dims = self.config.sharding_dims
         ad_logger.info(
             f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}"
         )
@@ -338,8 +340,39 @@ def detect_sharding_from_factory_config(
             # TODO: Sequence parallelism is not supported yet.
             ad_logger.warning("Sequence parallelism is not supported yet. Skipping.")
         elif "local" in config:
-            # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
-            ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
+            # Check if this applies to shared experts in EP parallelism.
+            # If yes, apply the TP col-row shard.
+            if "shared" in module_name:
+                col_row_action = config.replace("local_", "")
+                if col_row_action == "colwise":
+                    sharding_config.tp_transforms.append(
+                        TPShardingInfo(
+                            target_node=lin_node.name,
+                            split_dim=SplitDimension.COLUMN,
+                            rank=rank,
+                            world_size=world_size,
+                            dist_op=None,
+                            min_local_shape=min_local_shape,
+                        )
+                    )
+                elif col_row_action == "rowwise":
+                    sharding_config.tp_transforms.append(
+                        TPShardingInfo(
+                            target_node=lin_node.name,
+                            split_dim=SplitDimension.ROW,
+                            rank=rank,
+                            world_size=world_size,
+                            dist_op="all_reduce",
+                            min_local_shape=min_local_shape,
+                        )
+                    )
+                    num_row_col_shards += 1
+                else:
+                    ad_logger.warning("Invalid sharding config. Skipping.")
+            else:
+                # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
+                ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
         elif "gather" in config:
             # Simple shard (row + all_gather)
             sharding_config.tp_transforms.append(
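The branch above is the core of the new "local_*" handling: on shared-expert modules the "local_" prefix is stripped and the remaining colwise/rowwise action becomes a plain TP column/row split, while other "local" (hybrid EP+TP) entries are still skipped. A toy illustration of that mapping, using simplified return values instead of TPShardingInfo (function and module names here are hypothetical):

# Toy mapping for illustration only; it mirrors the branch above, not the real API.
def classify_local_entry(module_name: str, config: str):
    if "shared" not in module_name:
        return None  # hybrid EP+TP ("local") is still unsupported -> skipped with a warning
    action = config.replace("local_", "")
    if action == "colwise":
        return ("COLUMN", None)       # column split, no collective
    if action == "rowwise":
        return ("ROW", "all_reduce")  # row split followed by an all-reduce
    return None                       # invalid entry -> skipped with a warning

print(classify_local_entry("model.layers.0.mlp.shared_expert.gate_proj", "local_colwise"))
# -> ('COLUMN', None)
print(classify_local_entry("model.layers.0.mlp.experts.0.up_proj", "local_packed_rowwise"))
# -> None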
@@ -362,9 +395,35 @@
         f"Applied {num_shards} TP shards (simple: {num_simple_shards}, "
         f"row-col pattern: {num_row_col_shards})"
     )
+    num_matches = len(sharding_config.tp_transforms)
+    if sharding_config.support_partial_config:
+        ad_logger.info(
+            f"Partial factory config applied only for TP. "
+            f"Applying heuristics for {sharding_config.sharding_dims}."
+        )
+        # run EP sharding across ranks
+        if "ep" in sharding_config.sharding_dims:
+            ep_info = detect_ep_shard(gm, sharding_config)
+        else:
+            ep_info = TransformInfo(
+                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+        # run BMM sharding across ranks
+        if "bmm" in sharding_config.sharding_dims:
+            dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config)
+        else:
+            dp_bmm_info = TransformInfo(
+                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+        num_matches += ep_info.num_matches + dp_bmm_info.num_matches
     return TransformInfo(
         skipped=False,
-        num_matches=len(sharding_config.tp_transforms),
+        num_matches=num_matches,
         is_clean=False,
         has_valid_shapes=False,
     )
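Summarizing the new return path: TP matches from the factory plan are counted first, and when support_partial_config is set the EP and BMM heuristics still run for whatever is listed in sharding_dims, with their matches added to the total. A rough sketch of which dimension ends up covered by which source (hypothetical helper, heavily simplified):

# Rough sketch, not the PR's API: sharding coverage after a TP-only factory plan.
def plan_sources(support_partial_config: bool, sharding_dims: list) -> dict:
    sources = {"tp": "factory tp_plan"}
    if support_partial_config:
        for dim in ("ep", "bmm"):
            if dim in sharding_dims:
                sources[dim] = "heuristic"
    return sources

print(plan_sources(True, ["tp", "ep", "bmm"]))
# -> {'tp': 'factory tp_plan', 'ep': 'heuristic', 'bmm': 'heuristic'}
print(plan_sources(False, ["tp", "ep", "bmm"]))
# -> {'tp': 'factory tp_plan'}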


@@ -737,6 +737,7 @@ class ShardingConfig(BaseModel):
     predefined_config: Optional[Dict[str, Any]] = None
     simple_shard_only: bool = Field(default=False)
     use_sharding_from_factory: bool = False
+    support_partial_config: bool = False
     sharding_dims: List[str] = Field(default_factory=list)
     tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
     bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
@@ -781,7 +782,7 @@ class ShardingConfig(BaseModel):
         tp_plan = self.predefined_config["tp_plan"]
         values = set(tp_plan.values())
-        allowed_values = {
+        supported_modes = {
             "colwise",  # row split and no collective
             "rowwise",  # column split and all-reduce
             "gather",  # simple shard (row + all_gather)
@@ -793,7 +794,7 @@
             # "local_packed_rowwise",
             # "local",
         }
-        if not values.issubset(allowed_values):
+        if not self.support_partial_config and not values.issubset(supported_modes):
            ad_logger.warning("Sharding config contains invalid values. Skipping.")
             # invalidate the config
             self.predefined_config = {}
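The validation change above means an unknown tp_plan value no longer discards the whole factory config when partial configs are supported; unsupported entries are instead skipped per module during detection. A minimal sketch of the relaxed check, assuming a reduced mode set and a hypothetical helper name (the real supported_modes set is larger):

# Minimal sketch of the relaxed check; not the real validator.
SUPPORTED_MODES = {"colwise", "rowwise", "gather"}

def keep_predefined_config(tp_plan: dict, support_partial_config: bool) -> bool:
    values = set(tp_plan.values())
    return support_partial_config or values.issubset(SUPPORTED_MODES)

assert keep_predefined_config({"q_proj": "colwise", "o_proj": "rowwise"}, False)
assert not keep_predefined_config({"experts.gate_up_proj": "local_packed_rowwise"}, False)
assert keep_predefined_config({"experts.gate_up_proj": "local_packed_rowwise"}, True)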