From 6e72aff86608ee54cf1a2da7fbd14d381f510a87 Mon Sep 17 00:00:00 2001 From: tcherckez-nvidia <127761168+tcherckez-nvidia@users.noreply.github.com> Date: Thu, 22 Jan 2026 10:38:31 +0200 Subject: [PATCH] =?UTF-8?q?[#10838][fix]=20Add=20missing=20dist=20strategy?= =?UTF-8?q?=20param.=20fix=20typo=20for=20ad=5Flogger=E2=80=A6=20(#10892)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py | 3 ++- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index 455e57b491..43ea2c9e01 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -38,6 +38,7 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils._graph import add_graph_input from ...utils.cuda_mem_tracker import get_mem_info_in_mb +from ...utils.logger import ad_logger from ...utils.node_utils import is_op from ..interface import ( BaseTransform, @@ -342,7 +343,7 @@ class ResizeKVCache(BaseTransform): try: mod(**cm.named_args) except torch.OutOfMemoryError as e: - self.ad_logger.error( + ad_logger.error( f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}" ) raise e diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index a0a1734fcc..85f0e24a4e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -1532,7 +1532,9 @@ def _insert_sharded_mxfp4_mlp_ep( # Add a dist all-reduce after the op (sum partial results across EP ranks) with gm.graph.inserting_after(node): - red = gm.graph.call_function(torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,)) + red = gm.graph.call_function( + torch.ops.auto_deploy.torch_dist_all_reduce, args=(node, config.allreduce_strategy.name) + ) node.replace_all_uses_with(red) # keep dataflow: red(input=node) red.replace_input_with(red, node)