Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[TRTLLM-6342][feat] Support for partial sharding from factory (#7393)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com>
This commit is contained in:
parent 8fcd11515d
commit 8adaf0bb78
@@ -69,6 +69,7 @@ transforms:
     stage: sharding
     simple_shard_only: false
     use_sharding_from_factory: false
+    support_partial_config: false
     sharding_dims: ['tp', 'ep', 'bmm']
     # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:
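For orientation, the same options written as a plain Python mapping (keys copied from the hunk above; the name of the enclosing transform entry is not visible in this hunk, so it is omitted). This is only a reading aid, not code from the repository:

# Sharding transform options as listed in the YAML hunk above.
# "support_partial_config" is the key added by this commit.
sharding_transform_options = {
    "stage": "sharding",
    "simple_shard_only": False,
    "use_sharding_from_factory": False,
    "support_partial_config": False,
    "sharding_dims": ["tp", "ep", "bmm"],
}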
@@ -164,7 +164,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
     )

     sharding_dims: List[str] = Field(
-        default=["tp", "ep", "bmm"],
+        default=["tp", "ep", "dp"],
         description="The sharding methods to apply by the heuristic sharding stage.",
     )

@@ -127,6 +127,7 @@ class ShardingTransformConfig(TransformConfig):

     simple_shard_only: bool = Field(default=False)
     use_sharding_from_factory: bool = Field(default=False)
+    support_partial_config: bool = Field(default=False)
     # Which sharding families to run: any subset of {"tp", "ep", "bmm"}
     sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"])

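A minimal, self-contained sketch of how these fields behave, using a pydantic stand-in rather than the real ShardingTransformConfig (which inherits from TransformConfig). Field names and defaults are taken from the hunk; everything around them is illustrative:

from typing import List

from pydantic import BaseModel, Field


class DemoShardingTransformConfig(BaseModel):
    """Stand-in mirroring the fields shown in the diff."""

    simple_shard_only: bool = Field(default=False)
    use_sharding_from_factory: bool = Field(default=False)
    support_partial_config: bool = Field(default=False)
    # Which sharding families to run: any subset of {"tp", "ep", "bmm"}.
    sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"])


# Defaults: factory sharding off, all heuristic families enabled.
print(DemoShardingTransformConfig())

# Hypothetical override: take TP from the factory config, accept partial
# coverage, and let the heuristics fill in only EP sharding.
print(
    DemoShardingTransformConfig(
        use_sharding_from_factory=True,
        support_partial_config=True,
        sharding_dims=["ep"],
    )
)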
@@ -185,6 +186,9 @@ class Sharding(BaseTransform):
             else ShardingConfigSource.UNKNOWN
         )
         shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only
+        shared_config.sharding_config.support_partial_config = self.config.support_partial_config
+        shared_config.sharding_config.sharding_dims = self.config.sharding_dims
+
         shared_config.sharding_config.use_sharding_from_factory = (
             self.config.use_sharding_from_factory
         )
@@ -200,8 +204,6 @@
             factory_info = detect_sharding_from_factory_config(gm, sharding_config)
             return gm, factory_info

-        shared_config.sharding_config.sharding_dims = self.config.sharding_dims
-
         ad_logger.info(
             f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}"
         )
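Read together with the previous hunk, the flow is: copy the flags onto the shared sharding config, then either delegate to detect_sharding_from_factory_config (which, per the later hunks, can fall back to the heuristics when support_partial_config is set) or run the heuristics directly. A simplified sketch of that dispatch, with stand-in types and callables in place of the real transform machinery:

from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class DemoShardingOptions:
    """Stand-in for the shared sharding config the transform populates."""

    use_sharding_from_factory: bool = False
    support_partial_config: bool = False
    sharding_dims: List[str] = field(default_factory=lambda: ["tp", "ep", "bmm"])


def run_sharding_detection(
    opts: DemoShardingOptions,
    detect_from_factory: Callable[[DemoShardingOptions], int],
    detect_heuristics: Callable[[DemoShardingOptions], int],
) -> int:
    """Dispatch between factory-provided and heuristic sharding detection."""
    if opts.use_sharding_from_factory:
        # The factory path may itself re-enter the heuristics when the factory
        # config only covers part of the model (support_partial_config).
        return detect_from_factory(opts)
    return detect_heuristics(opts)


# Example wiring with dummy detectors that just report a match count.
opts = DemoShardingOptions(use_sharding_from_factory=True, support_partial_config=True)
print(run_sharding_detection(opts, lambda o: 12, lambda o: 7))  # 12: factory path taken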
@@ -338,8 +340,39 @@ def detect_sharding_from_factory_config(
             # TODO: Sequence parallelism is not supported yet.
             ad_logger.warning("Sequence parallelism is not supported yet. Skipping.")
         elif "local" in config:
-            # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
-            ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
+            # Check if this applies to shared experts in EP parallelism.
+            # If yes, apply the TP col-row shard.
+            if "shared" in module_name:
+                col_row_action = config.replace("local_", "")
+                if col_row_action == "colwise":
+                    sharding_config.tp_transforms.append(
+                        TPShardingInfo(
+                            target_node=lin_node.name,
+                            split_dim=SplitDimension.COLUMN,
+                            rank=rank,
+                            world_size=world_size,
+                            dist_op=None,
+                            min_local_shape=min_local_shape,
+                        )
+                    )
+                elif col_row_action == "rowwise":
+                    sharding_config.tp_transforms.append(
+                        TPShardingInfo(
+                            target_node=lin_node.name,
+                            split_dim=SplitDimension.ROW,
+                            rank=rank,
+                            world_size=world_size,
+                            dist_op="all_reduce",
+                            min_local_shape=min_local_shape,
+                        )
+                    )
+                    num_row_col_shards += 1
+                else:
+                    ad_logger.warning("Invalid sharding config. Skipping.")
+            else:
+                # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
+                ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
+
         elif "gather" in config:
             # Simple shard (row + all_gather)
             sharding_config.tp_transforms.append(
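The new branch turns a factory tp_plan action such as local_colwise / local_rowwise on a shared-expert module into an ordinary TP split: a column split with no collective, or a row split followed by an all-reduce. A hedged, self-contained sketch of that mapping; DemoTPShardingInfo and plan_shared_expert_shard are illustrative stand-ins, not the project's TPShardingInfo API, and the example module/node names are made up:

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class SplitDim(Enum):
    COLUMN = 0
    ROW = 1


@dataclass
class DemoTPShardingInfo:
    target_node: str
    split_dim: SplitDim
    rank: int
    world_size: int
    dist_op: Optional[str]
    min_local_shape: int


def plan_shared_expert_shard(
    module_name: str,
    action: str,
    node_name: str,
    rank: int,
    world_size: int,
    min_local_shape: int = 1,
) -> Optional[DemoTPShardingInfo]:
    """Map a 'local_*' tp_plan action on a shared-expert module to a TP split."""
    if "shared" not in module_name:
        # Hybrid EP+TP on routed experts: not handled, mirroring the warning above.
        return None
    action = action.replace("local_", "")
    if action == "colwise":
        # Column split with no collective (dist_op=None), as in the diff.
        return DemoTPShardingInfo(node_name, SplitDim.COLUMN, rank, world_size,
                                  dist_op=None, min_local_shape=min_local_shape)
    if action == "rowwise":
        # Row split whose partial results are combined with an all-reduce.
        return DemoTPShardingInfo(node_name, SplitDim.ROW, rank, world_size,
                                  dist_op="all_reduce", min_local_shape=min_local_shape)
    # Anything else is skipped, mirroring the "Invalid sharding config" warning.
    return None


# Hypothetical example: rank 0 of 4 sharding a shared expert's linear node.
print(plan_shared_expert_shard("model.layers.0.mlp.shared_expert.up_proj",
                               "local_colwise", "linear_12", rank=0, world_size=4))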
@@ -362,9 +395,35 @@
            f"Applied {num_shards} TP shards (simple: {num_simple_shards}, "
            f"row-col pattern: {num_row_col_shards})"
        )
+
+    num_matches = len(sharding_config.tp_transforms)
+
+    if sharding_config.support_partial_config:
+        ad_logger.info(
+            f"Partial factory config applied only for TP. "
+            f"Applying heuristics for {sharding_config.sharding_dims}."
+        )
+
+        # run EP sharding across ranks
+        if "ep" in sharding_config.sharding_dims:
+            ep_info = detect_ep_shard(gm, sharding_config)
+        else:
+            ep_info = TransformInfo(
+                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+
+        # run BMM sharding across ranks
+        if "bmm" in sharding_config.sharding_dims:
+            dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config)
+        else:
+            dp_bmm_info = TransformInfo(
+                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+        num_matches += ep_info.num_matches + dp_bmm_info.num_matches
+
     return TransformInfo(
         skipped=False,
-        num_matches=len(sharding_config.tp_transforms),
+        num_matches=num_matches,
         is_clean=False,
         has_valid_shapes=False,
     )
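With support_partial_config set, the factory path no longer reports only its TP matches: it optionally runs the EP and BMM heuristics (when they appear in sharding_dims) and folds their match counts into the returned TransformInfo. A condensed sketch of that aggregation, using a stand-in TransformInfo and dummy detector callables (the real detectors operate on the graph module):

from dataclasses import dataclass
from typing import Callable, List


@dataclass
class DemoTransformInfo:
    skipped: bool
    num_matches: int
    is_clean: bool
    has_valid_shapes: bool


SKIPPED = DemoTransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True)


def finish_factory_detection(
    num_tp_matches: int,
    support_partial_config: bool,
    sharding_dims: List[str],
    detect_ep: Callable[[], DemoTransformInfo],
    detect_dp_bmm: Callable[[], DemoTransformInfo],
) -> DemoTransformInfo:
    """Aggregate TP matches from the factory config with optional EP/BMM heuristics."""
    num_matches = num_tp_matches
    if support_partial_config:
        ep_info = detect_ep() if "ep" in sharding_dims else SKIPPED
        dp_bmm_info = detect_dp_bmm() if "bmm" in sharding_dims else SKIPPED
        num_matches += ep_info.num_matches + dp_bmm_info.num_matches
    return DemoTransformInfo(skipped=False, num_matches=num_matches,
                             is_clean=False, has_valid_shapes=False)


# Example: 12 TP shards from the factory plan, 4 more from the EP heuristic.
info = finish_factory_detection(
    12, True, ["tp", "ep", "bmm"],
    detect_ep=lambda: DemoTransformInfo(False, 4, False, False),
    detect_dp_bmm=lambda: DemoTransformInfo(True, 0, True, True),
)
print(info.num_matches)  # 16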
@@ -737,6 +737,7 @@ class ShardingConfig(BaseModel):
     predefined_config: Optional[Dict[str, Any]] = None
     simple_shard_only: bool = Field(default=False)
     use_sharding_from_factory: bool = False
+    support_partial_config: bool = False
     sharding_dims: List[str] = Field(default_factory=list)
     tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
     bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
@@ -781,7 +782,7 @@
         tp_plan = self.predefined_config["tp_plan"]

         values = set(tp_plan.values())
-        allowed_values = {
+        supported_modes = {
             "colwise",  # row split and no collective
             "rowwise",  # column split and all-reduce
             "gather",  # simple shard (row + all_gather)
@@ -793,7 +794,7 @@
             # "local_packed_rowwise",
             # "local",
         }
-        if not values.issubset(allowed_values):
+        if not self.support_partial_config and not values.issubset(supported_modes):
             ad_logger.warning("Sharding config contains invalid values. Skipping.")
             # invalidate the config
             self.predefined_config = {}
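The relaxed check means an unrecognized tp_plan mode no longer throws away the whole predefined config when support_partial_config is set; only strict mode still requires every mode to be supported. A small stand-alone sketch of the check (not the real ShardingConfig method; the supported set below is just the subset visible in the diff, and the example plan entries are made up):

from typing import Dict

SUPPORTED_MODES = {"colwise", "rowwise", "gather"}  # subset shown in the diff


def validate_tp_plan(tp_plan: Dict[str, str], support_partial_config: bool) -> bool:
    """Return True if the predefined tp_plan should be kept.

    In strict mode, any unsupported mode invalidates the whole plan; with
    partial-config support, unsupported entries are tolerated and simply
    skipped later by the detection pass.
    """
    values = set(tp_plan.values())
    if not support_partial_config and not values.issubset(SUPPORTED_MODES):
        return False  # the caller would reset predefined_config to {}
    return True


# Example: a plan mixing supported and (here) unsupported modes.
plan = {
    "model.layers.*.mlp.up_proj": "colwise",
    "model.layers.*.mlp.down_proj": "rowwise",
    "model.layers.*.self_attn.qkv_proj": "local_packed_rowwise",
}
print(validate_tp_plan(plan, support_partial_config=False))  # False
print(validate_tp_plan(plan, support_partial_config=True))   # True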