[nvbug 5273941] fix: broken cyclic reference detect (#5417)

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
This commit is contained in:
Yan Chunwei 2025-06-26 07:35:35 +08:00 committed by Zhenhuan Chen
parent be5ddb0533
commit ee7fcbf20e
4 changed files with 176 additions and 13 deletions

View File

@ -0,0 +1,80 @@
import gc
import inspect
import types
import weakref
from contextlib import contextmanager
from tensorrt_llm.logger import logger
def _is_internal_referrer(referrer, target):
    """Return True when *referrer* is inspection bookkeeping, not a real leak."""
    # Frames hold the locals of code that is currently looking at *target*,
    # and cells are how nested scopes capture such locals.
    if inspect.isframe(referrer) or isinstance(referrer, types.CellType):
        return True
    # A dict that also stores a weak reference to *target* is treated as the
    # container of the probe weakref rather than as a leak.
    return isinstance(referrer, dict) and any(
        isinstance(value, weakref.ref) and value() is target
        for value in referrer.values())


@contextmanager
def assert_resource_freed(object_creation_func, *args, **kwargs):
    """Assert that a freshly created resource is destroyed on context exit.

    The resource is built with ``object_creation_func(*args, **kwargs)`` while
    the cyclic garbage collector is disabled, so that only reference counting
    can free it. The resource is yielded to the ``with`` body. When the body
    finishes, this context manager drops its own reference and checks through
    a weak reference whether the object is gone. If it wasn't freed, the
    remaining referrers are collected and reported.

    Referrers that are frames, closure cells, or dicts holding a weak
    reference to the resource are ignored (for example a frame that still
    binds the ``with ... as name`` target).

    The garbage collector's enabled/disabled state is restored on every exit
    path, including when ``object_creation_func`` itself raises.

    Args:
        object_creation_func: Callable that builds the resource. The returned
            object must support weak references.
        *args: Positional arguments forwarded to ``object_creation_func``.
        **kwargs: Keyword arguments forwarded to ``object_creation_func``.

    Yields:
        The object returned by ``object_creation_func``.

    Raises:
        AssertionError: If the resource is still alive on exit and at least
            one non-ignored referrer remains; the message lists them.
        Exception: Whatever ``object_creation_func`` raises is propagated.
    """
    # Ensure a clean start
    gc.collect()
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        resource = object_creation_func(*args, **kwargs)
        alive_ref = weakref.ref(resource)
        try:
            yield resource
        finally:
            # Drop our own strong reference so that only outside references
            # can keep the resource alive.
            del resource

            leaked = alive_ref()
            if leaked is None:
                logger.info("Resource was freed upon context exit.")
            else:
                # Give GC a chance to finalize anything pending
                gc.collect()

                # Find all objects still referring to our instance, minus
                # the inspection internals.
                referrers = [
                    r for r in gc.get_referrers(leaked)
                    if not _is_internal_referrer(r, leaked)
                ]

                if referrers:
                    count = len(referrers)
                    # Build a human-readable report
                    report = "\n".join(
                        f" - {type(r).__name__} at 0x{id(r):x}: {repr(r)[:200]!r}"
                        for r in referrers)
                    # Release our handles before raising so the exception's
                    # traceback does not add further references.
                    del leaked, referrers
                    raise AssertionError(
                        "Resource was NOT freed upon context exit!\n"
                        f"{count} referrer(s) still alive:\n{report}\n")
    finally:
        # Restore the GC state on every path: a failing factory, a failing
        # body, a reported leak, or a clean exit.
        if gc_was_enabled:
            gc.enable()

View File

@ -0,0 +1,86 @@
import gc
import os
import sys
import unittest
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
from gc_utils import assert_resource_freed
# A module-level list used to simulate a leak: every LeakyObject appends
# itself here, so a strong reference to it survives after the creator lets go.
# TestAssertResourceFreed.setUp clears it before each test.
LEAKY_HOLD = []
class TestObject:
    """Minimal resource stand-in that stores a single value."""

    def __init__(self, value):
        # Keep the payload so tests can check they received this instance.
        self.value = value
class LeakyObject(TestObject):
    """TestObject variant that registers itself in LEAKY_HOLD on creation."""

    def __init__(self, value):
        super().__init__(value)
        # The global list keeps a strong reference to this instance, so it
        # stays alive for as long as it remains in the list.
        LEAKY_HOLD.append(self)
class TestAssertResourceFreed(unittest.TestCase):
    """Unit tests for the ``assert_resource_freed`` context manager."""

    def setUp(self):
        # Clear any previous leaks and force a clean GC
        LEAKY_HOLD.clear()
        gc.collect()

    def tearDown(self):
        # Objects leaked on purpose must not outlive the test that made them.
        LEAKY_HOLD.clear()

    def test_simple_object_freed(self):
        """A plain TestObject should be freed without error."""

        def factory():
            return TestObject("foo")

        # Exercise the context manager twice in a row; neither use may raise.
        with assert_resource_freed(factory) as obj:
            self.assertEqual(obj.value, "foo")

        with assert_resource_freed(factory) as obj2:
            self.assertEqual(obj2.value, "foo")

    def test_leaky_object_raises(self):
        """LeakyObject holds itself in a global list, so the check must raise."""

        def factory():
            return LeakyObject("bar")

        with self.assertRaises(AssertionError):
            with assert_resource_freed(factory):
                pass

    def test_diagnostic_message(self):
        """Check that the AssertionError message includes a count and type."""

        def factory():
            return LeakyObject("baz")

        with self.assertRaises(AssertionError) as cm:
            with assert_resource_freed(factory):
                pass

        msg = str(cm.exception)
        # e.g. "1 referrer(s) still alive"
        self.assertRegex(msg, r"\d+\s+referrer")
        # and something like "- list at 0x"
        self.assertIn("list", msg)

    def test_no_false_positive_from_generator_cell(self):
        """Ensure that our filter skips the internal cell, so no leak is reported."""

        def factory():
            return TestObject("qux")

        # If the internal cell weren't filtered, this would raise,
        # so no exception means our filter worked.
        with assert_resource_freed(factory):
            pass
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

View File

@ -1,8 +1,11 @@
import asyncio
import datetime
import gc
import json
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
from gc_utils import assert_resource_freed
# Required for test_generate_with_seed to pass.
# See the discussion in https://github.com/NVIDIA/TensorRT-LLM/pull/4264#issuecomment-2943269891
@ -120,17 +123,13 @@ def llm_test_harness(model_dir: str,
if tokenizer is None:
tokenizer = model_dir
if backend == "pytorch":
llm = LLM_torch(model_dir, tokenizer=tokenizer, **llm_kwargs)
else:
llm = LLM(model_dir, tokenizer=tokenizer, **llm_kwargs)
outputs = llm.generate(inputs, sampling_params=sampling_params)
print(outputs)
check_output(outputs, references, similar_threshold=similar_threshold)
llm_cls = LLM_torch if backend == "pytorch" else LLM
assert gc.is_tracked(llm)
assert len(
gc.get_referrers(llm)) == 0, f"the references: {gc.get_referrers(llm)}"
with assert_resource_freed(llm_cls, model_dir, tokenizer,
**llm_kwargs) as llm:
outputs = llm.generate(inputs, sampling_params=sampling_params)
print(outputs)
check_output(outputs, references, similar_threshold=similar_threshold)
def llm_check_output(llm: LLM,
@ -237,7 +236,6 @@ def test_llm_loading_from_hf():
kv_cache_config=global_kvcache_config)
@pytest.mark.skip(reason="https://nvbugs/5266240")
@force_ampere
@pytest.mark.part0
def test_llm_loading_from_ckpt():

View File

@ -73,7 +73,6 @@ def engine_from_checkpoint() -> tempfile.TemporaryDirectory:
return tmpdir
@pytest.mark.skip(reason="https://nvbugs/5266240")
@pytest.mark.gpu2
@pytest.mark.part0
def test_llm_loading_from_ckpt_for_tp2(