mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Fix test Pytorch model engine (#5416)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
This commit is contained in:
parent
d93a5e04b5
commit
846bbf1edc
@ -1,7 +1,6 @@
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import tensorrt_llm
|
||||
@ -43,6 +42,9 @@ class DummyModel(torch.nn.Module):
|
||||
torch_dtype=dtype))
|
||||
self.recorded_position_ids = None
|
||||
|
||||
def infer_max_seq_len(self):
|
||||
return 2048
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self.model_config.pretrained_config
|
||||
@ -133,7 +135,6 @@ def create_model_engine_and_kvcache(config: PyTorchConfig = None):
|
||||
return model_engine, kv_cache_manager
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="https://nvbugs/5324248")
|
||||
class PyTorchModelEngineTestCase(unittest.TestCase):
|
||||
|
||||
def test_pad_generation_requests(self) -> None:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user