dify/api/core/session/session.py

107 lines
3.2 KiB
Python

import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Any, Generic, Protocol, TypeVar
from pydantic import BaseModel, Field, ValidationError
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class SessionStorage(Protocol):
"""Session storage interface."""
def get(self, key: str) -> str | None: ...
def set(self, key: str, value: str, ttl: int) -> None: ...
def delete(self, key: str) -> bool: ...
def exists(self, key: str) -> bool: ...
def refresh_ttl(self, key: str, ttl: int) -> bool: ...
class RedisSessionStorage:
"""Redis storage implementation (default)."""
def get(self, key: str) -> str | None:
result = redis_client.get(key)
if result is None:
return None
return result.decode() if isinstance(result, bytes) else result
def set(self, key: str, value: str, ttl: int) -> None:
redis_client.setex(key, ttl, value)
def delete(self, key: str) -> bool:
return redis_client.delete(key) > 0
def exists(self, key: str) -> bool:
return redis_client.exists(key) > 0
def refresh_ttl(self, key: str, ttl: int) -> bool:
return bool(redis_client.expire(key, ttl))
class BaseSession(BaseModel):
"""Base session model."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
tenant_id: str
user_id: str
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
context: dict[str, Any] = Field(default_factory=dict)
def update_timestamp(self) -> None:
self.updated_at = datetime.now(UTC)
T = TypeVar("T", bound=BaseSession)
class SessionManager(Generic[T]):
"""Generic session manager."""
DEFAULT_TTL = 7200 # 2 hours
def __init__(
self,
key_prefix: str,
session_class: type[T],
storage: SessionStorage | None = None,
ttl: int | None = None,
):
self._key_prefix = key_prefix
self._session_class = session_class
self._storage = storage or RedisSessionStorage()
self._ttl = ttl or self.DEFAULT_TTL
def _get_key(self, session_id: str) -> str:
return f"{self._key_prefix}:{session_id}"
def save(self, session: T) -> None:
session.update_timestamp()
key = self._get_key(session.id)
self._storage.set(key, session.model_dump_json(), self._ttl)
def get(self, session_id: str) -> T | None:
key = self._get_key(session_id)
data = self._storage.get(key)
if data is None:
return None
try:
return self._session_class.model_validate(json.loads(data))
except (json.JSONDecodeError, ValidationError) as e:
logger.warning("Failed to deserialize session %s: %s", session_id, e)
return None
def delete(self, session_id: str) -> bool:
return self._storage.delete(self._get_key(session_id))
def exists(self, session_id: str) -> bool:
return self._storage.exists(self._get_key(session_id))
def refresh_ttl(self, session_id: str) -> bool:
return self._storage.refresh_ttl(self._get_key(session_id), self._ttl)