mirror of
https://github.com/datawhalechina/llms-from-scratch-cn.git
synced 2026-06-06 00:04:42 +00:00
bug fix
This commit is contained in:
@@ -2260,7 +2260,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):\n",
|
||||
"class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel, GenerationMixin):\n",
|
||||
" def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):\n",
|
||||
" super().__init__(config)\n",
|
||||
"\n",
|
||||
@@ -2273,11 +2273,10 @@
|
||||
" outputs: ModelOutput,\n",
|
||||
" model_kwargs: Dict[str, Any],\n",
|
||||
" is_encoder_decoder: bool = False,\n",
|
||||
" standardize_cache_format: bool = False,\n",
|
||||
" ) -> Dict[str, Any]:\n",
|
||||
" # 更新 past_key_values\n",
|
||||
" model_kwargs[\"past_key_values\"] = self._extract_past_from_model_output(\n",
|
||||
" outputs, standardize_cache_format=standardize_cache_format\n",
|
||||
" _, model_kwargs[\"past_key_values\"] = self._extract_past_from_model_output(\n",
|
||||
" outputs\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # 更新注意力掩码\n",
|
||||
@@ -2856,11 +2855,10 @@
|
||||
" outputs: ModelOutput,\n",
|
||||
" model_kwargs: Dict[str, Any],\n",
|
||||
" is_encoder_decoder: bool = False,\n",
|
||||
" standardize_cache_format: bool = False,\n",
|
||||
") -> Dict[str, Any]:\n",
|
||||
" # 更新 past_key_values\n",
|
||||
" model_kwargs[\"past_key_values\"] = self._extract_past_from_model_output(\n",
|
||||
" outputs, standardize_cache_format=standardize_cache_format\n",
|
||||
" _, model_kwargs[\"past_key_values\"] = self._extract_past_from_model_output(\n",
|
||||
" outputs\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # 更新注意力掩码\n",
|
||||
|
||||
@@ -23,7 +23,7 @@ from transformers.modeling_outputs import (
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import logging
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput, GenerationMixin
|
||||
|
||||
from configuration_chatglm import ChatGLMConfig
|
||||
|
||||
@@ -793,7 +793,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel, GenerationMixin):
|
||||
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -806,11 +806,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||
outputs: ModelOutput,
|
||||
model_kwargs: Dict[str, Any],
|
||||
is_encoder_decoder: bool = False,
|
||||
standardize_cache_format: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
# update past_key_values
|
||||
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
||||
outputs, standardize_cache_format=standardize_cache_format
|
||||
_, model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
||||
outputs
|
||||
)
|
||||
|
||||
# update attention mask
|
||||
|
||||
Reference in New Issue
Block a user