llms-from-scratch-cn/Model_Architecture_Discussions/mamba/demo.ipynb
2024-05-31 17:04:42 +08:00

366 lines
10 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"id": "531467a2-5160-4073-a990-0d81d574b014",
"metadata": {},
"source": [
"## (1) Load model"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "621ebeea-475a-4917-af01-22b1ec948075",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: HF_ENDPOINT=https://hf-mirror.com\n"
]
}
],
"source": [
"%env HF_ENDPOINT=https://hf-mirror.com"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cac008ac-d314-4ebe-a3d4-f9454d7e3614",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
"Collecting einops\n",
" Downloading https://pypi.tuna.tsinghua.edu.cn/packages/44/5a/f0b9ad6c0a9017e62d4735daaeb11ba3b6c009d69a26141b258cd37b5588/einops-0.8.0-py3-none-any.whl (43 kB)\n",
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m395.8 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m-:--:--\u001b[0m\n",
"\u001b[?25hInstalling collected packages: einops\n",
"Successfully installed einops-0.8.0\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install einops"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d9337043-4e7a-4b20-9d89-6c6257245334",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "570f0993fa46465d8866e3190fde4a0e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/200 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "76b85f2e1c84414290c77815e848a03c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"pytorch_model.bin: 0%| | 0.00/1.49G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
" return self.fget.__get__(instance, owner)()\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4851f6fb0db0428c8e1ff0dc9d147155",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0%| | 0.00/156 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "27ca8edc7934436a9350a3ee95d5590a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"vocab.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1c27d4a2251841cfbdce227ef4069ad2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"merges.txt: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a3ceb33a66c64dc9a7bb3a8ebc20463e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "407e741d1bb544bdb9540ac77c71dcfe",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"special_tokens_map.json: 0%| | 0.00/90.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from model import Mamba, ModelArgs\n",
"from transformers import AutoTokenizer\n",
"\n",
"# One of:\n",
"# 'state-spaces/mamba-2.8b-slimpj'\n",
"# 'state-spaces/mamba-2.8b'\n",
"# 'state-spaces/mamba-1.4b'\n",
"# 'state-spaces/mamba-790m'\n",
"# 'state-spaces/mamba-370m'\n",
"# 'state-spaces/mamba-130m'\n",
"pretrained_model_name = 'state-spaces/mamba-370m'\n",
"\n",
"model = Mamba.from_pretrained(pretrained_model_name)\n",
"tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')"
]
},
{
"cell_type": "markdown",
"id": "0b2efb17-37ad-472b-b029-9567acf17629",
"metadata": {},
"source": [
"## (2) Generate Text"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c4b2d62d-0d95-4a3f-bd98-aa37e3f26b39",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"def generate(model,\n",
"             tokenizer,\n",
"             prompt: str,\n",
"             n_tokens_to_gen: int = 50,\n",
"             sample: bool = True,\n",
"             top_k: int = 40):\n",
"    \"\"\"Autoregressively generate text from ``prompt`` with a Mamba LM.\n",
"\n",
"    Args:\n",
"        model: language model mapping token ids (batch, seq) to logits\n",
"            (batch, seq, vocab) -- assumed from usage below; confirm with model.py.\n",
"        tokenizer: Hugging Face tokenizer providing ``__call__`` and ``decode``.\n",
"        prompt: initial text to condition on.\n",
"        n_tokens_to_gen: number of new tokens to append.\n",
"        sample: if True, sample from the (top-k filtered) distribution;\n",
"            otherwise take the greedy argmax token.\n",
"        top_k: keep only the k most probable tokens before sampling;\n",
"            None disables the filter.\n",
"\n",
"    Returns:\n",
"        The decoded string: prompt followed by the generated continuation.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    \n",
"    input_ids = tokenizer(prompt, return_tensors='pt').input_ids\n",
"    \n",
"    for _ in range(n_tokens_to_gen):\n",
"        with torch.no_grad():\n",
"            # No state caching in this minimal implementation: the whole\n",
"            # sequence is re-processed on every step.\n",
"            next_token_logits = model(input_ids)[:, -1]\n",
"        \n",
"        probs = F.softmax(next_token_logits, dim=-1)\n",
"        \n",
"        if top_k is not None:\n",
"            # Zero out every probability below the k-th largest, then\n",
"            # renormalize so the kept mass sums to 1 per row.\n",
"            (values, _) = torch.topk(probs, k=top_k)\n",
"            probs[probs < values[:, -1, None]] = 0\n",
"            probs = probs / probs.sum(dim=1, keepdim=True)\n",
"        \n",
"        if sample:\n",
"            next_indices = torch.multinomial(probs, num_samples=1)\n",
"        else:\n",
"            next_indices = torch.argmax(probs, dim=-1)[:, None]\n",
"        \n",
"        input_ids = torch.cat([input_ids, next_indices], dim=1)\n",
"\n",
"    # Single-prompt batch: decode the first (only) row.\n",
"    return tokenizer.decode(input_ids[0].tolist())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ee877143-2042-4579-8042-a96db6200517",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mamba is the new bestselling game from the popular mobile developers. This addictive and immersive platformer is a fun and addictive mobile game created by Nintendo that requires players to take control of a large robot monster in an environment where you control the movement of its legs and\n"
]
}
],
"source": [
"print(generate(model, tokenizer, 'Mamba is the'))"
]
},
{
"cell_type": "raw",
"id": "2c7642e7-6702-4687-8c4c-95f858625c8d",
"metadata": {},
"source": [
"上述这一段在cpu上用了114s"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "65d70549-597f-49ca-9185-2184d2576f7d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The organization helps students gain expertise to be able to lead data mining projects in companies. The organization has developed various applications to help a learning process of data mining and artificial intelligence. Moreover, data management and the use of Big Data can also be used as\n"
]
}
],
"source": [
"print(generate(model, tokenizer, '\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.'))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6d419fc9-066b-4818-812c-2f1952528bc6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The meaning of life is \n",
"just this: It is the best you can do.\n",
"\n",
"--K.J.\n",
"\n",
"And finally: How to handle your emotions. \n",
"\n",
"<|endoftext|>Q:\n",
"\n",
"Error creating an EntityManager instance in JavaEE 7\n",
"\n",
"This is\n"
]
}
],
"source": [
"print(generate(model, tokenizer, 'The meaning of life is '))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2b189e6e-6a96-4770-88cf-7c5de22cb321",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def reverse_string(text, result):\n",
" # find the position of the start of the string.\n",
" start = text.index(text[0:-1])\n",
" # find the position where the string begins changing.\n",
" end = text.index\n"
]
}
],
"source": [
"print(generate(model, tokenizer, 'def reverse_string('))"
]
},
{
"cell_type": "raw",
"id": "be40cccb-e0e3-40ca-ac87-90249f1bbe3f",
"metadata": {},
"source": [
"我感觉mamba速度还是比rwkv差了不少：370m参数的模型不及rwkv-1b6参数量的模型快"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}