mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
107 lines
3.2 KiB
Python
107 lines
3.2 KiB
Python
import json
|
|
import logging
|
|
import uuid
|
|
from datetime import UTC, datetime
|
|
from typing import Any, Generic, Protocol, TypeVar
|
|
|
|
from pydantic import BaseModel, Field, ValidationError
|
|
|
|
from extensions.ext_redis import redis_client
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SessionStorage(Protocol):
|
|
"""Session storage interface."""
|
|
|
|
def get(self, key: str) -> str | None: ...
|
|
def set(self, key: str, value: str, ttl: int) -> None: ...
|
|
def delete(self, key: str) -> bool: ...
|
|
def exists(self, key: str) -> bool: ...
|
|
def refresh_ttl(self, key: str, ttl: int) -> bool: ...
|
|
|
|
|
|
class RedisSessionStorage:
    """Session storage backed by the shared ``redis_client`` (the default backend)."""

    def get(self, key: str) -> str | None:
        """Fetch the value stored under ``key``.

        Bytes returned by the client are decoded to ``str``; any other
        result, including ``None``, is returned unchanged.
        """
        raw = redis_client.get(key)
        if isinstance(raw, bytes):
            return raw.decode()
        # ``None`` or an already-decoded value: hand it back as-is.
        return raw

    def set(self, key: str, value: str, ttl: int) -> None:
        """Write ``value`` under ``key`` via ``setex`` with expiry ``ttl``."""
        redis_client.setex(key, ttl, value)

    def delete(self, key: str) -> bool:
        """Delete ``key``; ``True`` when the client reports a positive count."""
        removed = redis_client.delete(key)
        return removed > 0

    def exists(self, key: str) -> bool:
        """Return ``True`` when the client reports a positive count for ``key``."""
        found = redis_client.exists(key)
        return found > 0

    def refresh_ttl(self, key: str, ttl: int) -> bool:
        """Apply ``expire`` to ``key`` with ``ttl`` and return the result as a bool."""
        outcome = redis_client.expire(key, ttl)
        return bool(outcome)
|
|
|
|
|
|
class BaseSession(BaseModel):
    """Pydantic model holding the fields common to every session."""

    # Session identifier; defaults to a freshly generated UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Required tenant and user identifiers supplied by the caller.
    tenant_id: str
    user_id: str
    # Timezone-aware UTC timestamps, both defaulting to "now" at construction.
    created_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC))
    # Free-form per-session data; defaults to a new empty dict per instance.
    context: dict[str, Any] = Field(default_factory=dict)

    def update_timestamp(self) -> None:
        """Set ``updated_at`` to the current UTC time."""
        self.updated_at = datetime.now(tz=UTC)
|
|
|
|
|
|
# Type variable for session models; bound to BaseSession so it only accepts
# BaseSession or one of its subclasses.
T = TypeVar("T", bound=BaseSession)
|
|
|
|
|
|
class SessionManager(Generic[T]):
    """Generic manager that persists session models in a key-value storage.

    Each session is serialized to JSON and stored under the key
    ``"<key_prefix>:<session_id>"`` with a time-to-live. One manager
    instance handles exactly one session model class.
    """

    DEFAULT_TTL = 7200  # 2 hours

    def __init__(
        self,
        key_prefix: str,
        session_class: type[T],
        storage: SessionStorage | None = None,
        ttl: int | None = None,
    ):
        """Configure the manager.

        Args:
            key_prefix: Prefix prepended to every session id to build the
                storage key.
            session_class: Model class used to deserialize stored sessions.
            storage: Storage backend; a ``RedisSessionStorage`` is created
                when ``None`` is given.
            ttl: Time-to-live applied on save and refresh. ``None`` (or
                ``0``) selects ``DEFAULT_TTL``.
        """
        self._key_prefix = key_prefix
        self._session_class = session_class
        # Test for None explicitly instead of using truthiness: a caller's
        # storage object may legitimately be falsy (e.g. an empty in-memory
        # store that defines ``__len__``) and must not be swapped for Redis.
        self._storage = RedisSessionStorage() if storage is None else storage
        self._ttl = ttl or self.DEFAULT_TTL

    def _get_key(self, session_id: str) -> str:
        """Build the storage key for ``session_id``."""
        return f"{self._key_prefix}:{session_id}"

    def save(self, session: T) -> None:
        """Persist ``session`` with the configured TTL.

        The session's ``updated_at`` is refreshed (mutating the passed
        object) before it is serialized and written.
        """
        session.update_timestamp()
        key = self._get_key(session.id)
        self._storage.set(key, session.model_dump_json(), self._ttl)

    def get(self, session_id: str) -> T | None:
        """Load the session stored for ``session_id``.

        Returns:
            The deserialized session, or ``None`` when nothing is stored
            or the stored payload cannot be parsed/validated (a warning is
            logged in the latter case).
        """
        key = self._get_key(session_id)
        data = self._storage.get(key)
        if data is None:
            return None
        try:
            # Validate straight from the JSON text so the payload written by
            # ``model_dump_json`` round-trips in JSON mode (also for strict
            # models); malformed JSON surfaces as a ValidationError as well.
            return self._session_class.model_validate_json(data)
        except ValidationError as e:
            logger.warning("Failed to deserialize session %s: %s", session_id, e)
            return None

    def delete(self, session_id: str) -> bool:
        """Delete the stored session; returns the storage's success flag."""
        return self._storage.delete(self._get_key(session_id))

    def exists(self, session_id: str) -> bool:
        """Return whether the storage holds an entry for ``session_id``."""
        return self._storage.exists(self._get_key(session_id))

    def refresh_ttl(self, session_id: str) -> bool:
        """Reset the stored session's TTL to the configured value."""
        return self._storage.refresh_ttl(self._get_key(session_id), self._ttl)
|