From f32e176d6a38a52f9c491134050172bd1f7eb3f1 Mon Sep 17 00:00:00 2001 From: "Junyan Qin (Chin)" Date: Fri, 29 Aug 2025 14:10:51 +0800 Subject: [PATCH] feat: oauth provider (#24206) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: yessenia --- api/controllers/console/__init__.py | 2 +- api/controllers/console/auth/oauth_server.py | 189 ++++++++++++++++ ...47-8d289573e1da_add_oauth_provider_apps.py | 45 ++++ api/models/model.py | 26 +++ api/services/oauth_server.py | 94 ++++++++ .../account-page/AvatarWithEdit.tsx | 0 .../account-page/email-change-modal.tsx | 0 .../account-page/index.tsx | 0 .../account/{ => (commonLayout)}/avatar.tsx | 0 .../delete-account/components/check-email.tsx | 0 .../delete-account/components/feed-back.tsx | 0 .../components/verify-email.tsx | 0 .../delete-account/index.tsx | 0 .../delete-account/state.tsx | 0 .../account/{ => (commonLayout)}/header.tsx | 4 +- .../account/{ => (commonLayout)}/layout.tsx | 0 web/app/account/{ => (commonLayout)}/page.tsx | 0 web/app/account/oauth/authorize/layout.tsx | 37 ++++ web/app/account/oauth/authorize/page.tsx | 205 ++++++++++++++++++ web/app/components/base/toast/index.tsx | 9 +- web/app/components/swr-initializer.tsx | 7 +- web/app/signin/check-code/page.tsx | 9 +- .../components/mail-and-password-auth.tsx | 4 +- web/app/signin/invite-settings/page.tsx | 4 +- web/app/signin/layout.tsx | 2 +- web/app/signin/normal-form.tsx | 4 +- web/app/signin/utils/post-login-redirect.ts | 36 +++ web/context/app-context.tsx | 18 +- web/i18n-config/i18next-config.ts | 1 + web/i18n/en-US/oauth.ts | 27 +++ web/i18n/zh-Hans/oauth.ts | 27 +++ web/service/use-oauth.ts | 29 +++ 32 files changed, 757 insertions(+), 22 deletions(-) create mode 100644 api/controllers/console/auth/oauth_server.py create mode 100644 api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py create mode 100644 api/services/oauth_server.py rename web/app/account/{ => (commonLayout)}/account-page/AvatarWithEdit.tsx (100%) rename web/app/account/{ => (commonLayout)}/account-page/email-change-modal.tsx (100%) rename web/app/account/{ => (commonLayout)}/account-page/index.tsx (100%) rename web/app/account/{ => (commonLayout)}/avatar.tsx (100%) rename web/app/account/{ => (commonLayout)}/delete-account/components/check-email.tsx (100%) rename web/app/account/{ => (commonLayout)}/delete-account/components/feed-back.tsx (100%) rename web/app/account/{ => (commonLayout)}/delete-account/components/verify-email.tsx (100%) rename web/app/account/{ => (commonLayout)}/delete-account/index.tsx (100%) rename web/app/account/{ => (commonLayout)}/delete-account/state.tsx (100%) rename web/app/account/{ => (commonLayout)}/header.tsx (97%) rename web/app/account/{ => (commonLayout)}/layout.tsx (100%) rename web/app/account/{ => (commonLayout)}/page.tsx (100%) create mode 100644 web/app/account/oauth/authorize/layout.tsx create mode 100644 web/app/account/oauth/authorize/page.tsx create mode 100644 web/app/signin/utils/post-login-redirect.ts create mode 100644 web/i18n/en-US/oauth.ts create mode 100644 web/i18n/zh-Hans/oauth.ts create mode 100644 web/service/use-oauth.ts diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index e25f92399c..5ad7645969 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -70,7 +70,7 @@ from .app import ( ) # Import auth controllers -from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth +from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server # Import billing controllers from .billing import billing, compliance diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py new file mode 100644 index 0000000000..19ca464a79 --- /dev/null +++ b/api/controllers/console/auth/oauth_server.py @@ -0,0 +1,189 @@ +from functools import wraps +from typing import cast + +import flask_login +from flask import request +from flask_restx import Resource, reqparse +from werkzeug.exceptions import BadRequest, NotFound + +from controllers.console.wraps import account_initialization_required, setup_required +from core.model_runtime.utils.encoders import jsonable_encoder +from libs.login import login_required +from models.account import Account +from models.model import OAuthProviderApp +from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService + +from .. import api + + +def oauth_server_client_id_required(view): + @wraps(view) + def decorated(*args, **kwargs): + parser = reqparse.RequestParser() + parser.add_argument("client_id", type=str, required=True, location="json") + parsed_args = parser.parse_args() + client_id = parsed_args.get("client_id") + if not client_id: + raise BadRequest("client_id is required") + + oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id) + if not oauth_provider_app: + raise NotFound("client_id is invalid") + + kwargs["oauth_provider_app"] = oauth_provider_app + + return view(*args, **kwargs) + + return decorated + + +def oauth_server_access_token_required(view): + @wraps(view) + def decorated(*args, **kwargs): + oauth_provider_app = kwargs.get("oauth_provider_app") + if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp): + raise BadRequest("Invalid oauth_provider_app") + + if not request.headers.get("Authorization"): + raise BadRequest("Authorization is required") + + authorization_header = request.headers.get("Authorization") + if not authorization_header: + raise BadRequest("Authorization header is required") + + parts = authorization_header.split(" ") + if len(parts) != 2: + raise BadRequest("Invalid Authorization header format") + + token_type = parts[0] + if token_type != "Bearer": + raise BadRequest("token_type is invalid") + + access_token = parts[1] + if not access_token: + raise BadRequest("access_token is required") + + account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token) + if not account: + raise BadRequest("access_token or client_id is invalid") + + kwargs["account"] = account + + return view(*args, **kwargs) + + return decorated + + +class OAuthServerAppApi(Resource): + @setup_required + @oauth_server_client_id_required + def post(self, oauth_provider_app: OAuthProviderApp): + parser = reqparse.RequestParser() + parser.add_argument("redirect_uri", type=str, required=True, location="json") + parsed_args = parser.parse_args() + redirect_uri = parsed_args.get("redirect_uri") + + # check if redirect_uri is valid + if redirect_uri not in oauth_provider_app.redirect_uris: + raise BadRequest("redirect_uri is invalid") + + return jsonable_encoder( + { + "app_icon": oauth_provider_app.app_icon, + "app_label": oauth_provider_app.app_label, + "scope": oauth_provider_app.scope, + } + ) + + +class OAuthServerUserAuthorizeApi(Resource): + @setup_required + @login_required + @account_initialization_required + @oauth_server_client_id_required + def post(self, oauth_provider_app: OAuthProviderApp): + account = cast(Account, flask_login.current_user) + user_account_id = account.id + + code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id) + return jsonable_encoder( + { + "code": code, + } + ) + + +class OAuthServerUserTokenApi(Resource): + @setup_required + @oauth_server_client_id_required + def post(self, oauth_provider_app: OAuthProviderApp): + parser = reqparse.RequestParser() + parser.add_argument("grant_type", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=False, location="json") + parser.add_argument("client_secret", type=str, required=False, location="json") + parser.add_argument("redirect_uri", type=str, required=False, location="json") + parser.add_argument("refresh_token", type=str, required=False, location="json") + parsed_args = parser.parse_args() + + grant_type = OAuthGrantType(parsed_args["grant_type"]) + + if grant_type == OAuthGrantType.AUTHORIZATION_CODE: + if not parsed_args["code"]: + raise BadRequest("code is required") + + if parsed_args["client_secret"] != oauth_provider_app.client_secret: + raise BadRequest("client_secret is invalid") + + if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris: + raise BadRequest("redirect_uri is invalid") + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) + elif grant_type == OAuthGrantType.REFRESH_TOKEN: + if not parsed_args["refresh_token"]: + raise BadRequest("refresh_token is required") + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) + else: + raise BadRequest("invalid grant_type") + + +class OAuthServerUserAccountApi(Resource): + @setup_required + @oauth_server_client_id_required + @oauth_server_access_token_required + def post(self, oauth_provider_app: OAuthProviderApp, account: Account): + return jsonable_encoder( + { + "name": account.name, + "email": account.email, + "avatar": account.avatar, + "interface_language": account.interface_language, + "timezone": account.timezone, + } + ) + + +api.add_resource(OAuthServerAppApi, "/oauth/provider") +api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize") +api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token") +api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account") diff --git a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py new file mode 100644 index 0000000000..5986853f01 --- /dev/null +++ b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py @@ -0,0 +1,45 @@ +"""empty message + +Revision ID: 8d289573e1da +Revises: fa8b0fa6f407 +Create Date: 2025-08-20 17:47:17.015695 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '8d289573e1da' +down_revision = '0e154742a5fa' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('oauth_provider_apps', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_icon', sa.String(length=255), nullable=False), + sa.Column('app_label', sa.JSON(), server_default='{}', nullable=False), + sa.Column('client_id', sa.String(length=255), nullable=False), + sa.Column('client_secret', sa.String(length=255), nullable=False), + sa.Column('redirect_uris', sa.JSON(), server_default='[]', nullable=False), + sa.Column('scope', sa.String(length=255), server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='oauth_provider_app_pkey') + ) + with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op: + batch_op.create_index('oauth_provider_app_client_id_idx', ['client_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op: + batch_op.drop_index('oauth_provider_app_client_id_idx') + + op.drop_table('oauth_provider_apps') + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 53646c0155..6a0e0af482 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -580,6 +580,32 @@ class InstalledApp(Base): return tenant +class OAuthProviderApp(Base): + """ + Globally shared OAuth provider app information. + Only for Dify Cloud. + """ + + __tablename__ = "oauth_provider_apps" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="oauth_provider_app_pkey"), + sa.Index("oauth_provider_app_client_id_idx", "client_id"), + ) + + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + app_icon = mapped_column(String(255), nullable=False) + app_label = mapped_column(sa.JSON, nullable=False, server_default="{}") + client_id = mapped_column(String(255), nullable=False) + client_secret = mapped_column(String(255), nullable=False) + redirect_uris = mapped_column(sa.JSON, nullable=False, server_default="[]") + scope = mapped_column( + String(255), + nullable=False, + server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"), + ) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")) + + class Conversation(Base): __tablename__ = "conversations" __table_args__ = ( diff --git a/api/services/oauth_server.py b/api/services/oauth_server.py new file mode 100644 index 0000000000..b722dbee22 --- /dev/null +++ b/api/services/oauth_server.py @@ -0,0 +1,94 @@ +import enum +import uuid + +from sqlalchemy import select +from sqlalchemy.orm import Session +from werkzeug.exceptions import BadRequest + +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import Account +from models.model import OAuthProviderApp +from services.account_service import AccountService + + +class OAuthGrantType(enum.StrEnum): + AUTHORIZATION_CODE = "authorization_code" + REFRESH_TOKEN = "refresh_token" + + +OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}" +OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}" +OAUTH_ACCESS_TOKEN_EXPIRES_IN = 60 * 60 * 12 # 12 hours +OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}" +OAUTH_REFRESH_TOKEN_EXPIRES_IN = 60 * 60 * 24 * 30 # 30 days + + +class OAuthServerService: + @staticmethod + def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None: + query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id) + + with Session(db.engine) as session: + return session.execute(query).scalar_one_or_none() + + @staticmethod + def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str: + code = str(uuid.uuid4()) + redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code) + redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes + return code + + @staticmethod + def sign_oauth_access_token( + grant_type: OAuthGrantType, + code: str = "", + client_id: str = "", + refresh_token: str = "", + ) -> tuple[str, str]: + match grant_type: + case OAuthGrantType.AUTHORIZATION_CODE: + redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code) + user_account_id = redis_client.get(redis_key) + if not user_account_id: + raise BadRequest("invalid code") + + # delete code + redis_client.delete(redis_key) + + access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id) + refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id) + return access_token, refresh_token + case OAuthGrantType.REFRESH_TOKEN: + redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token) + user_account_id = redis_client.get(redis_key) + if not user_account_id: + raise BadRequest("invalid refresh token") + + access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id) + return access_token, refresh_token + + @staticmethod + def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str: + token = str(uuid.uuid4()) + redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token) + redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN) + return token + + @staticmethod + def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str: + token = str(uuid.uuid4()) + redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token) + redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN) + return token + + @staticmethod + def validate_oauth_access_token(client_id: str, token: str) -> Account | None: + redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token) + user_account_id = redis_client.get(redis_key) + if not user_account_id: + return None + + user_id_str = user_account_id.decode("utf-8") + + return AccountService.load_user(user_id_str) diff --git a/web/app/account/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx similarity index 100% rename from web/app/account/account-page/AvatarWithEdit.tsx rename to web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx diff --git a/web/app/account/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx similarity index 100% rename from web/app/account/account-page/email-change-modal.tsx rename to web/app/account/(commonLayout)/account-page/email-change-modal.tsx diff --git a/web/app/account/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx similarity index 100% rename from web/app/account/account-page/index.tsx rename to web/app/account/(commonLayout)/account-page/index.tsx diff --git a/web/app/account/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx similarity index 100% rename from web/app/account/avatar.tsx rename to web/app/account/(commonLayout)/avatar.tsx diff --git a/web/app/account/delete-account/components/check-email.tsx b/web/app/account/(commonLayout)/delete-account/components/check-email.tsx similarity index 100% rename from web/app/account/delete-account/components/check-email.tsx rename to web/app/account/(commonLayout)/delete-account/components/check-email.tsx diff --git a/web/app/account/delete-account/components/feed-back.tsx b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx similarity index 100% rename from web/app/account/delete-account/components/feed-back.tsx rename to web/app/account/(commonLayout)/delete-account/components/feed-back.tsx diff --git a/web/app/account/delete-account/components/verify-email.tsx b/web/app/account/(commonLayout)/delete-account/components/verify-email.tsx similarity index 100% rename from web/app/account/delete-account/components/verify-email.tsx rename to web/app/account/(commonLayout)/delete-account/components/verify-email.tsx diff --git a/web/app/account/delete-account/index.tsx b/web/app/account/(commonLayout)/delete-account/index.tsx similarity index 100% rename from web/app/account/delete-account/index.tsx rename to web/app/account/(commonLayout)/delete-account/index.tsx diff --git a/web/app/account/delete-account/state.tsx b/web/app/account/(commonLayout)/delete-account/state.tsx similarity index 100% rename from web/app/account/delete-account/state.tsx rename to web/app/account/(commonLayout)/delete-account/state.tsx diff --git a/web/app/account/header.tsx b/web/app/account/(commonLayout)/header.tsx similarity index 97% rename from web/app/account/header.tsx rename to web/app/account/(commonLayout)/header.tsx index af09ca1c9c..ce804055b5 100644 --- a/web/app/account/header.tsx +++ b/web/app/account/(commonLayout)/header.tsx @@ -2,11 +2,11 @@ import { useTranslation } from 'react-i18next' import { RiArrowRightUpLine, RiRobot2Line } from '@remixicon/react' import { useRouter } from 'next/navigation' -import Button from '../components/base/button' -import Avatar from './avatar' +import Button from '@/app/components/base/button' import DifyLogo from '@/app/components/base/logo/dify-logo' import { useCallback } from 'react' import { useGlobalPublicStore } from '@/context/global-public-context' +import Avatar from './avatar' const Header = () => { const { t } = useTranslation() diff --git a/web/app/account/layout.tsx b/web/app/account/(commonLayout)/layout.tsx similarity index 100% rename from web/app/account/layout.tsx rename to web/app/account/(commonLayout)/layout.tsx diff --git a/web/app/account/page.tsx b/web/app/account/(commonLayout)/page.tsx similarity index 100% rename from web/app/account/page.tsx rename to web/app/account/(commonLayout)/page.tsx diff --git a/web/app/account/oauth/authorize/layout.tsx b/web/app/account/oauth/authorize/layout.tsx new file mode 100644 index 0000000000..078d23114a --- /dev/null +++ b/web/app/account/oauth/authorize/layout.tsx @@ -0,0 +1,37 @@ +'use client' +import Header from '@/app/signin/_header' + +import cn from '@/utils/classnames' +import { useGlobalPublicStore } from '@/context/global-public-context' +import useDocumentTitle from '@/hooks/use-document-title' +import { AppContextProvider } from '@/context/app-context' +import { useMemo } from 'react' + +export default function SignInLayout({ children }: any) { + const { systemFeatures } = useGlobalPublicStore() + useDocumentTitle('') + const isLoggedIn = useMemo(() => { + try { + return Boolean(localStorage.getItem('console_token') && localStorage.getItem('refresh_token')) + } + catch { return false } + }, []) + return <> +
+
+
+
+
+ {isLoggedIn ? + {children} + + : children} +
+
+ {systemFeatures.branding.enabled === false &&
+ © {new Date().getFullYear()} LangGenius, Inc. All rights reserved. +
} +
+
+ +} diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx new file mode 100644 index 0000000000..6ad63996ae --- /dev/null +++ b/web/app/account/oauth/authorize/page.tsx @@ -0,0 +1,205 @@ +'use client' + +import React, { useEffect, useMemo, useRef } from 'react' +import { useTranslation } from 'react-i18next' +import { useRouter, useSearchParams } from 'next/navigation' +import Button from '@/app/components/base/button' +import Avatar from '@/app/components/base/avatar' +import Loading from '@/app/components/base/loading' +import Toast from '@/app/components/base/toast' +import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useAppContext } from '@/context/app-context' +import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth' +import { + RiAccountCircleLine, + RiGlobalLine, + RiInfoCardLine, + RiMailLine, + RiTranslate2, +} from '@remixicon/react' +import dayjs from 'dayjs' + +export const OAUTH_AUTHORIZE_PENDING_KEY = 'oauth_authorize_pending' +export const REDIRECT_URL_KEY = 'oauth_redirect_url' + +const OAUTH_AUTHORIZE_PENDING_TTL = 60 * 3 + +function setItemWithExpiry(key: string, value: string, ttl: number) { + const item = { + value, + expiry: dayjs().add(ttl, 'seconds').unix(), + } + localStorage.setItem(key, JSON.stringify(item)) +} + +function buildReturnUrl(pathname: string, search: string) { + try { + const base = `${globalThis.location.origin}${pathname}${search}` + return base + } + catch { + return pathname + search + } +} + +export default function OAuthAuthorize() { + const { t } = useTranslation() + + const SCOPE_INFO_MAP: Record, label: string }> = { + 'read:name': { + icon: RiInfoCardLine, + label: t('oauth.scopes.name'), + }, + 'read:email': { + icon: RiMailLine, + label: t('oauth.scopes.email'), + }, + 'read:avatar': { + icon: RiAccountCircleLine, + label: t('oauth.scopes.avatar'), + }, + 'read:interface_language': { + icon: RiTranslate2, + label: t('oauth.scopes.languagePreference'), + }, + 'read:timezone': { + icon: RiGlobalLine, + label: t('oauth.scopes.timezone'), + }, + } + + const router = useRouter() + const language = useLanguage() + const searchParams = useSearchParams() + const client_id = decodeURIComponent(searchParams.get('client_id') || '') + const redirect_uri = decodeURIComponent(searchParams.get('redirect_uri') || '') + const { userProfile } = useAppContext() + const { data: authAppInfo, isLoading, isError } = useOAuthAppInfo(client_id, redirect_uri) + const { mutateAsync: authorize, isPending: authorizing } = useAuthorizeOAuthApp() + const hasNotifiedRef = useRef(false) + + const isLoggedIn = useMemo(() => { + try { + return Boolean(localStorage.getItem('console_token') && localStorage.getItem('refresh_token')) + } + catch { return false } + }, []) + + const onLoginSwitchClick = () => { + try { + const returnUrl = buildReturnUrl('/account/oauth/authorize', `?client_id=${encodeURIComponent(client_id)}&redirect_uri=${encodeURIComponent(redirect_uri)}`) + setItemWithExpiry(OAUTH_AUTHORIZE_PENDING_KEY, returnUrl, OAUTH_AUTHORIZE_PENDING_TTL) + router.push(`/signin?${REDIRECT_URL_KEY}=${encodeURIComponent(returnUrl)}`) + } + catch { + router.push('/signin') + } + } + + const onAuthorize = async () => { + if (!client_id || !redirect_uri) + return + try { + const { code } = await authorize({ client_id }) + const url = new URL(redirect_uri) + url.searchParams.set('code', code) + globalThis.location.href = url.toString() + } + catch (err: any) { + Toast.notify({ + type: 'error', + message: `${t('oauth.error.authorizeFailed')}: ${err.message}`, + }) + } + } + + useEffect(() => { + const invalidParams = !client_id || !redirect_uri + if ((invalidParams || isError) && !hasNotifiedRef.current) { + hasNotifiedRef.current = true + Toast.notify({ + type: 'error', + message: invalidParams ? t('oauth.error.invalidParams') : t('oauth.error.authAppInfoFetchFailed'), + duration: 0, + }) + } + }, [client_id, redirect_uri, isError]) + + if (isLoading) { + return ( +
+ +
+ ) + } + + return ( +
+ {authAppInfo?.app_icon && ( +
+ app icon +
+ )} + +
+
+ {isLoggedIn &&
{t('oauth.connect')}
} +
{authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('oauth.unknownApp')}
+ {!isLoggedIn &&
{t('oauth.tips.notLoggedIn')}
} +
+
{isLoggedIn ? `${authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('oauth.unknownApp')} ${t('oauth.tips.loggedIn')}` : t('oauth.tips.needLogin')}
+
+ + {isLoggedIn && userProfile && ( +
+
+ +
+
{userProfile.name}
+
{userProfile.email}
+
+
+ +
+ )} + + {isLoggedIn && Boolean(authAppInfo?.scope) && ( +
+ {authAppInfo!.scope.split(/\s+/).filter(Boolean).map((scope: string) => { + const Icon = SCOPE_INFO_MAP[scope] + return ( +
+ {Icon ? : } + {Icon.label} +
+ ) + })} +
+ )} + +
+ {!isLoggedIn ? ( + + ) : ( + <> + + + + )} +
+
+ + + + + + + + + + +
+
{t('oauth.tips.common')}
+
+ ) +} diff --git a/web/app/components/base/toast/index.tsx b/web/app/components/base/toast/index.tsx index a23a60dbf1..245f709143 100644 --- a/web/app/components/base/toast/index.tsx +++ b/web/app/components/base/toast/index.tsx @@ -56,12 +56,11 @@ const Toast = ({ 'top-0', 'right-0', )}> -
@@ -162,7 +161,9 @@ Toast.notify = ({ , ) document.body.appendChild(holder) - setTimeout(toastHandler.clear, duration || defaultDuring) + const d = duration ?? defaultDuring + if (d > 0) + setTimeout(toastHandler.clear, d) } return toastHandler diff --git a/web/app/components/swr-initializer.tsx b/web/app/components/swr-initializer.tsx index a3f6e011d8..0a873400d6 100644 --- a/web/app/components/swr-initializer.tsx +++ b/web/app/components/swr-initializer.tsx @@ -9,6 +9,7 @@ import { EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, } from '@/app/education-apply/constants' +import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect' type SwrInitializerProps = { children: ReactNode @@ -63,7 +64,11 @@ const SwrInitializer = ({ if (searchParams.has('access_token') || searchParams.has('refresh_token')) { consoleToken && localStorage.setItem('console_token', consoleToken) refreshToken && localStorage.setItem('refresh_token', refreshToken) - router.replace(pathname) + const redirectUrl = resolvePostLoginRedirect(searchParams) + if (redirectUrl) + location.replace(redirectUrl) + else + router.replace(pathname) } setInit(true) diff --git a/web/app/signin/check-code/page.tsx b/web/app/signin/check-code/page.tsx index 9c3f7768f8..8edb12eb7e 100644 --- a/web/app/signin/check-code/page.tsx +++ b/web/app/signin/check-code/page.tsx @@ -10,6 +10,7 @@ import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { emailLoginWithCode, sendEMailLoginCode } from '@/service/common' import I18NContext from '@/context/i18n' +import { resolvePostLoginRedirect } from '../utils/post-login-redirect' export default function CheckCode() { const { t } = useTranslation() @@ -43,7 +44,13 @@ export default function CheckCode() { if (ret.result === 'success') { localStorage.setItem('console_token', ret.data.access_token) localStorage.setItem('refresh_token', ret.data.refresh_token) - router.replace(invite_token ? `/signin/invite-settings?${searchParams.toString()}` : '/apps') + if (invite_token) { + router.replace(`/signin/invite-settings?${searchParams.toString()}`) + } + else { + const redirectUrl = resolvePostLoginRedirect(searchParams) + router.replace(redirectUrl || '/apps') + } } } catch (error) { console.error(error) } diff --git a/web/app/signin/components/mail-and-password-auth.tsx b/web/app/signin/components/mail-and-password-auth.tsx index 7360fdac44..b7e010e2fd 100644 --- a/web/app/signin/components/mail-and-password-auth.tsx +++ b/web/app/signin/components/mail-and-password-auth.tsx @@ -10,6 +10,7 @@ import { login } from '@/service/common' import Input from '@/app/components/base/input' import I18NContext from '@/context/i18n' import { noop } from 'lodash-es' +import { resolvePostLoginRedirect } from '../utils/post-login-redirect' type MailAndPasswordAuthProps = { isInvite: boolean @@ -74,7 +75,8 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis else { localStorage.setItem('console_token', res.data.access_token) localStorage.setItem('refresh_token', res.data.refresh_token) - router.replace('/apps') + const redirectUrl = resolvePostLoginRedirect(searchParams) + router.replace(redirectUrl || '/apps') } } else if (res.code === 'account_not_found') { diff --git a/web/app/signin/invite-settings/page.tsx b/web/app/signin/invite-settings/page.tsx index fae62de530..036edfc478 100644 --- a/web/app/signin/invite-settings/page.tsx +++ b/web/app/signin/invite-settings/page.tsx @@ -18,6 +18,7 @@ import Loading from '@/app/components/base/loading' import Toast from '@/app/components/base/toast' import { noop } from 'lodash-es' import { useGlobalPublicStore } from '@/context/global-public-context' +import { resolvePostLoginRedirect } from '../utils/post-login-redirect' export default function InviteSettingsPage() { const { t } = useTranslation() @@ -60,7 +61,8 @@ export default function InviteSettingsPage() { localStorage.setItem('console_token', res.data.access_token) localStorage.setItem('refresh_token', res.data.refresh_token) await setLocaleOnClient(language, false) - router.replace('/apps') + const redirectUrl = resolvePostLoginRedirect(searchParams) + router.replace(redirectUrl || '/apps') } } catch { diff --git a/web/app/signin/layout.tsx b/web/app/signin/layout.tsx index 4e9ac7ebf9..7e7280f5b8 100644 --- a/web/app/signin/layout.tsx +++ b/web/app/signin/layout.tsx @@ -10,7 +10,7 @@ export default function SignInLayout({ children }: any) { useDocumentTitle('') return <>
-
+
diff --git a/web/app/signin/normal-form.tsx b/web/app/signin/normal-form.tsx index 51046fbd06..3d20b72c5f 100644 --- a/web/app/signin/normal-form.tsx +++ b/web/app/signin/normal-form.tsx @@ -14,6 +14,7 @@ import { LicenseStatus } from '@/types/feature' import Toast from '@/app/components/base/toast' import { IS_CE_EDITION } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' +import { resolvePostLoginRedirect } from './utils/post-login-redirect' const NormalForm = () => { const { t } = useTranslation() @@ -37,7 +38,8 @@ const NormalForm = () => { if (consoleToken && refreshToken) { localStorage.setItem('console_token', consoleToken) localStorage.setItem('refresh_token', refreshToken) - router.replace('/apps') + const redirectUrl = resolvePostLoginRedirect(searchParams) + router.replace(redirectUrl || '/apps') return } diff --git a/web/app/signin/utils/post-login-redirect.ts b/web/app/signin/utils/post-login-redirect.ts new file mode 100644 index 0000000000..37ab122dfa --- /dev/null +++ b/web/app/signin/utils/post-login-redirect.ts @@ -0,0 +1,36 @@ +import { OAUTH_AUTHORIZE_PENDING_KEY, REDIRECT_URL_KEY } from '@/app/account/oauth/authorize/page' +import dayjs from 'dayjs' +import type { ReadonlyURLSearchParams } from 'next/navigation' + +function getItemWithExpiry(key: string): string | null { + const itemStr = localStorage.getItem(key) + if (!itemStr) + return null + + try { + const item = JSON.parse(itemStr) + localStorage.removeItem(key) + if (!item?.value) return null + + return dayjs().unix() > item.expiry ? null : item.value + } + catch { + return null + } +} + +export const resolvePostLoginRedirect = (searchParams: ReadonlyURLSearchParams) => { + const redirectUrl = searchParams.get(REDIRECT_URL_KEY) + if (redirectUrl) { + try { + localStorage.removeItem(OAUTH_AUTHORIZE_PENDING_KEY) + return decodeURIComponent(redirectUrl) + } + catch (e) { + console.error('Failed to decode redirect URL:', e) + return redirectUrl + } + } + + return getItemWithExpiry(OAUTH_AUTHORIZE_PENDING_KEY) +} diff --git a/web/context/app-context.tsx b/web/context/app-context.tsx index 4ba9e3492d..c033e1dcfa 100644 --- a/web/context/app-context.tsx +++ b/web/context/app-context.tsx @@ -24,13 +24,13 @@ export type AppContextValue = { } const userProfilePlaceholder = { - id: '', - name: '', - email: '', - avatar: '', - avatar_url: '', - is_password_set: false, - } + id: '', + name: '', + email: '', + avatar: '', + avatar_url: '', + is_password_set: false, +} const initialLangGeniusVersionInfo = { current_env: '', @@ -96,13 +96,13 @@ export const AppContextProvider: FC = ({ children }) => const versionData = await fetchLangGeniusVersion({ url: '/version', params: { current_version } }) setLangGeniusVersionInfo({ ...versionData, current_version, latest_version: versionData.version, current_env }) } - catch (error) { + catch (error) { console.error('Failed to update user profile:', error) if (userProfile.id === '') setUserProfile(userProfilePlaceholder) } } - else if (userProfileError && userProfile.id === '') { + else if (userProfileError && userProfile.id === '') { setUserProfile(userProfilePlaceholder) } }, [userProfileResponse, userProfileError, userProfile.id]) diff --git a/web/i18n-config/i18next-config.ts b/web/i18n-config/i18next-config.ts index 19ac59ebb4..da3a2f3425 100644 --- a/web/i18n-config/i18next-config.ts +++ b/web/i18n-config/i18next-config.ts @@ -34,6 +34,7 @@ const NAMESPACES = [ 'explore', 'layout', 'login', + 'oauth', 'plugin-tags', 'plugin', 'register', diff --git a/web/i18n/en-US/oauth.ts b/web/i18n/en-US/oauth.ts new file mode 100644 index 0000000000..ff71487fcd --- /dev/null +++ b/web/i18n/en-US/oauth.ts @@ -0,0 +1,27 @@ +const translation = { + tips: { + loggedIn: 'wants to access the following information from your Dify Cloud account.', + notLoggedIn: 'wants to access your Dify Cloud account', + needLogin: 'Please log in to authorize', + common: 'We respect your privacy and will only use this information to enhance your experience with our developer tools.', + }, + connect: 'Connect to', + continue: 'Continue', + switchAccount: 'Switch Account', + login: 'Login', + scopes: { + name: 'Name', + email: 'Email', + avatar: 'Avatar', + languagePreference: 'Language Preference', + timezone: 'Timezone', + }, + error: { + invalidParams: 'Invalid parameters', + authorizeFailed: 'Authorize failed', + authAppInfoFetchFailed: 'Failed to fetch app info for authorization', + }, + unknownApp: 'Unknown App', +} + +export default translation diff --git a/web/i18n/zh-Hans/oauth.ts b/web/i18n/zh-Hans/oauth.ts new file mode 100644 index 0000000000..2afde687b2 --- /dev/null +++ b/web/i18n/zh-Hans/oauth.ts @@ -0,0 +1,27 @@ +const translation = { + tips: { + loggedIn: '想要访问您的 Dify Cloud 账号中的以下信息。', + notLoggedIn: '想要访问您的 Dify Cloud 账号', + needLogin: '请先登录以授权', + common: '我们尊重您的隐私,并仅使用此信息来增强您对我们开发工具的使用体验。', + }, + connect: '连接到', + continue: '继续', + switchAccount: '切换账号', + login: '登录', + scopes: { + name: '名称', + email: '邮箱', + avatar: '头像', + languagePreference: '语言偏好', + timezone: '时区', + }, + error: { + invalidParams: '无效的参数', + authorizeFailed: '授权失败', + authAppInfoFetchFailed: '获取待授权应用的信息失败', + }, + unknownApp: '未知应用', +} + +export default translation diff --git a/web/service/use-oauth.ts b/web/service/use-oauth.ts new file mode 100644 index 0000000000..d3860fe8d8 --- /dev/null +++ b/web/service/use-oauth.ts @@ -0,0 +1,29 @@ +import { post } from './base' +import { useMutation, useQuery } from '@tanstack/react-query' + +const NAME_SPACE = 'oauth-provider' + +export type OAuthAppInfo = { + app_icon: string + app_label: Record + scope: string +} + +export type OAuthAuthorizeResponse = { + code: string +} + +export const useOAuthAppInfo = (client_id: string, redirect_uri: string) => { + return useQuery({ + queryKey: [NAME_SPACE, 'authAppInfo', client_id, redirect_uri], + queryFn: () => post('/oauth/provider', { body: { client_id, redirect_uri } }, { silent: true }), + enabled: Boolean(client_id && redirect_uri), + }) +} + +export const useAuthorizeOAuthApp = () => { + return useMutation({ + mutationKey: [NAME_SPACE, 'authorize'], + mutationFn: (payload: { client_id: string }) => post('/oauth/provider/authorize', { body: payload }), + }) +}