mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-05 02:31:33 +08:00
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
This commit is contained in:
parent 9ce0511d86
commit 6e72aff866
@@ -38,6 +38,7 @@ from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils._graph import add_graph_input
 from ...utils.cuda_mem_tracker import get_mem_info_in_mb
+from ...utils.logger import ad_logger
 from ...utils.node_utils import is_op
 from ..interface import (
     BaseTransform,
@@ -342,7 +343,7 @@ class ResizeKVCache(BaseTransform):
         try:
             mod(**cm.named_args)
         except torch.OutOfMemoryError as e:
-            self.ad_logger.error(
+            ad_logger.error(
                 f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}"
             )
             raise e
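For context, the guarded forward pass above is how ResizeKVCache probes whether the current cache allocation still fits: run the model once and turn an OOM during that trial run into a clear log line before re-raising (the file also imports get_mem_info_in_mb, presumably to report free memory around the trial run). The sketch below shows that pattern in isolation; probe_free_mem_mb and run_trial_forward are illustrative stand-ins built on torch.cuda.mem_get_info, not the repository's get_mem_info_in_mb helper.

import torch


def probe_free_mem_mb() -> float:
    # Illustrative stand-in for get_mem_info_in_mb(): free device memory in MiB.
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    return free_bytes / (1024 * 1024)


def run_trial_forward(mod: torch.nn.Module, **named_args) -> float:
    # Run one forward pass; surface OOM with a descriptive message and re-raise,
    # mirroring the error handling in ResizeKVCache above.
    try:
        mod(**named_args)
    except torch.OutOfMemoryError as e:
        print(f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}")
        raise
    return probe_free_mem_mb()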
@@ -1532,7 +1532,9 @@ def _insert_sharded_mxfp4_mlp_ep(
 
     # Add a dist all-reduce after the op (sum partial results across EP ranks)
     with gm.graph.inserting_after(node):
-        red = gm.graph.call_function(torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,))
+        red = gm.graph.call_function(
+            torch.ops.auto_deploy.torch_dist_all_reduce, args=(node, config.allreduce_strategy.name)
+        )
     node.replace_all_uses_with(red)
     # keep dataflow: red(input=node)
     red.replace_input_with(red, node)
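Besides passing the configured all-reduce strategy name to the custom op, this hunk relies on a standard torch.fx rewrite idiom: insert a call_function node right after the producer, redirect every consumer of the producer to the new node, and then restore the new node's own input, since replace_all_uses_with also rewires the new node to point at itself. Below is a minimal, runnable sketch of that idiom on a toy module; torch.relu stands in for torch.ops.auto_deploy.torch_dist_all_reduce, and the module is invented purely for illustration.

import operator

import torch
from torch.fx import symbolic_trace


class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 1


gm = symbolic_trace(Toy())

# Pick the multiply node and insert a new op right after it
# (torch.relu stands in for the dist all-reduce custom op).
node = next(n for n in gm.graph.nodes if n.target is operator.mul)
with gm.graph.inserting_after(node):
    red = gm.graph.call_function(torch.relu, args=(node,))

# Route every consumer of `node` through the new node ...
node.replace_all_uses_with(red)
# ... which also rewrote red's own argument to red, so restore red(input=node).
red.replace_input_with(red, node)

gm.graph.lint()
gm.recompile()
print(gm(torch.tensor([-3.0, 2.0])))  # relu(x * 2) + 1 -> tensor([1., 5.])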