mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
This commit is contained in:
parent
4fa9284612
commit
d47ac4e3e5
@ -1,7 +1,6 @@
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import tensorrt_llm
|
||||
@ -43,6 +42,9 @@ class DummyModel(torch.nn.Module):
|
||||
torch_dtype=dtype))
|
||||
self.recorded_position_ids = None
|
||||
|
||||
def infer_max_seq_len(self):
|
||||
return 2048
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self.model_config.pretrained_config
|
||||
@ -133,7 +135,6 @@ def create_model_engine_and_kvcache(config: PyTorchConfig = None):
|
||||
return model_engine, kv_cache_manager
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="https://nvbugs/5324248")
|
||||
class PyTorchModelEngineTestCase(unittest.TestCase):
|
||||
|
||||
def test_pad_generation_requests(self) -> None:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user