mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[nvbug 5273941] fix: broken cyclic reference detection (#5417)
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
This commit is contained in:
parent
be5ddb0533
commit
ee7fcbf20e
80
tests/unittest/gc_utils.py
Normal file
80
tests/unittest/gc_utils.py
Normal file
@ -0,0 +1,80 @@
|
||||
import gc
|
||||
import inspect
|
||||
import types
|
||||
import weakref
|
||||
from contextlib import contextmanager
|
||||
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
|
||||
@contextmanager
def assert_resource_freed(object_creation_func, *args, **kwargs):
    """
    Create a resource via ``object_creation_func(*args, **kwargs)``, disable
    the generational GC so the object can only be freed by pure reference
    counting, yield the resource, then assert on context exit that it was
    destroyed.

    If the resource is still alive on exit, collect and report all remaining
    non-internal referrers via an ``AssertionError``.

    Args:
        object_creation_func: Callable returning the resource under test.
            The resource must support weak references.
        *args: Positional arguments forwarded to ``object_creation_func``.
        **kwargs: Keyword arguments forwarded to ``object_creation_func``.

    Raises:
        AssertionError: If non-internal referrers keep the resource alive
            after the ``with`` block exits.
    """
    # Ensure a clean start so pre-existing garbage does not skew the check.
    gc.collect()
    gc_was_enabled = gc.isenabled()
    gc.disable()

    try:
        resource = object_creation_func(*args, **kwargs)
    except BaseException:
        # Bug fix: the original left the GC permanently disabled when the
        # creation call itself raised. Restore the state before propagating.
        if gc_was_enabled:
            gc.enable()
        raise

    alive_ref = weakref.ref(resource)

    try:
        yield resource

    finally:
        # Drop our own strong reference so refcounting can free the object.
        del resource

        # If still alive, diagnose who is keeping it alive.
        leaked = alive_ref()
        if leaked is not None:
            # Restore GC so we can introspect
            if gc_was_enabled:
                gc.enable()
            # Give GC a chance to finalize anything pending
            gc.collect()

            # Find all objects still referring to our instance
            refs = gc.get_referrers(leaked)
            # Filter out inspection internals (frames, tracebacks, the weakref itself, etc.)
            filtered = []
            for r in refs:
                # skip the weakref container itself
                if isinstance(r, dict) and any(
                        isinstance(v, weakref.ref) and v() is leaked
                        for v in r.values()):
                    continue
                # skip our own local variables frame
                if inspect.isframe(r):
                    continue
                # skip the generator's internal cell
                if isinstance(r, types.CellType):
                    continue
                filtered.append(r)

            # Build a human-readable report. Bug fix: the original applied
            # !r to a string that was already the result of repr(), which
            # double-escaped the diagnostic output.
            report_lines = [
                f" - {type(r).__name__} at 0x{id(r):x}: {repr(r)[:200]}"
                for r in filtered
            ]
            report = "\n".join(
                report_lines) or " <no non‐internal referrers found>"

            if filtered:
                raise AssertionError(
                    "Resource was NOT freed upon context exit!\n"
                    f"{len(filtered)} referrer(s) still alive:\n{report}\n")
            else:
                logger.info("Resource was freed upon context exit.")

        # Otherwise, restore GC state
        if gc_was_enabled:
            gc.enable()
|
||||
86
tests/unittest/llmapi/test_gc_utils.py
Normal file
86
tests/unittest/llmapi/test_gc_utils.py
Normal file
@ -0,0 +1,86 @@
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
|
||||
from gc_utils import assert_resource_freed
|
||||
|
||||
# Global registry used to simulate a leak: LeakyObject instances append
# themselves here, so they remain strongly reachable after their creator
# drops its own reference.
LEAKY_HOLD = []


class TestObject:
    """Minimal value holder; a well-behaved (leak-free) test subject."""

    def __init__(self, value):
        self.value = value


class LeakyObject(TestObject):
    """Test subject that deliberately leaks by registering itself globally."""

    def __init__(self, value):
        super().__init__(value)
        # Keeping a global strong reference guarantees the instance can
        # never be freed by reference counting alone.
        LEAKY_HOLD.append(self)
|
||||
|
||||
|
||||
class TestAssertResourceFreed(unittest.TestCase):
    """Unit tests for the ``assert_resource_freed`` context manager."""

    def setUp(self):
        # Clear any previous leaks and force a clean GC
        LEAKY_HOLD.clear()
        gc.collect()

    def test_simple_object_freed(self):
        """A plain TestObject should be freed without error."""

        def factory():
            return TestObject("foo")

        # generator‐based
        with assert_resource_freed(factory) as obj:
            self.assertEqual(obj.value, "foo")

        # class‐based
        with assert_resource_freed(factory) as obj2:
            self.assertEqual(obj2.value, "foo")

    def test_leaky_object_raises(self):
        """LeakyObject holds itself in a global list → should raise."""

        def factory():
            return LeakyObject("bar")

        # Fix: the original bound the assertRaises context to an unused
        # variable `cm`; only the raise itself matters here.
        with self.assertRaises(AssertionError):
            with assert_resource_freed(factory):
                pass

    def test_diagnostic_message(self):
        """Check that the AssertionError message includes a count and type."""

        def factory():
            return LeakyObject("baz")

        with self.assertRaises(AssertionError) as cm:
            with assert_resource_freed(factory):
                pass

        msg = str(cm.exception)
        # e.g. "1 referrer(s) still alive"
        self.assertRegex(msg, r"\d+\s+referrer")
        # and something like "- list at 0x" (LEAKY_HOLD is a list)
        self.assertIn("list", msg)

    def test_no_false_positive_from_generator_cell(self):
        """Ensure that our filter skips the internal cell, so no leak is reported."""

        def factory():
            return TestObject("qux")

        # If the internal cell weren’t filtered, this would raise—
        # so no exception means our filter worked.
        with assert_resource_freed(factory):
            pass
|
||||
|
||||
|
||||
# Allow running this test module directly (outside of a pytest invocation).
if __name__ == "__main__":
    unittest.main()
|
||||
@ -1,8 +1,11 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
|
||||
from gc_utils import assert_resource_freed
|
||||
|
||||
# Required for test_generate_with_seed to pass.
|
||||
# See the discussion in https://github.com/NVIDIA/TensorRT-LLM/pull/4264#issuecomment-2943269891
|
||||
@ -120,17 +123,13 @@ def llm_test_harness(model_dir: str,
|
||||
if tokenizer is None:
|
||||
tokenizer = model_dir
|
||||
|
||||
if backend == "pytorch":
|
||||
llm = LLM_torch(model_dir, tokenizer=tokenizer, **llm_kwargs)
|
||||
else:
|
||||
llm = LLM(model_dir, tokenizer=tokenizer, **llm_kwargs)
|
||||
outputs = llm.generate(inputs, sampling_params=sampling_params)
|
||||
print(outputs)
|
||||
check_output(outputs, references, similar_threshold=similar_threshold)
|
||||
llm_cls = LLM_torch if backend == "pytorch" else LLM
|
||||
|
||||
assert gc.is_tracked(llm)
|
||||
assert len(
|
||||
gc.get_referrers(llm)) == 0, f"the references: {gc.get_referrers(llm)}"
|
||||
with assert_resource_freed(llm_cls, model_dir, tokenizer,
|
||||
**llm_kwargs) as llm:
|
||||
outputs = llm.generate(inputs, sampling_params=sampling_params)
|
||||
print(outputs)
|
||||
check_output(outputs, references, similar_threshold=similar_threshold)
|
||||
|
||||
|
||||
def llm_check_output(llm: LLM,
|
||||
@ -237,7 +236,6 @@ def test_llm_loading_from_hf():
|
||||
kv_cache_config=global_kvcache_config)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="https://nvbugs/5266240")
|
||||
@force_ampere
|
||||
@pytest.mark.part0
|
||||
def test_llm_loading_from_ckpt():
|
||||
|
||||
@ -73,7 +73,6 @@ def engine_from_checkpoint() -> tempfile.TemporaryDirectory:
|
||||
return tmpdir
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="https://nvbugs/5266240")
|
||||
@pytest.mark.gpu2
|
||||
@pytest.mark.part0
|
||||
def test_llm_loading_from_ckpt_for_tp2(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user