# Mirror of https://github.com/langgenius/dify.git (synced 2026-01-26 05:32:12 +08:00).
# Original file: 110 lines, 3.4 KiB, Python.
import logging
|
|
from typing import Optional
|
|
|
|
from sqlalchemy import update
|
|
|
|
from configs import dify_config
|
|
from core.errors.error import QuotaExceededError
|
|
from extensions.ext_database import db
|
|
from models import TenantCreditPool
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CreditPoolService:
    """Read, create and charge a tenant's credit pool (a ``TenantCreditPool`` row).

    Every method works through the shared ``db.session``. The methods that write
    (``create_default_pool``, ``get_or_create_pool``, ``check_and_deduct_credits``)
    commit that session.
    """

    @classmethod
    def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
        """Create and commit the default ("trial") credit pool for a new tenant.

        The pool starts with ``dify_config.HOSTED_POOL_CREDITS`` as its limit and
        zero credits used.

        :param tenant_id: tenant that will own the pool.
        :return: the newly committed pool.
        """
        credit_pool = TenantCreditPool(
            tenant_id=tenant_id,
            quota_limit=dify_config.HOSTED_POOL_CREDITS,
            quota_used=0,
            pool_type="trial",
        )
        db.session.add(credit_pool)
        db.session.commit()
        return credit_pool

    @classmethod
    def get_pool(cls, tenant_id: str) -> Optional[TenantCreditPool]:
        """Return the tenant's credit pool, or ``None`` if it has none."""
        return db.session.query(TenantCreditPool).filter_by(tenant_id=tenant_id).first()

    @classmethod
    def get_or_create_pool(cls, tenant_id: str) -> TenantCreditPool:
        """Return the tenant's credit pool, creating the default one if missing.

        If creating the pool fails, the session is rolled back and the pool is
        looked up again. A typical cause is another request inserting it first.
        The original exception is re-raised only when no pool exists even after
        the rollback.

        :param tenant_id: tenant whose pool is wanted.
        :return: the existing or newly created pool.
        """
        pool = cls.get_pool(tenant_id)
        if pool:
            return pool

        try:
            return cls.create_default_pool(tenant_id)
        except Exception:
            # NOTE(review): recovering from a concurrent insert assumes the DB rejects a
            # second pool row for the same tenant (e.g. a unique constraint on tenant_id);
            # confirm on the TenantCreditPool model.
            db.session.rollback()
            pool = cls.get_pool(tenant_id)
            if not pool:
                raise
            return pool

    @classmethod
    def check_and_deduct_credits(
        cls,
        tenant_id: str,
        credits_required: int,
    ) -> None:
        """Deduct ``credits_required`` credits from the tenant's pool.

        :param tenant_id: tenant whose pool is charged.
        :param credits_required: number of credits to consume.
        :raises QuotaExceededError: if the tenant has no pool, or if the pool lacks
            enough credits. Sufficiency is checked up front against the loaded row
            and again atomically in the UPDATE's WHERE clause.
        """
        logger.info("check and deduct credits")
        pool = cls.get_pool(tenant_id)
        if not pool:
            raise QuotaExceededError("Credit pool not found")

        # Cheap early rejection using the loaded row. The authoritative check is
        # the WHERE guard of the UPDATE below.
        if pool.remaining_credits < credits_required:
            raise QuotaExceededError(
                f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}"
            )

        # Increment in SQL (quota_used = quota_used + n) instead of writing back the
        # value read above, so concurrent deductions cannot overwrite each other.
        # The WHERE guard makes the limit check and the deduction one atomic step.
        stmt = (
            update(TenantCreditPool)
            .where(
                TenantCreditPool.tenant_id == tenant_id,
                TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
            )
            .values(quota_used=TenantCreditPool.quota_used + credits_required)
        )

        # get_pool() has already started a transaction on db.session, so opening
        # another one with db.session.begin() would fail. Execute and commit the
        # current transaction instead, as the other writers in this class do.
        try:
            result = db.session.execute(stmt)
            if result.rowcount == 0:
                # The guard rejected the row: credits were consumed concurrently
                # between the check above and this UPDATE.
                raise QuotaExceededError(f"Insufficient credits. Required: {credits_required}")
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise

    @classmethod
    def check_deduct_credits(cls, tenant_id: str, credits_required: int) -> bool:
        """Report whether the tenant's pool can currently cover ``credits_required``.

        Read-only: nothing is deducted. Compares against
        ``TenantCreditPool.remaining_credits`` (defined on the model; presumably
        ``quota_limit - quota_used``, so confirm there).

        :param tenant_id: tenant whose pool is checked.
        :param credits_required: number of credits that would be consumed.
        :return: ``False`` when the tenant has no pool or too few credits,
            otherwise ``True``.
        """
        pool = cls.get_pool(tenant_id)
        if not pool:
            return False
        return pool.remaining_credits >= credits_required