[#10838][fix] Add missing dist strategy param. Fix typo for ad_logger… (#10892)

Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
Authored by tcherckez-nvidia on 2026-01-22 10:38:31 +02:00, committed by GitHub
parent 9ce0511d86
commit 6e72aff866
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 2 deletions

@@ -38,6 +38,7 @@ from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils._graph import add_graph_input
 from ...utils.cuda_mem_tracker import get_mem_info_in_mb
+from ...utils.logger import ad_logger
 from ...utils.node_utils import is_op
 from ..interface import (
     BaseTransform,
@@ -342,7 +343,7 @@ class ResizeKVCache(BaseTransform):
         try:
             mod(**cm.named_args)
         except torch.OutOfMemoryError as e:
-            self.ad_logger.error(
+            ad_logger.error(
                 f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}"
             )
             raise e
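
The typo fix above matters because ad_logger is a module-level logger exported by utils.logger, not an attribute of the transform instance: inside the except handler, self.ad_logger would raise AttributeError and mask the original OutOfMemoryError. A minimal, self-contained sketch of that failure mode (the class and method names here are hypothetical stand-ins, not the real ResizeKVCache code):

import logging

logging.basicConfig(level=logging.ERROR)
# Module-level logger, standing in for auto_deploy's ad_logger from utils.logger.
ad_logger = logging.getLogger("auto_deploy")

class ResizeKVCacheSketch:  # hypothetical stand-in for the real transform
    def resize(self):
        try:
            # Stand-in for the real forward pass that can exhaust GPU memory.
            raise MemoryError("kv-cache resize failed")
        except MemoryError as e:
            # `self.ad_logger.error(...)` would raise AttributeError here and hide `e`;
            # the module-level logger is the correct reference.
            ad_logger.error(f"OutOfMemoryError while trying to resize the kv-cache:\n{e}")
            raise

try:
    ResizeKVCacheSketch().resize()
except MemoryError:
    pass  # the error was logged before being re-raised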

@@ -1532,7 +1532,9 @@ def _insert_sharded_mxfp4_mlp_ep(
     # Add a dist all-reduce after the op (sum partial results across EP ranks)
     with gm.graph.inserting_after(node):
-        red = gm.graph.call_function(torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,))
+        red = gm.graph.call_function(
+            torch.ops.auto_deploy.torch_dist_all_reduce, args=(node, config.allreduce_strategy.name)
+        )
     node.replace_all_uses_with(red)
     # keep dataflow: red(input=node)
     red.replace_input_with(red, node)
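
This hunk threads the configured all-reduce strategy into the inserted dist op and relies on a common torch.fx rewrite pattern: create the new call right after the producer node, point every consumer of the producer at the new call, then restore the new call's own input edge (replace_all_uses_with also rewires the freshly created node, so the explicit replace_input_with undoes that one edge). Below is a runnable single-process sketch of the same pattern; a plain Python function stands in for torch.ops.auto_deploy.torch_dist_all_reduce and the literal "AUTO" stands in for config.allreduce_strategy.name.

import torch
from torch import fx

# Stand-in for torch.ops.auto_deploy.torch_dist_all_reduce(tensor, strategy_name);
# it simply returns the tensor so the sketch runs on a single process.
def fake_all_reduce(t, strategy_name):
    return t

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

gm = fx.symbolic_trace(M())
node = next(n for n in gm.graph.nodes if n.op == "call_function" and n.target is torch.add)

# Insert the (fake) all-reduce right after the op, passing the strategy name.
with gm.graph.inserting_after(node):
    red = gm.graph.call_function(fake_all_reduce, args=(node, "AUTO"))

# Reroute every consumer of `node` to `red`; this also rewires red's own argument...
node.replace_all_uses_with(red)
# ...so restore the dataflow red(input=node).
red.replace_input_with(red, node)

gm.recompile()
print(gm(torch.ones(2), torch.ones(2)))  # tensor([2., 2.])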