TensorRT-LLMs/examples/enc_dec/download.py
Kaiyu Xie 75b6210ff4
Kaiyu/update main (#5)
* Update

* Update
2023-10-18 22:38:53 +08:00

19 lines
632 B
Python

import os

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
# Load the pretrained "t5-small" tokenizer and model through transformers
# (from_pretrained resolves the name; presumably this downloads the files on
# first use -- confirm against the transformers cache settings).
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Sanity check: push one translation prompt through the model and print both
# the raw token ids and the decoded text.
input_ids = tokenizer(
    "translate English to German: The house is wonderful.",
    return_tensors="pt",
).input_ids
# Generation is seeded with a single decoder token of id 0
# (NOTE(review): presumably T5's decoder start token -- confirm).
outputs = model.generate(
    input_ids,
    decoder_input_ids=torch.IntTensor([[0]]),
)
print("input", input_ids, "\noutput", outputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Build the state dict once and reuse it for both saving and listing.
state_dict = model.state_dict()

# torch.save does not create missing parent directories, so make sure the
# target directory exists before writing the checkpoint.
ckpt_path = './models/t5_small.ckpt'
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
torch.save(state_dict, ckpt_path)

# Print the name of every entry stored in the checkpoint; the tensor values
# are not needed here, so iterate over the keys only.
for name in state_dict:
    print(name)