From 5e50570739060640f89e9cc0817320bad64b0c40 Mon Sep 17 00:00:00 2001 From: GareArc Date: Mon, 7 Apr 2025 18:41:02 -0400 Subject: [PATCH] fix: update webapp jwt claim and add user accessibility support --- api/controllers/web/login.py | 19 +++++++++--- api/services/webapp_auth_service.py | 48 +++++++++++++++++++++++++---- 2 files changed, 56 insertions(+), 11 deletions(-) diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 235fcaf8cc..955c781989 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -5,6 +5,7 @@ from flask import request from flask_restful import Resource, reqparse from jwt import InvalidTokenError # type: ignore from web import api +from werkzeug.exceptions import BadRequest import services from controllers.console.auth.error import (EmailCodeError, @@ -16,7 +17,7 @@ from libs.helper import email from libs.password import valid_password from models.account import Account from services.account_service import AccountService -from services.webapp_auth_service import Unauthorized, WebAppAuthService +from services.webapp_auth_service import WebAppAuthService class LoginApi(Resource): @@ -31,7 +32,7 @@ class LoginApi(Resource): app_code = request.headers.get("X-App-Code") if app_code is None: - raise Unauthorized("X-App-Code header is missing.") + raise BadRequest("X-App-Code header is missing.") try: account = WebAppAuthService.authenticate(args["email"], args["password"]) @@ -42,7 +43,11 @@ class LoginApi(Resource): except services.errors.account.AccountNotFoundError: raise AccountNotFound() - token = WebAppAuthService.login(account=account, app_code=app_code) + WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) + + end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code) + + token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id) return {"result": "success", "token": token} @@ -90,7 +95,7 @@ class EmailCodeLoginApi(Resource): user_email = args["email"] app_code = request.headers.get("X-App-Code") if app_code is None: - raise Unauthorized("X-App-Code header is missing.") + raise BadRequest("X-App-Code header is missing.") token_data = WebAppAuthService.get_email_code_login_data(args["token"]) if token_data is None: @@ -107,7 +112,11 @@ class EmailCodeLoginApi(Resource): if not account: raise AccountNotFound() - token = WebAppAuthService.login(account=account, app_code=app_code) + WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) + + end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code) + + token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id) AccountService.reset_login_error_rate_limit(args["email"]) return {"result": "success", "token": token} diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 65501bbffa..24d1177d87 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -5,14 +5,18 @@ from typing import Any, Optional, cast from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config +from controllers.web.error import (WebAppAuthFailedError, + WebAppAuthRequiredError) from extensions.ext_database import db from libs.helper import TokenManager from libs.passport import PassportService from libs.password import compare_password from models.account import Account, AccountStatus -from models.model import Site +from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService from services.errors.account import (AccountLoginError, AccountNotFoundError, AccountPasswordError) +from services.feature_service import FeatureService from tasks.mail_email_code_login import send_email_code_login_mail_task @@ -35,13 +39,13 @@ class WebAppAuthService: return cast(Account, account) - @staticmethod - def login(account: Account, app_code: str) -> str: + @classmethod + def login(cls, account: Account, app_code: str, end_user_id: str) -> str: site = db.session.query(Site).filter(Site.code == app_code).first() if not site: raise NotFound("Site not found.") - access_token = WebAppAuthService._get_account_jwt_token(account=account, site=site) + access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id) return access_token @@ -84,8 +88,39 @@ class WebAppAuthService: def revoke_email_code_login_token(cls, token: str): TokenManager.revoke_token(token, "webapp_email_code_login") - @staticmethod - def _get_account_jwt_token(account: Account, site: Site) -> str: + @classmethod + def create_end_user(cls, app_code, email) -> EndUser: + site = db.session.query(Site).filter(Site.code == app_code).first() + app_model = db.session.query(App).filter(App.id == site.app_id).first() + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type="browser", + is_anonymous=False, + session_id=email, + name="enterpriseuser", + external_user_id="enterpriseuser" + ) + db.session.add(end_user) + db.session.commit() + + return end_user + + @classmethod + def _validate_user_accessibility(cls, account: Account, app_code: str): + """Check if the user is allowed to access the app.""" + system_features = FeatureService.get_system_features() + if system_features.webapp_auth.enabled: + app_settings = EnterpriseService.get_web_app_settings(app_code=app_code) + if not app_settings or not app_settings.access_mode == "public": + raise WebAppAuthRequiredError() + if app_settings.access_mode == "private" and not EnterpriseService.is_user_allowed_to_access_webapp( + account.id, app_code=app_code + ): + raise WebAppAuthFailedError() + + @classmethod + def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str: exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.WebAppSessionTimeoutInHours * 24) exp = int(exp_dt.timestamp()) @@ -95,6 +130,7 @@ class WebAppAuthService: "app_id": site.app_id, "app_code": site.code, "user_id": account.id, + "end_user_id": end_user_id, "token_source": "webapp", "exp": exp, }