Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[TRTLLM-6342][fix] Fixed triggering BMM sharding (#7389)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
parent c622f61609
commit 3755f8ab7d
@@ -56,7 +56,7 @@ transforms:
     stage: sharding
     simple_shard_only: false
     use_sharding_from_factory: false
-    sharding_dims: ['tp', 'ep', 'dp']
+    sharding_dims: ['tp', 'ep', 'bmm']
     # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:
     stage: sharding
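As a sketch of what this YAML entry amounts to programmatically: the dict-style override below mirrors it, with the "detect_sharding" key name taken from the job config shown later in this diff (the YAML hunk itself does not show its parent key); the variable name is hypothetical.

# Hypothetical standalone override mirroring the YAML above; the
# "detect_sharding" key comes from the job config later in this diff.
transforms_override = {
    "detect_sharding": {
        "stage": "sharding",
        "simple_shard_only": False,
        "use_sharding_from_factory": False,
        # 'dp' is replaced by 'bmm' so the BMM sharding heuristic actually runs.
        "sharding_dims": ["tp", "ep", "bmm"],
    },
    "sharding_transform_executor": {
        "stage": "sharding",
    },
}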
@@ -166,7 +166,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
     )

     sharding_dims: List[str] = Field(
-        default=["tp", "ep", "dp"],
+        default=["tp", "ep", "bmm"],
         description="The sharding methods to apply by the heuristic sharding stage.",
     )
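A quick way to sanity-check the new default without constructing the full settings object (AutoDeployConfig may have required fields not shown in this hunk) is to inspect the pydantic field directly. A minimal sketch, assuming pydantic v2; the import path is an assumption, since the diff shows only the class body:

# Import path is an assumption; the diff only shows the class definition.
from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig

# Pydantic v2 exposes declared fields via model_fields; reading the stored
# default avoids instantiating the settings class.
default_dims = AutoDeployConfig.model_fields["sharding_dims"].default
assert default_dims == ["tp", "ep", "bmm"]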
@@ -67,6 +67,7 @@ def _run_job(
         "detect_sharding": {
             "stage": "sharding",
             "use_sharding_from_factory": False,
+            "sharding_dims": ["bmm"],
         },
         "sharding_transform_executor": {
             "stage": "sharding",
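Note that the job config above pins sharding_dims to just ["bmm"]; judging from the surrounding context, this exercises BMM sharding in isolation rather than inheriting the full default list.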