Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>

Parent: 4fa9284612
Commit: d47ac4e3e5
The diff below removes the @pytest.mark.skip decorator (nvbugs/5324248) from
PyTorchModelEngineTestCase, drops the now-unused pytest import, and adds an
infer_max_seq_len method to DummyModel.

@@ -1,7 +1,6 @@
 import unittest
 from dataclasses import dataclass
 
-import pytest
 import torch
 
 import tensorrt_llm
@@ -43,6 +42,9 @@ class DummyModel(torch.nn.Module):
                 torch_dtype=dtype))
         self.recorded_position_ids = None
 
+    def infer_max_seq_len(self):
+        return 2048
+
     @property
     def config(self):
         return self.model_config.pretrained_config
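The diff does not say why DummyModel now needs infer_max_seq_len, so the sketch
below is only a minimal, hypothetical illustration of the usual pattern: engine
setup code asks the model for the longest sequence it supports and sizes its
buffers from that value. FakeModel, FakeEngine, tokens_per_block and
num_cache_blocks are invented names for this sketch, not TensorRT-LLM APIs.

    # Minimal sketch, assuming an engine that sizes a KV-cache-like buffer from
    # the model's reported maximum sequence length. All names are hypothetical.
    class FakeModel:
        def infer_max_seq_len(self):
            # Same fixed value the commit adds to DummyModel.
            return 2048


    class FakeEngine:
        def __init__(self, model, tokens_per_block=64):
            self.model = model
            self.tokens_per_block = tokens_per_block

        def num_cache_blocks(self):
            # Enough blocks so one request of maximum length fits.
            max_seq_len = self.model.infer_max_seq_len()
            return -(-max_seq_len // self.tokens_per_block)  # ceil division


    if __name__ == "__main__":
        engine = FakeEngine(FakeModel())
        print(engine.num_cache_blocks())  # 2048 / 64 -> 32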
@@ -133,7 +135,6 @@ def create_model_engine_and_kvcache(config: PyTorchConfig = None):
     return model_engine, kv_cache_manager
 
 
-@pytest.mark.skip(reason="https://nvbugs/5324248")
 class PyTorchModelEngineTestCase(unittest.TestCase):
 
     def test_pad_generation_requests(self) -> None:
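For context on the mechanism this hunk touches: pytest honors @pytest.mark.skip
applied to a unittest.TestCase class, so deleting the decorator (and the import
that only served it) is enough to put every method of PyTorchModelEngineTestCase
back into the run. A minimal, self-contained sketch follows; the class and test
bodies are invented for illustration, not taken from the repository.

    import unittest

    import pytest


    @pytest.mark.skip(reason="example of the decorator removed by this commit")
    class SkippedCase(unittest.TestCase):
        def test_addition(self) -> None:
            self.assertEqual(1 + 1, 2)


    class EnabledCase(unittest.TestCase):
        # Same body, no class-level marker: collected and executed normally.
        def test_addition(self) -> None:
            self.assertEqual(1 + 1, 2)

Running pytest -q on such a file reports one skipped and one passed test; once
the marker is removed, both run.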