diff --git a/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch b/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch deleted file mode 100644 index ef9e74c73a..0000000000 --- a/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch +++ /dev/null @@ -1,279 +0,0 @@ -diff --git a/client.js b/client.js -index 33b4ff6309d5f29187dab4e285d07dac20340bab..8f568637ee9e4677585931fb0284c8165a933f69 100644 ---- a/client.js -+++ b/client.js -@@ -433,7 +433,7 @@ class OpenAI { - 'User-Agent': this.getUserAgent(), - 'X-Stainless-Retry-Count': String(retryCount), - ...(options.timeout ? { 'X-Stainless-Timeout': String(Math.trunc(options.timeout / 1000)) } : {}), -- ...(0, detect_platform_1.getPlatformHeaders)(), -+ // ...(0, detect_platform_1.getPlatformHeaders)(), - 'OpenAI-Organization': this.organization, - 'OpenAI-Project': this.project, - }, -diff --git a/client.mjs b/client.mjs -index c34c18213073540ebb296ea540b1d1ad39527906..1ce1a98256d7e90e26ca963582f235b23e996e73 100644 ---- a/client.mjs -+++ b/client.mjs -@@ -430,7 +430,7 @@ export class OpenAI { - 'User-Agent': this.getUserAgent(), - 'X-Stainless-Retry-Count': String(retryCount), - ...(options.timeout ? { 'X-Stainless-Timeout': String(Math.trunc(options.timeout / 1000)) } : {}), -- ...getPlatformHeaders(), -+ // ...getPlatformHeaders(), - 'OpenAI-Organization': this.organization, - 'OpenAI-Project': this.project, - }, -diff --git a/core/error.js b/core/error.js -index a12d9d9ccd242050161adeb0f82e1b98d9e78e20..fe3a5462480558bc426deea147f864f12b36f9bd 100644 ---- a/core/error.js -+++ b/core/error.js -@@ -40,7 +40,7 @@ class APIError extends OpenAIError { - if (!status || !headers) { - return new APIConnectionError({ message, cause: (0, errors_1.castToError)(errorResponse) }); - } -- const error = errorResponse?.['error']; -+ const error = errorResponse?.['error'] || errorResponse; - if (status === 400) { - return new BadRequestError(status, error, message, headers); - } -diff --git a/core/error.mjs b/core/error.mjs -index 83cefbaffeb8c657536347322d8de9516af479a2..63334b7972ec04882aa4a0800c1ead5982345045 100644 ---- a/core/error.mjs -+++ b/core/error.mjs -@@ -36,7 +36,7 @@ export class APIError extends OpenAIError { - if (!status || !headers) { - return new APIConnectionError({ message, cause: castToError(errorResponse) }); - } -- const error = errorResponse?.['error']; -+ const error = errorResponse?.['error'] || errorResponse; - if (status === 400) { - return new BadRequestError(status, error, message, headers); - } -diff --git a/resources/embeddings.js b/resources/embeddings.js -index 2404264d4ba0204322548945ebb7eab3bea82173..8f1bc45cc45e0797d50989d96b51147b90ae6790 100644 ---- a/resources/embeddings.js -+++ b/resources/embeddings.js -@@ -5,52 +5,64 @@ exports.Embeddings = void 0; - const resource_1 = require("../core/resource.js"); - const utils_1 = require("../internal/utils.js"); - class Embeddings extends resource_1.APIResource { -- /** -- * Creates an embedding vector representing the input text. -- * -- * @example -- * ```ts -- * const createEmbeddingResponse = -- * await client.embeddings.create({ -- * input: 'The quick brown fox jumped over the lazy dog', -- * model: 'text-embedding-3-small', -- * }); -- * ``` -- */ -- create(body, options) { -- const hasUserProvidedEncodingFormat = !!body.encoding_format; -- // No encoding_format specified, defaulting to base64 for performance reasons -- // See https://github.com/openai/openai-node/pull/1312 -- let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64'; -- if (hasUserProvidedEncodingFormat) { -- (0, utils_1.loggerFor)(this._client).debug('embeddings/user defined encoding_format:', body.encoding_format); -- } -- const response = this._client.post('/embeddings', { -- body: { -- ...body, -- encoding_format: encoding_format, -- }, -- ...options, -- }); -- // if the user specified an encoding_format, return the response as-is -- if (hasUserProvidedEncodingFormat) { -- return response; -- } -- // in this stage, we are sure the user did not specify an encoding_format -- // and we defaulted to base64 for performance reasons -- // we are sure then that the response is base64 encoded, let's decode it -- // the returned result will be a float32 array since this is OpenAI API's default encoding -- (0, utils_1.loggerFor)(this._client).debug('embeddings/decoding base64 embeddings from base64'); -- return response._thenUnwrap((response) => { -- if (response && response.data) { -- response.data.forEach((embeddingBase64Obj) => { -- const embeddingBase64Str = embeddingBase64Obj.embedding; -- embeddingBase64Obj.embedding = (0, utils_1.toFloat32Array)(embeddingBase64Str); -- }); -- } -- return response; -- }); -- } -+ /** -+ * Creates an embedding vector representing the input text. -+ * -+ * @example -+ * ```ts -+ * const createEmbeddingResponse = -+ * await client.embeddings.create({ -+ * input: 'The quick brown fox jumped over the lazy dog', -+ * model: 'text-embedding-3-small', -+ * }); -+ * ``` -+ */ -+ create(body, options) { -+ const hasUserProvidedEncodingFormat = !!body.encoding_format; -+ // No encoding_format specified, defaulting to base64 for performance reasons -+ // See https://github.com/openai/openai-node/pull/1312 -+ let encoding_format = hasUserProvidedEncodingFormat -+ ? body.encoding_format -+ : "base64"; -+ if (body.model.includes("jina")) { -+ encoding_format = undefined; -+ } -+ if (hasUserProvidedEncodingFormat) { -+ (0, utils_1.loggerFor)(this._client).debug( -+ "embeddings/user defined encoding_format:", -+ body.encoding_format -+ ); -+ } -+ const response = this._client.post("/embeddings", { -+ body: { -+ ...body, -+ encoding_format: encoding_format, -+ }, -+ ...options, -+ }); -+ // if the user specified an encoding_format, return the response as-is -+ if (hasUserProvidedEncodingFormat || body.model.includes("jina")) { -+ return response; -+ } -+ // in this stage, we are sure the user did not specify an encoding_format -+ // and we defaulted to base64 for performance reasons -+ // we are sure then that the response is base64 encoded, let's decode it -+ // the returned result will be a float32 array since this is OpenAI API's default encoding -+ (0, utils_1.loggerFor)(this._client).debug( -+ "embeddings/decoding base64 embeddings from base64" -+ ); -+ return response._thenUnwrap((response) => { -+ if (response && response.data && typeof response.data[0]?.embedding === 'string') { -+ response.data.forEach((embeddingBase64Obj) => { -+ const embeddingBase64Str = embeddingBase64Obj.embedding; -+ embeddingBase64Obj.embedding = (0, utils_1.toFloat32Array)( -+ embeddingBase64Str -+ ); -+ }); -+ } -+ return response; -+ }); -+ } - } - exports.Embeddings = Embeddings; - //# sourceMappingURL=embeddings.js.map -diff --git a/resources/embeddings.mjs b/resources/embeddings.mjs -index 19dcaef578c194a89759c4360073cfd4f7dd2cbf..0284e9cc615c900eff508eb595f7360a74bd9200 100644 ---- a/resources/embeddings.mjs -+++ b/resources/embeddings.mjs -@@ -2,51 +2,61 @@ - import { APIResource } from "../core/resource.mjs"; - import { loggerFor, toFloat32Array } from "../internal/utils.mjs"; - export class Embeddings extends APIResource { -- /** -- * Creates an embedding vector representing the input text. -- * -- * @example -- * ```ts -- * const createEmbeddingResponse = -- * await client.embeddings.create({ -- * input: 'The quick brown fox jumped over the lazy dog', -- * model: 'text-embedding-3-small', -- * }); -- * ``` -- */ -- create(body, options) { -- const hasUserProvidedEncodingFormat = !!body.encoding_format; -- // No encoding_format specified, defaulting to base64 for performance reasons -- // See https://github.com/openai/openai-node/pull/1312 -- let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64'; -- if (hasUserProvidedEncodingFormat) { -- loggerFor(this._client).debug('embeddings/user defined encoding_format:', body.encoding_format); -- } -- const response = this._client.post('/embeddings', { -- body: { -- ...body, -- encoding_format: encoding_format, -- }, -- ...options, -- }); -- // if the user specified an encoding_format, return the response as-is -- if (hasUserProvidedEncodingFormat) { -- return response; -- } -- // in this stage, we are sure the user did not specify an encoding_format -- // and we defaulted to base64 for performance reasons -- // we are sure then that the response is base64 encoded, let's decode it -- // the returned result will be a float32 array since this is OpenAI API's default encoding -- loggerFor(this._client).debug('embeddings/decoding base64 embeddings from base64'); -- return response._thenUnwrap((response) => { -- if (response && response.data) { -- response.data.forEach((embeddingBase64Obj) => { -- const embeddingBase64Str = embeddingBase64Obj.embedding; -- embeddingBase64Obj.embedding = toFloat32Array(embeddingBase64Str); -- }); -- } -- return response; -- }); -- } -+ /** -+ * Creates an embedding vector representing the input text. -+ * -+ * @example -+ * ```ts -+ * const createEmbeddingResponse = -+ * await client.embeddings.create({ -+ * input: 'The quick brown fox jumped over the lazy dog', -+ * model: 'text-embedding-3-small', -+ * }); -+ * ``` -+ */ -+ create(body, options) { -+ const hasUserProvidedEncodingFormat = !!body.encoding_format; -+ // No encoding_format specified, defaulting to base64 for performance reasons -+ // See https://github.com/openai/openai-node/pull/1312 -+ let encoding_format = hasUserProvidedEncodingFormat -+ ? body.encoding_format -+ : "base64"; -+ if (body.model.includes("jina")) { -+ encoding_format = undefined; -+ } -+ if (hasUserProvidedEncodingFormat) { -+ loggerFor(this._client).debug( -+ "embeddings/user defined encoding_format:", -+ body.encoding_format -+ ); -+ } -+ const response = this._client.post("/embeddings", { -+ body: { -+ ...body, -+ encoding_format: encoding_format, -+ }, -+ ...options, -+ }); -+ // if the user specified an encoding_format, return the response as-is -+ if (hasUserProvidedEncodingFormat || body.model.includes("jina")) { -+ return response; -+ } -+ // in this stage, we are sure the user did not specify an encoding_format -+ // and we defaulted to base64 for performance reasons -+ // we are sure then that the response is base64 encoded, let's decode it -+ // the returned result will be a float32 array since this is OpenAI API's default encoding -+ loggerFor(this._client).debug( -+ "embeddings/decoding base64 embeddings from base64" -+ ); -+ return response._thenUnwrap((response) => { -+ if (response && response.data && typeof response.data[0]?.embedding === 'string') { -+ response.data.forEach((embeddingBase64Obj) => { -+ const embeddingBase64Str = embeddingBase64Obj.embedding; -+ embeddingBase64Obj.embedding = toFloat32Array(embeddingBase64Str); -+ }); -+ } -+ return response; -+ }); -+ } - } - //# sourceMappingURL=embeddings.mjs.map diff --git a/.yarn/patches/openai-npm-5.12.0-a06a6369b2.patch b/.yarn/patches/openai-npm-5.12.0-a06a6369b2.patch new file mode 100644 index 0000000000..39f0c9b7da --- /dev/null +++ b/.yarn/patches/openai-npm-5.12.0-a06a6369b2.patch @@ -0,0 +1,344 @@ +diff --git a/client.js b/client.js +index 22cc08d77ce849842a28f684c20dd5738152efa4..0c20f96405edbe7724b87517115fa2a61934b343 100644 +--- a/client.js ++++ b/client.js +@@ -444,7 +444,7 @@ class OpenAI { + 'User-Agent': this.getUserAgent(), + 'X-Stainless-Retry-Count': String(retryCount), + ...(options.timeout ? { 'X-Stainless-Timeout': String(Math.trunc(options.timeout / 1000)) } : {}), +- ...(0, detect_platform_1.getPlatformHeaders)(), ++ // ...(0, detect_platform_1.getPlatformHeaders)(), + 'OpenAI-Organization': this.organization, + 'OpenAI-Project': this.project, + }, +diff --git a/client.mjs b/client.mjs +index 7f1af99fb30d2cae03eea6687b53e6c7828faceb..fd66373a5eff31a5846084387a3fd97956c9ad48 100644 +--- a/client.mjs ++++ b/client.mjs +@@ -1,43 +1,41 @@ + // File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + var _OpenAI_instances, _a, _OpenAI_encoder, _OpenAI_baseURLOverridden; +-import { __classPrivateFieldGet, __classPrivateFieldSet } from "./internal/tslib.mjs"; +-import { uuid4 } from "./internal/utils/uuid.mjs"; +-import { validatePositiveInteger, isAbsoluteURL, safeJSON } from "./internal/utils/values.mjs"; +-import { sleep } from "./internal/utils/sleep.mjs"; +-import { castToError, isAbortError } from "./internal/errors.mjs"; +-import { getPlatformHeaders } from "./internal/detect-platform.mjs"; +-import * as Shims from "./internal/shims.mjs"; +-import * as Opts from "./internal/request-options.mjs"; +-import * as qs from "./internal/qs/index.mjs"; +-import { VERSION } from "./version.mjs"; ++import { APIPromise } from "./core/api-promise.mjs"; + import * as Errors from "./core/error.mjs"; + import * as Pagination from "./core/pagination.mjs"; + import * as Uploads from "./core/uploads.mjs"; +-import * as API from "./resources/index.mjs"; +-import { APIPromise } from "./core/api-promise.mjs"; +-import { Batches, } from "./resources/batches.mjs"; +-import { Completions, } from "./resources/completions.mjs"; +-import { Embeddings, } from "./resources/embeddings.mjs"; +-import { Files, } from "./resources/files.mjs"; +-import { Images, } from "./resources/images.mjs"; +-import { Models } from "./resources/models.mjs"; +-import { Moderations, } from "./resources/moderations.mjs"; +-import { Webhooks } from "./resources/webhooks.mjs"; ++import { isRunningInBrowser } from "./internal/detect-platform.mjs"; ++import { castToError, isAbortError } from "./internal/errors.mjs"; ++import { buildHeaders } from "./internal/headers.mjs"; ++import * as qs from "./internal/qs/index.mjs"; ++import * as Opts from "./internal/request-options.mjs"; ++import * as Shims from "./internal/shims.mjs"; ++import { __classPrivateFieldGet, __classPrivateFieldSet } from "./internal/tslib.mjs"; ++import { readEnv } from "./internal/utils/env.mjs"; ++import { formatRequestDetails, loggerFor, parseLogLevel, } from "./internal/utils/log.mjs"; ++import { sleep } from "./internal/utils/sleep.mjs"; ++import { uuid4 } from "./internal/utils/uuid.mjs"; ++import { isAbsoluteURL, isEmptyObj, safeJSON, validatePositiveInteger } from "./internal/utils/values.mjs"; + import { Audio } from "./resources/audio/audio.mjs"; ++import { Batches, } from "./resources/batches.mjs"; + import { Beta } from "./resources/beta/beta.mjs"; + import { Chat } from "./resources/chat/chat.mjs"; ++import { Completions, } from "./resources/completions.mjs"; + import { Containers, } from "./resources/containers/containers.mjs"; ++import { Embeddings, } from "./resources/embeddings.mjs"; + import { Evals, } from "./resources/evals/evals.mjs"; ++import { Files, } from "./resources/files.mjs"; + import { FineTuning } from "./resources/fine-tuning/fine-tuning.mjs"; + import { Graders } from "./resources/graders/graders.mjs"; ++import { Images, } from "./resources/images.mjs"; ++import * as API from "./resources/index.mjs"; ++import { Models } from "./resources/models.mjs"; ++import { Moderations, } from "./resources/moderations.mjs"; + import { Responses } from "./resources/responses/responses.mjs"; + import { Uploads as UploadsAPIUploads, } from "./resources/uploads/uploads.mjs"; + import { VectorStores, } from "./resources/vector-stores/vector-stores.mjs"; +-import { isRunningInBrowser } from "./internal/detect-platform.mjs"; +-import { buildHeaders } from "./internal/headers.mjs"; +-import { readEnv } from "./internal/utils/env.mjs"; +-import { formatRequestDetails, loggerFor, parseLogLevel, } from "./internal/utils/log.mjs"; +-import { isEmptyObj } from "./internal/utils/values.mjs"; ++import { Webhooks } from "./resources/webhooks.mjs"; ++import { VERSION } from "./version.mjs"; + /** + * API Client for interfacing with the OpenAI API. + */ +@@ -441,7 +439,7 @@ export class OpenAI { + 'User-Agent': this.getUserAgent(), + 'X-Stainless-Retry-Count': String(retryCount), + ...(options.timeout ? { 'X-Stainless-Timeout': String(Math.trunc(options.timeout / 1000)) } : {}), +- ...getPlatformHeaders(), ++ // ...getPlatformHeaders(), + 'OpenAI-Organization': this.organization, + 'OpenAI-Project': this.project, + }, +diff --git a/core/error.js b/core/error.js +index c302cc356f0f24b50c3f5a0aa3ea0b79ae1e9a8d..164ee2ee31cd7eea8f70139e25d140b763e91d36 100644 +--- a/core/error.js ++++ b/core/error.js +@@ -40,7 +40,7 @@ class APIError extends OpenAIError { + if (!status || !headers) { + return new APIConnectionError({ message, cause: (0, errors_1.castToError)(errorResponse) }); + } +- const error = errorResponse?.['error']; ++ const error = errorResponse?.['error'] || errorResponse; + if (status === 400) { + return new BadRequestError(status, error, message, headers); + } +diff --git a/core/error.mjs b/core/error.mjs +index 75f5b0c328cc4894478f3490a00dbf6abd96fc12..269f46f96e9fad1f7a1649a3810562abc7fae37f 100644 +--- a/core/error.mjs ++++ b/core/error.mjs +@@ -36,7 +36,7 @@ export class APIError extends OpenAIError { + if (!status || !headers) { + return new APIConnectionError({ message, cause: castToError(errorResponse) }); + } +- const error = errorResponse?.['error']; ++ const error = errorResponse?.['error'] || errorResponse; + if (status === 400) { + return new BadRequestError(status, error, message, headers); + } +diff --git a/resources/embeddings.js b/resources/embeddings.js +index 2404264d4ba0204322548945ebb7eab3bea82173..93b9e286f62101b5aa7532e96ddba61f682ece3f 100644 +--- a/resources/embeddings.js ++++ b/resources/embeddings.js +@@ -5,52 +5,64 @@ exports.Embeddings = void 0; + const resource_1 = require("../core/resource.js"); + const utils_1 = require("../internal/utils.js"); + class Embeddings extends resource_1.APIResource { +- /** +- * Creates an embedding vector representing the input text. +- * +- * @example +- * ```ts +- * const createEmbeddingResponse = +- * await client.embeddings.create({ +- * input: 'The quick brown fox jumped over the lazy dog', +- * model: 'text-embedding-3-small', +- * }); +- * ``` +- */ +- create(body, options) { +- const hasUserProvidedEncodingFormat = !!body.encoding_format; +- // No encoding_format specified, defaulting to base64 for performance reasons +- // See https://github.com/openai/openai-node/pull/1312 +- let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64'; +- if (hasUserProvidedEncodingFormat) { +- (0, utils_1.loggerFor)(this._client).debug('embeddings/user defined encoding_format:', body.encoding_format); +- } +- const response = this._client.post('/embeddings', { +- body: { +- ...body, +- encoding_format: encoding_format, +- }, +- ...options, +- }); +- // if the user specified an encoding_format, return the response as-is +- if (hasUserProvidedEncodingFormat) { +- return response; +- } +- // in this stage, we are sure the user did not specify an encoding_format +- // and we defaulted to base64 for performance reasons +- // we are sure then that the response is base64 encoded, let's decode it +- // the returned result will be a float32 array since this is OpenAI API's default encoding +- (0, utils_1.loggerFor)(this._client).debug('embeddings/decoding base64 embeddings from base64'); +- return response._thenUnwrap((response) => { +- if (response && response.data) { +- response.data.forEach((embeddingBase64Obj) => { +- const embeddingBase64Str = embeddingBase64Obj.embedding; +- embeddingBase64Obj.embedding = (0, utils_1.toFloat32Array)(embeddingBase64Str); +- }); +- } +- return response; +- }); ++ /** ++ * Creates an embedding vector representing the input text. ++ * ++ * @example ++ * ```ts ++ * const createEmbeddingResponse = ++ * await client.embeddings.create({ ++ * input: 'The quick brown fox jumped over the lazy dog', ++ * model: 'text-embedding-3-small', ++ * }); ++ * ``` ++ */ ++ create(body, options) { ++ const hasUserProvidedEncodingFormat = !!body.encoding_format; ++ // No encoding_format specified, defaulting to base64 for performance reasons ++ // See https://github.com/openai/openai-node/pull/1312 ++ let encoding_format = hasUserProvidedEncodingFormat ++ ? body.encoding_format ++ : "base64"; ++ if (body.model.includes("jina")) { ++ encoding_format = undefined; ++ } ++ if (hasUserProvidedEncodingFormat) { ++ (0, utils_1.loggerFor)(this._client).debug( ++ "embeddings/user defined encoding_format:", ++ body.encoding_format ++ ); + } ++ const response = this._client.post("/embeddings", { ++ body: { ++ ...body, ++ encoding_format: encoding_format, ++ }, ++ ...options, ++ }); ++ // if the user specified an encoding_format, return the response as-is ++ if (hasUserProvidedEncodingFormat || body.model.includes("jina")) { ++ return response; ++ } ++ // in this stage, we are sure the user did not specify an encoding_format ++ // and we defaulted to base64 for performance reasons ++ // we are sure then that the response is base64 encoded, let's decode it ++ // the returned result will be a float32 array since this is OpenAI API's default encoding ++ (0, utils_1.loggerFor)(this._client).debug( ++ "embeddings/decoding base64 embeddings from base64" ++ ); ++ return response._thenUnwrap((response) => { ++ if (response && response.data && typeof response.data[0]?.embedding === 'string') { ++ response.data.forEach((embeddingBase64Obj) => { ++ const embeddingBase64Str = embeddingBase64Obj.embedding; ++ embeddingBase64Obj.embedding = (0, utils_1.toFloat32Array)( ++ embeddingBase64Str ++ ); ++ }); ++ } ++ return response; ++ }); ++ } + } + exports.Embeddings = Embeddings; + //# sourceMappingURL=embeddings.js.map +diff --git a/resources/embeddings.mjs b/resources/embeddings.mjs +index 19dcaef578c194a89759c4360073cfd4f7dd2cbf..42c903fadb03c707356a983603ff09e4152ecf11 100644 +--- a/resources/embeddings.mjs ++++ b/resources/embeddings.mjs +@@ -2,51 +2,61 @@ + import { APIResource } from "../core/resource.mjs"; + import { loggerFor, toFloat32Array } from "../internal/utils.mjs"; + export class Embeddings extends APIResource { +- /** +- * Creates an embedding vector representing the input text. +- * +- * @example +- * ```ts +- * const createEmbeddingResponse = +- * await client.embeddings.create({ +- * input: 'The quick brown fox jumped over the lazy dog', +- * model: 'text-embedding-3-small', +- * }); +- * ``` +- */ +- create(body, options) { +- const hasUserProvidedEncodingFormat = !!body.encoding_format; +- // No encoding_format specified, defaulting to base64 for performance reasons +- // See https://github.com/openai/openai-node/pull/1312 +- let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64'; +- if (hasUserProvidedEncodingFormat) { +- loggerFor(this._client).debug('embeddings/user defined encoding_format:', body.encoding_format); +- } +- const response = this._client.post('/embeddings', { +- body: { +- ...body, +- encoding_format: encoding_format, +- }, +- ...options, +- }); +- // if the user specified an encoding_format, return the response as-is +- if (hasUserProvidedEncodingFormat) { +- return response; +- } +- // in this stage, we are sure the user did not specify an encoding_format +- // and we defaulted to base64 for performance reasons +- // we are sure then that the response is base64 encoded, let's decode it +- // the returned result will be a float32 array since this is OpenAI API's default encoding +- loggerFor(this._client).debug('embeddings/decoding base64 embeddings from base64'); +- return response._thenUnwrap((response) => { +- if (response && response.data) { +- response.data.forEach((embeddingBase64Obj) => { +- const embeddingBase64Str = embeddingBase64Obj.embedding; +- embeddingBase64Obj.embedding = toFloat32Array(embeddingBase64Str); +- }); +- } +- return response; +- }); ++ /** ++ * Creates an embedding vector representing the input text. ++ * ++ * @example ++ * ```ts ++ * const createEmbeddingResponse = ++ * await client.embeddings.create({ ++ * input: 'The quick brown fox jumped over the lazy dog', ++ * model: 'text-embedding-3-small', ++ * }); ++ * ``` ++ */ ++ create(body, options) { ++ const hasUserProvidedEncodingFormat = !!body.encoding_format; ++ // No encoding_format specified, defaulting to base64 for performance reasons ++ // See https://github.com/openai/openai-node/pull/1312 ++ let encoding_format = hasUserProvidedEncodingFormat ++ ? body.encoding_format ++ : "base64"; ++ if (body.model.includes("jina")) { ++ encoding_format = undefined; ++ } ++ if (hasUserProvidedEncodingFormat) { ++ loggerFor(this._client).debug( ++ "embeddings/user defined encoding_format:", ++ body.encoding_format ++ ); + } ++ const response = this._client.post("/embeddings", { ++ body: { ++ ...body, ++ encoding_format: encoding_format, ++ }, ++ ...options, ++ }); ++ // if the user specified an encoding_format, return the response as-is ++ if (hasUserProvidedEncodingFormat || body.model.includes("jina")) { ++ return response; ++ } ++ // in this stage, we are sure the user did not specify an encoding_format ++ // and we defaulted to base64 for performance reasons ++ // we are sure then that the response is base64 encoded, let's decode it ++ // the returned result will be a float32 array since this is OpenAI API's default encoding ++ loggerFor(this._client).debug( ++ "embeddings/decoding base64 embeddings from base64" ++ ); ++ return response._thenUnwrap((response) => { ++ if (response && response.data && typeof response.data[0]?.embedding === 'string') { ++ response.data.forEach((embeddingBase64Obj) => { ++ const embeddingBase64Str = embeddingBase64Obj.embedding; ++ embeddingBase64Obj.embedding = toFloat32Array(embeddingBase64Str); ++ }); ++ } ++ return response; ++ }); ++ } + } + //# sourceMappingURL=embeddings.mjs.map diff --git a/.yarn/patches/windows-system-proxy-npm-1.0.0-ff2a828eec.patch b/.yarn/patches/windows-system-proxy-npm-1.0.0-ff2a828eec.patch new file mode 100644 index 0000000000..354f806148 --- /dev/null +++ b/.yarn/patches/windows-system-proxy-npm-1.0.0-ff2a828eec.patch @@ -0,0 +1,23 @@ +diff --git a/dist/index.js b/dist/index.js +index b54962b2d332c1a3affadbdb37d39fdf90ab9f82..7906b4ea3bf9dffe60d74c279e9cfe885489c9f9 100644 +--- a/dist/index.js ++++ b/dist/index.js +@@ -36,12 +36,12 @@ async function getWindowsSystemProxy() { + const proxies = Object.fromEntries(proxyConfigString + .split(';') + .map((proxyPair) => proxyPair.split('='))); +- const proxyUrl = proxies['https'] +- ? `https://${proxies['https']}` +- : proxies['http'] +- ? `http://${proxies['http']}` +- : proxies['socks'] +- ? `socks://${proxies['socks']}` ++ const proxyUrl = proxies['http'] ++ ? `http://${proxies['http']}` ++ : proxies['socks'] ++ ? `socks://${proxies['socks']}` ++ : proxies['https'] ++ ? `https://${proxies['https']}` + : undefined; + if (!proxyUrl) { + throw new Error(`Could not get usable proxy URL from ${proxyConfigString}`); diff --git a/docs/technical/CodeBlockView-en.md b/docs/technical/CodeBlockView-en.md new file mode 100644 index 0000000000..786d7aa029 --- /dev/null +++ b/docs/technical/CodeBlockView-en.md @@ -0,0 +1,180 @@ +# CodeBlockView Component Structure + +## Overview + +CodeBlockView is the core component in Cherry Studio for displaying and manipulating code blocks. It supports multiple view modes and visual previews for special languages, providing rich interactive tools. + +## Component Structure + +```mermaid +graph TD + A[CodeBlockView] --> B[CodeToolbar] + A --> C[SourceView] + A --> D[SpecialView] + A --> E[StatusBar] + + B --> F[CodeToolButton] + + C --> G[CodeEditor / CodeViewer] + + D --> H[MermaidPreview] + D --> I[PlantUmlPreview] + D --> J[SvgPreview] + D --> K[GraphvizPreview] + + F --> L[useCopyTool] + F --> M[useDownloadTool] + F --> N[useViewSourceTool] + F --> O[useSplitViewTool] + F --> P[useRunTool] + F --> Q[useExpandTool] + F --> R[useWrapTool] + F --> S[useSaveTool] +``` + +## Core Concepts + +### View Types + +- **preview**: Preview view, where non-source code is displayed as special views +- **edit**: Edit view + +### View Modes + +- **source**: Source code view mode +- **special**: Special view mode (Mermaid, PlantUML, SVG) +- **split**: Split view mode (source code and special view displayed side by side) + +### Special View Languages + +- mermaid +- plantuml +- svg +- dot +- graphviz + +## Component Details + +### CodeBlockView Main Component + +Main responsibilities: + +1. Managing view mode state +2. Coordinating the display of source code view and special view +3. Managing toolbar tools +4. Handling code execution state + +### Subcomponents + +#### CodeToolbar + +- Toolbar displayed at the top-right corner of the code block +- Contains core and quick tools +- Dynamically displays relevant tools based on context + +#### CodeEditor/CodeViewer Source View + +- Editable code editor or read-only code viewer +- Uses either component based on settings +- Supports syntax highlighting for multiple programming languages + +#### Special View Components + +- **MermaidPreview**: Mermaid diagram preview +- **PlantUmlPreview**: PlantUML diagram preview +- **SvgPreview**: SVG image preview +- **GraphvizPreview**: Graphviz diagram preview + +All special view components share a common architecture for consistent user experience and functionality. For detailed information about these components and their implementation, see [Image Preview Components Documentation](./ImagePreview-en.md). + +#### StatusBar + +- Displays Python code execution results +- Can show both text and image results + +## Tool System + +CodeBlockView uses a hook-based tool system: + +```mermaid +graph TD + A[CodeBlockView] --> B[useCopyTool] + A --> C[useDownloadTool] + A --> D[useViewSourceTool] + A --> E[useSplitViewTool] + A --> F[useRunTool] + A --> G[useExpandTool] + A --> H[useWrapTool] + A --> I[useSaveTool] + + B --> J[ToolManager] + C --> J + D --> J + E --> J + F --> J + G --> J + H --> J + I --> J + + J --> K[CodeToolbar] +``` + +Each tool hook is responsible for registering specific function tool buttons to the tool manager, which then passes these tools to the CodeToolbar component for rendering. + +### Tool Types + +- **core**: Core tools, always displayed in the toolbar +- **quick**: Quick tools, displayed in a dropdown menu when there are more than one + +### Tool List + +1. **Copy**: Copy code or image +2. **Download**: Download code or image +3. **View Source**: Switch between special view and source code view +4. **Split View**: Toggle split view mode +5. **Run**: Run Python code +6. **Expand/Collapse**: Control code block expansion/collapse +7. **Wrap**: Control automatic line wrapping +8. **Save**: Save edited code + +## State Management + +CodeBlockView manages the following states through React hooks: + +1. **viewMode**: Current view mode ('source' | 'special' | 'split') +2. **isRunning**: Python code execution status +3. **executionResult**: Python code execution result +4. **tools**: Toolbar tool list +5. **expandOverride/unwrapOverride**: User override settings for expand/wrap +6. **sourceScrollHeight**: Source code view scroll height + +## Interaction Flow + +```mermaid +sequenceDiagram + participant U as User + participant CB as CodeBlockView + participant CT as CodeToolbar + participant SV as SpecialView + participant SE as SourceEditor + + U->>CB: View code block + CB->>CB: Initialize state + CB->>CT: Register tools + CB->>SV: Render special view (if applicable) + CB->>SE: Render source view + U->>CT: Click tool button + CT->>CB: Trigger tool callback + CB->>CB: Update state + CB->>CT: Re-register tools (if needed) +``` + +## Special Handling + +### HTML Code Blocks + +HTML code blocks are specially handled using the HtmlArtifactsCard component. + +### Python Code Execution + +Supports executing Python code and displaying results using Pyodide to run Python code in the browser. diff --git a/docs/technical/CodeBlockView-zh.md b/docs/technical/CodeBlockView-zh.md new file mode 100644 index 0000000000..a817e99361 --- /dev/null +++ b/docs/technical/CodeBlockView-zh.md @@ -0,0 +1,180 @@ +# CodeBlockView 组件结构说明 + +## 概述 + +CodeBlockView 是 Cherry Studio 中用于显示和操作代码块的核心组件。它支持多种视图模式和特殊语言的可视化预览,提供丰富的交互工具。 + +## 组件结构 + +```mermaid +graph TD + A[CodeBlockView] --> B[CodeToolbar] + A --> C[SourceView] + A --> D[SpecialView] + A --> E[StatusBar] + + B --> F[CodeToolButton] + + C --> G[CodeEditor / CodeViewer] + + D --> H[MermaidPreview] + D --> I[PlantUmlPreview] + D --> J[SvgPreview] + D --> K[GraphvizPreview] + + F --> L[useCopyTool] + F --> M[useDownloadTool] + F --> N[useViewSourceTool] + F --> O[useSplitViewTool] + F --> P[useRunTool] + F --> Q[useExpandTool] + F --> R[useWrapTool] + F --> S[useSaveTool] +``` + +## 核心概念 + +### 视图类型 + +- **preview**: 预览视图,非源代码的是特殊视图 +- **edit**: 编辑视图 + +### 视图模式 + +- **source**: 源代码视图模式 +- **special**: 特殊视图模式(Mermaid、PlantUML、SVG) +- **split**: 分屏模式(源代码和特殊视图并排显示) + +### 特殊视图语言 + +- mermaid +- plantuml +- svg +- dot +- graphviz + +## 组件详细说明 + +### CodeBlockView 主组件 + +主要负责: + +1. 管理视图模式状态 +2. 协调源代码视图和特殊视图的显示 +3. 管理工具栏工具 +4. 处理代码执行状态 + +### 子组件 + +#### CodeToolbar 工具栏 + +- 显示在代码块右上角的工具栏 +- 包含核心(core)和快捷(quick)两类工具 +- 根据上下文动态显示相关工具 + +#### CodeEditor/CodeViewer 源代码视图 + +- 可编辑的代码编辑器或只读的代码查看器 +- 根据设置决定使用哪个组件 +- 支持多种编程语言高亮 + +#### 特殊视图组件 + +- **MermaidPreview**: Mermaid 图表预览 +- **PlantUmlPreview**: PlantUML 图表预览 +- **SvgPreview**: SVG 图像预览 +- **GraphvizPreview**: Graphviz 图表预览 + +所有特殊视图组件共享通用架构,以确保一致的用户体验和功能。有关这些组件及其实现的详细信息,请参阅 [图像预览组件文档](./ImagePreview-zh.md)。 + +#### StatusBar 状态栏 + +- 显示 Python 代码执行结果 +- 可显示文本和图像结果 + +## 工具系统 + +CodeBlockView 使用基于 hooks 的工具系统: + +```mermaid +graph TD + A[CodeBlockView] --> B[useCopyTool] + A --> C[useDownloadTool] + A --> D[useViewSourceTool] + A --> E[useSplitViewTool] + A --> F[useRunTool] + A --> G[useExpandTool] + A --> H[useWrapTool] + A --> I[useSaveTool] + + B --> J[ToolManager] + C --> J + D --> J + E --> J + F --> J + G --> J + H --> J + I --> J + + J --> K[CodeToolbar] +``` + +每个工具 hook 负责注册特定功能的工具按钮到工具管理器,工具管理器再将这些工具传递给 CodeToolbar 组件进行渲染。 + +### 工具类型 + +- **core**: 核心工具,始终显示在工具栏 +- **quick**: 快捷工具,当数量大于1时通过下拉菜单显示 + +### 工具列表 + +1. **复制(copy)**: 复制代码或图像 +2. **下载(download)**: 下载代码或图像 +3. **查看源码(view-source)**: 在特殊视图和源码视图间切换 +4. **分屏(split-view)**: 切换分屏模式 +5. **运行(run)**: 运行 Python 代码 +6. **展开/折叠(expand)**: 控制代码块的展开/折叠 +7. **换行(wrap)**: 控制代码的自动换行 +8. **保存(save)**: 保存编辑的代码 + +## 状态管理 + +CodeBlockView 通过 React hooks 管理以下状态: + +1. **viewMode**: 当前视图模式 ('source' | 'special' | 'split') +2. **isRunning**: Python 代码执行状态 +3. **executionResult**: Python 代码执行结果 +4. **tools**: 工具栏工具列表 +5. **expandOverride/unwrapOverride**: 用户展开/换行的覆盖设置 +6. **sourceScrollHeight**: 源代码视图滚动高度 + +## 交互流程 + +```mermaid +sequenceDiagram + participant U as User + participant CB as CodeBlockView + participant CT as CodeToolbar + participant SV as SpecialView + participant SE as SourceEditor + + U->>CB: 查看代码块 + CB->>CB: 初始化状态 + CB->>CT: 注册工具 + CB->>SV: 渲染特殊视图(如果适用) + CB->>SE: 渲染源码视图 + U->>CT: 点击工具按钮 + CT->>CB: 触发工具回调 + CB->>CB: 更新状态 + CB->>CT: 重新注册工具(如果需要) +``` + +## 特殊处理 + +### HTML 代码块 + +HTML 代码块会被特殊处理,使用 HtmlArtifactsCard 组件显示。 + +### Python 代码执行 + +支持执行 Python 代码并显示结果,使用 Pyodide 在浏览器中运行 Python 代码。 diff --git a/docs/technical/ImagePreview-en.md b/docs/technical/ImagePreview-en.md new file mode 100644 index 0000000000..383bf5c664 --- /dev/null +++ b/docs/technical/ImagePreview-en.md @@ -0,0 +1,195 @@ +# Image Preview Components + +## Overview + +Image Preview Components are a set of specialized components in Cherry Studio for rendering and displaying various diagram and image formats. They provide a consistent user experience across different preview types with shared functionality for loading states, error handling, and interactive controls. + +## Supported Formats + +- **Mermaid**: Interactive diagrams and flowcharts +- **PlantUML**: UML diagrams and system architecture +- **SVG**: Scalable vector graphics +- **Graphviz/DOT**: Graph visualization and network diagrams + +## Architecture + +```mermaid +graph TD + A[MermaidPreview] --> D[ImagePreviewLayout] + B[PlantUmlPreview] --> D + C[SvgPreview] --> D + E[GraphvizPreview] --> D + + D --> F[ImageToolbar] + D --> G[useDebouncedRender] + + F --> H[Pan Controls] + F --> I[Zoom Controls] + F --> J[Reset Function] + F --> K[Dialog Control] + + G --> L[Debounced Rendering] + G --> M[Error Handling] + G --> N[Loading State] + G --> O[Dependency Management] +``` + +## Core Components + +### ImagePreviewLayout + +A common layout wrapper that provides the foundation for all image preview components. + +**Features:** + +- **Loading State Management**: Shows loading spinner during rendering +- **Error Display**: Displays error messages when rendering fails +- **Toolbar Integration**: Conditionally renders ImageToolbar when enabled +- **Container Management**: Wraps preview content with consistent styling +- **Responsive Design**: Adapts to different container sizes + +**Props:** + +- `children`: The preview content to be displayed +- `loading`: Boolean indicating if content is being rendered +- `error`: Error message to display if rendering fails +- `enableToolbar`: Whether to show the interactive toolbar +- `imageRef`: Reference to the container element for image manipulation + +### ImageToolbar + +Interactive toolbar component providing image manipulation controls. + +**Features:** + +- **Pan Controls**: 4-directional pan buttons (up, down, left, right) +- **Zoom Controls**: Zoom in/out functionality with configurable increments +- **Reset Function**: Restore original pan and zoom state +- **Dialog Control**: Open preview in expanded dialog view +- **Accessible Design**: Full keyboard navigation and screen reader support + +**Layout:** + +- 3x3 grid layout positioned at bottom-right of preview +- Responsive button sizing +- Tooltip support for all controls + +### useDebouncedRender Hook + +A specialized React hook for managing preview rendering with performance optimizations. + +**Features:** + +- **Debounced Rendering**: Prevents excessive re-renders during rapid content changes (default 300ms delay) +- **Automatic Dependency Management**: Handles dependencies for render and condition functions +- **Error Handling**: Catches and manages rendering errors with detailed error messages +- **Loading State**: Tracks rendering progress with automatic state updates +- **Conditional Rendering**: Supports pre-render condition checks +- **Manual Controls**: Provides trigger, cancel, and state management functions + +**API:** + +```typescript +const { containerRef, error, isLoading, triggerRender, cancelRender, clearError, setLoading } = useDebouncedRender( + value, + renderFunction, + options +) +``` + +**Options:** + +- `debounceDelay`: Customize debounce timing +- `shouldRender`: Function for conditional rendering logic + +## Component Implementations + +### MermaidPreview + +Renders Mermaid diagrams with special handling for visibility detection. + +**Special Features:** + +- Syntax validation before rendering +- Visibility detection to handle collapsed containers +- SVG coordinate fixing for edge cases +- Integration with mermaid.js library + +### PlantUmlPreview + +Renders PlantUML diagrams using the online PlantUML server. + +**Special Features:** + +- Network error handling and retry logic +- Diagram encoding using deflate compression +- Support for light/dark themes +- Server status monitoring + +### SvgPreview + +Renders SVG content using Shadow DOM for isolation. + +**Special Features:** + +- Shadow DOM rendering for style isolation +- Direct SVG content injection +- Minimal processing overhead +- Cross-browser compatibility + +### GraphvizPreview + +Renders Graphviz/DOT diagrams using the viz.js library. + +**Special Features:** + +- Client-side rendering with viz.js +- Lazy loading of viz.js library +- SVG element generation +- Memory-efficient processing + +## Shared Functionality + +### Error Handling + +All preview components provide consistent error handling: + +- Network errors (connection failures) +- Syntax errors (invalid diagram code) +- Server errors (external service failures) +- Rendering errors (library failures) + +### Loading States + +Standardized loading indicators across all components: + +- Spinner animation during processing +- Progress feedback for long operations +- Smooth transitions between states + +### Interactive Controls + +Common interaction patterns: + +- Pan and zoom functionality +- Reset to original view +- Full-screen dialog mode +- Keyboard accessibility + +### Performance Optimizations + +- Debounced rendering to prevent excessive updates +- Lazy loading of heavy libraries +- Memory management for large diagrams +- Efficient re-rendering strategies + +## Integration with CodeBlockView + +Image Preview Components integrate seamlessly with CodeBlockView: + +- Automatic format detection based on language tags +- Consistent toolbar integration +- Shared state management +- Responsive layout adaptation + +For more information about the overall CodeBlockView architecture, see [CodeBlockView Documentation](./CodeBlockView-en.md). diff --git a/docs/technical/ImagePreview-zh.md b/docs/technical/ImagePreview-zh.md new file mode 100644 index 0000000000..8a68b84312 --- /dev/null +++ b/docs/technical/ImagePreview-zh.md @@ -0,0 +1,195 @@ +# 图像预览组件 + +## 概述 + +图像预览组件是 Cherry Studio 中用于渲染和显示各种图表和图像格式的专用组件集合。它们为不同预览类型提供一致的用户体验,具有共享的加载状态、错误处理和交互控制功能。 + +## 支持格式 + +- **Mermaid**: 交互式图表和流程图 +- **PlantUML**: UML 图表和系统架构 +- **SVG**: 可缩放矢量图形 +- **Graphviz/DOT**: 图形可视化和网络图表 + +## 架构 + +```mermaid +graph TD + A[MermaidPreview] --> D[ImagePreviewLayout] + B[PlantUmlPreview] --> D + C[SvgPreview] --> D + E[GraphvizPreview] --> D + + D --> F[ImageToolbar] + D --> G[useDebouncedRender] + + F --> H[平移控制] + F --> I[缩放控制] + F --> J[重置功能] + F --> K[对话框控制] + + G --> L[防抖渲染] + G --> M[错误处理] + G --> N[加载状态] + G --> O[依赖管理] +``` + +## 核心组件 + +### ImagePreviewLayout 图像预览布局 + +为所有图像预览组件提供基础的通用布局包装器。 + +**功能特性:** + +- **加载状态管理**: 在渲染期间显示加载动画 +- **错误显示**: 渲染失败时显示错误信息 +- **工具栏集成**: 启用时有条件地渲染 ImageToolbar +- **容器管理**: 使用一致的样式包装预览内容 +- **响应式设计**: 适应不同的容器尺寸 + +**属性:** + +- `children`: 要显示的预览内容 +- `loading`: 指示内容是否正在渲染的布尔值 +- `error`: 渲染失败时显示的错误信息 +- `enableToolbar`: 是否显示交互式工具栏 +- `imageRef`: 用于图像操作的容器元素引用 + +### ImageToolbar 图像工具栏 + +提供图像操作控制的交互式工具栏组件。 + +**功能特性:** + +- **平移控制**: 4方向平移按钮(上、下、左、右) +- **缩放控制**: 放大/缩小功能,支持可配置的增量 +- **重置功能**: 恢复原始平移和缩放状态 +- **对话框控制**: 在展开对话框中打开预览 +- **无障碍设计**: 完整的键盘导航和屏幕阅读器支持 + +**布局:** + +- 3x3 网格布局,位于预览右下角 +- 响应式按钮尺寸 +- 所有控件的工具提示支持 + +### useDebouncedRender Hook 防抖渲染钩子 + +用于管理预览渲染的专用 React Hook,具有性能优化功能。 + +**功能特性:** + +- **防抖渲染**: 防止内容快速变化时的过度重新渲染(默认 300ms 延迟) +- **自动依赖管理**: 处理渲染和条件函数的依赖项 +- **错误处理**: 捕获和管理渲染错误,提供详细的错误信息 +- **加载状态**: 跟踪渲染进度并自动更新状态 +- **条件渲染**: 支持预渲染条件检查 +- **手动控制**: 提供触发、取消和状态管理功能 + +**API:** + +```typescript +const { containerRef, error, isLoading, triggerRender, cancelRender, clearError, setLoading } = useDebouncedRender( + value, + renderFunction, + options +) +``` + +**选项:** + +- `debounceDelay`: 自定义防抖时间 +- `shouldRender`: 条件渲染逻辑函数 + +## 组件实现 + +### MermaidPreview Mermaid 预览 + +渲染 Mermaid 图表,具有可见性检测的特殊处理。 + +**特殊功能:** + +- 渲染前语法验证 +- 可见性检测以处理折叠的容器 +- 边缘情况的 SVG 坐标修复 +- 与 mermaid.js 库集成 + +### PlantUmlPreview PlantUML 预览 + +使用在线 PlantUML 服务器渲染 PlantUML 图表。 + +**特殊功能:** + +- 网络错误处理和重试逻辑 +- 使用 deflate 压缩的图表编码 +- 支持明/暗主题 +- 服务器状态监控 + +### SvgPreview SVG 预览 + +使用 Shadow DOM 隔离渲染 SVG 内容。 + +**特殊功能:** + +- Shadow DOM 渲染实现样式隔离 +- 直接 SVG 内容注入 +- 最小化处理开销 +- 跨浏览器兼容性 + +### GraphvizPreview Graphviz 预览 + +使用 viz.js 库渲染 Graphviz/DOT 图表。 + +**特殊功能:** + +- 使用 viz.js 进行客户端渲染 +- viz.js 库的懒加载 +- SVG 元素生成 +- 内存高效处理 + +## 共享功能 + +### 错误处理 + +所有预览组件提供一致的错误处理: + +- 网络错误(连接失败) +- 语法错误(无效的图表代码) +- 服务器错误(外部服务失败) +- 渲染错误(库失败) + +### 加载状态 + +所有组件的标准化加载指示器: + +- 处理期间的动画 +- 长时间操作的进度反馈 +- 状态间的平滑过渡 + +### 交互控制 + +通用交互模式: + +- 平移和缩放功能 +- 重置到原始视图 +- 全屏对话框模式 +- 键盘无障碍访问 + +### 性能优化 + +- 防抖渲染以防止过度更新 +- 重型库的懒加载 +- 大型图表的内存管理 +- 高效的重新渲染策略 + +## 与 CodeBlockView 的集成 + +图像预览组件与 CodeBlockView 无缝集成: + +- 基于语言标签的自动格式检测 +- 一致的工具栏集成 +- 共享状态管理 +- 响应式布局适应 + +有关整体 CodeBlockView 架构的更多信息,请参阅 [CodeBlockView 文档](./CodeBlockView-zh.md)。 diff --git a/electron-builder.yml b/electron-builder.yml index 180b38fc68..4fc42854a3 100644 --- a/electron-builder.yml +++ b/electron-builder.yml @@ -50,6 +50,7 @@ files: - '!node_modules/rollup-plugin-visualizer' - '!node_modules/js-tiktoken' - '!node_modules/@tavily/core/node_modules/js-tiktoken' + - '!node_modules/pdf-parse/lib/pdf.js/{v1.9.426,v1.10.88,v2.0.550}' - '!node_modules/mammoth/{mammoth.browser.js,mammoth.browser.min.js}' - '!node_modules/selection-hook/prebuilds/**/*' # we rebuild .node, don't use prebuilds - '!node_modules/selection-hook/node_modules' # we don't need what in the node_modules dir @@ -117,18 +118,4 @@ afterSign: scripts/notarize.js artifactBuildCompleted: scripts/artifact-build-completed.js releaseInfo: releaseNotes: | - 新增服务商:AWS Bedrock - 富文本编辑器支持:提升提示词编辑体验,支持更丰富的格式调整 - 拖拽输入优化:支持从其他软件直接拖拽文本至输入框,简化内容输入流程 - 参数调节增强:新增 Top-P 和 Temperature 开关设置,提供更灵活的模型调控选项 - 翻译任务后台执行:翻译任务支持后台运行,提升多任务处理效率 - 新模型支持:新增 Qwen-MT、Qwen3235BA22Bthinking 和 sonar-deep-research 模型,扩展推理能力 - 推理稳定性提升:修复部分模型思考内容无法输出的问题,确保推理结果完整 - Mistral 模型修复:解决 Mistral 模型无法使用的问题,恢复其推理功能 - 备份目录优化:支持相对路径输入,提升备份配置灵活性 - 数据导出调整:新增引用内容导出开关,提供更精细的导出控制 - 文本流完整性:修复文本流末尾文字丢失问题,确保输出内容完整 - 内存泄漏修复:优化代码逻辑,解决内存泄漏问题,提升运行稳定性 - 嵌入模型简化:降低嵌入模型配置复杂度,提高易用性 - MCP Tool 长时间运行:增强 MCP 工具的稳定性,支持长时间任务执行 - 设置页面优化:优化设置页面布局,提升用户体验 + 稳定性改进和错误修复 diff --git a/package.json b/package.json index 8e3df2c693..4b70697842 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "CherryStudio", - "version": "1.5.4", + "version": "1.5.5", "private": true, "description": "A powerful AI assistant for producer.", "main": "./out/main/index.js", @@ -219,7 +219,7 @@ "motion": "^12.10.5", "notion-helper": "^1.3.22", "npx-scope-finder": "^1.2.0", - "openai": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch", + "openai": "patch:openai@npm%3A5.12.0#~/.yarn/patches/openai-npm-5.12.0-a06a6369b2.patch", "p-queue": "^8.1.0", "pdf-lib": "^1.17.1", "playwright": "^1.52.0", @@ -273,20 +273,22 @@ "zod": "^3.25.74" }, "resolutions": { + "pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch", "@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch", "@langchain/openai@npm:>=0.1.0 <0.4.0": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch", "libsql@npm:^0.4.4": "patch:libsql@npm%3A0.4.7#~/.yarn/patches/libsql-npm-0.4.7-444e260fb1.patch", - "openai@npm:^4.77.0": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch", + "openai@npm:^4.77.0": "patch:openai@npm%3A5.12.0#~/.yarn/patches/openai-npm-5.12.0-a06a6369b2.patch", "pkce-challenge@npm:^4.1.0": "patch:pkce-challenge@npm%3A4.1.0#~/.yarn/patches/pkce-challenge-npm-4.1.0-fbc51695a3.patch", "app-builder-lib@npm:26.0.13": "patch:app-builder-lib@npm%3A26.0.13#~/.yarn/patches/app-builder-lib-npm-26.0.13-a064c9e1d0.patch", - "openai@npm:^4.87.3": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch", + "openai@npm:^4.87.3": "patch:openai@npm%3A5.12.0#~/.yarn/patches/openai-npm-5.12.0-a06a6369b2.patch", "app-builder-lib@npm:26.0.15": "patch:app-builder-lib@npm%3A26.0.15#~/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch", "@langchain/core@npm:^0.3.26": "patch:@langchain/core@npm%3A0.3.44#~/.yarn/patches/@langchain-core-npm-0.3.44-41d5c3cb0a.patch", "node-abi": "4.12.0", "undici": "6.21.2", "vite": "npm:rolldown-vite@latest", "atomically@npm:^1.7.0": "patch:atomically@npm%3A1.7.0#~/.yarn/patches/atomically-npm-1.7.0-e742e5293b.patch", - "file-stream-rotator@npm:^0.6.1": "patch:file-stream-rotator@npm%3A0.6.1#~/.yarn/patches/file-stream-rotator-npm-0.6.1-eab45fb13d.patch" + "file-stream-rotator@npm:^0.6.1": "patch:file-stream-rotator@npm%3A0.6.1#~/.yarn/patches/file-stream-rotator-npm-0.6.1-eab45fb13d.patch", + "windows-system-proxy@npm:^1.0.0": "patch:windows-system-proxy@npm%3A1.0.0#~/.yarn/patches/windows-system-proxy-npm-1.0.0-ff2a828eec.patch" }, "packageManager": "yarn@4.9.1", "lint-staged": { diff --git a/src/main/ipc.ts b/src/main/ipc.ts index e4db5ec210..e337d0d247 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -94,17 +94,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { let proxyConfig: ProxyConfig if (proxy === 'system') { + // system proxy will use the system filter by themselves proxyConfig = { mode: 'system' } } else if (proxy) { - proxyConfig = { mode: 'fixed_servers', proxyRules: proxy } + proxyConfig = { mode: 'fixed_servers', proxyRules: proxy, proxyBypassRules: bypassRules } } else { proxyConfig = { mode: 'direct' } } - if (bypassRules) { - proxyConfig.proxyBypassRules = bypassRules - } - await proxyManager.configureProxy(proxyConfig) }) diff --git a/src/main/services/ProxyManager.ts b/src/main/services/ProxyManager.ts index 48b6da6fa7..620a6a5fef 100644 --- a/src/main/services/ProxyManager.ts +++ b/src/main/services/ProxyManager.ts @@ -1,5 +1,4 @@ import { loggerService } from '@logger' -import { defaultByPassRules } from '@shared/config/constant' import axios from 'axios' import { app, ProxyConfig, session } from 'electron' import { socksDispatcher } from 'fetch-socks' @@ -10,9 +9,13 @@ import { ProxyAgent } from 'proxy-agent' import { Dispatcher, EnvHttpProxyAgent, getGlobalDispatcher, setGlobalDispatcher } from 'undici' const logger = loggerService.withContext('ProxyManager') -let byPassRules = defaultByPassRules.split(',') +let byPassRules: string[] = [] const isByPass = (hostname: string) => { + if (byPassRules.length === 0) { + return false + } + return byPassRules.includes(hostname) } @@ -98,7 +101,7 @@ export class ProxyManager { await this.configureProxy({ mode: 'system', proxyRules: currentProxy?.proxyUrl.toLowerCase(), - proxyBypassRules: this.config.proxyBypassRules + proxyBypassRules: undefined }) }, 1000 * 60) } @@ -131,7 +134,7 @@ export class ProxyManager { this.monitorSystemProxy() } - byPassRules = config.proxyBypassRules?.split(',') || defaultByPassRules.split(',') + byPassRules = config.proxyBypassRules?.split(',') || [] this.setGlobalProxy(this.config) } catch (error) { logger.error('Failed to config proxy:', error as Error) diff --git a/src/main/services/WindowService.ts b/src/main/services/WindowService.ts index 7e7c1466e2..8b410323b1 100644 --- a/src/main/services/WindowService.ts +++ b/src/main/services/WindowService.ts @@ -252,7 +252,9 @@ export class WindowService { 'https://cloud.siliconflow.cn/expensebill', 'https://aihubmix.com/token', 'https://aihubmix.com/topup', - 'https://aihubmix.com/statistics' + 'https://aihubmix.com/statistics', + 'https://dash.302.ai/sso/login', + 'https://dash.302.ai/charge' ] if (oauthProviderUrls.some((link) => url.startsWith(link))) { diff --git a/src/renderer/src/aiCore/clients/BaseApiClient.ts b/src/renderer/src/aiCore/clients/BaseApiClient.ts index 9bb0f92789..ff91259143 100644 --- a/src/renderer/src/aiCore/clients/BaseApiClient.ts +++ b/src/renderer/src/aiCore/clients/BaseApiClient.ts @@ -3,25 +3,28 @@ import { isFunctionCallingModel, isNotSupportTemperatureAndTopP, isOpenAIModel, - isSupportedFlexServiceTier + isSupportFlexServiceTierModel } from '@renderer/config/models' import { REFERENCE_PROMPT } from '@renderer/config/prompts' +import { isSupportServiceTierProviders } from '@renderer/config/providers' import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio' -import { getStoreSetting } from '@renderer/hooks/useSettings' import { getAssistantSettings } from '@renderer/services/AssistantService' -import { SettingsState } from '@renderer/store/settings' import { Assistant, FileTypes, GenerateImageParams, + GroqServiceTiers, + isGroqServiceTier, + isOpenAIServiceTier, KnowledgeReference, MCPCallToolResponse, MCPTool, MCPToolResponse, MemoryItem, Model, - OpenAIServiceTier, + OpenAIServiceTiers, Provider, + SystemProviderIds, ToolCallResponse, WebSearchProviderResponse, WebSearchResponse @@ -201,29 +204,37 @@ export abstract class BaseApiClient< return assistantSettings?.enableTopP ? assistantSettings?.topP : undefined } + // NOTE: 这个也许可以迁移到OpenAIBaseClient protected getServiceTier(model: Model) { - if (!isOpenAIModel(model) || model.provider === 'github' || model.provider === 'copilot') { + const serviceTierSetting = this.provider.serviceTier + + if (!isSupportServiceTierProviders(this.provider) || !isOpenAIModel(model) || !serviceTierSetting) { return undefined } - const openAI = getStoreSetting('openAI') as SettingsState['openAI'] - let serviceTier = 'auto' as OpenAIServiceTier - - if (openAI && openAI?.serviceTier === 'flex') { - if (isSupportedFlexServiceTier(model)) { - serviceTier = 'flex' - } else { - serviceTier = 'auto' + // 处理不同供应商需要 fallback 到默认值的情况 + if (this.provider.id === SystemProviderIds.groq) { + if ( + !isGroqServiceTier(serviceTierSetting) || + (serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model)) + ) { + return undefined } } else { - serviceTier = openAI.serviceTier + // 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同 + if ( + !isOpenAIServiceTier(serviceTierSetting) || + (serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model)) + ) { + return undefined + } } - return serviceTier + return serviceTierSetting } protected getTimeout(model: Model) { - if (isSupportedFlexServiceTier(model)) { + if (isSupportFlexServiceTierModel(model)) { return 15 * 1000 * 60 } return defaultTimeout diff --git a/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts b/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts index 6a73bf47ce..f286b40d59 100644 --- a/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts +++ b/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts @@ -11,7 +11,6 @@ import { import { ContentBlock, ContentBlockParam, - MessageCreateParams, MessageCreateParamsBase, RedactedThinkingBlockParam, ServerToolUseBlockParam, @@ -70,6 +69,7 @@ import { mcpToolsToAnthropicTools } from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' +import { t } from 'i18next' import { BaseApiClient } from '../BaseApiClient' import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types' @@ -494,22 +494,14 @@ export class AnthropicAPIClient extends BaseApiClient< system: systemMessage ? [systemMessage] : undefined, thinking: this.getBudgetToken(assistant, model), tools: tools.length > 0 ? tools : undefined, + stream: streamOutput, // 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑 + // 注意:用户自定义参数总是应该覆盖其他参数 ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {}) } - const finalParams: MessageCreateParams = streamOutput - ? { - ...commonParams, - stream: true - } - : { - ...commonParams, - stream: false - } - const timeout = this.getTimeout(model) - return { payload: finalParams, messages: sdkMessages, metadata: { timeout } } + return { payload: commonParams, messages: sdkMessages, metadata: { timeout } } } } } @@ -520,6 +512,14 @@ export class AnthropicAPIClient extends BaseApiClient< const toolCalls: Record = {} return { async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController) { + if (typeof rawChunk === 'string') { + try { + rawChunk = JSON.parse(rawChunk) + } catch (error) { + logger.error('invalid chunk', { rawChunk, error }) + throw new Error(t('error.chat.chunk.non_json')) + } + } switch (rawChunk.type) { case 'message': { let i = 0 diff --git a/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts b/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts index 6117dffa18..de9c7c2c17 100644 --- a/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts +++ b/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts @@ -42,6 +42,7 @@ import { mcpToolsToAwsBedrockTools } from '@renderer/utils/mcp-tools' import { findImageBlocks } from '@renderer/utils/messageUtils/find' +import { t } from 'i18next' import { BaseApiClient } from '../BaseApiClient' import { RequestTransformer, ResponseChunkTransformer } from '../types' @@ -417,7 +418,10 @@ export class AwsBedrockAPIClient extends BaseApiClient< temperature: this.getTemperature(assistant, model), topP: this.getTopP(assistant, model), stream: streamOutput !== false, - tools: tools.length > 0 ? tools : undefined + tools: tools.length > 0 ? tools : undefined, + // 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑 + // 注意:用户自定义参数总是应该覆盖其他参数 + ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {}) } const timeout = this.getTimeout(model) @@ -436,6 +440,15 @@ export class AwsBedrockAPIClient extends BaseApiClient< async transform(rawChunk: AwsBedrockSdkRawChunk, controller: TransformStreamDefaultController) { logger.silly('Processing AWS Bedrock chunk:', rawChunk) + if (typeof rawChunk === 'string') { + try { + rawChunk = JSON.parse(rawChunk) + } catch (error) { + logger.error('invalid chunk', { rawChunk, error }) + throw new Error(t('error.chat.chunk.non_json')) + } + } + // 处理消息开始事件 if (rawChunk.messageStart) { controller.enqueue({ diff --git a/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts b/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts index edc8a1190a..bdd7689d6f 100644 --- a/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts +++ b/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts @@ -60,6 +60,7 @@ import { } from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { defaultTimeout, MB } from '@shared/config/constant' +import { t } from 'i18next' import { BaseApiClient } from '../BaseApiClient' import { RequestTransformer, ResponseChunkTransformer } from '../types' @@ -531,6 +532,7 @@ export class GeminiAPIClient extends BaseApiClient< ...(enableGenerateImage ? this.getGenerateImageParameter() : {}), ...this.getBudgetToken(assistant, model), // 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑 + // 注意:用户自定义参数总是应该覆盖其他参数 ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {}) } @@ -557,6 +559,14 @@ export class GeminiAPIClient extends BaseApiClient< return () => ({ async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController) { logger.silly('chunk', chunk) + if (typeof chunk === 'string') { + try { + chunk = JSON.parse(chunk) + } catch (error) { + logger.error('invalid chunk', { chunk, error }) + throw new Error(t('error.chat.chunk.non_json')) + } + } if (chunk.candidates && chunk.candidates.length > 0) { for (const candidate of chunk.candidates) { if (candidate.content) { diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts index 1faff88983..617637a7e1 100644 --- a/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts +++ b/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts @@ -4,10 +4,11 @@ import { findTokenLimit, GEMINI_FLASH_MODEL_REGEX, getOpenAIWebSearchParams, + getThinkModelType, isDoubaoThinkingAutoModel, isGrokReasoningModel, isNotSupportSystemMessageModel, - isQwen3235BA22BThinkingModel, + isQwenAlwaysThinkModel, isQwenMTModel, isQwenReasoningModel, isReasoningModel, @@ -20,12 +21,13 @@ import { isSupportedThinkingTokenModel, isSupportedThinkingTokenQwenModel, isSupportedThinkingTokenZhipuModel, - isVisionModel + isVisionModel, + MODEL_SUPPORTED_REASONING_EFFORT } from '@renderer/config/models' import { isSupportArrayContentProvider, isSupportDeveloperRoleProvider, - isSupportQwen3EnableThinkingProvider, + isSupportEnableThinkingProvider, isSupportStreamOptionsProvider } from '@renderer/config/providers' import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService' @@ -39,6 +41,7 @@ import { MCPTool, MCPToolResponse, Model, + OpenAIServiceTier, Provider, ToolCallResponse, TranslateAssistant, @@ -63,6 +66,7 @@ import { openAIToolsToMcpTool } from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' +import { t } from 'i18next' import OpenAI, { AzureOpenAI } from 'openai' import { ChatCompletionContentPart, ChatCompletionContentPartRefusal, ChatCompletionTool } from 'openai/resources' @@ -146,10 +150,11 @@ export class OpenAIAPIClient extends OpenAIBaseClient< } return { reasoning: { enabled: false, exclude: true } } } - if (isSupportedThinkingTokenQwenModel(model) || isSupportedThinkingTokenHunyuanModel(model)) { - if (isQwen3235BA22BThinkingModel(model)) { - return {} - } + + if ( + isSupportEnableThinkingProvider(this.provider) && + (isSupportedThinkingTokenQwenModel(model) || isSupportedThinkingTokenHunyuanModel(model)) + ) { return { enable_thinking: false } } @@ -178,6 +183,8 @@ export class OpenAIAPIClient extends OpenAIBaseClient< return {} } + + // reasoningEffort有效的情况 const effortRatio = EFFORT_RATIO[reasoningEffort] const budgetTokens = Math.floor( (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min! @@ -195,9 +202,10 @@ export class OpenAIAPIClient extends OpenAIBaseClient< } // Qwen models - if (isSupportedThinkingTokenQwenModel(model)) { + if (isQwenReasoningModel(model)) { const thinkConfig = { - enable_thinking: isQwen3235BA22BThinkingModel(model) ? undefined : true, + enable_thinking: + isQwenAlwaysThinkModel(model) || !isSupportEnableThinkingProvider(this.provider) ? undefined : true, thinking_budget: budgetTokens } if (this.provider.id === 'dashscope') { @@ -210,7 +218,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient< } // Hunyuan models - if (isSupportedThinkingTokenHunyuanModel(model)) { + if (isSupportedThinkingTokenHunyuanModel(model) && isSupportEnableThinkingProvider(this.provider)) { return { enable_thinking: true } @@ -218,8 +226,18 @@ export class OpenAIAPIClient extends OpenAIBaseClient< // Grok models/Perplexity models/OpenAI models if (isSupportedReasoningEffortModel(model)) { - return { - reasoning_effort: reasoningEffort + // 检查模型是否支持所选选项 + const modelType = getThinkModelType(model) + const supportedOptions = MODEL_SUPPORTED_REASONING_EFFORT[modelType] + if (supportedOptions.includes(reasoningEffort)) { + return { + reasoning_effort: reasoningEffort + } + } else { + // 如果不支持,fallback到第一个支持的值 + return { + reasoning_effort: supportedOptions[0] + } } } @@ -530,7 +548,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient< if ( lastUserMsg && isSupportedThinkingTokenQwenModel(model) && - !isSupportQwen3EnableThinkingProvider(this.provider) + !isSupportEnableThinkingProvider(this.provider) ) { const postsuffix = '/no_think' const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true @@ -550,7 +568,11 @@ export class OpenAIAPIClient extends OpenAIBaseClient< reqMessages = processReqMessages(model, reqMessages) // 5. 创建通用参数 - const commonParams = { + // Create the appropriate parameters object based on whether streaming is enabled + // Note: Some providers like Mistral don't support stream_options + const shouldIncludeStreamOptions = streamOutput && isSupportStreamOptionsProvider(this.provider) + + const commonParams: OpenAISdkParams = { model: model.id, messages: isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 @@ -560,35 +582,24 @@ export class OpenAIAPIClient extends OpenAIBaseClient< top_p: this.getTopP(assistant, model), max_tokens: maxTokens, tools: tools.length > 0 ? tools : undefined, - service_tier: this.getServiceTier(model), + stream: streamOutput, + ...(shouldIncludeStreamOptions ? { stream_options: { include_usage: true } } : {}), + // groq 有不同的 service tier 配置,不符合 openai 接口类型 + service_tier: this.getServiceTier(model) as OpenAIServiceTier, ...this.getProviderSpecificParameters(assistant, model), ...this.getReasoningEffort(assistant, model), ...getOpenAIWebSearchParams(model, enableWebSearch), - // 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑 - ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {}), // OpenRouter usage tracking ...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}), - ...(isQwenMTModel(model) ? extra_body : {}) + ...(isQwenMTModel(model) ? extra_body : {}), + // 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑 + // 注意:用户自定义参数总是应该覆盖其他参数 + ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {}) } - // Create the appropriate parameters object based on whether streaming is enabled - // Note: Some providers like Mistral don't support stream_options - const shouldIncludeStreamOptions = streamOutput && isSupportStreamOptionsProvider(this.provider) - - const sdkParams: OpenAISdkParams = streamOutput - ? { - ...commonParams, - stream: true, - ...(shouldIncludeStreamOptions ? { stream_options: { include_usage: true } } : {}) - } - : { - ...commonParams, - stream: false - } - const timeout = this.getTimeout(model) - return { payload: sdkParams, messages: reqMessages, metadata: { timeout } } + return { payload: commonParams, messages: reqMessages, metadata: { timeout } } } } } @@ -758,6 +769,15 @@ export class OpenAIAPIClient extends OpenAIBaseClient< return } + if (typeof chunk === 'string') { + try { + chunk = JSON.parse(chunk) + } catch (error) { + logger.error('invalid chunk', { chunk, error }) + throw new Error(t('error.chat.chunk.non_json')) + } + } + // 处理chunk if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) { for (const choice of chunk.choices) { diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts index 9e4042fa3c..f2ee0f58f4 100644 --- a/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts +++ b/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts @@ -99,8 +99,12 @@ export abstract class OpenAIBaseClient< override async listModels(): Promise { try { const sdk = await this.getSdkInstance() - const response = await sdk.models.list() if (this.provider.id === 'github') { + // GitHub Models 其 models 和 chat completions 两个接口的 baseUrl 不一样 + const baseUrl = 'https://models.github.ai/catalog/' + const newSdk = sdk.withOptions({ baseURL: baseUrl }) + const response = await newSdk.models.list() + // @ts-ignore key is not typed return response?.body .map((model) => ({ @@ -111,6 +115,7 @@ export abstract class OpenAIBaseClient< })) .filter(isSupportedModel) } + const response = await sdk.models.list() if (this.provider.id === 'together') { // @ts-ignore key is not typed return response?.body.map((model) => ({ diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts index 970dd1399f..f740c5bdcf 100644 --- a/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts +++ b/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts @@ -1,3 +1,4 @@ +import { loggerService } from '@logger' import { GenericChunk } from '@renderer/aiCore/middleware/schemas' import { CompletionsContext } from '@renderer/aiCore/middleware/types' import { @@ -6,7 +7,7 @@ import { isSupportedReasoningEffortOpenAIModel, isVisionModel } from '@renderer/config/models' -import { isSupportDeveloperRoleProvider } from '@renderer/config/providers' +import { isSupportDeveloperRoleProvider, isSupportStreamOptionsProvider } from '@renderer/config/providers' import { estimateTextTokens } from '@renderer/services/TokenService' import { FileMetadata, @@ -15,6 +16,7 @@ import { MCPTool, MCPToolResponse, Model, + OpenAIServiceTier, Provider, ToolCallResponse, WebSearchSource @@ -38,6 +40,7 @@ import { } from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' import { MB } from '@shared/config/constant' +import { t } from 'i18next' import { isEmpty } from 'lodash' import OpenAI, { AzureOpenAI } from 'openai' import { ResponseInput } from 'openai/resources/responses/responses' @@ -46,6 +49,7 @@ import { RequestTransformer, ResponseChunkTransformer } from '../types' import { OpenAIAPIClient } from './OpenAIApiClient' import { OpenAIBaseClient } from './OpenAIBaseClient' +const logger = loggerService.withContext('OpenAIResponseAPIClient') export class OpenAIResponseAPIClient extends OpenAIBaseClient< OpenAI, OpenAIResponseSdkParams, @@ -338,8 +342,8 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient< } public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] { - if (typeof sdkPayload.input === 'string') { - return [{ role: 'user', content: sdkPayload.input }] + if (!sdkPayload.input || typeof sdkPayload.input === 'string') { + return [{ role: 'user', content: sdkPayload.input ?? '' }] } return sdkPayload.input } @@ -437,7 +441,10 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient< } tools = tools.concat(extraTools) - const commonParams = { + + const shouldIncludeStreamOptions = streamOutput && isSupportStreamOptionsProvider(this.provider) + + const commonParams: OpenAIResponseSdkParams = { model: model.id, input: isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 @@ -447,23 +454,17 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient< top_p: this.getTopP(assistant, model), max_output_tokens: maxTokens, stream: streamOutput, + ...(shouldIncludeStreamOptions ? { stream_options: { include_usage: true } } : {}), tools: !isEmpty(tools) ? tools : undefined, - service_tier: this.getServiceTier(model), + // groq 有不同的 service tier 配置,不符合 openai 接口类型 + service_tier: this.getServiceTier(model) as OpenAIServiceTier, ...(this.getReasoningEffort(assistant, model) as OpenAI.Reasoning), // 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑 + // 注意:用户自定义参数总是应该覆盖其他参数 ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {}) } - const sdkParams: OpenAIResponseSdkParams = streamOutput - ? { - ...commonParams, - stream: true - } - : { - ...commonParams, - stream: false - } const timeout = this.getTimeout(model) - return { payload: sdkParams, messages: reqMessages, metadata: { timeout } } + return { payload: commonParams, messages: reqMessages, metadata: { timeout } } } } } @@ -477,6 +478,14 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient< let isFirstTextChunk = true return () => ({ async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController) { + if (typeof chunk === 'string') { + try { + chunk = JSON.parse(chunk) + } catch (error) { + logger.error('invalid chunk', { chunk, error }) + throw new Error(t('error.chat.chunk.non_json')) + } + } // 处理chunk if ('output' in chunk) { if (ctx._internal?.toolProcessingState) { diff --git a/src/renderer/src/aiCore/index.ts b/src/renderer/src/aiCore/index.ts index 47fb4cd707..16c8949cfc 100644 --- a/src/renderer/src/aiCore/index.ts +++ b/src/renderer/src/aiCore/index.ts @@ -123,7 +123,10 @@ export default class AiProvider { } const middlewares = builder.build() - logger.silly('middlewares', middlewares) + logger.silly( + 'middlewares', + middlewares.map((m) => m.name) + ) // 3. Create the wrapped SDK method with middlewares const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares) diff --git a/src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts b/src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts index e36e45807a..57498b97fb 100644 --- a/src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts +++ b/src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts @@ -85,9 +85,15 @@ const FinalChunkConsumerMiddleware: CompletionsMiddleware = logger.warn(`Received undefined chunk before stream was done.`) } } - } catch (error) { + } catch (error: any) { logger.error(`Error consuming stream:`, error as Error) - throw error + // FIXME: 临时解决方案。该中间件的异常无法被 ErrorHandlerMiddleware捕获。 + if (params.onError) { + params.onError(error) + } + if (params.shouldThrow) { + throw error + } } finally { if (params.onChunk && !isRecursiveCall) { params.onChunk({ diff --git a/src/renderer/src/components/ActionTools/__tests__/useImageTools.test.tsx b/src/renderer/src/components/ActionTools/__tests__/useImageTools.test.tsx new file mode 100644 index 0000000000..6083a508df --- /dev/null +++ b/src/renderer/src/components/ActionTools/__tests__/useImageTools.test.tsx @@ -0,0 +1,555 @@ +import { useImageTools } from '@renderer/components/ActionTools' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock dependencies +const mocks = vi.hoisted(() => ({ + i18n: { + t: (key: string) => key + }, + svgToPngBlob: vi.fn(), + svgToSvgBlob: vi.fn(), + download: vi.fn(), + ImagePreviewService: { + show: vi.fn() + } +})) + +vi.mock('@renderer/utils/image', () => ({ + svgToPngBlob: mocks.svgToPngBlob, + svgToSvgBlob: mocks.svgToSvgBlob +})) + +vi.mock('@renderer/utils/download', () => ({ + download: mocks.download +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: mocks.i18n.t + }) +})) + +vi.mock('@renderer/services/ImagePreviewService', () => ({ + ImagePreviewService: mocks.ImagePreviewService +})) + +vi.mock('@renderer/context/ThemeProvider', () => ({ + useTheme: () => ({ + theme: 'light' + }) +})) + +// Mock navigator.clipboard +const mockWrite = vi.fn() + +// Mock window.message +const mockMessage = { + success: vi.fn(), + error: vi.fn() +} + +// Mock ClipboardItem +class MockClipboardItem { + constructor(items: any) { + return items + } +} + +// Mock URL +const mockCreateObjectURL = vi.fn(() => 'blob:test-url') +const mockRevokeObjectURL = vi.fn() + +describe('useImageTools', () => { + beforeEach(() => { + // Setup global mocks + Object.defineProperty(global.navigator, 'clipboard', { + value: { write: mockWrite }, + writable: true + }) + + Object.defineProperty(global.window, 'message', { + value: mockMessage, + writable: true + }) + + // Mock ClipboardItem + global.ClipboardItem = MockClipboardItem as any + + // Mock URL + global.URL = { + createObjectURL: mockCreateObjectURL, + revokeObjectURL: mockRevokeObjectURL + } as any + + // Mock DOMMatrix + global.DOMMatrix = class DOMMatrix { + m41 = 0 + m42 = 0 + a = 1 + d = 1 + + constructor(transform?: string) { + if (transform) { + // 简单解析 translate(x, y) + const translateMatch = transform.match(/translate\(([^,]+),\s*([^)]+)\)/) + if (translateMatch) { + this.m41 = parseFloat(translateMatch[1]) + this.m42 = parseFloat(translateMatch[2]) + } + + // 解析 scale(s) + const scaleMatch = transform.match(/scale\(([^)]+)\)/) + if (scaleMatch) { + const scaleValue = parseFloat(scaleMatch[1]) + this.a = scaleValue + this.d = scaleValue + } + } + } + + static fromMatrix() { + return new DOMMatrix() + } + } as any + + vi.clearAllMocks() + }) + + // 创建模拟的 DOM 环境 + const createMockContainer = () => { + const mockContainer = { + addEventListener: vi.fn(), + removeEventListener: vi.fn(), + contains: vi.fn().mockReturnValue(true), + style: { + cursor: '' + }, + querySelector: vi.fn(), + shadowRoot: null + } as unknown as HTMLDivElement + + return mockContainer + } + + const createMockSvgElement = () => { + const mockSvg = { + style: { + transform: '', + transformOrigin: '' + }, + cloneNode: vi.fn().mockReturnThis() + } as unknown as SVGElement + + return mockSvg + } + + describe('initialization', () => { + it('should initialize with default scale', () => { + const mockContainer = createMockContainer() + const { result } = renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg' + } + ) + ) + + const transform = result.current.getCurrentTransform() + expect(transform.scale).toBe(1) + }) + }) + + describe('pan function', () => { + it('should pan with relative and absolute coordinates', () => { + const mockContainer = createMockContainer() + const mockSvg = createMockSvgElement() + mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg) + + const { result } = renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg' + } + ) + ) + + // 相对坐标平移 + act(() => { + result.current.pan(10, 20) + }) + expect(mockSvg.style.transform).toContain('translate(10px, 20px)') + + // 绝对坐标平移 + act(() => { + result.current.pan(50, 60, true) + }) + expect(mockSvg.style.transform).toContain('translate(50px, 60px)') + }) + }) + + describe('zoom function', () => { + it('should zoom in/out and set absolute zoom level', () => { + const mockContainer = createMockContainer() + const mockSvg = createMockSvgElement() + mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg) + + const { result } = renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg' + } + ) + ) + + // 放大 + act(() => { + result.current.zoom(0.5) + }) + expect(result.current.getCurrentTransform().scale).toBe(1.5) + expect(mockSvg.style.transform).toContain('scale(1.5)') + + // 缩小 + act(() => { + result.current.zoom(-0.3) + }) + expect(result.current.getCurrentTransform().scale).toBe(1.2) + expect(mockSvg.style.transform).toContain('scale(1.2)') + + // 设置绝对缩放级别 + act(() => { + result.current.zoom(2.5, true) + }) + expect(result.current.getCurrentTransform().scale).toBe(2.5) + }) + + it('should constrain zoom between 0.1 and 3', () => { + const mockContainer = createMockContainer() + const mockSvg = createMockSvgElement() + mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg) + + const { result } = renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg' + } + ) + ) + + // 尝试过度缩小 + act(() => { + result.current.zoom(-10) + }) + expect(result.current.getCurrentTransform().scale).toBe(0.1) + + // 尝试过度放大 + act(() => { + result.current.zoom(10) + }) + expect(result.current.getCurrentTransform().scale).toBe(3) + }) + }) + + describe('copy and download functions', () => { + it('should copy image to clipboard successfully', async () => { + const mockContainer = createMockContainer() + const mockSvg = createMockSvgElement() + mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg) + + // Mock svgToPngBlob to return a blob + const mockBlob = new Blob(['test'], { type: 'image/png' }) + mocks.svgToPngBlob.mockResolvedValue(mockBlob) + + const { result } = renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg' + } + ) + ) + + await act(async () => { + await result.current.copy() + }) + + expect(mocks.svgToPngBlob).toHaveBeenCalledWith(mockSvg) + expect(mockWrite).toHaveBeenCalled() + expect(mockMessage.success).toHaveBeenCalledWith('message.copy.success') + }) + + it('should download image as PNG and SVG', async () => { + const mockContainer = createMockContainer() + const mockSvg = createMockSvgElement() + mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg) + + // Mock svgToPngBlob to return a blob + const pngBlob = new Blob(['test'], { type: 'image/png' }) + mocks.svgToPngBlob.mockResolvedValue(pngBlob) + + // Mock svgToSvgBlob to return a blob + const svgBlob = new Blob([''], { type: 'image/svg+xml' }) + mocks.svgToSvgBlob.mockReturnValue(svgBlob) + + const { result } = renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg' + } + ) + ) + + // 下载 PNG + await act(async () => { + await result.current.download('png') + }) + expect(mocks.svgToPngBlob).toHaveBeenCalledWith(mockSvg) + + // 下载 SVG + await act(async () => { + await result.current.download('svg') + }) + expect(mocks.svgToSvgBlob).toHaveBeenCalledWith(mockSvg) + + // 验证通用的下载流程 + expect(mockCreateObjectURL).toHaveBeenCalledTimes(2) + expect(mocks.download).toHaveBeenCalledTimes(2) + expect(mockRevokeObjectURL).toHaveBeenCalledTimes(2) + }) + + it('should handle copy/download failures and missing elements', async () => { + const mockContainer = createMockContainer() + const mockSvg = createMockSvgElement() + + // 测试无元素情况 + mockContainer.querySelector = vi.fn().mockReturnValue(null) + const { result } = renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg' + } + ) + ) + + // 复制无元素 + await act(async () => { + await result.current.copy() + }) + expect(mocks.svgToPngBlob).not.toHaveBeenCalled() + + // 下载无元素 + await act(async () => { + await result.current.download('png') + }) + expect(mocks.svgToPngBlob).not.toHaveBeenCalled() + + // 测试失败情况 + mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg) + mocks.svgToPngBlob.mockRejectedValue(new Error('Conversion failed')) + + // 复制失败 + await act(async () => { + await result.current.copy() + }) + expect(mockMessage.error).toHaveBeenCalledWith('message.copy.failed') + + // 下载失败 + await act(async () => { + await result.current.download('png') + }) + expect(mockMessage.error).toHaveBeenCalledWith('message.download.failed') + }) + }) + + describe('dialog function', () => { + it('should preview image successfully', async () => { + const mockContainer = createMockContainer() + const mockSvg = createMockSvgElement() + mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg) + + mocks.ImagePreviewService.show.mockResolvedValue(undefined) + + const { result } = renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg' + } + ) + ) + + await act(async () => { + await result.current.dialog() + }) + + expect(mocks.ImagePreviewService.show).toHaveBeenCalledWith(mockSvg, { format: 'svg' }) + }) + + it('should handle preview failure', async () => { + const mockContainer = createMockContainer() + const mockSvg = createMockSvgElement() + mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg) + + mocks.ImagePreviewService.show.mockRejectedValue(new Error('Preview failed')) + + const { result } = renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg' + } + ) + ) + + await act(async () => { + await result.current.dialog() + }) + + expect(mockMessage.error).toHaveBeenCalledWith('message.dialog.failed') + }) + + it('should do nothing when no element is found', async () => { + const mockContainer = createMockContainer() + mockContainer.querySelector = vi.fn().mockReturnValue(null) + + const { result } = renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg' + } + ) + ) + + await act(async () => { + await result.current.dialog() + }) + + expect(mocks.ImagePreviewService.show).not.toHaveBeenCalled() + }) + }) + + describe('event listener management', () => { + it('should attach/remove event listeners based on options', () => { + const mockContainer = createMockContainer() + + // 启用拖拽和滚轮缩放 + renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg', + enableDrag: true, + enableWheelZoom: true + } + ) + ) + + expect(mockContainer.addEventListener).toHaveBeenCalledWith('mousedown', expect.any(Function)) + expect(mockContainer.addEventListener).toHaveBeenCalledWith('wheel', expect.any(Function), { passive: true }) + + // 重置并测试禁用情况 + vi.clearAllMocks() + + renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg', + enableDrag: false, + enableWheelZoom: false + } + ) + ) + + expect(mockContainer.addEventListener).not.toHaveBeenCalledWith('mousedown', expect.any(Function)) + expect(mockContainer.addEventListener).not.toHaveBeenCalledWith('wheel', expect.any(Function)) + }) + }) + + describe('getCurrentTransform function', () => { + it('should return current scale and position', () => { + const mockContainer = createMockContainer() + const mockSvg = createMockSvgElement() + mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg) + + const { result } = renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg' + } + ) + ) + + // 初始状态 + const initialTransform = result.current.getCurrentTransform() + expect(initialTransform).toEqual({ scale: 1, x: 0, y: 0 }) + + // 缩放后状态 + act(() => { + result.current.zoom(0.5) + }) + const zoomedTransform = result.current.getCurrentTransform() + expect(zoomedTransform.scale).toBe(1.5) + expect(zoomedTransform.x).toBe(0) + expect(zoomedTransform.y).toBe(0) + + // 平移后状态 + act(() => { + result.current.pan(10, 20) + }) + const pannedTransform = result.current.getCurrentTransform() + expect(pannedTransform.scale).toBe(1.5) + expect(pannedTransform.x).toBe(10) + expect(pannedTransform.y).toBe(20) + }) + + it('should get position from DOMMatrix when element has transform', () => { + const mockContainer = createMockContainer() + const mockSvg = createMockSvgElement() + mockSvg.style.transform = 'translate(30px, 40px) scale(2)' + mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg) + + const { result } = renderHook(() => + useImageTools( + { current: mockContainer }, + { + prefix: 'test', + imgSelector: 'svg' + } + ) + ) + + // 手动设置 transformRef 以匹配 DOM 状态 + act(() => { + result.current.pan(30, 40, true) + result.current.zoom(2, true) + }) + + const transform = result.current.getCurrentTransform() + expect(transform.scale).toBe(2) + expect(transform.x).toBe(30) + expect(transform.y).toBe(40) + }) + }) +}) diff --git a/src/renderer/src/components/ActionTools/__tests__/useToolManager.test.ts b/src/renderer/src/components/ActionTools/__tests__/useToolManager.test.ts new file mode 100644 index 0000000000..86ec67b760 --- /dev/null +++ b/src/renderer/src/components/ActionTools/__tests__/useToolManager.test.ts @@ -0,0 +1,215 @@ +import { ActionTool, useToolManager } from '@renderer/components/ActionTools' +import { act, renderHook } from '@testing-library/react' +import { useState } from 'react' +import { describe, expect, it } from 'vitest' + +// 创建测试工具数据 +const createTestTool = (overrides: Partial = {}): ActionTool => ({ + id: 'test-tool', + type: 'core', + order: 10, + icon: 'TestIcon', + tooltip: 'Test Tool', + ...overrides +}) + +describe('useToolManager', () => { + describe('registerTool', () => { + it('should register a new tool', () => { + const { result } = renderHook(() => { + const [tools, setTools] = useState([]) + const { registerTool } = useToolManager(setTools) + return { tools, registerTool } + }) + + const testTool = createTestTool() + + act(() => { + result.current.registerTool(testTool) + }) + + expect(result.current.tools).toHaveLength(1) + expect(result.current.tools[0]).toEqual(testTool) + }) + + it('should replace existing tool with same id', () => { + const { result } = renderHook(() => { + const [tools, setTools] = useState([]) + const { registerTool } = useToolManager(setTools) + return { tools, registerTool } + }) + + const originalTool = createTestTool({ tooltip: 'Original' }) + const updatedTool = createTestTool({ tooltip: 'Updated' }) + + act(() => { + result.current.registerTool(originalTool) + result.current.registerTool(updatedTool) + }) + + expect(result.current.tools).toHaveLength(1) + expect(result.current.tools[0]).toEqual(updatedTool) + }) + + it('should sort tools by order (descending)', () => { + const { result } = renderHook(() => { + const [tools, setTools] = useState([]) + const { registerTool } = useToolManager(setTools) + return { tools, registerTool } + }) + + const tool1 = createTestTool({ id: 'tool1', order: 10 }) + const tool2 = createTestTool({ id: 'tool2', order: 30 }) + const tool3 = createTestTool({ id: 'tool3', order: 20 }) + + act(() => { + result.current.registerTool(tool1) + result.current.registerTool(tool2) + result.current.registerTool(tool3) + }) + + // 应该按 order 降序排列 + expect(result.current.tools[0].id).toBe('tool2') // order: 30 + expect(result.current.tools[1].id).toBe('tool3') // order: 20 + expect(result.current.tools[2].id).toBe('tool1') // order: 10 + }) + + it('should handle tools with children', () => { + const { result } = renderHook(() => { + const [tools, setTools] = useState([]) + const { registerTool } = useToolManager(setTools) + return { tools, registerTool } + }) + + const childTool = createTestTool({ id: 'child-tool', order: 5 }) + const parentTool = createTestTool({ + id: 'parent-tool', + order: 15, + children: [childTool] + }) + + act(() => { + result.current.registerTool(parentTool) + }) + + expect(result.current.tools).toHaveLength(1) + expect(result.current.tools[0]).toEqual(parentTool) + expect(result.current.tools[0].children).toEqual([childTool]) + }) + + it('should not modify state if setTools is not provided', () => { + const { result } = renderHook(() => useToolManager(undefined)) + + // 不应该抛出错误 + expect(() => { + act(() => { + result.current.registerTool(createTestTool()) + }) + }).not.toThrow() + }) + }) + + describe('removeTool', () => { + it('should remove tool by id', () => { + const { result } = renderHook(() => { + const [tools, setTools] = useState([createTestTool()]) + const { registerTool, removeTool } = useToolManager(setTools) + return { tools, registerTool, removeTool } + }) + + expect(result.current.tools).toHaveLength(1) + + act(() => { + result.current.removeTool('test-tool') + }) + + expect(result.current.tools).toHaveLength(0) + }) + + it('should not affect other tools when removing one', () => { + const { result } = renderHook(() => { + const toolsData = [ + createTestTool({ id: 'tool1' }), + createTestTool({ id: 'tool2' }), + createTestTool({ id: 'tool3' }) + ] + const [tools, setTools] = useState(toolsData) + const { removeTool } = useToolManager(setTools) + return { tools, removeTool } + }) + + expect(result.current.tools).toHaveLength(3) + + act(() => { + result.current.removeTool('tool2') + }) + + expect(result.current.tools).toHaveLength(2) + expect(result.current.tools[0].id).toBe('tool1') + expect(result.current.tools[1].id).toBe('tool3') + }) + + it('should handle removing non-existent tool', () => { + const { result } = renderHook(() => { + const [tools, setTools] = useState([createTestTool()]) + const { removeTool } = useToolManager(setTools) + return { tools, removeTool } + }) + + expect(result.current.tools).toHaveLength(1) + + act(() => { + result.current.removeTool('non-existent-tool') + }) + + expect(result.current.tools).toHaveLength(1) // 应该没有变化 + }) + + it('should not modify state if setTools is not provided', () => { + const { result } = renderHook(() => useToolManager(undefined)) + + // 不应该抛出错误 + expect(() => { + act(() => { + result.current.removeTool('test-tool') + }) + }).not.toThrow() + }) + }) + + describe('integration', () => { + it('should handle register and remove operations together', () => { + const { result } = renderHook(() => { + const [tools, setTools] = useState([]) + const { registerTool, removeTool } = useToolManager(setTools) + return { tools, registerTool, removeTool } + }) + + const tool1 = createTestTool({ id: 'tool1' }) + const tool2 = createTestTool({ id: 'tool2' }) + + // 注册两个工具 + act(() => { + result.current.registerTool(tool1) + result.current.registerTool(tool2) + }) + + expect(result.current.tools).toHaveLength(2) + + // 移除一个工具 + act(() => { + result.current.removeTool('tool1') + }) + + expect(result.current.tools).toHaveLength(1) + expect(result.current.tools[0].id).toBe('tool2') + + // 再次注册被移除的工具 + act(() => { + result.current.registerTool(tool1) + }) + + expect(result.current.tools).toHaveLength(2) + }) + }) +}) diff --git a/src/renderer/src/components/CodeToolbar/constants.ts b/src/renderer/src/components/ActionTools/constants.ts similarity index 91% rename from src/renderer/src/components/CodeToolbar/constants.ts rename to src/renderer/src/components/ActionTools/constants.ts index 4eeec0fa15..c2b4966e5f 100644 --- a/src/renderer/src/components/CodeToolbar/constants.ts +++ b/src/renderer/src/components/ActionTools/constants.ts @@ -1,6 +1,6 @@ -import { CodeToolSpec } from './types' +import { ActionToolSpec } from './types' -export const TOOL_SPECS: Record = { +export const TOOL_SPECS: Record = { // Core tools copy: { id: 'copy', diff --git a/src/renderer/src/components/ActionTools/hooks/useImageTools.tsx b/src/renderer/src/components/ActionTools/hooks/useImageTools.tsx new file mode 100644 index 0000000000..e02d8846a9 --- /dev/null +++ b/src/renderer/src/components/ActionTools/hooks/useImageTools.tsx @@ -0,0 +1,292 @@ +import { loggerService } from '@logger' +import { useTheme } from '@renderer/context/ThemeProvider' +import { ImagePreviewService } from '@renderer/services/ImagePreviewService' +import { download as downloadFile } from '@renderer/utils/download' +import { svgToPngBlob, svgToSvgBlob } from '@renderer/utils/image' +import { RefObject, useCallback, useEffect, useRef } from 'react' +import { useTranslation } from 'react-i18next' + +const logger = loggerService.withContext('usePreviewToolHandlers') + +/** + * 使用图像处理工具的自定义Hook + * 提供图像缩放、复制和下载功能 + */ +export const useImageTools = ( + containerRef: RefObject, + options: { + prefix: string + imgSelector: string + enableDrag?: boolean + enableWheelZoom?: boolean + } +) => { + const transformRef = useRef({ scale: 1, x: 0, y: 0 }) // 管理变换状态 + const { imgSelector, prefix, enableDrag, enableWheelZoom } = options + const { t } = useTranslation() + const { theme } = useTheme() + + // 创建选择器函数 + const getImgElement = useCallback(() => { + if (!containerRef.current) return null + + // 优先尝试从 Shadow DOM 中查找 + const shadowRoot = containerRef.current.shadowRoot + if (shadowRoot) { + return shadowRoot.querySelector(imgSelector) as SVGElement | null + } + + // 降级到常规 DOM 查找 + return containerRef.current.querySelector(imgSelector) as SVGElement | null + }, [containerRef, imgSelector]) + + // 获取原始图像元素(移除所有变换) + const getCleanImgElement = useCallback((): SVGElement | null => { + const imgElement = getImgElement() + if (!imgElement) return null + + const clonedElement = imgElement.cloneNode(true) as SVGElement + clonedElement.style.transform = '' + clonedElement.style.transformOrigin = '' + return clonedElement + }, [getImgElement]) + + // 查询当前位置 + const getCurrentPosition = useCallback(() => { + const imgElement = getImgElement() + if (!imgElement) return transformRef.current + + const transform = imgElement.style.transform + if (!transform || transform === 'none') return transformRef.current + + // 使用CSS矩阵解析 + const matrix = new DOMMatrix(transform) + return { x: matrix.m41, y: matrix.m42 } + }, [getImgElement]) + + /** + * 平移缩放变换 + * @param element 要应用变换的元素 + * @param x X轴偏移量 + * @param y Y轴偏移量 + * @param scale 缩放比例 + */ + const applyTransform = useCallback((element: SVGElement | null, x: number, y: number, scale: number) => { + if (!element) return + element.style.transformOrigin = 'top left' + element.style.transform = `translate(${x}px, ${y}px) scale(${scale})` + }, []) + + /** + * 平移函数 - 按指定方向和距离移动图像 + * @param dx X轴偏移量(正数向右,负数向左) + * @param dy Y轴偏移量(正数向下,负数向上) + * @param absolute 是否为绝对位置(true)或相对偏移(false) + */ + const pan = useCallback( + (dx: number, dy: number, absolute = false) => { + const currentPos = getCurrentPosition() + const newX = absolute ? dx : currentPos.x + dx + const newY = absolute ? dy : currentPos.y + dy + + transformRef.current.x = newX + transformRef.current.y = newY + + const imgElement = getImgElement() + applyTransform(imgElement, newX, newY, transformRef.current.scale) + }, + [getCurrentPosition, getImgElement, applyTransform] + ) + + // 拖拽平移支持 + useEffect(() => { + if (!enableDrag || !containerRef.current) return + + const container = containerRef.current + const startPos = { x: 0, y: 0 } + + const handleMouseMove = (e: MouseEvent) => { + const dx = e.clientX - startPos.x + const dy = e.clientY - startPos.y + + // 直接使用 transformRef 中的初始偏移量进行计算 + const newX = transformRef.current.x + dx + const newY = transformRef.current.y + dy + + const imgElement = getImgElement() + // 实时应用变换,但不更新 ref,避免累积误差 + applyTransform(imgElement, newX, newY, transformRef.current.scale) + e.preventDefault() + } + + const handleMouseUp = (e: MouseEvent) => { + document.removeEventListener('mousemove', handleMouseMove) + document.removeEventListener('mouseup', handleMouseUp) + + container.style.cursor = 'default' + + // 拖拽结束后,计算最终位置并更新 ref + const dx = e.clientX - startPos.x + const dy = e.clientY - startPos.y + transformRef.current.x += dx + transformRef.current.y += dy + } + + const handleMouseDown = (e: MouseEvent) => { + if (e.button !== 0) return // 只响应左键 + + // 每次拖拽开始时,都以 ref 中当前的位置为基准 + const currentPos = getCurrentPosition() + transformRef.current.x = currentPos.x + transformRef.current.y = currentPos.y + + startPos.x = e.clientX + startPos.y = e.clientY + + container.style.cursor = 'grabbing' + e.preventDefault() + + document.addEventListener('mousemove', handleMouseMove) + document.addEventListener('mouseup', handleMouseUp) + } + + container.addEventListener('mousedown', handleMouseDown) + + return () => { + container.removeEventListener('mousedown', handleMouseDown) + // 清理以防万一,例如组件在拖拽过程中被卸载 + document.removeEventListener('mousemove', handleMouseMove) + document.removeEventListener('mouseup', handleMouseUp) + } + }, [containerRef, getImgElement, applyTransform, getCurrentPosition, enableDrag]) + + /** + * 缩放 + * @param delta 缩放增量(正值放大,负值缩小) + */ + const zoom = useCallback( + (delta: number, absolute = false) => { + const newScale = absolute + ? Math.max(0.1, Math.min(3, delta)) + : Math.max(0.1, Math.min(3, transformRef.current.scale + delta)) + + transformRef.current.scale = newScale + + const imgElement = getImgElement() + applyTransform(imgElement, transformRef.current.x, transformRef.current.y, newScale) + }, + [getImgElement, applyTransform] + ) + + // 滚轮缩放支持 + useEffect(() => { + if (!enableWheelZoom || !containerRef.current) return + + const container = containerRef.current + + const handleWheel = (e: WheelEvent) => { + if ((e.ctrlKey || e.metaKey) && e.target) { + // 确认事件发生在容器内部 + if (container.contains(e.target as Node)) { + const delta = e.deltaY < 0 ? 0.1 : -0.1 + zoom(delta) + } + } + } + + container.addEventListener('wheel', handleWheel, { passive: true }) + return () => container.removeEventListener('wheel', handleWheel) + }, [containerRef, zoom, enableWheelZoom]) + + /** + * 复制图像 + * + * 目前使用了清理变换后的图像,因此不适用于画布 + */ + const copy = useCallback(async () => { + try { + const imgElement = getCleanImgElement() + if (!imgElement) return + + const blob = await svgToPngBlob(imgElement) + await navigator.clipboard.write([new ClipboardItem({ 'image/png': blob })]) + window.message.success(t('message.copy.success')) + } catch (error) { + logger.error('Copy failed:', error as Error) + window.message.error(t('message.copy.failed')) + } + }, [getCleanImgElement, t]) + + /** + * 下载图像 + * + * 目前使用了清理变换后的图像,因此不适用于画布 + */ + const download = useCallback( + async (format: 'svg' | 'png') => { + try { + const imgElement = getCleanImgElement() + if (!imgElement) return + + const timestamp = Date.now() + + if (format === 'svg') { + const blob = svgToSvgBlob(imgElement) + const url = URL.createObjectURL(blob) + downloadFile(url, `${prefix}-${timestamp}.svg`) + URL.revokeObjectURL(url) + } else { + const blob = await svgToPngBlob(imgElement) + const pngUrl = URL.createObjectURL(blob) + downloadFile(pngUrl, `${prefix}-${timestamp}.png`) + URL.revokeObjectURL(pngUrl) + } + } catch (error) { + logger.error('Download failed:', error as Error) + window.message.error(t('message.download.failed')) + } + }, + [getCleanImgElement, prefix, t] + ) + + /** + * 预览 dialog + * + * 目前使用了清理变换后的图像,因此不适用于画布 + */ + const dialog = useCallback(async () => { + try { + const imgElement = getCleanImgElement() + if (!imgElement) return + + await ImagePreviewService.show(imgElement, { format: 'svg' }) + } catch (error) { + logger.error('Dialog preview failed:', error as Error) + window.message.error(t('message.dialog.failed')) + } + }, [getCleanImgElement, t]) + + // 获取当前变换状态 + const getCurrentTransform = useCallback(() => { + return { + scale: transformRef.current.scale, + x: transformRef.current.x, + y: transformRef.current.y + } + }, [transformRef]) + + // 切换主题时重置变换 + useEffect(() => { + pan(0, 0, true) + zoom(1, true) + }, [pan, zoom, theme]) + + return { + zoom, + pan, + copy, + download, + dialog, + getCurrentTransform + } +} diff --git a/src/renderer/src/components/CodeToolbar/hook.ts b/src/renderer/src/components/ActionTools/hooks/useToolManager.ts similarity index 76% rename from src/renderer/src/components/CodeToolbar/hook.ts rename to src/renderer/src/components/ActionTools/hooks/useToolManager.ts index 5b5d6b338f..ae73fcdb5d 100644 --- a/src/renderer/src/components/CodeToolbar/hook.ts +++ b/src/renderer/src/components/ActionTools/hooks/useToolManager.ts @@ -1,11 +1,11 @@ import { useCallback } from 'react' -import { CodeTool } from './types' +import { ActionTool, ToolRegisterProps } from '../types' -export const useCodeTool = (setTools?: (value: React.SetStateAction) => void) => { +export const useToolManager = (setTools?: ToolRegisterProps['setTools']) => { // 注册工具,如果已存在同ID工具则替换 const registerTool = useCallback( - (tool: CodeTool) => { + (tool: ActionTool) => { setTools?.((prev) => { const filtered = prev.filter((t) => t.id !== tool.id) return [...filtered, tool].sort((a, b) => b.order - a.order) diff --git a/src/renderer/src/components/ActionTools/index.ts b/src/renderer/src/components/ActionTools/index.ts new file mode 100644 index 0000000000..4c223a6613 --- /dev/null +++ b/src/renderer/src/components/ActionTools/index.ts @@ -0,0 +1,4 @@ +export * from './constants' +export * from './hooks/useImageTools' +export * from './hooks/useToolManager' +export * from './types' diff --git a/src/renderer/src/components/ActionTools/types.ts b/src/renderer/src/components/ActionTools/types.ts new file mode 100644 index 0000000000..db9855f74d --- /dev/null +++ b/src/renderer/src/components/ActionTools/types.ts @@ -0,0 +1,34 @@ +/** + * 动作工具基本信息 + */ +export interface ActionToolSpec { + id: string + type: 'core' | 'quick' + order: number +} + +/** + * 动作工具定义接口 + * @param id 唯一标识符 + * @param type 工具类型 + * @param order 显示顺序,越小越靠右 + * @param icon 按钮图标 + * @param tooltip 提示文本 + * @param visible 显示条件 + * @param onClick 点击动作 + * @param children 子工具(例如 more 下拉菜单) + */ +export interface ActionTool extends ActionToolSpec { + icon: React.ReactNode + tooltip?: string + visible?: () => boolean + onClick?: () => void + children?: Omit[] +} + +/** + * 子组件向父组件注册工具所需的 props + */ +export interface ToolRegisterProps { + setTools?: (value: React.SetStateAction) => void +} diff --git a/src/renderer/src/components/CodeBlockView/GraphvizPreview.tsx b/src/renderer/src/components/CodeBlockView/GraphvizPreview.tsx deleted file mode 100644 index 48b45bf875..0000000000 --- a/src/renderer/src/components/CodeBlockView/GraphvizPreview.tsx +++ /dev/null @@ -1,102 +0,0 @@ -import { usePreviewToolHandlers, usePreviewTools } from '@renderer/components/CodeToolbar' -import { LoadingIcon } from '@renderer/components/Icons' -import { AsyncInitializer } from '@renderer/utils/asyncInitializer' -import { Flex, Spin } from 'antd' -import { debounce } from 'lodash' -import React, { memo, startTransition, useCallback, useEffect, useMemo, useRef, useState } from 'react' -import styled from 'styled-components' - -import PreviewError from './PreviewError' -import { BasicPreviewProps } from './types' - -// 管理 viz 实例 -const vizInitializer = new AsyncInitializer(async () => { - const module = await import('@viz-js/viz') - return await module.instance() -}) - -/** 预览 Graphviz 图表 - * 通过防抖渲染提供比较统一的体验,减少闪烁。 - */ -const GraphvizPreview: React.FC = ({ children, setTools }) => { - const graphvizRef = useRef(null) - const [error, setError] = useState(null) - const [isLoading, setIsLoading] = useState(false) - - // 使用通用图像工具 - const { handleZoom, handleCopyImage, handleDownload } = usePreviewToolHandlers(graphvizRef, { - imgSelector: 'svg', - prefix: 'graphviz', - enableWheelZoom: true - }) - - // 使用工具栏 - usePreviewTools({ - setTools, - handleZoom, - handleCopyImage, - handleDownload - }) - - // 实际的渲染函数 - const renderGraphviz = useCallback(async (content: string) => { - if (!content || !graphvizRef.current) return - - try { - setIsLoading(true) - - const viz = await vizInitializer.get() - const svgElement = viz.renderSVGElement(content) - - // 清空容器并添加新的 SVG - graphvizRef.current.innerHTML = '' - graphvizRef.current.appendChild(svgElement) - - // 渲染成功,清除错误记录 - setError(null) - } catch (error) { - setError((error as Error).message || 'DOT syntax error or rendering failed') - } finally { - setIsLoading(false) - } - }, []) - - // debounce 渲染 - const debouncedRender = useMemo( - () => - debounce((content: string) => { - startTransition(() => renderGraphviz(content)) - }, 300), - [renderGraphviz] - ) - - // 触发渲染 - useEffect(() => { - if (children) { - setIsLoading(true) - debouncedRender(children) - } else { - debouncedRender.cancel() - setIsLoading(false) - } - - return () => { - debouncedRender.cancel() - } - }, [children, debouncedRender]) - - return ( - }> - - {error && {error}} - - - - ) -} - -const StyledGraphviz = styled.div` - overflow: auto; -` - -export default memo(GraphvizPreview) diff --git a/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx b/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx index 24a9749021..9fa4038459 100644 --- a/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx +++ b/src/renderer/src/components/CodeBlockView/HtmlArtifactsPopup.tsx @@ -22,45 +22,51 @@ const HtmlArtifactsPopup: React.FC = ({ open, title, ht const [currentHtml, setCurrentHtml] = useState(html) const [isFullscreen, setIsFullscreen] = useState(false) - // 预览刷新相关状态 + // Preview refresh related state const [previewHtml, setPreviewHtml] = useState(html) const intervalRef = useRef(null) const latestHtmlRef = useRef(html) + const currentPreviewHtmlRef = useRef(html) - // 当外部html更新时,同步更新内部状态 + // Sync internal state when external html updates useEffect(() => { setCurrentHtml(html) latestHtmlRef.current = html }, [html]) - // 当内部编辑的html更新时,更新引用 + // Update reference when internally edited html changes useEffect(() => { latestHtmlRef.current = currentHtml }, [currentHtml]) - // 2秒定时检查并刷新预览(仅在内容变化时) + // Update reference when preview content changes + useEffect(() => { + currentPreviewHtmlRef.current = previewHtml + }, [previewHtml]) + + // Check and refresh preview every 2 seconds (only when content changes) useEffect(() => { if (!open) return - // 立即设置初始预览内容 - setPreviewHtml(currentHtml) + // Set initial preview content immediately + setPreviewHtml(latestHtmlRef.current) - // 设置定时器,每2秒检查一次内容是否有变化 + // Set timer to check for content changes every 2 seconds intervalRef.current = setInterval(() => { - if (latestHtmlRef.current !== previewHtml) { + if (latestHtmlRef.current !== currentPreviewHtmlRef.current) { setPreviewHtml(latestHtmlRef.current) } }, 2000) - // 清理函数 + // Cleanup function return () => { if (intervalRef.current) { clearInterval(intervalRef.current) } } - }, [currentHtml, open, previewHtml]) + }, [open]) - // 全屏时防止 body 滚动 + // Prevent body scroll when fullscreen useEffect(() => { if (!open || !isFullscreen) return @@ -147,9 +153,10 @@ const HtmlArtifactsPopup: React.FC = ({ open, title, ht editable={true} onSave={setCurrentHtml} style={{ height: '100%' }} + expanded + unwrapped={false} options={{ - stream: false, - collapsible: false + stream: false }} /> @@ -159,7 +166,7 @@ const HtmlArtifactsPopup: React.FC = ({ open, title, ht {previewHtml.trim() ? ( = ({ open, title, ht ) } -// 简化的样式组件 const StyledModal = styled(Modal)<{ $isFullscreen?: boolean }>` ${(props) => props.$isFullscreen diff --git a/src/renderer/src/components/CodeBlockView/MermaidPreview.tsx b/src/renderer/src/components/CodeBlockView/MermaidPreview.tsx deleted file mode 100644 index b57c4a68a8..0000000000 --- a/src/renderer/src/components/CodeBlockView/MermaidPreview.tsx +++ /dev/null @@ -1,155 +0,0 @@ -import { nanoid } from '@reduxjs/toolkit' -import { usePreviewToolHandlers, usePreviewTools } from '@renderer/components/CodeToolbar' -import { LoadingIcon } from '@renderer/components/Icons' -import { useMermaid } from '@renderer/hooks/useMermaid' -import { Flex, Spin } from 'antd' -import { debounce } from 'lodash' -import React, { memo, startTransition, useCallback, useEffect, useMemo, useRef, useState } from 'react' -import styled from 'styled-components' - -import PreviewError from './PreviewError' -import { BasicPreviewProps } from './types' - -/** 预览 Mermaid 图表 - * 通过防抖渲染提供比较统一的体验,减少闪烁。 - * FIXME: 等将来容易判断代码块结束位置时再重构。 - */ -const MermaidPreview: React.FC = ({ children, setTools }) => { - const { mermaid, isLoading: isLoadingMermaid, error: mermaidError } = useMermaid() - const mermaidRef = useRef(null) - const diagramId = useRef(`mermaid-${nanoid(6)}`).current - const [error, setError] = useState(null) - const [isRendering, setIsRendering] = useState(false) - const [isVisible, setIsVisible] = useState(true) - - // 使用通用图像工具 - const { handleZoom, handleCopyImage, handleDownload } = usePreviewToolHandlers(mermaidRef, { - imgSelector: 'svg', - prefix: 'mermaid', - enableWheelZoom: true - }) - - // 使用工具栏 - usePreviewTools({ - setTools, - handleZoom, - handleCopyImage, - handleDownload - }) - - // 实际的渲染函数 - const renderMermaid = useCallback( - async (content: string) => { - if (!content || !mermaidRef.current) return - - try { - setIsRendering(true) - - // 验证语法,提前抛出异常 - await mermaid.parse(content) - - const { svg } = await mermaid.render(diagramId, content, mermaidRef.current) - - // 避免不可见时产生 undefined 和 NaN - const fixedSvg = svg.replace(/translate\(undefined,\s*NaN\)/g, 'translate(0, 0)') - mermaidRef.current.innerHTML = fixedSvg - - // 渲染成功,清除错误记录 - setError(null) - } catch (error) { - setError((error as Error).message) - } finally { - setIsRendering(false) - } - }, - [diagramId, mermaid] - ) - - // debounce 渲染 - const debouncedRender = useMemo( - () => - debounce((content: string) => { - startTransition(() => renderMermaid(content)) - }, 300), - [renderMermaid] - ) - - /** - * 监听可见性变化,用于触发重新渲染。 - * 这是为了解决 `MessageGroup` 组件的 `fold` 布局中被 `display: none` 隐藏的图标无法正确渲染的问题。 - * 监听时向上遍历到第一个有 `fold` className 的父节点为止(也就是目前的 `MessageWrapper`)。 - * FIXME: 将来 mermaid-js 修复此问题后可以移除这里的相关逻辑。 - */ - useEffect(() => { - if (!mermaidRef.current) return - - const checkVisibility = () => { - const element = mermaidRef.current - if (!element) return - - const currentlyVisible = element.offsetParent !== null - setIsVisible(currentlyVisible) - } - - // 初始检查 - checkVisibility() - - const observer = new MutationObserver(() => { - checkVisibility() - }) - - let targetElement = mermaidRef.current.parentElement - while (targetElement) { - observer.observe(targetElement, { - attributes: true, - attributeFilter: ['class', 'style'] - }) - - if (targetElement.className?.includes('fold')) { - break - } - - targetElement = targetElement.parentElement - } - - return () => { - observer.disconnect() - } - }, []) - - // 触发渲染 - useEffect(() => { - if (isLoadingMermaid) return - - if (mermaidRef.current?.offsetParent === null) return - - if (children) { - setIsRendering(true) - debouncedRender(children) - } else { - debouncedRender.cancel() - setIsRendering(false) - } - - return () => { - debouncedRender.cancel() - } - }, [children, isLoadingMermaid, debouncedRender, isVisible]) - - const isLoading = isLoadingMermaid || isRendering - - return ( - }> - - {(mermaidError || error) && {mermaidError || error}} - - - - ) -} - -const StyledMermaid = styled.div` - overflow: auto; -` - -export default memo(MermaidPreview) diff --git a/src/renderer/src/components/CodeBlockView/PlantUmlPreview.tsx b/src/renderer/src/components/CodeBlockView/PlantUmlPreview.tsx deleted file mode 100644 index 0916056039..0000000000 --- a/src/renderer/src/components/CodeBlockView/PlantUmlPreview.tsx +++ /dev/null @@ -1,192 +0,0 @@ -import { LoadingOutlined } from '@ant-design/icons' -import { usePreviewToolHandlers, usePreviewTools } from '@renderer/components/CodeToolbar' -import { Spin } from 'antd' -import pako from 'pako' -import React, { memo, useCallback, useRef, useState } from 'react' -import { useTranslation } from 'react-i18next' -import styled from 'styled-components' - -import { BasicPreviewProps } from './types' - -const PlantUMLServer = 'https://www.plantuml.com/plantuml' -function encode64(data: Uint8Array) { - let r = '' - for (let i = 0; i < data.length; i += 3) { - if (i + 2 === data.length) { - r += append3bytes(data[i], data[i + 1], 0) - } else if (i + 1 === data.length) { - r += append3bytes(data[i], 0, 0) - } else { - r += append3bytes(data[i], data[i + 1], data[i + 2]) - } - } - return r -} - -function encode6bit(b: number) { - if (b < 10) { - return String.fromCharCode(48 + b) - } - b -= 10 - if (b < 26) { - return String.fromCharCode(65 + b) - } - b -= 26 - if (b < 26) { - return String.fromCharCode(97 + b) - } - b -= 26 - if (b === 0) { - return '-' - } - if (b === 1) { - return '_' - } - return '?' -} - -function append3bytes(b1: number, b2: number, b3: number) { - const c1 = b1 >> 2 - const c2 = ((b1 & 0x3) << 4) | (b2 >> 4) - const c3 = ((b2 & 0xf) << 2) | (b3 >> 6) - const c4 = b3 & 0x3f - let r = '' - r += encode6bit(c1 & 0x3f) - r += encode6bit(c2 & 0x3f) - r += encode6bit(c3 & 0x3f) - r += encode6bit(c4 & 0x3f) - return r -} -/** - * https://plantuml.com/zh/code-javascript-synchronous - * To use PlantUML image generation, a text diagram description have to be : - 1. Encoded in UTF-8 - 2. Compressed using Deflate algorithm - 3. Reencoded in ASCII using a transformation _close_ to base64 - */ -function encodeDiagram(diagram: string): string { - const utf8text = new TextEncoder().encode(diagram) - const compressed = pako.deflateRaw(utf8text) - return encode64(compressed) -} - -async function downloadUrl(url: string, filename: string) { - const response = await fetch(url) - if (!response.ok) { - window.message.warning({ content: response.statusText, duration: 1.5 }) - return - } - const blob = await response.blob() - const link = document.createElement('a') - link.href = URL.createObjectURL(blob) - link.download = filename - document.body.appendChild(link) - link.click() - document.body.removeChild(link) - URL.revokeObjectURL(link.href) -} - -type PlantUMLServerImageProps = { - format: 'png' | 'svg' - diagram: string - onClick?: React.MouseEventHandler - className?: string -} - -function getPlantUMLImageUrl(format: 'png' | 'svg', diagram: string, isDark?: boolean) { - const encodedDiagram = encodeDiagram(diagram) - if (isDark) { - return `${PlantUMLServer}/d${format}/${encodedDiagram}` - } - return `${PlantUMLServer}/${format}/${encodedDiagram}` -} - -const PlantUMLServerImage: React.FC = ({ format, diagram, onClick, className }) => { - const [loading, setLoading] = useState(true) - // FIXME: 黑暗模式背景太黑了,目前让 PlantUML 和 SVG 一样保持白色背景 - const url = getPlantUMLImageUrl(format, diagram, false) - return ( - - - }> - { - setLoading(false) - }} - onError={(e) => { - setLoading(false) - const target = e.target as HTMLImageElement - target.style.opacity = '0.5' - target.style.filter = 'blur(2px)' - }} - /> - - - ) -} - -const PlantUmlPreview: React.FC = ({ children, setTools }) => { - const { t } = useTranslation() - const containerRef = useRef(null) - - const encodedDiagram = encodeDiagram(children) - - // 自定义 PlantUML 下载方法 - const customDownload = useCallback( - (format: 'svg' | 'png') => { - const timestamp = Date.now() - const url = `${PlantUMLServer}/${format}/${encodedDiagram}` - const filename = `plantuml-diagram-${timestamp}.${format}` - downloadUrl(url, filename).catch(() => { - window.message.error(t('code_block.download.failed.network')) - }) - }, - [encodedDiagram, t] - ) - - // 使用通用图像工具,提供自定义下载方法 - const { handleZoom, handleCopyImage } = usePreviewToolHandlers(containerRef, { - imgSelector: '.plantuml-preview img', - prefix: 'plantuml-diagram', - enableWheelZoom: true, - customDownloader: customDownload - }) - - // 使用工具栏 - usePreviewTools({ - setTools, - handleZoom, - handleCopyImage, - handleDownload: customDownload - }) - - return ( -
- -
- ) -} - -const StyledPlantUML = styled.div` - max-height: calc(80vh - 100px); - text-align: left; - overflow-y: auto; - background-color: white; - img { - max-width: 100%; - height: auto; - min-height: 100px; - transition: transform 0.2s ease; - } -` - -export default memo(PlantUmlPreview) diff --git a/src/renderer/src/components/CodeBlockView/PreviewError.tsx b/src/renderer/src/components/CodeBlockView/PreviewError.tsx deleted file mode 100644 index 1139dea7ff..0000000000 --- a/src/renderer/src/components/CodeBlockView/PreviewError.tsx +++ /dev/null @@ -1,14 +0,0 @@ -import { memo } from 'react' -import { styled } from 'styled-components' - -const PreviewError = styled.div` - overflow: auto; - padding: 16px; - color: #ff4d4f; - border: 1px solid #ff4d4f; - border-radius: 4px; - word-wrap: break-word; - white-space: pre-wrap; -` - -export default memo(PreviewError) diff --git a/src/renderer/src/components/CodeBlockView/StatusBar.tsx b/src/renderer/src/components/CodeBlockView/StatusBar.tsx index 651405863f..defd070ac8 100644 --- a/src/renderer/src/components/CodeBlockView/StatusBar.tsx +++ b/src/renderer/src/components/CodeBlockView/StatusBar.tsx @@ -18,6 +18,7 @@ const Container = styled(Flex)` gap: 8px; overflow-y: auto; text-wrap: wrap; + border-radius: 0 0 8px 8px; ` export default memo(StatusBar) diff --git a/src/renderer/src/components/CodeBlockView/SvgPreview.tsx b/src/renderer/src/components/CodeBlockView/SvgPreview.tsx deleted file mode 100644 index fe60101519..0000000000 --- a/src/renderer/src/components/CodeBlockView/SvgPreview.tsx +++ /dev/null @@ -1,61 +0,0 @@ -import { usePreviewToolHandlers, usePreviewTools } from '@renderer/components/CodeToolbar' -import { memo, useEffect, useRef } from 'react' - -import { BasicPreviewProps } from './types' - -/** - * 使用 Shadow DOM 渲染 SVG - */ -const SvgPreview: React.FC = ({ children, setTools }) => { - const svgContainerRef = useRef(null) - - useEffect(() => { - const container = svgContainerRef.current - if (!container) return - - const shadowRoot = container.shadowRoot || container.attachShadow({ mode: 'open' }) - - // 添加基础样式 - const style = document.createElement('style') - style.textContent = ` - :host { - padding: 1em; - background-color: white; - overflow: auto; - border: 0.5px solid var(--color-code-background); - border-top-left-radius: 0; - border-top-right-radius: 0; - display: block; - } - svg { - max-width: 100%; - height: auto; - } - ` - - // 清空并重新添加内容 - shadowRoot.innerHTML = '' - shadowRoot.appendChild(style) - - const svgContainer = document.createElement('div') - svgContainer.innerHTML = children - shadowRoot.appendChild(svgContainer) - }, [children]) - - // 使用通用图像工具 - const { handleCopyImage, handleDownload } = usePreviewToolHandlers(svgContainerRef, { - imgSelector: 'svg', - prefix: 'svg-image' - }) - - // 使用工具栏 - usePreviewTools({ - setTools, - handleCopyImage, - handleDownload - }) - - return
-} - -export default memo(SvgPreview) diff --git a/src/renderer/src/components/CodeBlockView/constants.ts b/src/renderer/src/components/CodeBlockView/constants.ts index fc6687d5f1..8d73d04e4a 100644 --- a/src/renderer/src/components/CodeBlockView/constants.ts +++ b/src/renderer/src/components/CodeBlockView/constants.ts @@ -1,7 +1,4 @@ -import GraphvizPreview from './GraphvizPreview' -import MermaidPreview from './MermaidPreview' -import PlantUmlPreview from './PlantUmlPreview' -import SvgPreview from './SvgPreview' +import { GraphvizPreview, MermaidPreview, PlantUmlPreview, SvgPreview } from '@renderer/components/Preview' /** * 特殊视图语言列表 diff --git a/src/renderer/src/components/CodeBlockView/types.ts b/src/renderer/src/components/CodeBlockView/types.ts index 5ec413658f..b1bb959458 100644 --- a/src/renderer/src/components/CodeBlockView/types.ts +++ b/src/renderer/src/components/CodeBlockView/types.ts @@ -1,13 +1,3 @@ -import { CodeTool } from '@renderer/components/CodeToolbar' - -/** - * 预览组件的基本 props - */ -export interface BasicPreviewProps { - children: string - setTools?: (value: React.SetStateAction) => void -} - /** * 视图模式 */ diff --git a/src/renderer/src/components/CodeBlockView/view.tsx b/src/renderer/src/components/CodeBlockView/view.tsx index 4a844b7b33..3a5975a901 100644 --- a/src/renderer/src/components/CodeBlockView/view.tsx +++ b/src/renderer/src/components/CodeBlockView/view.tsx @@ -1,19 +1,30 @@ import { loggerService } from '@logger' -import CodeEditor from '@renderer/components/CodeEditor' -import { CodeTool, CodeToolbar, TOOL_SPECS, useCodeTool } from '@renderer/components/CodeToolbar' -import { LoadingIcon } from '@renderer/components/Icons' +import { ActionTool } from '@renderer/components/ActionTools' +import CodeEditor, { CodeEditorHandles } from '@renderer/components/CodeEditor' +import { + CodeToolbar, + useCopyTool, + useDownloadTool, + useExpandTool, + useRunTool, + useSaveTool, + useSplitViewTool, + useViewSourceTool, + useWrapTool +} from '@renderer/components/CodeToolbar' +import CodeViewer from '@renderer/components/CodeViewer' +import ImageViewer from '@renderer/components/ImageViewer' +import { BasicPreviewHandles } from '@renderer/components/Preview' +import { MAX_COLLAPSED_CODE_HEIGHT } from '@renderer/config/constant' import { useSettings } from '@renderer/hooks/useSettings' import { pyodideService } from '@renderer/services/PyodideService' import { extractTitle } from '@renderer/utils/formats' -import { getExtensionByLanguage, isHtmlCode, isValidPlantUML } from '@renderer/utils/markdown' +import { getExtensionByLanguage, isHtmlCode } from '@renderer/utils/markdown' import dayjs from 'dayjs' -import { CirclePlay, CodeXml, Copy, Download, Eye, Square, SquarePen, SquareSplitHorizontal } from 'lucide-react' -import React, { memo, useCallback, useEffect, useMemo, useState } from 'react' +import React, { memo, startTransition, useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import styled from 'styled-components' +import styled, { css } from 'styled-components' -import ImageViewer from '../ImageViewer' -import CodePreview from './CodePreview' import { SPECIAL_VIEW_COMPONENTS, SPECIAL_VIEWS } from './constants' import HtmlArtifactsCard from './HtmlArtifactsCard' import StatusBar from './StatusBar' @@ -45,31 +56,83 @@ interface Props { */ export const CodeBlockView: React.FC = memo(({ children, language, onSave }) => { const { t } = useTranslation() - const { codeEditor, codeExecution } = useSettings() + const { codeEditor, codeExecution, codeImageTools, codeCollapsible, codeWrappable } = useSettings() + + const [viewState, setViewState] = useState({ + mode: 'special' as ViewMode, + previousMode: 'special' as ViewMode + }) + const { mode: viewMode } = viewState + + const setViewMode = useCallback((newMode: ViewMode) => { + setViewState((current) => ({ + mode: newMode, + // 当新模式不是 'split' 时才更新 + previousMode: newMode !== 'split' ? newMode : current.previousMode + })) + }, []) + + const toggleSplitView = useCallback(() => { + setViewState((current) => { + // 如果当前是 split 模式,恢复到上一个模式 + if (current.mode === 'split') { + return { ...current, mode: current.previousMode } + } + return { mode: 'split', previousMode: current.mode } + }) + }, []) - const [viewMode, setViewMode] = useState('special') const [isRunning, setIsRunning] = useState(false) const [executionResult, setExecutionResult] = useState<{ text: string; image?: string } | null>(null) - const [tools, setTools] = useState([]) - const { registerTool, removeTool } = useCodeTool(setTools) + const [tools, setTools] = useState([]) const isExecutable = useMemo(() => { return codeExecution.enabled && language === 'python' }, [codeExecution.enabled, language]) + const sourceViewRef = useRef(null) + const specialViewRef = useRef(null) + const hasSpecialView = useMemo(() => SPECIAL_VIEWS.includes(language), [language]) const isInSpecialView = useMemo(() => { return hasSpecialView && viewMode === 'special' }, [hasSpecialView, viewMode]) + const [expandOverride, setExpandOverride] = useState(!codeCollapsible) + const [unwrapOverride, setUnwrapOverride] = useState(!codeWrappable) + + // 重置用户操作 + useEffect(() => { + setExpandOverride(!codeCollapsible) + }, [codeCollapsible]) + + // 重置用户操作 + useEffect(() => { + setUnwrapOverride(!codeWrappable) + }, [codeWrappable]) + + const shouldExpand = useMemo(() => !codeCollapsible || expandOverride, [codeCollapsible, expandOverride]) + const shouldUnwrap = useMemo(() => !codeWrappable || unwrapOverride, [codeWrappable, unwrapOverride]) + + const [sourceScrollHeight, setSourceScrollHeight] = useState(0) + const expandable = useMemo(() => { + return codeCollapsible && sourceScrollHeight > MAX_COLLAPSED_CODE_HEIGHT + }, [codeCollapsible, sourceScrollHeight]) + + const handleHeightChange = useCallback((height: number) => { + startTransition(() => { + setSourceScrollHeight((prev) => (prev === height ? prev : height)) + }) + }, []) + const handleCopySource = useCallback(() => { navigator.clipboard.writeText(children) window.message.success({ content: t('code_block.copy.success'), key: 'copy-code' }) }, [children, t]) - const handleDownloadSource = useCallback(async () => { + const handleDownloadSource = useCallback(() => { let fileName = '' // 尝试提取 HTML 标题 @@ -82,7 +145,7 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave fileName = `${dayjs().format('YYYYMMDDHHmm')}` } - const ext = await getExtensionByLanguage(language) + const ext = getExtensionByLanguage(language) window.api.file.save(`${fileName}${ext}`, children) }, [children, language]) @@ -106,101 +169,103 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave }) }, [children, codeExecution.timeoutMinutes]) - useEffect(() => { - // 复制按钮 - registerTool({ - ...TOOL_SPECS.copy, - icon: , - tooltip: t('code_block.copy.source'), - onClick: handleCopySource - }) + const showPreviewTools = useMemo(() => { + return viewMode !== 'source' && hasSpecialView + }, [hasSpecialView, viewMode]) - // 下载按钮 - registerTool({ - ...TOOL_SPECS.download, - icon: , - tooltip: t('code_block.download.source'), - onClick: handleDownloadSource - }) - return () => { - removeTool(TOOL_SPECS.copy.id) - removeTool(TOOL_SPECS.download.id) - } - }, [handleCopySource, handleDownloadSource, registerTool, removeTool, t]) + // 复制按钮 + useCopyTool({ + showPreviewTools, + previewRef: specialViewRef, + onCopySource: handleCopySource, + setTools + }) - // 特殊视图的编辑按钮,在分屏模式下不可用 - useEffect(() => { - if (!hasSpecialView || viewMode === 'split') return + // 下载按钮 + useDownloadTool({ + showPreviewTools, + previewRef: specialViewRef, + onDownloadSource: handleDownloadSource, + setTools + }) - const viewSourceToolSpec = codeEditor.enabled ? TOOL_SPECS.edit : TOOL_SPECS['view-source'] + // 特殊视图的编辑/查看源码按钮,在分屏模式下不可用 + useViewSourceTool({ + enabled: hasSpecialView, + editable: codeEditor.enabled, + viewMode, + onViewModeChange: setViewMode, + setTools + }) - if (codeEditor.enabled) { - registerTool({ - ...viewSourceToolSpec, - icon: viewMode === 'source' ? : , - tooltip: viewMode === 'source' ? t('code_block.preview.label') : t('code_block.edit.label'), - onClick: () => setViewMode(viewMode === 'source' ? 'special' : 'source') - }) - } else { - registerTool({ - ...viewSourceToolSpec, - icon: viewMode === 'source' ? : , - tooltip: viewMode === 'source' ? t('code_block.preview.label') : t('code_block.preview.source'), - onClick: () => setViewMode(viewMode === 'source' ? 'special' : 'source') - }) - } - - return () => removeTool(viewSourceToolSpec.id) - }, [codeEditor.enabled, hasSpecialView, viewMode, registerTool, removeTool, t]) - - // 特殊视图的分屏按钮 - useEffect(() => { - if (!hasSpecialView) return - - registerTool({ - ...TOOL_SPECS['split-view'], - icon: viewMode === 'split' ? : , - tooltip: viewMode === 'split' ? t('code_block.split.restore') : t('code_block.split.label'), - onClick: () => setViewMode(viewMode === 'split' ? 'special' : 'split') - }) - - return () => removeTool(TOOL_SPECS['split-view'].id) - }, [hasSpecialView, viewMode, registerTool, removeTool, t]) + // 特殊视图存在时的分屏按钮 + useSplitViewTool({ + enabled: hasSpecialView, + viewMode, + onToggleSplitView: toggleSplitView, + setTools + }) // 运行按钮 - useEffect(() => { - if (!isExecutable) return + useRunTool({ + enabled: isExecutable, + isRunning, + onRun: handleRunScript, + setTools + }) - registerTool({ - ...TOOL_SPECS.run, - icon: isRunning ? : , - tooltip: t('code_block.run'), - onClick: () => !isRunning && handleRunScript() - }) + // 源代码视图的展开/折叠按钮 + useExpandTool({ + enabled: !isInSpecialView, + expanded: shouldExpand, + expandable, + toggle: useCallback(() => setExpandOverride((prev) => !prev), []), + setTools + }) - return () => isExecutable && removeTool(TOOL_SPECS.run.id) - }, [isExecutable, isRunning, handleRunScript, registerTool, removeTool, t]) + // 源代码视图的自动换行按钮 + useWrapTool({ + enabled: !isInSpecialView, + unwrapped: shouldUnwrap, + wrappable: codeWrappable, + toggle: useCallback(() => setUnwrapOverride((prev) => !prev), []), + setTools + }) + + // 代码编辑器的保存按钮 + useSaveTool({ + enabled: codeEditor.enabled && !isInSpecialView, + sourceViewRef, + setTools + }) // 源代码视图组件 - const sourceView = useMemo(() => { - if (codeEditor.enabled) { - return ( + const sourceView = useMemo( + () => + codeEditor.enabled ? ( - ) - } else { - return ( - + ) : ( + {children} - - ) - } - }, [children, codeEditor.enabled, language, onSave, setTools]) + + ), + [children, codeEditor.enabled, handleHeightChange, language, onSave, shouldExpand, shouldUnwrap] + ) // 特殊视图组件映射 const specialView = useMemo(() => { @@ -208,13 +273,12 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave if (!SpecialView) return null - // PlantUML 语法验证 - if (language === 'plantuml' && !isValidPlantUML(children)) { - return null - } - - return {children} - }, [children, language]) + return ( + + {children} + + ) + }, [children, codeImageTools, language]) const renderHeader = useMemo(() => { const langTag = '<' + language.toUpperCase() + '>' @@ -227,7 +291,7 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave const showSourceView = !specialView || viewMode !== 'special' return ( - + {showSpecialView && specialView} {showSourceView && sourceView} @@ -260,7 +324,7 @@ const CodeBlockWrapper = styled.div<{ $isInSpecialView: boolean }>` position: relative; width: 100%; /* FIXME: 最小宽度用于解决两个问题。 - * 一是 CodePreview 在气泡样式下的用户消息中无法撑开气泡, + * 一是 CodeViewer 在气泡样式下的用户消息中无法撑开气泡, * 二是 代码块内容过少时 toolbar 会和 title 重叠。 */ min-width: 45ch; @@ -295,9 +359,10 @@ const CodeHeader = styled.div<{ $isInSpecialView: boolean }>` border-top-right-radius: 8px; margin-top: ${(props) => (props.$isInSpecialView ? '6px' : '0')}; height: ${(props) => (props.$isInSpecialView ? '16px' : '34px')}; + background-color: ${(props) => (props.$isInSpecialView ? 'transparent' : 'var(--color-background-mute)')}; ` -const SplitViewWrapper = styled.div` +const SplitViewWrapper = styled.div<{ $viewMode?: ViewMode }>` display: flex; > * { @@ -306,7 +371,27 @@ const SplitViewWrapper = styled.div` } &:not(:has(+ [class*='Container'])) { - border-radius: 0 0 8px 8px; + // 特殊视图的 header 会隐藏,所以全都使用圆角 + border-radius: ${(props) => (props.$viewMode === 'special' ? '8px' : '0 0 8px 8px')}; overflow: hidden; } + + // 在 split 模式下添加中间分隔线 + ${(props) => + props.$viewMode === 'split' && + css` + position: relative; + + &:before { + content: ''; + position: absolute; + top: 0; + bottom: 0; + left: 50%; + width: 1px; + background-color: var(--color-background-mute); + transform: translateX(-50%); + z-index: 1; + } + `} ` diff --git a/src/renderer/src/components/CodeEditor/hooks.ts b/src/renderer/src/components/CodeEditor/hooks.ts index 53ea6f4a22..d49a703297 100644 --- a/src/renderer/src/components/CodeEditor/hooks.ts +++ b/src/renderer/src/components/CodeEditor/hooks.ts @@ -175,3 +175,26 @@ export function useBlurHandler({ onBlur }: UseBlurHandlerProps) { }) }, [onBlur]) } + +interface UseHeightListenerProps { + onHeightChange?: (scrollHeight: number) => void +} + +/** + * CodeMirror 扩展,用于监听编辑器高度变化 + * @param onHeightChange 高度变化时触发的回调函数 + * @returns 扩展或空数组 + */ +export function useHeightListener({ onHeightChange }: UseHeightListenerProps) { + return useMemo(() => { + if (!onHeightChange) { + return [] + } + + return EditorView.updateListener.of((update) => { + if (update.docChanged || update.heightChanged) { + onHeightChange(update.view.scrollDOM?.scrollHeight ?? 0) + } + }) + }, [onHeightChange]) +} diff --git a/src/renderer/src/components/CodeEditor/index.tsx b/src/renderer/src/components/CodeEditor/index.tsx index c36c7f7076..3ae87ad5dd 100644 --- a/src/renderer/src/components/CodeEditor/index.tsx +++ b/src/renderer/src/components/CodeEditor/index.tsx @@ -1,32 +1,29 @@ -import { CodeTool, TOOL_SPECS, useCodeTool } from '@renderer/components/CodeToolbar' +import { MAX_COLLAPSED_CODE_HEIGHT } from '@renderer/config/constant' import { useCodeStyle } from '@renderer/context/CodeStyleProvider' import { useSettings } from '@renderer/hooks/useSettings' import CodeMirror, { Annotation, BasicSetupOptions, EditorView, Extension } from '@uiw/react-codemirror' import diff from 'fast-diff' -import { - ChevronsDownUp, - ChevronsUpDown, - Save as SaveIcon, - Text as UnWrapIcon, - WrapText as WrapIcon -} from 'lucide-react' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useCallback, useEffect, useImperativeHandle, useMemo, useRef } from 'react' import { memo } from 'react' -import { useTranslation } from 'react-i18next' -import { useBlurHandler, useLanguageExtensions, useSaveKeymap } from './hooks' +import { useBlurHandler, useHeightListener, useLanguageExtensions, useSaveKeymap } from './hooks' // 标记非用户编辑的变更 const External = Annotation.define() -interface Props { +export interface CodeEditorHandles { + save?: () => void +} + +interface CodeEditorProps { + ref?: React.RefObject value: string placeholder?: string | HTMLElement language: string onSave?: (newContent: string) => void onChange?: (newContent: string) => void onBlur?: (newContent: string) => void - setTools?: (value: React.SetStateAction) => void + onHeightChange?: (scrollHeight: number) => void height?: string minHeight?: string maxHeight?: string @@ -35,15 +32,16 @@ interface Props { options?: { stream?: boolean // 用于流式响应场景,默认 false lint?: boolean - collapsible?: boolean - wrappable?: boolean keymap?: boolean } & BasicSetupOptions /** 用于追加 extensions */ extensions?: Extension[] /** 用于覆写编辑器的样式,会直接传给 CodeMirror 的 style 属性 */ style?: React.CSSProperties + className?: string editable?: boolean + expanded?: boolean + unwrapped?: boolean } /** @@ -52,13 +50,14 @@ interface Props { * 目前必须和 CodeToolbar 配合使用。 */ const CodeEditor = ({ + ref, value, placeholder, language, onSave, onChange, onBlur, - setTools, + onHeightChange, height, minHeight, maxHeight, @@ -66,17 +65,12 @@ const CodeEditor = ({ options, extensions, style, - editable = true -}: Props) => { - const { - fontSize: _fontSize, - codeShowLineNumbers: _lineNumbers, - codeCollapsible: _collapsible, - codeWrappable: _wrappable, - codeEditor - } = useSettings() - const collapsible = useMemo(() => options?.collapsible ?? _collapsible, [options?.collapsible, _collapsible]) - const wrappable = useMemo(() => options?.wrappable ?? _wrappable, [options?.wrappable, _wrappable]) + className, + editable = true, + expanded = true, + unwrapped = false +}: CodeEditorProps) => { + const { fontSize: _fontSize, codeShowLineNumbers: _lineNumbers, codeEditor } = useSettings() const enableKeymap = useMemo(() => options?.keymap ?? codeEditor.keymap, [options?.keymap, codeEditor.keymap]) // 合并 codeEditor 和 options 的 basicSetup,options 优先 @@ -91,63 +85,16 @@ const CodeEditor = ({ const customFontSize = useMemo(() => fontSize ?? `${_fontSize - 1}px`, [fontSize, _fontSize]) const { activeCmTheme } = useCodeStyle() - const [isExpanded, setIsExpanded] = useState(!collapsible) - const [isUnwrapped, setIsUnwrapped] = useState(!wrappable) const initialContent = useRef(options?.stream ? (value ?? '').trimEnd() : (value ?? '')) - const [editorReady, setEditorReady] = useState(false) const editorViewRef = useRef(null) - const { t } = useTranslation() const langExtensions = useLanguageExtensions(language, options?.lint) - const { registerTool, removeTool } = useCodeTool(setTools) - - // 展开/折叠工具 - useEffect(() => { - registerTool({ - ...TOOL_SPECS.expand, - icon: isExpanded ? : , - tooltip: isExpanded ? t('code_block.collapse') : t('code_block.expand'), - visible: () => { - const scrollHeight = editorViewRef?.current?.scrollDOM?.scrollHeight - return collapsible && (scrollHeight ?? 0) > 350 - }, - onClick: () => setIsExpanded((prev) => !prev) - }) - - return () => removeTool(TOOL_SPECS.expand.id) - }, [collapsible, isExpanded, registerTool, removeTool, t, editorReady]) - - // 自动换行工具 - useEffect(() => { - registerTool({ - ...TOOL_SPECS.wrap, - icon: isUnwrapped ? : , - tooltip: isUnwrapped ? t('code_block.wrap.on') : t('code_block.wrap.off'), - visible: () => wrappable, - onClick: () => setIsUnwrapped((prev) => !prev) - }) - - return () => removeTool(TOOL_SPECS.wrap.id) - }, [wrappable, isUnwrapped, registerTool, removeTool, t]) - const handleSave = useCallback(() => { const currentDoc = editorViewRef.current?.state.doc.toString() ?? '' onSave?.(currentDoc) }, [onSave]) - // 保存按钮 - useEffect(() => { - registerTool({ - ...TOOL_SPECS.save, - icon: , - tooltip: t('code_block.edit.save.label'), - onClick: handleSave - }) - - return () => removeTool(TOOL_SPECS.save.id) - }, [handleSave, registerTool, removeTool, t]) - // 流式响应过程中计算 changes 来更新 EditorView // 无法处理用户在流式响应过程中编辑代码的情况(应该也不必处理) useEffect(() => { @@ -166,26 +113,24 @@ const CodeEditor = ({ } }, [options?.stream, value]) - useEffect(() => { - setIsExpanded(!collapsible) - }, [collapsible]) - - useEffect(() => { - setIsUnwrapped(!wrappable) - }, [wrappable]) - const saveKeymapExtension = useSaveKeymap({ onSave, enabled: enableKeymap }) const blurExtension = useBlurHandler({ onBlur }) + const heightListenerExtension = useHeightListener({ onHeightChange }) const customExtensions = useMemo(() => { return [ ...(extensions ?? []), ...langExtensions, - ...(isUnwrapped ? [] : [EditorView.lineWrapping]), + ...(unwrapped ? [] : [EditorView.lineWrapping]), saveKeymapExtension, - blurExtension + blurExtension, + heightListenerExtension ].flat() - }, [extensions, langExtensions, isUnwrapped, saveKeymapExtension, blurExtension]) + }, [extensions, langExtensions, unwrapped, saveKeymapExtension, blurExtension, heightListenerExtension]) + + useImperativeHandle(ref, () => ({ + save: handleSave + })) return ( { editorViewRef.current = view - setEditorReady(true) + onHeightChange?.(view.scrollDOM?.scrollHeight ?? 0) }} onChange={(value, viewUpdate) => { if (onChange && viewUpdate.docChanged) onChange(value) @@ -230,6 +175,7 @@ const CodeEditor = ({ borderRadius: 'inherit', ...style }} + className={`code-editor ${className ?? ''}`} /> ) } diff --git a/src/renderer/src/components/CodeToolbar/__tests__/CodeToolButton.test.tsx b/src/renderer/src/components/CodeToolbar/__tests__/CodeToolButton.test.tsx new file mode 100644 index 0000000000..045d242158 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/__tests__/CodeToolButton.test.tsx @@ -0,0 +1,164 @@ +import { ActionTool } from '@renderer/components/ActionTools' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import CodeToolButton from '../button' + +// Mock Antd components +const mocks = vi.hoisted(() => ({ + Tooltip: vi.fn(({ children, title }) => ( +
+ {children} +
+ )), + Dropdown: vi.fn(({ children, menu }) => ( +
+ {children} +
+ )) +})) + +vi.mock('antd', () => ({ + Tooltip: mocks.Tooltip, + Dropdown: mocks.Dropdown +})) + +// Mock ToolWrapper +vi.mock('../styles', () => ({ + ToolWrapper: ({ children, onClick }: { children: React.ReactNode; onClick?: () => void }) => ( + + ) +})) + +// Helper function to create mock tools +const createMockTool = (overrides: Partial = {}): ActionTool => ({ + id: 'test-tool', + type: 'core', + order: 10, + icon: Test Icon, + tooltip: 'Test Tool', + onClick: vi.fn(), + ...overrides +}) + +const createMockChildTool = (id: string, tooltip: string): Omit => ({ + id, + type: 'quick', + order: 10, + icon: {tooltip} Icon, + tooltip, + onClick: vi.fn() +}) + +describe('CodeToolButton', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering modes', () => { + it('should render as simple button when no children', () => { + const tool = createMockTool() + render() + + // Should render button with tooltip + expect(screen.getByTestId('tooltip')).toBeInTheDocument() + expect(screen.getByTestId('tool-wrapper')).toBeInTheDocument() + expect(screen.getByTestId('test-icon')).toBeInTheDocument() + + // Should not render dropdown + expect(screen.queryByTestId('dropdown')).not.toBeInTheDocument() + }) + + it('should render as simple button when children array is empty', () => { + const tool = createMockTool({ children: [] }) + render() + + expect(screen.queryByTestId('dropdown')).not.toBeInTheDocument() + expect(screen.getByTestId('tooltip')).toBeInTheDocument() + }) + + it('should render as dropdown when has children', () => { + const children = [createMockChildTool('child1', 'Child 1')] + const tool = createMockTool({ children }) + render() + + // Should render dropdown containing the main button + expect(screen.getByTestId('dropdown')).toBeInTheDocument() + expect(screen.getByTestId('tooltip')).toBeInTheDocument() + expect(screen.getByTestId('tool-wrapper')).toBeInTheDocument() + }) + }) + + describe('user interactions', () => { + it('should trigger onClick when simple button is clicked', () => { + const mockOnClick = vi.fn() + const tool = createMockTool({ onClick: mockOnClick }) + render() + + fireEvent.click(screen.getByTestId('tool-wrapper')) + + expect(mockOnClick).toHaveBeenCalledTimes(1) + }) + + it('should handle missing onClick gracefully', () => { + const tool = createMockTool({ onClick: undefined }) + render() + + expect(() => { + fireEvent.click(screen.getByTestId('tool-wrapper')) + }).not.toThrow() + }) + }) + + describe('dropdown functionality', () => { + it('should configure dropdown with correct menu structure', () => { + const mockOnClick1 = vi.fn() + const mockOnClick2 = vi.fn() + const children = [createMockChildTool('child1', 'Child 1'), createMockChildTool('child2', 'Child 2')] + children[0].onClick = mockOnClick1 + children[1].onClick = mockOnClick2 + + const tool = createMockTool({ children }) + render() + + // Verify dropdown was called with correct menu structure + expect(mocks.Dropdown).toHaveBeenCalled() + const dropdownProps = mocks.Dropdown.mock.calls[0][0] + + expect(dropdownProps.menu.items).toHaveLength(2) + expect(dropdownProps.menu.items[0].key).toBe('child1') + expect(dropdownProps.menu.items[0].label).toBe('Child 1') + expect(dropdownProps.menu.items[0].onClick).toBe(mockOnClick1) + expect(dropdownProps.trigger).toEqual(['click']) + }) + }) + + describe('accessibility', () => { + it('should provide accessible button element with tooltip', () => { + const tool = createMockTool({ tooltip: 'Accessible Tool' }) + render() + + const button = screen.getByTestId('tool-wrapper') + expect(button.tagName).toBe('BUTTON') + expect(screen.getByTestId('tooltip')).toHaveAttribute('data-title', 'Accessible Tool') + }) + }) + + describe('error handling', () => { + it('should render without crashing for minimal tool configuration', () => { + const minimalTool: ActionTool = { + id: 'minimal', + type: 'core', + order: 1, + icon: null, + tooltip: '' + } + + expect(() => { + render() + }).not.toThrow() + }) + }) +}) diff --git a/src/renderer/src/components/CodeToolbar/__tests__/CodeToolbar.test.tsx b/src/renderer/src/components/CodeToolbar/__tests__/CodeToolbar.test.tsx new file mode 100644 index 0000000000..5c38de461b --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/__tests__/CodeToolbar.test.tsx @@ -0,0 +1,262 @@ +import { ActionTool } from '@renderer/components/ActionTools' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import CodeToolbar from '../toolbar' + +// Test constants +const MORE_BUTTON_TOOLTIP = 'code_block.more' + +// Mock components +const mocks = vi.hoisted(() => ({ + CodeToolButton: vi.fn(({ tool }) => ( +
+ {tool.icon} +
+ )), + Tooltip: vi.fn(({ children, title }) => ( +
+ {children} +
+ )), + HStack: vi.fn(({ children, className }) => ( +
+ {children} +
+ )), + ToolWrapper: vi.fn(({ children, onClick, className }) => ( +
+ {children} +
+ )), + EllipsisVertical: vi.fn(() =>
), + useTranslation: vi.fn(() => ({ + t: vi.fn((key: string) => key) + })) +})) + +vi.mock('../button', () => ({ + default: mocks.CodeToolButton +})) + +vi.mock('antd', () => ({ + Tooltip: mocks.Tooltip +})) + +vi.mock('@renderer/components/Layout', () => ({ + HStack: mocks.HStack +})) + +vi.mock('./styles', () => ({ + ToolWrapper: mocks.ToolWrapper +})) + +vi.mock('lucide-react', () => ({ + EllipsisVertical: mocks.EllipsisVertical +})) + +vi.mock('react-i18next', () => ({ + useTranslation: mocks.useTranslation +})) + +// Helper function to create mock tools +const createMockTool = (overrides: Partial = {}): ActionTool => ({ + id: 'test-tool', + type: 'core', + order: 1, + icon:
Icon
, + tooltip: 'Test Tool', + onClick: vi.fn(), + ...overrides +}) + +// Common test data +const createMixedTools = () => [ + createMockTool({ id: 'quick1', type: 'quick' }), + createMockTool({ id: 'quick2', type: 'quick' }), + createMockTool({ id: 'core1', type: 'core' }) +] + +const createCoreOnlyTools = () => [ + createMockTool({ id: 'core1', type: 'core' }), + createMockTool({ id: 'core2', type: 'core' }) +] + +// Helper function to click more button +const clickMoreButton = () => { + const tooltip = screen.getByTestId('tooltip') + fireEvent.click(tooltip.firstChild as Element) +} + +describe('CodeToolbar', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('basic rendering', () => { + it('should match snapshot with mixed tools', () => { + const { container } = render() + expect(container).toMatchSnapshot() + }) + + it('should match snapshot with only core tools', () => { + const { container } = render() + expect(container).toMatchSnapshot() + }) + }) + + describe('empty state', () => { + it('should render nothing when no tools provided', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + + it('should render nothing when all tools are not visible', () => { + const tools = [ + createMockTool({ id: 'tool1', visible: () => false }), + createMockTool({ id: 'tool2', visible: () => false }) + ] + const { container } = render() + expect(container.firstChild).toBeNull() + }) + }) + + describe('tool visibility filtering', () => { + it('should only render visible tools', () => { + const tools = [ + createMockTool({ id: 'visible-tool', visible: () => true }), + createMockTool({ id: 'hidden-tool', visible: () => false }), + createMockTool({ id: 'no-visible-prop' }) // Should be visible by default + ] + render() + + expect(screen.getByTestId('tool-button-visible-tool')).toBeInTheDocument() + expect(screen.getByTestId('tool-button-no-visible-prop')).toBeInTheDocument() + expect(screen.queryByTestId('tool-button-hidden-tool')).not.toBeInTheDocument() + }) + + it('should show tools without visible function by default', () => { + const tools = [createMockTool({ id: 'default-visible' })] + render() + + expect(screen.getByTestId('tool-button-default-visible')).toBeInTheDocument() + }) + }) + + describe('tool type grouping and quick tools behavior', () => { + it('should separate core and quick tools - show quick tools when expanded', () => { + const tools = [ + createMockTool({ id: 'core1', type: 'core' }), + createMockTool({ id: 'quick1', type: 'quick' }), + createMockTool({ id: 'core2', type: 'core' }), + createMockTool({ id: 'quick2', type: 'quick' }) + ] + render() + + // Initial state: core tools visible, quick tools hidden + expect(screen.getByTestId('tool-button-core1')).toBeInTheDocument() + expect(screen.getByTestId('tool-button-core2')).toBeInTheDocument() + expect(screen.queryByTestId('tool-button-quick1')).not.toBeInTheDocument() + expect(screen.queryByTestId('tool-button-quick2')).not.toBeInTheDocument() + + // After clicking more button, quick tools should be visible + clickMoreButton() + + expect(screen.getByTestId('tool-button-quick1')).toBeInTheDocument() + expect(screen.getByTestId('tool-button-quick2')).toBeInTheDocument() + }) + + it('should render only core tools when no quick tools exist', () => { + render() + + expect(screen.getByTestId('tool-button-core1')).toBeInTheDocument() + expect(screen.getByTestId('tool-button-core2')).toBeInTheDocument() + expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument() // No more button + }) + + it('should show single quick tool directly without more button', () => { + const tools = [createMockTool({ id: 'quick1', type: 'quick' }), createMockTool({ id: 'core1', type: 'core' })] + render() + + expect(screen.getByTestId('tool-button-quick1')).toBeInTheDocument() + expect(screen.getByTestId('tool-button-core1')).toBeInTheDocument() + expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument() // No more button + }) + + it('should show more button when multiple quick tools exist', () => { + render() + + // Initially quick tools should be hidden + expect(screen.queryByTestId('tool-button-quick1')).not.toBeInTheDocument() + expect(screen.queryByTestId('tool-button-quick2')).not.toBeInTheDocument() + expect(screen.getByTestId('tool-button-core1')).toBeInTheDocument() + expect(screen.getByTestId('tooltip')).toBeInTheDocument() // More button exists + }) + + it('should toggle quick tools visibility when more button is clicked', () => { + render() + + // Initial state: quick tools hidden + expect(screen.queryByTestId('tool-button-quick1')).not.toBeInTheDocument() + expect(screen.queryByTestId('tool-button-quick2')).not.toBeInTheDocument() + + // Click more button: quick tools visible + clickMoreButton() + expect(screen.getByTestId('tool-button-quick1')).toBeInTheDocument() + expect(screen.getByTestId('tool-button-quick2')).toBeInTheDocument() + + // Click more button again: quick tools hidden + clickMoreButton() + expect(screen.queryByTestId('tool-button-quick1')).not.toBeInTheDocument() + expect(screen.queryByTestId('tool-button-quick2')).not.toBeInTheDocument() + }) + + it('should apply active class to more button when quick tools are shown', () => { + const tools = [createMockTool({ id: 'quick1', type: 'quick' }), createMockTool({ id: 'quick2', type: 'quick' })] + render() + + const tooltip = screen.getByTestId('tooltip') + const moreButton = tooltip.firstChild as Element + + // Initial state: no active class + expect(moreButton).not.toHaveClass('active') + + // After click: has active class + fireEvent.click(moreButton) + expect(moreButton).toHaveClass('active') + + // After second click: no active class + fireEvent.click(moreButton) + expect(moreButton).not.toHaveClass('active') + }) + + it('should display correct tooltip and icon for more button', () => { + render() + + const tooltip = screen.getByTestId('tooltip') + expect(tooltip).toHaveAttribute('data-title', MORE_BUTTON_TOOLTIP) + + expect(screen.getByTestId('ellipsis-icon')).toBeInTheDocument() + expect(screen.getByTestId('ellipsis-icon')).toHaveClass('tool-icon') + }) + + it('should render core tools regardless of quick tools state', () => { + const tools = [ + createMockTool({ id: 'quick1', type: 'quick' }), + createMockTool({ id: 'quick2', type: 'quick' }), + createMockTool({ id: 'core1', type: 'core' }), + createMockTool({ id: 'core2', type: 'core' }) + ] + render() + + // Core tools always visible + expect(screen.getByTestId('tool-button-core1')).toBeInTheDocument() + expect(screen.getByTestId('tool-button-core2')).toBeInTheDocument() + + // After clicking more button, core tools still visible + clickMoreButton() + expect(screen.getByTestId('tool-button-core1')).toBeInTheDocument() + expect(screen.getByTestId('tool-button-core2')).toBeInTheDocument() + }) + }) +}) diff --git a/src/renderer/src/components/CodeToolbar/__tests__/__snapshots__/CodeToolbar.test.tsx.snap b/src/renderer/src/components/CodeToolbar/__tests__/__snapshots__/CodeToolbar.test.tsx.snap new file mode 100644 index 0000000000..c2b4028e32 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/__tests__/__snapshots__/CodeToolbar.test.tsx.snap @@ -0,0 +1,129 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`CodeToolbar > basic rendering > should match snapshot with mixed tools 1`] = ` +.c2 { + display: flex; + align-items: center; + justify-content: center; + width: 24px; + height: 24px; + border-radius: 4px; + cursor: pointer; + user-select: none; + transition: all 0.2s ease; + color: var(--color-text-3); +} + +.c2:hover { + background-color: var(--color-background-soft); +} + +.c2:hover .tool-icon { + color: var(--color-text-1); +} + +.c2.active { + color: var(--color-primary); +} + +.c2.active .tool-icon { + color: var(--color-primary); +} + +.c2 .tool-icon { + width: 14px; + height: 14px; + color: var(--color-text-3); +} + +.c0 { + position: sticky; + top: 28px; + z-index: 10; +} + +.c1 { + position: absolute; + align-items: center; + bottom: 0.3rem; + right: 0.5rem; + height: 24px; + gap: 4px; +} + +
+
+
+
+
+
+
+
+
+
+ Icon +
+
+
+
+
+`; + +exports[`CodeToolbar > basic rendering > should match snapshot with only core tools 1`] = ` +.c0 { + position: sticky; + top: 28px; + z-index: 10; +} + +.c1 { + position: absolute; + align-items: center; + bottom: 0.3rem; + right: 0.5rem; + height: 24px; + gap: 4px; +} + +
+
+
+
+
+ Icon +
+
+
+
+
+`; diff --git a/src/renderer/src/components/CodeToolbar/__tests__/useCopyTool.test.tsx b/src/renderer/src/components/CodeToolbar/__tests__/useCopyTool.test.tsx new file mode 100644 index 0000000000..2b39950eab --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/__tests__/useCopyTool.test.tsx @@ -0,0 +1,251 @@ +import { useCopyTool } from '@renderer/components/CodeToolbar/hooks/useCopyTool' +import { BasicPreviewHandles } from '@renderer/components/Preview' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock dependencies +const mocks = vi.hoisted(() => ({ + i18n: { + t: vi.fn((key: string) => key) + }, + useTemporaryValue: vi.fn(), + useToolManager: vi.fn(), + TOOL_SPECS: { + copy: { + id: 'copy', + type: 'core', + order: 11 + }, + 'copy-image': { + id: 'copy-image', + type: 'quick', + order: 30 + } + } +})) + +vi.mock('lucide-react', () => ({ + Check: () =>
, + Image: () =>
+})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: mocks.i18n.t + }) +})) + +vi.mock('@renderer/components/Icons', () => ({ + CopyIcon: () =>
+})) + +vi.mock('@renderer/components/ActionTools', () => ({ + TOOL_SPECS: mocks.TOOL_SPECS, + useToolManager: mocks.useToolManager +})) + +vi.mock('@renderer/hooks/useTemporaryValue', () => ({ + useTemporaryValue: mocks.useTemporaryValue +})) + +// Mock useToolManager +const mockRegisterTool = vi.fn() +const mockRemoveTool = vi.fn() +mocks.useToolManager.mockImplementation(() => ({ + registerTool: mockRegisterTool, + removeTool: mockRemoveTool +})) + +// Mock useTemporaryValue setters +const mockSetCopiedTemporarily = vi.fn() +const mockSetCopiedImageTemporarily = vi.fn() + +describe('useCopyTool', () => { + beforeEach(() => { + vi.clearAllMocks() + // Reset mocks for each test to ensure isolation + mocks.useTemporaryValue + .mockImplementationOnce(() => [false, mockSetCopiedTemporarily]) + .mockImplementationOnce(() => [false, mockSetCopiedImageTemporarily]) + }) + + // Helper function to create mock props + const createMockProps = (overrides: Partial[0]> = {}) => ({ + showPreviewTools: false, + previewRef: { current: null }, + onCopySource: vi.fn(), + setTools: vi.fn(), + ...overrides + }) + + const createMockPreviewHandles = (): BasicPreviewHandles => ({ + pan: vi.fn(), + zoom: vi.fn(), + copy: vi.fn(), + download: vi.fn() + }) + + describe('tool registration', () => { + it('should register only the copy-source tool when showPreviewTools is false', () => { + const props = createMockProps({ showPreviewTools: false }) + renderHook(() => useCopyTool(props)) + + expect(mocks.useToolManager).toHaveBeenCalledWith(props.setTools) + expect(mockRegisterTool).toHaveBeenCalledTimes(1) + expect(mockRegisterTool).toHaveBeenCalledWith( + expect.objectContaining({ + id: 'copy', + tooltip: 'code_block.copy.source' + }) + ) + }) + + it('should register only the copy-source tool when previewRef is null', () => { + const props = createMockProps({ showPreviewTools: true, previewRef: { current: null } }) + renderHook(() => useCopyTool(props)) + + expect(mockRegisterTool).toHaveBeenCalledTimes(1) + expect(mockRegisterTool).toHaveBeenCalledWith( + expect.objectContaining({ + id: 'copy' + }) + ) + }) + + it('should register both copy-source and copy-image tools when preview is available', () => { + const props = createMockProps({ + showPreviewTools: true, + previewRef: { current: createMockPreviewHandles() } + }) + + renderHook(() => useCopyTool(props)) + + expect(mockRegisterTool).toHaveBeenCalledTimes(2) + + // Check first tool: copy source + expect(mockRegisterTool).toHaveBeenCalledWith( + expect.objectContaining({ + id: 'copy', + tooltip: 'code_block.copy.source', + onClick: expect.any(Function) + }) + ) + + // Check second tool: copy image + expect(mockRegisterTool).toHaveBeenCalledWith( + expect.objectContaining({ + id: 'copy-image', + tooltip: 'preview.copy.image', + onClick: expect.any(Function) + }) + ) + }) + }) + + describe('copy functionality', () => { + it('should execute copy source behavior when copy-source tool is clicked', () => { + const mockOnCopySource = vi.fn() + const props = createMockProps({ onCopySource: mockOnCopySource }) + renderHook(() => useCopyTool(props)) + + const copySourceTool = mockRegisterTool.mock.calls[0][0] + act(() => { + copySourceTool.onClick() + }) + + expect(mockOnCopySource).toHaveBeenCalledTimes(1) + expect(mockSetCopiedTemporarily).toHaveBeenCalledWith(true) + }) + + it('should execute copy image behavior when copy-image tool is clicked', () => { + const mockPreviewHandles = createMockPreviewHandles() + const props = createMockProps({ + showPreviewTools: true, + previewRef: { current: mockPreviewHandles } + }) + + renderHook(() => useCopyTool(props)) + + // The copy-image tool is the second one registered + const copyImageTool = mockRegisterTool.mock.calls[1][0] + act(() => { + copyImageTool.onClick() + }) + + expect(mockPreviewHandles.copy).toHaveBeenCalledTimes(1) + expect(mockSetCopiedImageTemporarily).toHaveBeenCalledWith(true) + }) + }) + + describe('cleanup', () => { + it('should remove both tools on unmount when both are registered', () => { + const props = createMockProps({ + showPreviewTools: true, + previewRef: { current: createMockPreviewHandles() } + }) + const { unmount } = renderHook(() => useCopyTool(props)) + + unmount() + + expect(mockRemoveTool).toHaveBeenCalledTimes(2) + expect(mockRemoveTool).toHaveBeenCalledWith('copy') + expect(mockRemoveTool).toHaveBeenCalledWith('copy-image') + }) + + it('should attempt to remove both tools on unmount even if only one is registered', () => { + const props = createMockProps({ showPreviewTools: false }) + const { unmount } = renderHook(() => useCopyTool(props)) + + unmount() + + // The cleanup function is static and always tries to remove both + expect(mockRemoveTool).toHaveBeenCalledTimes(2) + expect(mockRemoveTool).toHaveBeenCalledWith('copy') + expect(mockRemoveTool).toHaveBeenCalledWith('copy-image') + }) + }) + + describe('edge cases', () => { + it('should handle copy source failure gracefully', () => { + const mockOnCopySource = vi.fn().mockImplementation(() => { + throw new Error('Copy failed') + }) + const props = createMockProps({ onCopySource: mockOnCopySource }) + renderHook(() => useCopyTool(props)) + + const copySourceTool = mockRegisterTool.mock.calls[0][0] + + expect(() => { + act(() => { + copySourceTool.onClick() + }) + }).toThrow('Copy failed') + + expect(mockOnCopySource).toHaveBeenCalledTimes(1) + expect(mockSetCopiedTemporarily).toHaveBeenCalledWith(false) + }) + + it('should handle copy image failure gracefully', () => { + const mockPreviewHandles = createMockPreviewHandles() + mockPreviewHandles.copy = vi.fn().mockImplementation(() => { + throw new Error('Image copy failed') + }) + const props = createMockProps({ + showPreviewTools: true, + previewRef: { current: mockPreviewHandles } + }) + renderHook(() => useCopyTool(props)) + + const copyImageTool = mockRegisterTool.mock.calls[1][0] + + expect(() => { + act(() => { + copyImageTool.onClick() + }) + }).toThrow('Image copy failed') + + expect(mockPreviewHandles.copy).toHaveBeenCalledTimes(1) + expect(mockSetCopiedImageTemporarily).toHaveBeenCalledWith(false) + }) + }) +}) diff --git a/src/renderer/src/components/CodeToolbar/__tests__/useDownloadTool.test.tsx b/src/renderer/src/components/CodeToolbar/__tests__/useDownloadTool.test.tsx new file mode 100644 index 0000000000..0181dfc5fe --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/__tests__/useDownloadTool.test.tsx @@ -0,0 +1,348 @@ +import { useDownloadTool } from '@renderer/components/CodeToolbar/hooks/useDownloadTool' +import { BasicPreviewHandles } from '@renderer/components/Preview' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock dependencies +const mocks = vi.hoisted(() => ({ + i18n: { + t: vi.fn((key: string) => key) + }, + useToolManager: vi.fn(), + TOOL_SPECS: { + download: { + id: 'download', + type: 'core', + order: 10 + }, + 'download-svg': { + id: 'download-svg', + type: 'quick', + order: 31 + }, + 'download-png': { + id: 'download-png', + type: 'quick', + order: 32 + } + } +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: mocks.i18n.t + }) +})) + +vi.mock('@renderer/components/Icons', () => ({ + FilePngIcon: () =>
, + FileSvgIcon: () =>
+})) + +vi.mock('@renderer/components/ActionTools', () => ({ + TOOL_SPECS: mocks.TOOL_SPECS, + useToolManager: mocks.useToolManager +})) + +// Mock useToolManager +const mockRegisterTool = vi.fn() +const mockRemoveTool = vi.fn() +mocks.useToolManager.mockImplementation(() => ({ + registerTool: mockRegisterTool, + removeTool: mockRemoveTool +})) + +describe('useDownloadTool', () => { + beforeEach(() => { + vi.clearAllMocks() + // Note: mock implementations are already set in vi.hoisted() above + }) + + // Helper function to create mock props + const createMockProps = (overrides: Partial[0]> = {}) => { + const defaultProps = { + showPreviewTools: false, + previewRef: { current: null }, + onDownloadSource: vi.fn(), + setTools: vi.fn() + } + + return { ...defaultProps, ...overrides } + } + + // Helper function to create mock preview handles + const createMockPreviewHandles = (): BasicPreviewHandles => ({ + pan: vi.fn(), + zoom: vi.fn(), + copy: vi.fn(), + download: vi.fn() + }) + + // Helper function for tool registration assertions + const expectToolRegistration = (times: number, toolConfig?: object) => { + expect(mockRegisterTool).toHaveBeenCalledTimes(times) + if (times > 0 && toolConfig) { + expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig)) + } + } + + const expectNoChildren = () => { + const registeredTool = mockRegisterTool.mock.calls[0][0] + expect(registeredTool).not.toHaveProperty('children') + } + + describe('tool registration', () => { + it('should register single download tool when showPreviewTools is false', () => { + const props = createMockProps({ showPreviewTools: false }) + renderHook(() => useDownloadTool(props)) + + expect(mocks.useToolManager).toHaveBeenCalledWith(props.setTools) + expectToolRegistration(1, { + id: 'download', + type: 'core', + order: 10, + tooltip: 'code_block.download.source', + onClick: expect.any(Function), + icon: expect.any(Object) + }) + expectNoChildren() + }) + + it('should register single download tool when showPreviewTools is true but previewRef.current is null', () => { + const props = createMockProps({ showPreviewTools: true, previewRef: { current: null } }) + renderHook(() => useDownloadTool(props)) + + expectToolRegistration(1, { + id: 'download', + type: 'core', + order: 10, + tooltip: 'code_block.download.source', // When previewRef.current is null, showPreviewTools is false + onClick: expect.any(Function), + icon: expect.any(Object) + }) + expectNoChildren() + }) + + it('should register download tool with children when showPreviewTools is true and previewRef.current is not null', () => { + const mockPreviewHandles = createMockPreviewHandles() + const props = createMockProps({ + showPreviewTools: true, + previewRef: { current: mockPreviewHandles } + }) + + renderHook(() => useDownloadTool(props)) + + expectToolRegistration(1, { + id: 'download', + type: 'core', + order: 10, + tooltip: undefined, + icon: expect.any(Object), + children: expect.arrayContaining([ + expect.objectContaining({ + id: 'download', + type: 'core', + order: 10, + tooltip: 'code_block.download.source', + onClick: expect.any(Function), + icon: expect.any(Object) + }), + expect.objectContaining({ + id: 'download-svg', + type: 'quick', + order: 31, + tooltip: 'code_block.download.svg', + onClick: expect.any(Function), + icon: expect.any(Object) + }), + expect.objectContaining({ + id: 'download-png', + type: 'quick', + order: 32, + tooltip: 'code_block.download.png', + onClick: expect.any(Function), + icon: expect.any(Object) + }) + ]) + }) + }) + }) + + describe('download functionality', () => { + it('should execute download source behavior when tool is activated', () => { + const mockOnDownloadSource = vi.fn() + const props = createMockProps({ onDownloadSource: mockOnDownloadSource }) + renderHook(() => useDownloadTool(props)) + + // Get the onClick handler from the registered tool + const registeredTool = mockRegisterTool.mock.calls[0][0] + act(() => { + registeredTool.onClick() + }) + + expect(mockOnDownloadSource).toHaveBeenCalledTimes(1) + }) + + it('should execute download SVG behavior when SVG download tool is activated', () => { + const mockPreviewHandles = createMockPreviewHandles() + const props = createMockProps({ + showPreviewTools: true, + previewRef: { current: mockPreviewHandles } + }) + + renderHook(() => useDownloadTool(props)) + + // Get the download-svg child tool + const registeredTool = mockRegisterTool.mock.calls[0][0] + const downloadSvgTool = registeredTool.children?.find((child: any) => child.tooltip === 'code_block.download.svg') + + expect(downloadSvgTool).toBeDefined() + + act(() => { + downloadSvgTool.onClick() + }) + + expect(mockPreviewHandles.download).toHaveBeenCalledTimes(1) + expect(mockPreviewHandles.download).toHaveBeenCalledWith('svg') + }) + + it('should execute download PNG behavior when PNG download tool is activated', () => { + const mockPreviewHandles = createMockPreviewHandles() + const props = createMockProps({ + showPreviewTools: true, + previewRef: { current: mockPreviewHandles } + }) + + renderHook(() => useDownloadTool(props)) + + // Get the download-png child tool + const registeredTool = mockRegisterTool.mock.calls[0][0] + const downloadPngTool = registeredTool.children?.find((child: any) => child.tooltip === 'code_block.download.png') + + expect(downloadPngTool).toBeDefined() + + act(() => { + downloadPngTool.onClick() + }) + + expect(mockPreviewHandles.download).toHaveBeenCalledTimes(1) + expect(mockPreviewHandles.download).toHaveBeenCalledWith('png') + }) + + it('should execute download source behavior from child tool', () => { + const mockOnDownloadSource = vi.fn() + const props = createMockProps({ + showPreviewTools: true, + onDownloadSource: mockOnDownloadSource, + previewRef: { current: createMockPreviewHandles() } + }) + + renderHook(() => useDownloadTool(props)) + + // Get the download source child tool + const registeredTool = mockRegisterTool.mock.calls[0][0] + const downloadSourceTool = registeredTool.children?.find( + (child: any) => child.tooltip === 'code_block.download.source' + ) + + expect(downloadSourceTool).toBeDefined() + + act(() => { + downloadSourceTool.onClick() + }) + + expect(mockOnDownloadSource).toHaveBeenCalledTimes(1) + }) + }) + + describe('cleanup', () => { + it('should remove tool on unmount', () => { + const props = createMockProps() + const { unmount } = renderHook(() => useDownloadTool(props)) + + unmount() + + expect(mockRemoveTool).toHaveBeenCalledWith('download') + }) + }) + + describe('edge cases', () => { + it('should handle missing setTools gracefully', () => { + const props = createMockProps({ setTools: undefined }) + + expect(() => { + renderHook(() => useDownloadTool(props)) + }).not.toThrow() + + // Should still call useToolManager (but won't actually register) + expect(mocks.useToolManager).toHaveBeenCalledWith(undefined) + }) + + it('should handle missing previewRef.current gracefully', () => { + const props = createMockProps({ + showPreviewTools: true, + previewRef: { current: null } + }) + + expect(() => { + renderHook(() => useDownloadTool(props)) + }).not.toThrow() + + // Should register single tool without children + expectToolRegistration(1) + const registeredTool = mockRegisterTool.mock.calls[0][0] + expect(registeredTool).not.toHaveProperty('children') + }) + + it('should handle download source operation failures gracefully', () => { + const mockOnDownloadSource = vi.fn().mockImplementation(() => { + throw new Error('Download failed') + }) + + const props = createMockProps({ onDownloadSource: mockOnDownloadSource }) + renderHook(() => useDownloadTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + + // Errors should be propagated up + expect(() => { + act(() => { + registeredTool.onClick() + }) + }).toThrow('Download failed') + + // Callback should still be called + expect(mockOnDownloadSource).toHaveBeenCalledTimes(1) + }) + + it('should handle download image operation failures gracefully', () => { + const mockPreviewHandles = createMockPreviewHandles() + mockPreviewHandles.download = vi.fn().mockImplementation(() => { + throw new Error('Image download failed') + }) + + const props = createMockProps({ + showPreviewTools: true, + previewRef: { current: mockPreviewHandles } + }) + + renderHook(() => useDownloadTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + const downloadSvgTool = registeredTool.children?.find((child: any) => child.tooltip === 'code_block.download.svg') + + expect(downloadSvgTool).toBeDefined() + + // Errors should be propagated up + expect(() => { + act(() => { + downloadSvgTool.onClick() + }) + }).toThrow('Image download failed') + + // Callback should still be called + expect(mockPreviewHandles.download).toHaveBeenCalledTimes(1) + expect(mockPreviewHandles.download).toHaveBeenCalledWith('svg') + }) + }) +}) diff --git a/src/renderer/src/components/CodeToolbar/__tests__/useExpandTool.test.tsx b/src/renderer/src/components/CodeToolbar/__tests__/useExpandTool.test.tsx new file mode 100644 index 0000000000..2b539f9673 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/__tests__/useExpandTool.test.tsx @@ -0,0 +1,190 @@ +import { useExpandTool } from '@renderer/components/CodeToolbar/hooks/useExpandTool' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock dependencies +const mocks = vi.hoisted(() => ({ + i18n: { + t: vi.fn((key: string) => key) + }, + useToolManager: vi.fn(), + TOOL_SPECS: { + expand: { + id: 'expand', + type: 'quick', + order: 12 + } + } +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: mocks.i18n.t + }) +})) + +vi.mock('@renderer/components/ActionTools', () => ({ + TOOL_SPECS: mocks.TOOL_SPECS, + useToolManager: mocks.useToolManager +})) + +// Mock useToolManager +const mockRegisterTool = vi.fn() +const mockRemoveTool = vi.fn() +mocks.useToolManager.mockImplementation(() => ({ + registerTool: mockRegisterTool, + removeTool: mockRemoveTool +})) + +vi.mock('lucide-react', () => ({ + ChevronsDownUp: () =>
, + ChevronsUpDown: () =>
+})) + +describe('useExpandTool', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // Helper function to create mock props + const createMockProps = (overrides: Partial[0]> = {}) => { + const defaultProps = { + enabled: true, + expanded: false, + expandable: true, + toggle: vi.fn(), + setTools: vi.fn() + } + + return { ...defaultProps, ...overrides } + } + + // Helper function for tool registration assertions + const expectToolRegistration = (times: number, toolConfig?: object) => { + expect(mockRegisterTool).toHaveBeenCalledTimes(times) + if (times > 0 && toolConfig) { + expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig)) + } + } + + describe('tool registration', () => { + it('should register expand tool when enabled', () => { + const props = createMockProps({ enabled: true }) + renderHook(() => useExpandTool(props)) + + expect(mocks.useToolManager).toHaveBeenCalledWith(props.setTools) + expectToolRegistration(1, { + id: 'expand', + type: 'quick', + order: 12, + tooltip: 'code_block.expand', + onClick: expect.any(Function), + visible: expect.any(Function) + }) + }) + + it('should not register tool when disabled', () => { + const props = createMockProps({ enabled: false }) + renderHook(() => useExpandTool(props)) + + expect(mockRegisterTool).not.toHaveBeenCalled() + }) + + it('should re-register tool when expanded changes', () => { + const props = createMockProps({ expanded: false }) + const { rerender } = renderHook((hookProps) => useExpandTool(hookProps), { + initialProps: props + }) + + expect(mockRegisterTool).toHaveBeenCalledTimes(1) + const firstCall = mockRegisterTool.mock.calls[0][0] + expect(firstCall.tooltip).toBe('code_block.expand') + + // Change expanded to true and rerender + const newProps = { ...props, expanded: true } + rerender(newProps) + + expect(mockRegisterTool).toHaveBeenCalledTimes(2) + const secondCall = mockRegisterTool.mock.calls[1][0] + expect(secondCall.tooltip).toBe('code_block.collapse') + }) + }) + + describe('visibility behavior', () => { + it('should be visible when expandable is true', () => { + const props = createMockProps({ expandable: true }) + renderHook(() => useExpandTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + expect(registeredTool.visible()).toBe(true) + }) + + it('should not be visible when expandable is false', () => { + const props = createMockProps({ expandable: false }) + renderHook(() => useExpandTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + expect(registeredTool.visible()).toBe(false) + }) + + it('should not be visible when expandable is undefined', () => { + const props = createMockProps({ expandable: undefined }) + renderHook(() => useExpandTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + expect(registeredTool.visible()).toBe(false) + }) + }) + + describe('toggle functionality', () => { + it('should execute toggle function when tool is clicked', () => { + const mockToggle = vi.fn() + const props = createMockProps({ toggle: mockToggle }) + renderHook(() => useExpandTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + act(() => { + registeredTool.onClick() + }) + + expect(mockToggle).toHaveBeenCalledTimes(1) + }) + }) + + describe('cleanup', () => { + it('should remove tool on unmount', () => { + const props = createMockProps() + const { unmount } = renderHook(() => useExpandTool(props)) + + unmount() + + expect(mockRemoveTool).toHaveBeenCalledWith('expand') + }) + }) + + describe('edge cases', () => { + it('should handle missing setTools gracefully', () => { + const props = createMockProps({ setTools: undefined }) + + expect(() => { + renderHook(() => useExpandTool(props)) + }).not.toThrow() + + // Should still call useToolManager (but won't actually register) + expect(mocks.useToolManager).toHaveBeenCalledWith(undefined) + }) + + it('should not break when toggle is undefined', () => { + const props = createMockProps({ toggle: undefined }) + renderHook(() => useExpandTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + + expect(() => { + act(() => { + registeredTool.onClick() + }) + }).not.toThrow() + }) + }) +}) diff --git a/src/renderer/src/components/CodeToolbar/__tests__/useRunTool.test.tsx b/src/renderer/src/components/CodeToolbar/__tests__/useRunTool.test.tsx new file mode 100644 index 0000000000..99b8460405 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/__tests__/useRunTool.test.tsx @@ -0,0 +1,165 @@ +import { useRunTool } from '@renderer/components/CodeToolbar/hooks/useRunTool' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock dependencies +const mocks = vi.hoisted(() => ({ + i18n: { + t: vi.fn((key: string) => key) + }, + useToolManager: vi.fn(), + TOOL_SPECS: { + run: { + id: 'run', + type: 'quick', + order: 11 + } + } +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: mocks.i18n.t + }) +})) + +vi.mock('lucide-react', () => ({ + CirclePlay: () =>
CirclePlay
+})) + +vi.mock('@renderer/components/Icons', () => ({ + LoadingIcon: () =>
Loading
+})) + +vi.mock('@renderer/components/ActionTools', () => ({ + TOOL_SPECS: mocks.TOOL_SPECS, + useToolManager: mocks.useToolManager +})) + +const mockRegisterTool = vi.fn() +const mockRemoveTool = vi.fn() +mocks.useToolManager.mockImplementation(() => ({ + registerTool: mockRegisterTool, + removeTool: mockRemoveTool +})) + +describe('useRunTool', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + const createMockProps = (overrides: Partial[0]> = {}) => { + const defaultProps = { + enabled: true, + isRunning: false, + onRun: vi.fn(), + setTools: vi.fn() + } + + return { ...defaultProps, ...overrides } + } + + const expectToolRegistration = (times: number, toolConfig?: object) => { + expect(mockRegisterTool).toHaveBeenCalledTimes(times) + if (times > 0 && toolConfig) { + expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig)) + } + } + + describe('tool registration', () => { + it('should not register tool when disabled', () => { + const props = createMockProps({ enabled: false }) + renderHook(() => useRunTool(props)) + + expect(mockRegisterTool).not.toHaveBeenCalled() + }) + + it('should register run tool when enabled', () => { + const props = createMockProps({ enabled: true }) + renderHook(() => useRunTool(props)) + + expectToolRegistration(1, { + id: 'run', + type: 'quick', + order: 11, + tooltip: 'code_block.run' + }) + }) + + it('should re-register tool when isRunning changes', () => { + const props = createMockProps({ isRunning: false }) + const { rerender } = renderHook((hookProps) => useRunTool(hookProps), { + initialProps: props + }) + + expect(mockRegisterTool).toHaveBeenCalledTimes(1) + + const newProps = { ...props, isRunning: true } + rerender(newProps) + + expect(mockRegisterTool).toHaveBeenCalledTimes(2) + }) + }) + + describe('run functionality', () => { + it('should execute onRun when tool is clicked and not running', () => { + const mockOnRun = vi.fn() + const props = createMockProps({ onRun: mockOnRun, isRunning: false }) + renderHook(() => useRunTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + act(() => { + registeredTool.onClick() + }) + + expect(mockOnRun).toHaveBeenCalledTimes(1) + }) + + it('should not execute onRun when tool is clicked and already running', () => { + const mockOnRun = vi.fn() + const props = createMockProps({ onRun: mockOnRun, isRunning: true }) + renderHook(() => useRunTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + act(() => { + registeredTool.onClick() + }) + + expect(mockOnRun).not.toHaveBeenCalled() + }) + }) + + describe('cleanup', () => { + it('should remove tool on unmount', () => { + const props = createMockProps() + const { unmount } = renderHook(() => useRunTool(props)) + + unmount() + + expect(mockRemoveTool).toHaveBeenCalledWith('run') + }) + }) + + describe('edge cases', () => { + it('should handle missing setTools gracefully', () => { + const props = createMockProps({ setTools: undefined }) + + expect(() => { + renderHook(() => useRunTool(props)) + }).not.toThrow() + }) + + it('should not break when onRun is undefined', () => { + const props = createMockProps({ onRun: undefined }) + renderHook(() => useRunTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + + expect(() => { + act(() => { + registeredTool.onClick() + }) + }).not.toThrow() + }) + }) +}) diff --git a/src/renderer/src/components/CodeToolbar/__tests__/useSaveTool.test.tsx b/src/renderer/src/components/CodeToolbar/__tests__/useSaveTool.test.tsx new file mode 100644 index 0000000000..69a394e1d3 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/__tests__/useSaveTool.test.tsx @@ -0,0 +1,193 @@ +import { useSaveTool } from '@renderer/components/CodeToolbar/hooks/useSaveTool' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock dependencies +const mocks = vi.hoisted(() => ({ + i18n: { + t: vi.fn((key: string) => key) + }, + useToolManager: vi.fn(), + useTemporaryValue: vi.fn(), + TOOL_SPECS: { + save: { + id: 'save', + type: 'core', + order: 14 + } + } +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: mocks.i18n.t + }) +})) + +vi.mock('@renderer/components/ActionTools', () => ({ + TOOL_SPECS: mocks.TOOL_SPECS, + useToolManager: mocks.useToolManager +})) + +// Mock useTemporaryValue +const mockSetTemporaryValue = vi.fn() +mocks.useTemporaryValue.mockImplementation(() => [false, mockSetTemporaryValue]) + +vi.mock('@renderer/hooks/useTemporaryValue', () => ({ + useTemporaryValue: mocks.useTemporaryValue +})) + +// Mock useToolManager +const mockRegisterTool = vi.fn() +const mockRemoveTool = vi.fn() +mocks.useToolManager.mockImplementation(() => ({ + registerTool: mockRegisterTool, + removeTool: mockRemoveTool +})) + +vi.mock('lucide-react', () => ({ + Check: () =>
, + SaveIcon: () =>
+})) + +describe('useSaveTool', () => { + beforeEach(() => { + vi.clearAllMocks() + // Reset to default values + mocks.useTemporaryValue.mockImplementation(() => [false, mockSetTemporaryValue]) + }) + + // Helper function to create mock props + const createMockProps = (overrides: Partial[0]> = {}) => { + const defaultProps = { + enabled: true, + sourceViewRef: { current: null }, + setTools: vi.fn() + } + + return { ...defaultProps, ...overrides } + } + + // Helper function for tool registration assertions + const expectToolRegistration = (times: number, toolConfig?: object) => { + expect(mockRegisterTool).toHaveBeenCalledTimes(times) + if (times > 0 && toolConfig) { + expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig)) + } + } + + describe('tool registration', () => { + it('should register save tool when enabled', () => { + const props = createMockProps({ enabled: true }) + renderHook(() => useSaveTool(props)) + + expect(mocks.useToolManager).toHaveBeenCalledWith(props.setTools) + expectToolRegistration(1, { + id: 'save', + type: 'core', + order: 14, + tooltip: 'code_block.edit.save.label', + onClick: expect.any(Function) + }) + }) + + it('should not register tool when disabled', () => { + const props = createMockProps({ enabled: false }) + renderHook(() => useSaveTool(props)) + + expect(mockRegisterTool).not.toHaveBeenCalled() + }) + + it('should re-register tool when saved state changes', () => { + // Initially not saved + mocks.useTemporaryValue.mockImplementation(() => [false, mockSetTemporaryValue]) + const props = createMockProps() + const { rerender } = renderHook(() => useSaveTool(props)) + + expect(mockRegisterTool).toHaveBeenCalledTimes(1) + + // Change to saved state and rerender + mocks.useTemporaryValue.mockImplementation(() => [true, mockSetTemporaryValue]) + rerender() + + expect(mockRegisterTool).toHaveBeenCalledTimes(2) + }) + }) + + describe('save functionality', () => { + it('should execute save behavior when tool is clicked', () => { + const mockSave = vi.fn() + const mockEditorHandles = { save: mockSave } + const props = createMockProps({ + sourceViewRef: { current: mockEditorHandles } + }) + renderHook(() => useSaveTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + act(() => { + registeredTool.onClick() + }) + + expect(mockSave).toHaveBeenCalledTimes(1) + expect(mockSetTemporaryValue).toHaveBeenCalledWith(true) + }) + + it('should handle when sourceViewRef.current is null', () => { + const props = createMockProps({ + sourceViewRef: { current: null } + }) + renderHook(() => useSaveTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + + expect(() => { + act(() => { + registeredTool.onClick() + }) + }).not.toThrow() + + expect(mockSetTemporaryValue).toHaveBeenCalledWith(true) + }) + + it('should handle when sourceViewRef.current.save is undefined', () => { + const props = createMockProps({ + sourceViewRef: { current: {} } + }) + renderHook(() => useSaveTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + + expect(() => { + act(() => { + registeredTool.onClick() + }) + }).not.toThrow() + + expect(mockSetTemporaryValue).toHaveBeenCalledWith(true) + }) + }) + + describe('cleanup', () => { + it('should remove tool on unmount', () => { + const props = createMockProps() + const { unmount } = renderHook(() => useSaveTool(props)) + + unmount() + + expect(mockRemoveTool).toHaveBeenCalledWith('save') + }) + }) + + describe('edge cases', () => { + it('should handle missing setTools gracefully', () => { + const props = createMockProps({ setTools: undefined }) + + expect(() => { + renderHook(() => useSaveTool(props)) + }).not.toThrow() + + // Should still call useToolManager (but won't actually register) + expect(mocks.useToolManager).toHaveBeenCalledWith(undefined) + }) + }) +}) diff --git a/src/renderer/src/components/CodeToolbar/__tests__/useSplitViewTool.test.tsx b/src/renderer/src/components/CodeToolbar/__tests__/useSplitViewTool.test.tsx new file mode 100644 index 0000000000..fbe52bbb35 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/__tests__/useSplitViewTool.test.tsx @@ -0,0 +1,180 @@ +import { ViewMode } from '@renderer/components/CodeBlockView/types' +import { useSplitViewTool } from '@renderer/components/CodeToolbar/hooks/useSplitViewTool' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock dependencies +const mocks = vi.hoisted(() => ({ + i18n: { + t: vi.fn((key: string) => key) + }, + useToolManager: vi.fn(), + TOOL_SPECS: { + 'split-view': { + id: 'split-view', + type: 'quick', + order: 10 + } + } +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: mocks.i18n.t + }) +})) + +vi.mock('@renderer/components/ActionTools', () => ({ + TOOL_SPECS: mocks.TOOL_SPECS, + useToolManager: mocks.useToolManager +})) + +// Mock useToolManager +const mockRegisterTool = vi.fn() +const mockRemoveTool = vi.fn() +mocks.useToolManager.mockImplementation(() => ({ + registerTool: mockRegisterTool, + removeTool: mockRemoveTool +})) + +describe('useSplitViewTool', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // Helper function to create mock props + const createMockProps = (overrides: Partial[0]> = {}) => { + const defaultProps = { + enabled: true, + viewMode: 'special' as ViewMode, + onToggleSplitView: vi.fn(), + setTools: vi.fn() + } + + return { ...defaultProps, ...overrides } + } + + // Helper function for tool registration assertions + const expectToolRegistration = (times: number, toolConfig?: object) => { + expect(mockRegisterTool).toHaveBeenCalledTimes(times) + if (times > 0 && toolConfig) { + expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig)) + } + } + + describe('tool registration', () => { + it('should not register tool when disabled', () => { + const props = createMockProps({ enabled: false }) + renderHook(() => useSplitViewTool(props)) + + expect(mocks.useToolManager).toHaveBeenCalledWith(props.setTools) + expect(mockRegisterTool).not.toHaveBeenCalled() + }) + + it('should register split view tool when enabled', () => { + const props = createMockProps({ enabled: true }) + renderHook(() => useSplitViewTool(props)) + + expectToolRegistration(1, { + id: 'split-view', + type: 'quick', + order: 10, + tooltip: 'code_block.split.label', + onClick: expect.any(Function), + icon: expect.any(Object) + }) + }) + + it('should show different tooltip when in split mode', () => { + const props = createMockProps({ viewMode: 'split' }) + renderHook(() => useSplitViewTool(props)) + + expectToolRegistration(1, { + tooltip: 'code_block.split.restore' + }) + }) + + it('should show different tooltip when not in split mode', () => { + const props = createMockProps({ viewMode: 'special' }) + renderHook(() => useSplitViewTool(props)) + + expectToolRegistration(1, { + tooltip: 'code_block.split.label' + }) + }) + + it('should re-register tool when viewMode changes', () => { + const props = createMockProps({ viewMode: 'special' }) + const { rerender } = renderHook((hookProps) => useSplitViewTool(hookProps), { + initialProps: props + }) + + expect(mockRegisterTool).toHaveBeenCalledTimes(1) + + // Change viewMode and rerender + const newProps = { ...props, viewMode: 'split' as ViewMode } + rerender(newProps) + + // Should register tool again with updated state + expect(mockRegisterTool).toHaveBeenCalledTimes(2) + + // Verify the new registration has correct tooltip + const secondRegistration = mockRegisterTool.mock.calls[1][0] + expect(secondRegistration.tooltip).toBe('code_block.split.restore') + }) + }) + + describe('view mode switching', () => { + it('should call onToggleSplitView when tool is clicked', () => { + const mockOnToggleSplitView = vi.fn() + const props = createMockProps({ + onToggleSplitView: mockOnToggleSplitView + }) + renderHook(() => useSplitViewTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + act(() => { + registeredTool.onClick() + }) + + expect(mockOnToggleSplitView).toHaveBeenCalledTimes(1) + }) + }) + + describe('cleanup', () => { + it('should remove tool on unmount', () => { + const props = createMockProps() + const { unmount } = renderHook(() => useSplitViewTool(props)) + + unmount() + + expect(mockRemoveTool).toHaveBeenCalledWith('split-view') + }) + }) + + describe('edge cases', () => { + it('should handle missing setTools gracefully', () => { + const props = createMockProps({ setTools: undefined }) + + expect(() => { + renderHook(() => useSplitViewTool(props)) + }).not.toThrow() + + // Should still call useToolManager (but won't actually register) + expect(mocks.useToolManager).toHaveBeenCalledWith(undefined) + }) + + it('should not break when onToggleSplitView is undefined', () => { + const props = createMockProps({ onToggleSplitView: undefined }) + renderHook(() => useSplitViewTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + + expect(() => { + act(() => { + registeredTool.onClick() + }) + }).not.toThrow() + }) + }) +}) diff --git a/src/renderer/src/components/CodeToolbar/__tests__/useViewSourceTool.test.tsx b/src/renderer/src/components/CodeToolbar/__tests__/useViewSourceTool.test.tsx new file mode 100644 index 0000000000..9bac34c57a --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/__tests__/useViewSourceTool.test.tsx @@ -0,0 +1,226 @@ +import { ViewMode } from '@renderer/components/CodeBlockView/types' +import { useViewSourceTool } from '@renderer/components/CodeToolbar/hooks/useViewSourceTool' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock dependencies +const mocks = vi.hoisted(() => ({ + i18n: { + t: vi.fn((key: string) => key) + }, + useToolManager: vi.fn(), + TOOL_SPECS: { + edit: { + id: 'edit', + type: 'core', + order: 12 + }, + 'view-source': { + id: 'view-source', + type: 'core', + order: 12 + } + } +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: mocks.i18n.t + }) +})) + +vi.mock('@renderer/components/ActionTools', () => ({ + TOOL_SPECS: mocks.TOOL_SPECS, + useToolManager: mocks.useToolManager +})) + +const mockRegisterTool = vi.fn() +const mockRemoveTool = vi.fn() +mocks.useToolManager.mockImplementation(() => ({ + registerTool: mockRegisterTool, + removeTool: mockRemoveTool +})) + +describe('useViewSourceTool', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + const createMockProps = (overrides: Partial[0]> = {}) => { + const defaultProps = { + enabled: true, + editable: false, + viewMode: 'special' as ViewMode, + onViewModeChange: vi.fn(), + setTools: vi.fn() + } + + return { ...defaultProps, ...overrides } + } + + const expectToolRegistration = (times: number, toolConfig?: object) => { + expect(mockRegisterTool).toHaveBeenCalledTimes(times) + if (times > 0 && toolConfig) { + expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig)) + } + } + + describe('tool registration', () => { + it('should not register tool when disabled', () => { + const props = createMockProps({ enabled: false }) + renderHook(() => useViewSourceTool(props)) + + expect(mockRegisterTool).not.toHaveBeenCalled() + }) + + it('should not register tool when in split mode', () => { + const props = createMockProps({ viewMode: 'split' }) + renderHook(() => useViewSourceTool(props)) + + expect(mockRegisterTool).not.toHaveBeenCalled() + }) + + it('should register view-source tool when not editable', () => { + const props = createMockProps({ editable: false }) + renderHook(() => useViewSourceTool(props)) + + expectToolRegistration(1, { + id: 'view-source', + type: 'core', + order: 12 + }) + }) + + it('should register edit tool when editable', () => { + const props = createMockProps({ editable: true }) + renderHook(() => useViewSourceTool(props)) + + expectToolRegistration(1, { + id: 'edit', + type: 'core', + order: 12 + }) + }) + + it('should re-register tool when editable changes', () => { + const props = createMockProps({ editable: false }) + const { rerender } = renderHook((hookProps) => useViewSourceTool(hookProps), { + initialProps: props + }) + + expect(mockRegisterTool).toHaveBeenCalledTimes(1) + + const newProps = { ...props, editable: true } + rerender(newProps) + + expect(mockRegisterTool).toHaveBeenCalledTimes(2) + expect(mockRemoveTool).toHaveBeenCalledWith('view-source') + }) + }) + + describe('tooltip variations', () => { + it('should show correct tooltips for edit mode', () => { + const props = createMockProps({ editable: true, viewMode: 'source' }) + renderHook(() => useViewSourceTool(props)) + + expectToolRegistration(1, { + tooltip: 'preview.label' + }) + + vi.clearAllMocks() + + const propsSpecial = createMockProps({ editable: true, viewMode: 'special' }) + renderHook(() => useViewSourceTool(propsSpecial)) + + expectToolRegistration(1, { + tooltip: 'code_block.edit.label' + }) + }) + + it('should show correct tooltips for view-source mode', () => { + const props = createMockProps({ editable: false, viewMode: 'source' }) + renderHook(() => useViewSourceTool(props)) + + expectToolRegistration(1, { + tooltip: 'preview.label' + }) + + vi.clearAllMocks() + + const propsSpecial = createMockProps({ editable: false, viewMode: 'special' }) + renderHook(() => useViewSourceTool(propsSpecial)) + + expectToolRegistration(1, { + tooltip: 'preview.source' + }) + }) + }) + + describe('view mode switching', () => { + it('should switch from special to source when tool is clicked', () => { + const mockOnViewModeChange = vi.fn() + const props = createMockProps({ + viewMode: 'special', + onViewModeChange: mockOnViewModeChange + }) + renderHook(() => useViewSourceTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + act(() => { + registeredTool.onClick() + }) + + expect(mockOnViewModeChange).toHaveBeenCalledWith('source') + }) + + it('should switch from source to special when tool is clicked', () => { + const mockOnViewModeChange = vi.fn() + const props = createMockProps({ + viewMode: 'source', + onViewModeChange: mockOnViewModeChange + }) + renderHook(() => useViewSourceTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + act(() => { + registeredTool.onClick() + }) + + expect(mockOnViewModeChange).toHaveBeenCalledWith('special') + }) + }) + + describe('cleanup', () => { + it('should remove tool on unmount', () => { + const props = createMockProps() + const { unmount } = renderHook(() => useViewSourceTool(props)) + + unmount() + + expect(mockRemoveTool).toHaveBeenCalledWith('view-source') + }) + }) + + describe('edge cases', () => { + it('should handle missing setTools gracefully', () => { + const props = createMockProps({ setTools: undefined }) + + expect(() => { + renderHook(() => useViewSourceTool(props)) + }).not.toThrow() + }) + + it('should not break when onViewModeChange is undefined', () => { + const props = createMockProps({ onViewModeChange: undefined }) + renderHook(() => useViewSourceTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + + expect(() => { + act(() => { + registeredTool.onClick() + }) + }).not.toThrow() + }) + }) +}) diff --git a/src/renderer/src/components/CodeToolbar/__tests__/useWrapTool.test.tsx b/src/renderer/src/components/CodeToolbar/__tests__/useWrapTool.test.tsx new file mode 100644 index 0000000000..ca601cd37f --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/__tests__/useWrapTool.test.tsx @@ -0,0 +1,190 @@ +import { useWrapTool } from '@renderer/components/CodeToolbar/hooks/useWrapTool' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock dependencies +const mocks = vi.hoisted(() => ({ + i18n: { + t: vi.fn((key: string) => key) + }, + useToolManager: vi.fn(), + TOOL_SPECS: { + wrap: { + id: 'wrap', + type: 'quick', + order: 13 + } + } +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: mocks.i18n.t + }) +})) + +vi.mock('@renderer/components/ActionTools', () => ({ + TOOL_SPECS: mocks.TOOL_SPECS, + useToolManager: mocks.useToolManager +})) + +// Mock useToolManager +const mockRegisterTool = vi.fn() +const mockRemoveTool = vi.fn() +mocks.useToolManager.mockImplementation(() => ({ + registerTool: mockRegisterTool, + removeTool: mockRemoveTool +})) + +vi.mock('lucide-react', () => ({ + Text: () =>
, + WrapText: () =>
+})) + +describe('useWrapTool', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // Helper function to create mock props + const createMockProps = (overrides: Partial[0]> = {}) => { + const defaultProps = { + enabled: true, + unwrapped: false, + wrappable: true, + toggle: vi.fn(), + setTools: vi.fn() + } + + return { ...defaultProps, ...overrides } + } + + // Helper function for tool registration assertions + const expectToolRegistration = (times: number, toolConfig?: object) => { + expect(mockRegisterTool).toHaveBeenCalledTimes(times) + if (times > 0 && toolConfig) { + expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig)) + } + } + + describe('tool registration', () => { + it('should register wrap tool when enabled', () => { + const props = createMockProps({ enabled: true }) + renderHook(() => useWrapTool(props)) + + expect(mocks.useToolManager).toHaveBeenCalledWith(props.setTools) + expectToolRegistration(1, { + id: 'wrap', + type: 'quick', + order: 13, + tooltip: 'code_block.wrap.off', + onClick: expect.any(Function), + visible: expect.any(Function) + }) + }) + + it('should not register tool when disabled', () => { + const props = createMockProps({ enabled: false }) + renderHook(() => useWrapTool(props)) + + expect(mockRegisterTool).not.toHaveBeenCalled() + }) + + it('should re-register tool when unwrapped changes', () => { + const props = createMockProps({ unwrapped: false }) + const { rerender } = renderHook((hookProps) => useWrapTool(hookProps), { + initialProps: props + }) + + expect(mockRegisterTool).toHaveBeenCalledTimes(1) + const firstCall = mockRegisterTool.mock.calls[0][0] + expect(firstCall.tooltip).toBe('code_block.wrap.off') + + // Change unwrapped to true and rerender + const newProps = { ...props, unwrapped: true } + rerender(newProps) + + expect(mockRegisterTool).toHaveBeenCalledTimes(2) + const secondCall = mockRegisterTool.mock.calls[1][0] + expect(secondCall.tooltip).toBe('code_block.wrap.on') + }) + }) + + describe('visibility behavior', () => { + it('should be visible when wrappable is true', () => { + const props = createMockProps({ wrappable: true }) + renderHook(() => useWrapTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + expect(registeredTool.visible()).toBe(true) + }) + + it('should not be visible when wrappable is false', () => { + const props = createMockProps({ wrappable: false }) + renderHook(() => useWrapTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + expect(registeredTool.visible()).toBe(false) + }) + + it('should not be visible when wrappable is undefined', () => { + const props = createMockProps({ wrappable: undefined }) + renderHook(() => useWrapTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + expect(registeredTool.visible()).toBe(false) + }) + }) + + describe('toggle functionality', () => { + it('should execute toggle function when tool is clicked', () => { + const mockToggle = vi.fn() + const props = createMockProps({ toggle: mockToggle }) + renderHook(() => useWrapTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + act(() => { + registeredTool.onClick() + }) + + expect(mockToggle).toHaveBeenCalledTimes(1) + }) + }) + + describe('cleanup', () => { + it('should remove tool on unmount', () => { + const props = createMockProps() + const { unmount } = renderHook(() => useWrapTool(props)) + + unmount() + + expect(mockRemoveTool).toHaveBeenCalledWith('wrap') + }) + }) + + describe('edge cases', () => { + it('should handle missing setTools gracefully', () => { + const props = createMockProps({ setTools: undefined }) + + expect(() => { + renderHook(() => useWrapTool(props)) + }).not.toThrow() + + // Should still call useToolManager (but won't actually register) + expect(mocks.useToolManager).toHaveBeenCalledWith(undefined) + }) + + it('should not break when toggle is undefined', () => { + const props = createMockProps({ toggle: undefined }) + renderHook(() => useWrapTool(props)) + + const registeredTool = mockRegisterTool.mock.calls[0][0] + + expect(() => { + act(() => { + registeredTool.onClick() + }) + }).not.toThrow() + }) + }) +}) diff --git a/src/renderer/src/components/CodeToolbar/button.tsx b/src/renderer/src/components/CodeToolbar/button.tsx new file mode 100644 index 0000000000..1488752726 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/button.tsx @@ -0,0 +1,41 @@ +import { ActionTool } from '@renderer/components/ActionTools' +import { Dropdown, Tooltip } from 'antd' +import { memo, useMemo } from 'react' + +import { ToolWrapper } from './styles' + +interface CodeToolButtonProps { + tool: ActionTool +} + +const CodeToolButton = ({ tool }: CodeToolButtonProps) => { + const mainTool = useMemo( + () => ( + + {tool.icon} + + ), + [tool] + ) + + if (tool.children?.length && tool.children.length > 0) { + return ( + ({ + key: child.id, + label: child.tooltip, + icon: child.icon, + onClick: child.onClick + })) + }} + trigger={['click']}> + {mainTool} + + ) + } + + return mainTool +} + +export default memo(CodeToolButton) diff --git a/src/renderer/src/components/CodeToolbar/hooks/index.ts b/src/renderer/src/components/CodeToolbar/hooks/index.ts new file mode 100644 index 0000000000..bd35bf2681 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/hooks/index.ts @@ -0,0 +1,8 @@ +export * from './useCopyTool' +export * from './useDownloadTool' +export * from './useExpandTool' +export * from './useRunTool' +export * from './useSaveTool' +export * from './useSplitViewTool' +export * from './useViewSourceTool' +export * from './useWrapTool' diff --git a/src/renderer/src/components/CodeToolbar/hooks/useCopyTool.tsx b/src/renderer/src/components/CodeToolbar/hooks/useCopyTool.tsx new file mode 100644 index 0000000000..ea928df4fd --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/hooks/useCopyTool.tsx @@ -0,0 +1,89 @@ +import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools' +import { CopyIcon } from '@renderer/components/Icons' +import { BasicPreviewHandles } from '@renderer/components/Preview' +import { useTemporaryValue } from '@renderer/hooks/useTemporaryValue' +import { Check, Image } from 'lucide-react' +import { useCallback, useEffect } from 'react' +import { useTranslation } from 'react-i18next' + +interface UseCopyToolProps { + showPreviewTools?: boolean + previewRef: React.RefObject + onCopySource: () => void + setTools: React.Dispatch> +} + +export const useCopyTool = ({ showPreviewTools, previewRef, onCopySource, setTools }: UseCopyToolProps) => { + const [copied, setCopiedTemporarily] = useTemporaryValue(false) + const [copiedImage, setCopiedImageTemporarily] = useTemporaryValue(false) + const { t } = useTranslation() + const { registerTool, removeTool } = useToolManager(setTools) + + const handleCopySource = useCallback(() => { + try { + onCopySource() + setCopiedTemporarily(true) + } catch (error) { + setCopiedTemporarily(false) + throw error + } + }, [onCopySource, setCopiedTemporarily]) + + const handleCopyImage = useCallback(() => { + try { + previewRef.current?.copy() + setCopiedImageTemporarily(true) + } catch (error) { + setCopiedImageTemporarily(false) + throw error + } + }, [previewRef, setCopiedImageTemporarily]) + + useEffect(() => { + const includePreviewTools = showPreviewTools && previewRef.current !== null + + const baseTool = { + ...TOOL_SPECS.copy, + icon: copied ? ( + + ) : ( + + ), + tooltip: t('code_block.copy.source'), + onClick: handleCopySource + } + + const copyImageTool = { + ...TOOL_SPECS['copy-image'], + icon: copiedImage ? ( + + ) : ( + + ), + tooltip: t('preview.copy.image'), + onClick: handleCopyImage + } + + registerTool(baseTool) + + if (includePreviewTools) { + registerTool(copyImageTool) + } + + return () => { + removeTool(TOOL_SPECS.copy.id) + removeTool(TOOL_SPECS['copy-image'].id) + } + }, [ + onCopySource, + registerTool, + removeTool, + t, + copied, + copiedImage, + handleCopySource, + handleCopyImage, + showPreviewTools, + previewRef + ]) +} diff --git a/src/renderer/src/components/CodeToolbar/hooks/useDownloadTool.tsx b/src/renderer/src/components/CodeToolbar/hooks/useDownloadTool.tsx new file mode 100644 index 0000000000..397c51c921 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/hooks/useDownloadTool.tsx @@ -0,0 +1,61 @@ +import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools' +import { FilePngIcon, FileSvgIcon } from '@renderer/components/Icons' +import { BasicPreviewHandles } from '@renderer/components/Preview' +import { Download, FileCode } from 'lucide-react' +import { useEffect } from 'react' +import { useTranslation } from 'react-i18next' + +interface UseDownloadToolProps { + showPreviewTools?: boolean + previewRef: React.RefObject + onDownloadSource: () => void + setTools: React.Dispatch> +} + +export const useDownloadTool = ({ showPreviewTools, previewRef, onDownloadSource, setTools }: UseDownloadToolProps) => { + const { t } = useTranslation() + const { registerTool, removeTool } = useToolManager(setTools) + + useEffect(() => { + const includePreviewTools = showPreviewTools && previewRef.current !== null + + const baseTool = { + ...TOOL_SPECS.download, + icon: , + tooltip: includePreviewTools ? undefined : t('code_block.download.source') + } + + if (includePreviewTools) { + registerTool({ + ...baseTool, + children: [ + { + ...TOOL_SPECS.download, + icon: , + tooltip: t('code_block.download.source'), + onClick: onDownloadSource + }, + { + ...TOOL_SPECS['download-svg'], + icon: , + tooltip: t('code_block.download.svg'), + onClick: () => previewRef.current?.download('svg') + }, + { + ...TOOL_SPECS['download-png'], + icon: , + tooltip: t('code_block.download.png'), + onClick: () => previewRef.current?.download('png') + } + ] + }) + } else { + registerTool({ + ...baseTool, + onClick: onDownloadSource + }) + } + + return () => removeTool(TOOL_SPECS.download.id) + }, [onDownloadSource, registerTool, removeTool, t, showPreviewTools, previewRef]) +} diff --git a/src/renderer/src/components/CodeToolbar/hooks/useExpandTool.tsx b/src/renderer/src/components/CodeToolbar/hooks/useExpandTool.tsx new file mode 100644 index 0000000000..6428a9c543 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/hooks/useExpandTool.tsx @@ -0,0 +1,35 @@ +import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools' +import { ChevronsDownUp, ChevronsUpDown } from 'lucide-react' +import { useCallback, useEffect } from 'react' +import { useTranslation } from 'react-i18next' + +interface UseExpandToolProps { + enabled?: boolean + expanded?: boolean + expandable?: boolean + toggle: () => void + setTools: React.Dispatch> +} + +export const useExpandTool = ({ enabled, expanded, expandable, toggle, setTools }: UseExpandToolProps) => { + const { t } = useTranslation() + const { registerTool, removeTool } = useToolManager(setTools) + + const handleToggle = useCallback(() => { + toggle?.() + }, [toggle]) + + useEffect(() => { + if (enabled) { + registerTool({ + ...TOOL_SPECS.expand, + icon: expanded ? : , + tooltip: expanded ? t('code_block.collapse') : t('code_block.expand'), + visible: () => expandable ?? false, + onClick: handleToggle + }) + } + + return () => removeTool(TOOL_SPECS.expand.id) + }, [enabled, expandable, expanded, handleToggle, registerTool, removeTool, t]) +} diff --git a/src/renderer/src/components/CodeToolbar/hooks/useRunTool.tsx b/src/renderer/src/components/CodeToolbar/hooks/useRunTool.tsx new file mode 100644 index 0000000000..4c46681a4d --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/hooks/useRunTool.tsx @@ -0,0 +1,30 @@ +import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools' +import { LoadingIcon } from '@renderer/components/Icons' +import { CirclePlay } from 'lucide-react' +import { useEffect } from 'react' +import { useTranslation } from 'react-i18next' + +interface UseRunToolProps { + enabled: boolean + isRunning: boolean + onRun: () => void + setTools: React.Dispatch> +} + +export const useRunTool = ({ enabled, isRunning, onRun, setTools }: UseRunToolProps) => { + const { t } = useTranslation() + const { registerTool, removeTool } = useToolManager(setTools) + + useEffect(() => { + if (!enabled) return + + registerTool({ + ...TOOL_SPECS.run, + icon: isRunning ? : , + tooltip: t('code_block.run'), + onClick: () => !isRunning && onRun?.() + }) + + return () => removeTool(TOOL_SPECS.run.id) + }, [enabled, isRunning, onRun, registerTool, removeTool, t]) +} diff --git a/src/renderer/src/components/CodeToolbar/hooks/useSaveTool.tsx b/src/renderer/src/components/CodeToolbar/hooks/useSaveTool.tsx new file mode 100644 index 0000000000..c847b6ca90 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/hooks/useSaveTool.tsx @@ -0,0 +1,40 @@ +import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools' +import { CodeEditorHandles } from '@renderer/components/CodeEditor' +import { useTemporaryValue } from '@renderer/hooks/useTemporaryValue' +import { Check, SaveIcon } from 'lucide-react' +import { useCallback, useEffect } from 'react' +import { useTranslation } from 'react-i18next' + +interface UseSaveToolProps { + enabled?: boolean + sourceViewRef: React.RefObject + setTools: React.Dispatch> +} + +export const useSaveTool = ({ enabled, sourceViewRef, setTools }: UseSaveToolProps) => { + const [saved, setSavedTemporarily] = useTemporaryValue(false) + const { t } = useTranslation() + const { registerTool, removeTool } = useToolManager(setTools) + + const handleSave = useCallback(() => { + sourceViewRef.current?.save?.() + setSavedTemporarily(true) + }, [sourceViewRef, setSavedTemporarily]) + + useEffect(() => { + if (enabled) { + registerTool({ + ...TOOL_SPECS.save, + icon: saved ? ( + + ) : ( + + ), + tooltip: t('code_block.edit.save.label'), + onClick: handleSave + }) + } + + return () => removeTool(TOOL_SPECS.save.id) + }, [enabled, handleSave, registerTool, removeTool, saved, t]) +} diff --git a/src/renderer/src/components/CodeToolbar/hooks/useSplitViewTool.tsx b/src/renderer/src/components/CodeToolbar/hooks/useSplitViewTool.tsx new file mode 100644 index 0000000000..63367d692f --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/hooks/useSplitViewTool.tsx @@ -0,0 +1,34 @@ +import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools' +import { ViewMode } from '@renderer/components/CodeBlockView/types' +import { Square, SquareSplitHorizontal } from 'lucide-react' +import { useCallback, useEffect } from 'react' +import { useTranslation } from 'react-i18next' + +interface UseSplitViewToolProps { + enabled: boolean + viewMode: ViewMode + onToggleSplitView: () => void + setTools: React.Dispatch> +} + +export const useSplitViewTool = ({ enabled, viewMode, onToggleSplitView, setTools }: UseSplitViewToolProps) => { + const { t } = useTranslation() + const { registerTool, removeTool } = useToolManager(setTools) + + const handleToggleSplitView = useCallback(() => { + onToggleSplitView?.() + }, [onToggleSplitView]) + + useEffect(() => { + if (!enabled) return + + registerTool({ + ...TOOL_SPECS['split-view'], + icon: viewMode === 'split' ? : , + tooltip: viewMode === 'split' ? t('code_block.split.restore') : t('code_block.split.label'), + onClick: handleToggleSplitView + }) + + return () => removeTool(TOOL_SPECS['split-view'].id) + }, [enabled, viewMode, registerTool, removeTool, t, handleToggleSplitView]) +} diff --git a/src/renderer/src/components/CodeToolbar/hooks/useViewSourceTool.tsx b/src/renderer/src/components/CodeToolbar/hooks/useViewSourceTool.tsx new file mode 100644 index 0000000000..a3a6da0152 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/hooks/useViewSourceTool.tsx @@ -0,0 +1,53 @@ +import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools' +import { ViewMode } from '@renderer/components/CodeBlockView/types' +import { CodeXml, Eye, SquarePen } from 'lucide-react' +import { useCallback, useEffect } from 'react' +import { useTranslation } from 'react-i18next' + +interface UseViewSourceToolProps { + enabled: boolean + editable: boolean + viewMode: ViewMode + onViewModeChange: (mode: ViewMode) => void + setTools: React.Dispatch> +} + +export const useViewSourceTool = ({ + enabled, + editable, + viewMode, + onViewModeChange, + setTools +}: UseViewSourceToolProps) => { + const { t } = useTranslation() + const { registerTool, removeTool } = useToolManager(setTools) + + const handleToggleViewMode = useCallback(() => { + const newMode = viewMode === 'source' ? 'special' : 'source' + onViewModeChange?.(newMode) + }, [viewMode, onViewModeChange]) + + useEffect(() => { + if (!enabled || viewMode === 'split') return + + const toolSpec = editable ? TOOL_SPECS.edit : TOOL_SPECS['view-source'] + + if (editable) { + registerTool({ + ...toolSpec, + icon: viewMode === 'source' ? : , + tooltip: viewMode === 'source' ? t('preview.label') : t('code_block.edit.label'), + onClick: handleToggleViewMode + }) + } else { + registerTool({ + ...toolSpec, + icon: viewMode === 'source' ? : , + tooltip: viewMode === 'source' ? t('preview.label') : t('preview.source'), + onClick: handleToggleViewMode + }) + } + + return () => removeTool(toolSpec.id) + }, [enabled, editable, viewMode, registerTool, removeTool, t, handleToggleViewMode]) +} diff --git a/src/renderer/src/components/CodeToolbar/hooks/useWrapTool.tsx b/src/renderer/src/components/CodeToolbar/hooks/useWrapTool.tsx new file mode 100644 index 0000000000..c0354e78fd --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/hooks/useWrapTool.tsx @@ -0,0 +1,35 @@ +import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools' +import { Text as UnWrapIcon, WrapText as WrapIcon } from 'lucide-react' +import { useCallback, useEffect } from 'react' +import { useTranslation } from 'react-i18next' + +interface UseWrapToolProps { + enabled?: boolean + unwrapped?: boolean + wrappable?: boolean + toggle: () => void + setTools: React.Dispatch> +} + +export const useWrapTool = ({ enabled, unwrapped, wrappable, toggle, setTools }: UseWrapToolProps) => { + const { t } = useTranslation() + const { registerTool, removeTool } = useToolManager(setTools) + + const handleToggle = useCallback(() => { + toggle?.() + }, [toggle]) + + useEffect(() => { + if (enabled) { + registerTool({ + ...TOOL_SPECS.wrap, + icon: unwrapped ? : , + tooltip: unwrapped ? t('code_block.wrap.on') : t('code_block.wrap.off'), + visible: () => wrappable ?? false, + onClick: handleToggle + }) + } + + return () => removeTool(TOOL_SPECS.wrap.id) + }, [enabled, handleToggle, registerTool, removeTool, t, unwrapped, wrappable]) +} diff --git a/src/renderer/src/components/CodeToolbar/index.ts b/src/renderer/src/components/CodeToolbar/index.ts index 96434b97e9..f672a4fa42 100644 --- a/src/renderer/src/components/CodeToolbar/index.ts +++ b/src/renderer/src/components/CodeToolbar/index.ts @@ -1,5 +1,3 @@ -export * from './constants' -export * from './hook' -export * from './toolbar' -export * from './types' -export * from './usePreviewTools' +export { default as CodeToolButton } from './button' +export * from './hooks' +export { default as CodeToolbar } from './toolbar' diff --git a/src/renderer/src/components/CodeToolbar/styles.ts b/src/renderer/src/components/CodeToolbar/styles.ts new file mode 100644 index 0000000000..8db4211c80 --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/styles.ts @@ -0,0 +1,35 @@ +import styled from 'styled-components' + +export const ToolWrapper = styled.div` + display: flex; + align-items: center; + justify-content: center; + width: 24px; + height: 24px; + border-radius: 4px; + cursor: pointer; + user-select: none; + transition: all 0.2s ease; + color: var(--color-text-3); + + &:hover { + background-color: var(--color-background-soft); + .tool-icon { + color: var(--color-text-1); + } + } + + &.active { + color: var(--color-primary); + .tool-icon { + color: var(--color-primary); + } + } + + /* For Lucide icons */ + .tool-icon { + width: 14px; + height: 14px; + color: var(--color-text-3); + } +` diff --git a/src/renderer/src/components/CodeToolbar/toolbar.tsx b/src/renderer/src/components/CodeToolbar/toolbar.tsx index cd615afcb4..7b17a6f0e8 100644 --- a/src/renderer/src/components/CodeToolbar/toolbar.tsx +++ b/src/renderer/src/components/CodeToolbar/toolbar.tsx @@ -1,25 +1,15 @@ +import { ActionTool } from '@renderer/components/ActionTools' import { HStack } from '@renderer/components/Layout' import { Tooltip } from 'antd' import { EllipsisVertical } from 'lucide-react' -import React, { memo, useMemo, useState } from 'react' +import { memo, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import styled from 'styled-components' -import { CodeTool } from './types' +import CodeToolButton from './button' +import { ToolWrapper } from './styles' -interface CodeToolButtonProps { - tool: CodeTool -} - -const CodeToolButton: React.FC = memo(({ tool }) => { - return ( - - tool.onClick()}>{tool.icon} - - ) -}) - -export const CodeToolbar: React.FC<{ tools: CodeTool[] }> = memo(({ tools }) => { +const CodeToolbar = ({ tools }: { tools: ActionTool[] }) => { const [showQuickTools, setShowQuickTools] = useState(false) const { t } = useTranslation() @@ -51,7 +41,7 @@ export const CodeToolbar: React.FC<{ tools: CodeTool[] }> = memo(({ tools }) => {quickTools.length > 1 && ( setShowQuickTools(!showQuickTools)} className={showQuickTools ? 'active' : ''}> - + )} @@ -63,7 +53,7 @@ export const CodeToolbar: React.FC<{ tools: CodeTool[] }> = memo(({ tools }) => ) -}) +} const StickyWrapper = styled.div` position: sticky; @@ -80,36 +70,4 @@ const ToolbarWrapper = styled(HStack)` gap: 4px; ` -const ToolWrapper = styled.div` - display: flex; - align-items: center; - justify-content: center; - width: 24px; - height: 24px; - border-radius: 4px; - cursor: pointer; - user-select: none; - transition: all 0.2s ease; - color: var(--color-text-3); - - &:hover { - background-color: var(--color-background-soft); - .icon { - color: var(--color-text-1); - } - } - - &.active { - color: var(--color-primary); - .icon { - color: var(--color-primary); - } - } - - /* For Lucide icons */ - .icon { - width: 14px; - height: 14px; - color: var(--color-text-3); - } -` +export default memo(CodeToolbar) diff --git a/src/renderer/src/components/CodeToolbar/types.ts b/src/renderer/src/components/CodeToolbar/types.ts deleted file mode 100644 index d1181650fe..0000000000 --- a/src/renderer/src/components/CodeToolbar/types.ts +++ /dev/null @@ -1,25 +0,0 @@ -/** - * 代码块工具基本信息 - */ -export interface CodeToolSpec { - id: string - type: 'core' | 'quick' - order: number -} - -/** - * 代码块工具定义接口 - * @param id 唯一标识符 - * @param type 工具类型 - * @param icon 按钮图标 - * @param tooltip 提示文本 - * @param condition 显示条件 - * @param onClick 点击动作 - * @param order 显示顺序,越小越靠右 - */ -export interface CodeTool extends CodeToolSpec { - icon: React.ReactNode - tooltip: string - visible?: () => boolean - onClick: () => void -} diff --git a/src/renderer/src/components/CodeToolbar/usePreviewTools.tsx b/src/renderer/src/components/CodeToolbar/usePreviewTools.tsx deleted file mode 100644 index 8914862e43..0000000000 --- a/src/renderer/src/components/CodeToolbar/usePreviewTools.tsx +++ /dev/null @@ -1,363 +0,0 @@ -import { loggerService } from '@logger' -import { download } from '@renderer/utils/download' -import { FileImage, ZoomIn, ZoomOut } from 'lucide-react' -import { RefObject, useCallback, useEffect, useRef, useState } from 'react' -import { useTranslation } from 'react-i18next' - -import { DownloadPngIcon, DownloadSvgIcon } from '../Icons/DownloadIcons' -import { TOOL_SPECS } from './constants' -import { useCodeTool } from './hook' -import { CodeTool } from './types' - -const logger = loggerService.withContext('usePreviewToolHandlers') - -// 预编译正则表达式用于查询位置 -const TRANSFORM_REGEX = /translate\((-?\d+\.?\d*)px,\s*(-?\d+\.?\d*)px\)/ - -/** - * 使用图像处理工具的自定义Hook - * 提供图像缩放、复制和下载功能 - */ -export const usePreviewToolHandlers = ( - containerRef: RefObject, - options: { - prefix: string - imgSelector: string - enableWheelZoom?: boolean - customDownloader?: (format: 'svg' | 'png') => void - } -) => { - const transformRef = useRef({ scale: 1, x: 0, y: 0 }) // 管理变换状态 - const [renderTrigger, setRenderTrigger] = useState(0) // 仅用于触发组件重渲染的状态 - const { imgSelector, prefix, customDownloader, enableWheelZoom } = options - const { t } = useTranslation() - - // 创建选择器函数 - const getImgElement = useCallback(() => { - if (!containerRef.current) return null - - // 优先尝试从 Shadow DOM 中查找 - const shadowRoot = containerRef.current.shadowRoot - if (shadowRoot) { - return shadowRoot.querySelector(imgSelector) as SVGElement | null - } - - // 降级到常规 DOM 查找 - return containerRef.current.querySelector(imgSelector) as SVGElement | null - }, [containerRef, imgSelector]) - - // 查询当前位置 - const getCurrentPosition = useCallback(() => { - const imgElement = getImgElement() - if (!imgElement) return { x: transformRef.current.x, y: transformRef.current.y } - - const transform = imgElement.style.transform - if (!transform || transform === 'none') return { x: transformRef.current.x, y: transformRef.current.y } - - const match = transform.match(TRANSFORM_REGEX) - if (match && match.length >= 3) { - return { - x: parseFloat(match[1]), - y: parseFloat(match[2]) - } - } - - return { x: transformRef.current.x, y: transformRef.current.y } - }, [getImgElement]) - - // 平移缩放变换 - const applyTransform = useCallback((element: SVGElement | null, x: number, y: number, scale: number) => { - if (!element) return - element.style.transformOrigin = 'top left' - element.style.transform = `translate(${x}px, ${y}px) scale(${scale})` - }, []) - - // 拖拽平移支持 - useEffect(() => { - const container = containerRef.current - if (!container) return - - let isDragging = false - const startPos = { x: 0, y: 0 } - const startOffset = { x: 0, y: 0 } - - const onMouseDown = (e: MouseEvent) => { - if (e.button !== 0) return // 只响应左键 - - // 更新当前实际位置 - const position = getCurrentPosition() - transformRef.current.x = position.x - transformRef.current.y = position.y - - isDragging = true - startPos.x = e.clientX - startPos.y = e.clientY - startOffset.x = position.x - startOffset.y = position.y - - container.style.cursor = 'grabbing' - e.preventDefault() - } - - const onMouseMove = (e: MouseEvent) => { - if (!isDragging) return - - const dx = e.clientX - startPos.x - const dy = e.clientY - startPos.y - const newX = startOffset.x + dx - const newY = startOffset.y + dy - - const imgElement = getImgElement() - applyTransform(imgElement, newX, newY, transformRef.current.scale) - - e.preventDefault() - } - - const stopDrag = () => { - if (!isDragging) return - - // 更新位置但不立即触发状态变更 - const position = getCurrentPosition() - transformRef.current.x = position.x - transformRef.current.y = position.y - - // 只触发一次渲染以保持组件状态同步 - setRenderTrigger((prev) => prev + 1) - - isDragging = false - container.style.cursor = 'default' - } - - // 绑定到document以确保拖拽可以在鼠标离开容器后继续 - container.addEventListener('mousedown', onMouseDown) - document.addEventListener('mousemove', onMouseMove) - document.addEventListener('mouseup', stopDrag) - - return () => { - container.removeEventListener('mousedown', onMouseDown) - document.removeEventListener('mousemove', onMouseMove) - document.removeEventListener('mouseup', stopDrag) - } - }, [containerRef, getCurrentPosition, getImgElement, applyTransform]) - - // 缩放处理函数 - const handleZoom = useCallback( - (delta: number) => { - const newScale = Math.max(0.1, Math.min(3, transformRef.current.scale + delta)) - transformRef.current.scale = newScale - - const imgElement = getImgElement() - applyTransform(imgElement, transformRef.current.x, transformRef.current.y, newScale) - - // 触发重渲染以保持组件状态同步 - setRenderTrigger((prev) => prev + 1) - }, - [getImgElement, applyTransform] - ) - - // 滚轮缩放支持 - useEffect(() => { - if (!enableWheelZoom || !containerRef.current) return - - const container = containerRef.current - - const handleWheel = (e: WheelEvent) => { - if ((e.ctrlKey || e.metaKey) && e.target) { - // 确认事件发生在容器内部 - if (container.contains(e.target as Node)) { - const delta = e.deltaY < 0 ? 0.1 : -0.1 - handleZoom(delta) - } - } - } - - container.addEventListener('wheel', handleWheel, { passive: true }) - return () => container.removeEventListener('wheel', handleWheel) - }, [containerRef, handleZoom, enableWheelZoom]) - - // 复制图像处理函数 - const handleCopyImage = useCallback(async () => { - try { - const imgElement = getImgElement() - if (!imgElement) return - - const canvas = document.createElement('canvas') - const ctx = canvas.getContext('2d') - const img = new Image() - img.crossOrigin = 'anonymous' - - const viewBox = imgElement.getAttribute('viewBox')?.split(' ').map(Number) || [] - const width = viewBox[2] || imgElement.clientWidth || imgElement.getBoundingClientRect().width - const height = viewBox[3] || imgElement.clientHeight || imgElement.getBoundingClientRect().height - - const svgData = new XMLSerializer().serializeToString(imgElement) - const svgBase64 = `data:image/svg+xml;base64,${btoa(unescape(encodeURIComponent(svgData)))}` - - img.onload = async () => { - const scale = 3 - canvas.width = width * scale - canvas.height = height * scale - - if (ctx) { - ctx.scale(scale, scale) - ctx.drawImage(img, 0, 0, width, height) - const blob = await new Promise((resolve) => canvas.toBlob((b) => resolve(b!), 'image/png')) - await navigator.clipboard.write([new ClipboardItem({ 'image/png': blob })]) - window.message.success(t('message.copy.success')) - } - } - img.src = svgBase64 - } catch (error) { - logger.error('Copy failed:', error as Error) - window.message.error(t('message.copy.failed')) - } - }, [getImgElement, t]) - - // 下载处理函数 - const handleDownload = useCallback( - (format: 'svg' | 'png') => { - // 如果有自定义下载器,使用自定义实现 - if (customDownloader) { - customDownloader(format) - return - } - - try { - const imgElement = getImgElement() - if (!imgElement) return - - const timestamp = Date.now() - - if (format === 'svg') { - const svgData = new XMLSerializer().serializeToString(imgElement) - const blob = new Blob([svgData], { type: 'image/svg+xml' }) - const url = URL.createObjectURL(blob) - download(url, `${prefix}-${timestamp}.svg`) - URL.revokeObjectURL(url) - } else if (format === 'png') { - const canvas = document.createElement('canvas') - const ctx = canvas.getContext('2d') - const img = new Image() - img.crossOrigin = 'anonymous' - - const viewBox = imgElement.getAttribute('viewBox')?.split(' ').map(Number) || [] - const width = viewBox[2] || imgElement.clientWidth || imgElement.getBoundingClientRect().width - const height = viewBox[3] || imgElement.clientHeight || imgElement.getBoundingClientRect().height - - const svgData = new XMLSerializer().serializeToString(imgElement) - const svgBase64 = `data:image/svg+xml;base64,${btoa(unescape(encodeURIComponent(svgData)))}` - - img.onload = () => { - const scale = 3 - canvas.width = width * scale - canvas.height = height * scale - - if (ctx) { - ctx.scale(scale, scale) - ctx.drawImage(img, 0, 0, width, height) - } - - canvas.toBlob((blob) => { - if (blob) { - const pngUrl = URL.createObjectURL(blob) - download(pngUrl, `${prefix}-${timestamp}.png`) - URL.revokeObjectURL(pngUrl) - } - }, 'image/png') - } - img.src = svgBase64 - } - } catch (error) { - logger.error('Download failed:', error as Error) - } - }, - [getImgElement, prefix, customDownloader] - ) - - return { - scale: transformRef.current.scale, - handleZoom, - handleCopyImage, - handleDownload, - renderTrigger // 导出渲染触发器,万一要用 - } -} - -export interface PreviewToolsOptions { - setTools?: (value: React.SetStateAction) => void - handleZoom?: (delta: number) => void - handleCopyImage?: () => Promise - handleDownload?: (format: 'svg' | 'png') => void -} - -/** - * 提供预览组件通用工具栏功能的自定义Hook - */ -export const usePreviewTools = ({ setTools, handleZoom, handleCopyImage, handleDownload }: PreviewToolsOptions) => { - const { t } = useTranslation() - const { registerTool, removeTool } = useCodeTool(setTools) - - useEffect(() => { - // 根据提供的功能有选择性地注册工具 - if (handleZoom) { - // 放大工具 - registerTool({ - ...TOOL_SPECS['zoom-in'], - icon: , - tooltip: t('code_block.preview.zoom_in'), - onClick: () => handleZoom(0.1) - }) - - // 缩小工具 - registerTool({ - ...TOOL_SPECS['zoom-out'], - icon: , - tooltip: t('code_block.preview.zoom_out'), - onClick: () => handleZoom(-0.1) - }) - } - - if (handleCopyImage) { - // 复制图片工具 - registerTool({ - ...TOOL_SPECS['copy-image'], - icon: , - tooltip: t('code_block.preview.copy.image'), - onClick: handleCopyImage - }) - } - - if (handleDownload) { - // 下载 SVG 工具 - registerTool({ - ...TOOL_SPECS['download-svg'], - icon: , - tooltip: t('code_block.download.svg'), - onClick: () => handleDownload('svg') - }) - - // 下载 PNG 工具 - registerTool({ - ...TOOL_SPECS['download-png'], - icon: , - tooltip: t('code_block.download.png'), - onClick: () => handleDownload('png') - }) - } - - // 清理函数 - return () => { - if (handleZoom) { - removeTool(TOOL_SPECS['zoom-in'].id) - removeTool(TOOL_SPECS['zoom-out'].id) - } - if (handleCopyImage) { - removeTool(TOOL_SPECS['copy-image'].id) - } - if (handleDownload) { - removeTool(TOOL_SPECS['download-svg'].id) - removeTool(TOOL_SPECS['download-png'].id) - } - } - }, [handleCopyImage, handleDownload, handleZoom, registerTool, removeTool, t]) -} diff --git a/src/renderer/src/components/CodeBlockView/CodePreview.tsx b/src/renderer/src/components/CodeViewer.tsx similarity index 69% rename from src/renderer/src/components/CodeBlockView/CodePreview.tsx rename to src/renderer/src/components/CodeViewer.tsx index 9e08dab5ae..c73063e73d 100644 --- a/src/renderer/src/components/CodeBlockView/CodePreview.tsx +++ b/src/renderer/src/components/CodeViewer.tsx @@ -1,4 +1,4 @@ -import { TOOL_SPECS, useCodeTool } from '@renderer/components/CodeToolbar' +import { MAX_COLLAPSED_CODE_HEIGHT } from '@renderer/config/constant' import { useCodeStyle } from '@renderer/context/CodeStyleProvider' import { useCodeHighlight } from '@renderer/hooks/useCodeHighlight' import { useSettings } from '@renderer/hooks/useSettings' @@ -6,82 +6,34 @@ import { uuid } from '@renderer/utils' import { getReactStyleFromToken } from '@renderer/utils/shiki' import { useVirtualizer } from '@tanstack/react-virtual' import { debounce } from 'lodash' -import { ChevronsDownUp, ChevronsUpDown, Text as UnWrapIcon, WrapText as WrapIcon } from 'lucide-react' -import React, { memo, useCallback, useEffect, useLayoutEffect, useMemo, useRef, useState } from 'react' -import { useTranslation } from 'react-i18next' +import React, { memo, useCallback, useEffect, useLayoutEffect, useMemo, useRef } from 'react' import { ThemedToken } from 'shiki/core' import styled from 'styled-components' -import { BasicPreviewProps } from './types' - -interface CodePreviewProps extends BasicPreviewProps { +interface CodeViewerProps { language: string + children: string + expanded?: boolean + unwrapped?: boolean + onHeightChange?: (scrollHeight: number) => void + className?: string } -const MAX_COLLAPSE_HEIGHT = 350 - /** * Shiki 流式代码高亮组件 * - 通过 shiki tokenizer 处理流式响应,高性能 * - 使用虚拟滚动和按需高亮,改善页面内有大量长代码块时的响应 * - 并发安全 */ -const CodePreview = ({ children, language, setTools }: CodePreviewProps) => { - const { codeShowLineNumbers, fontSize, codeCollapsible, codeWrappable } = useSettings() +const CodeViewer = ({ children, language, expanded, unwrapped, onHeightChange, className }: CodeViewerProps) => { + const { codeShowLineNumbers, fontSize } = useSettings() const { getShikiPreProperties, isShikiThemeDark } = useCodeStyle() - const [expandOverride, setExpandOverride] = useState(!codeCollapsible) - const [unwrapOverride, setUnwrapOverride] = useState(!codeWrappable) const shikiThemeRef = useRef(null) const scrollerRef = useRef(null) const callerId = useRef(`${Date.now()}-${uuid()}`).current const rawLines = useMemo(() => (typeof children === 'string' ? children.trimEnd().split('\n') : []), [children]) - const { t } = useTranslation() - const { registerTool, removeTool } = useCodeTool(setTools) - - // 展开/折叠工具 - useEffect(() => { - registerTool({ - ...TOOL_SPECS.expand, - icon: expandOverride ? : , - tooltip: expandOverride ? t('code_block.collapse') : t('code_block.expand'), - visible: () => { - const scrollHeight = scrollerRef.current?.scrollHeight - return codeCollapsible && (scrollHeight ?? 0) > MAX_COLLAPSE_HEIGHT - }, - onClick: () => setExpandOverride((prev) => !prev) - }) - - return () => removeTool(TOOL_SPECS.expand.id) - }, [codeCollapsible, expandOverride, registerTool, removeTool, t]) - - // 自动换行工具 - useEffect(() => { - registerTool({ - ...TOOL_SPECS.wrap, - icon: unwrapOverride ? : , - tooltip: unwrapOverride ? t('code_block.wrap.on') : t('code_block.wrap.off'), - visible: () => codeWrappable, - onClick: () => setUnwrapOverride((prev) => !prev) - }) - - return () => removeTool(TOOL_SPECS.wrap.id) - }, [codeWrappable, unwrapOverride, registerTool, removeTool, t]) - - // 重置用户操作(可以考虑移除,保持用户操作结果) - useEffect(() => { - setExpandOverride(!codeCollapsible) - }, [codeCollapsible]) - - // 重置用户操作(可以考虑移除,保持用户操作结果) - useEffect(() => { - setUnwrapOverride(!codeWrappable) - }, [codeWrappable]) - - const shouldCollapse = useMemo(() => codeCollapsible && !expandOverride, [codeCollapsible, expandOverride]) - const shouldWrap = useMemo(() => codeWrappable && !unwrapOverride, [codeWrappable, unwrapOverride]) - // 计算行号数字位数 const gutterDigits = useMemo( () => (codeShowLineNumbers ? Math.max(rawLines.length.toString().length, 1) : 0), @@ -90,10 +42,12 @@ const CodePreview = ({ children, language, setTools }: CodePreviewProps) => { // 设置 pre 标签属性 useLayoutEffect(() => { + let mounted = true getShikiPreProperties(language).then((properties) => { + if (!mounted) return const shikiTheme = shikiThemeRef.current if (shikiTheme) { - shikiTheme.className = `${properties.class || 'shiki'}` + shikiTheme.className = `${properties.class || 'shiki'} code-viewer ${className ?? ''}` // 滚动条适应 shiki 主题变化而非应用主题 shikiTheme.classList.add(isShikiThemeDark ? 'shiki-dark' : 'shiki-light') @@ -103,7 +57,10 @@ const CodePreview = ({ children, language, setTools }: CodePreviewProps) => { shikiTheme.tabIndex = properties.tabindex } }) - }, [language, getShikiPreProperties, isShikiThemeDark]) + return () => { + mounted = false + } + }, [language, getShikiPreProperties, isShikiThemeDark, className]) // Virtualizer 配置 const getScrollElement = useCallback(() => scrollerRef.current, []) @@ -140,19 +97,25 @@ const CodePreview = ({ children, language, setTools }: CodePreviewProps) => { } }, [virtualItems, debouncedHighlightLines]) + // Report scrollHeight when it might change + useLayoutEffect(() => { + onHeightChange?.(scrollerRef.current?.scrollHeight ?? 0) + }, [rawLines.length, onHeightChange]) + return (
{ width: '100%', transform: `translateY(${virtualItems[0]?.start ?? 0}px)` }}> - {virtualizer.getVirtualItems().map((virtualItem) => ( + {virtualItems.map((virtualItem) => (
{ ) } -CodePreview.displayName = 'CodePreview' +CodeViewer.displayName = 'CodeViewer' const plainTokenStyle = { color: 'inherit', @@ -259,20 +222,24 @@ VirtualizedRow.displayName = 'VirtualizedRow' const ScrollContainer = styled.div<{ $wrap?: boolean + $expanded?: boolean $lineHeight?: number }>` display: block; overflow-x: auto; position: relative; border-radius: inherit; - padding: 0.5em 1em; + /* padding right 下沉到 line-content 中 */ + padding: 0.5em 0 0.5em 1em; .line { display: flex; align-items: flex-start; width: 100%; line-height: ${(props) => props.$lineHeight}px; - contain: content; + /* contain 优化 wrap 时滚动性能,will-change 优化 unwrap 时滚动性能 */ + contain: ${(props) => (props.$wrap ? 'content' : 'none')}; + will-change: ${(props) => (!props.$wrap && !props.$expanded ? 'transform' : 'auto')}; .line-number { width: var(--gutter-width, 1.2ch); @@ -288,6 +255,7 @@ const ScrollContainer = styled.div<{ .line-content { flex: 1; + padding-right: 1em; * { white-space: ${(props) => (props.$wrap ? 'pre-wrap' : 'pre')}; overflow-wrap: ${(props) => (props.$wrap ? 'break-word' : 'normal')}; @@ -296,4 +264,4 @@ const ScrollContainer = styled.div<{ } ` -export default memo(CodePreview) +export default memo(CodeViewer) diff --git a/src/renderer/src/components/CustomCollapse.tsx b/src/renderer/src/components/CustomCollapse.tsx index d41a9ffd60..8362d8a479 100644 --- a/src/renderer/src/components/CustomCollapse.tsx +++ b/src/renderer/src/components/CustomCollapse.tsx @@ -78,7 +78,7 @@ const CustomCollapse: FC = ({ style={collapseStyle} defaultActiveKey={defaultActiveKey} activeKey={activeKey} - destroyInactivePanel={destroyInactivePanel} + destroyOnHidden={destroyInactivePanel} collapsible={collapsible} onChange={(keys) => { setActiveKeys(keys) diff --git a/src/renderer/src/components/__tests__/DraggableList.test.tsx b/src/renderer/src/components/DraggableList/__tests__/DraggableList.test.tsx similarity index 99% rename from src/renderer/src/components/__tests__/DraggableList.test.tsx rename to src/renderer/src/components/DraggableList/__tests__/DraggableList.test.tsx index 4878fd4838..a570f58bcf 100644 --- a/src/renderer/src/components/__tests__/DraggableList.test.tsx +++ b/src/renderer/src/components/DraggableList/__tests__/DraggableList.test.tsx @@ -3,7 +3,7 @@ import { render, screen } from '@testing-library/react' import { describe, expect, it, vi } from 'vitest' -import { DraggableList } from '../DraggableList' +import { DraggableList } from '../' // mock @hello-pangea/dnd 组件 vi.mock('@hello-pangea/dnd', () => { diff --git a/src/renderer/src/components/__tests__/DraggableVirtualList.test.tsx b/src/renderer/src/components/DraggableList/__tests__/DraggableVirtualList.test.tsx similarity index 98% rename from src/renderer/src/components/__tests__/DraggableVirtualList.test.tsx rename to src/renderer/src/components/DraggableList/__tests__/DraggableVirtualList.test.tsx index b82181ef42..74a7a414ee 100644 --- a/src/renderer/src/components/__tests__/DraggableVirtualList.test.tsx +++ b/src/renderer/src/components/DraggableList/__tests__/DraggableVirtualList.test.tsx @@ -3,7 +3,7 @@ import { render, screen } from '@testing-library/react' import { describe, expect, it, vi } from 'vitest' -import DraggableVirtualList from '../DraggableList/virtual-list' +import { DraggableVirtualList } from '../' // Mock 依赖项 vi.mock('@hello-pangea/dnd', () => ({ diff --git a/src/renderer/src/components/__tests__/__snapshots__/DraggableList.test.tsx.snap b/src/renderer/src/components/DraggableList/__tests__/__snapshots__/DraggableList.test.tsx.snap similarity index 100% rename from src/renderer/src/components/__tests__/__snapshots__/DraggableList.test.tsx.snap rename to src/renderer/src/components/DraggableList/__tests__/__snapshots__/DraggableList.test.tsx.snap diff --git a/src/renderer/src/components/__tests__/__snapshots__/DraggableVirtualList.test.tsx.snap b/src/renderer/src/components/DraggableList/__tests__/__snapshots__/DraggableVirtualList.test.tsx.snap similarity index 100% rename from src/renderer/src/components/__tests__/__snapshots__/DraggableVirtualList.test.tsx.snap rename to src/renderer/src/components/DraggableList/__tests__/__snapshots__/DraggableVirtualList.test.tsx.snap diff --git a/src/renderer/src/components/DraggableList/__tests__/useDraggableReorder.test.ts b/src/renderer/src/components/DraggableList/__tests__/useDraggableReorder.test.ts new file mode 100644 index 0000000000..f2d2fe837f --- /dev/null +++ b/src/renderer/src/components/DraggableList/__tests__/useDraggableReorder.test.ts @@ -0,0 +1,151 @@ +import { DropResult } from '@hello-pangea/dnd' +import { act, renderHook } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' + +import { useDraggableReorder } from '../useDraggableReorder' + +// 辅助函数和模拟数据 +const createMockItem = (id: number) => ({ id: `item-${id}`, name: `Item ${id}` }) +const mockOriginalList = [createMockItem(1), createMockItem(2), createMockItem(3), createMockItem(4), createMockItem(5)] + +/** + * 创建一个符合 DropResult 类型的模拟对象。 + * @param sourceIndex - 拖拽源的视图索引 + * @param destIndex - 拖拽目标的视图索引 + * @param draggableId - 被拖拽项的唯一 ID,应与其 itemKey 对应 + */ +const createMockDropResult = (sourceIndex: number, destIndex: number | null, draggableId: string): DropResult => ({ + reason: 'DROP', + source: { index: sourceIndex, droppableId: 'droppable' }, + destination: destIndex !== null ? { index: destIndex, droppableId: 'droppable' } : null, + combine: null, + mode: 'FLUID', + draggableId, + type: 'DEFAULT' +}) + +describe('useDraggableReorder', () => { + describe('reorder', () => { + it('should correctly reorder the list when it is not filtered', () => { + const onUpdate = vi.fn() + const { result } = renderHook(() => + useDraggableReorder({ + originalList: mockOriginalList, + filteredList: mockOriginalList, // 列表未过滤 + onUpdate, + idKey: 'id' + }) + ) + + // 模拟将第一项 (视图索引 0, 原始索引 0) 拖到第三项的位置 (视图索引 2) + // 在未过滤列表中,itemKey(0) 返回 0 + const dropResult = createMockDropResult(0, 2, '0') + + act(() => { + result.current.onDragEnd(dropResult) + }) + + expect(onUpdate).toHaveBeenCalledTimes(1) + const newList = onUpdate.mock.calls[0][0] + // 原始: [1, 2, 3, 4, 5] -> 拖拽后预期: [2, 3, 1, 4, 5] + expect(newList.map((i) => i.id)).toEqual(['item-2', 'item-3', 'item-1', 'item-4', 'item-5']) + }) + + it('should correctly reorder the original list when the list is filtered', () => { + const onUpdate = vi.fn() + // 过滤后只剩下奇数项: [item-1, item-3, item-5] + const filteredList = [mockOriginalList[0], mockOriginalList[2], mockOriginalList[4]] + + const { result } = renderHook(() => + useDraggableReorder({ + originalList: mockOriginalList, + filteredList, + onUpdate, + idKey: 'id' + }) + ) + + // 在过滤后的列表中,将最后一项 'item-5' (视图索引 2) 拖到第一项 'item-1' (视图索引 0) 的位置 + // 'item-5' 的原始索引是 4, 所以 itemKey(2) 返回 4 + const dropResult = createMockDropResult(2, 0, '4') + + act(() => { + result.current.onDragEnd(dropResult) + }) + + expect(onUpdate).toHaveBeenCalledTimes(1) + const newList = onUpdate.mock.calls[0][0] + // 原始: [1, 2, 3, 4, 5] + // 拖拽后预期: 'item-5' 移动到 'item-1' 的位置 -> [5, 1, 2, 3, 4] + expect(newList.map((i) => i.id)).toEqual(['item-5', 'item-1', 'item-2', 'item-3', 'item-4']) + }) + }) + + describe('onUpdate', () => { + it('should not call onUpdate if destination is null', () => { + const onUpdate = vi.fn() + const { result } = renderHook(() => + useDraggableReorder({ + originalList: mockOriginalList, + filteredList: mockOriginalList, + onUpdate, + idKey: 'id' + }) + ) + + // 模拟拖拽到列表外 + const dropResult = createMockDropResult(0, null, '0') + + act(() => { + result.current.onDragEnd(dropResult) + }) + + expect(onUpdate).not.toHaveBeenCalled() + }) + + it('should not call onUpdate if source and destination are the same', () => { + const onUpdate = vi.fn() + const { result } = renderHook(() => + useDraggableReorder({ + originalList: mockOriginalList, + filteredList: mockOriginalList, + onUpdate, + idKey: 'id' + }) + ) + + // 模拟拖拽后放回原位 + const dropResult = createMockDropResult(1, 1, '1') + + act(() => { + result.current.onDragEnd(dropResult) + }) + + expect(onUpdate).not.toHaveBeenCalled() + }) + }) + + describe('itemKey', () => { + it('should return the correct original index from a filtered list index', () => { + const onUpdate = vi.fn() + // 过滤后只剩下奇数项: [item-1, item-3, item-5] + const filteredList = [mockOriginalList[0], mockOriginalList[2], mockOriginalList[4]] + + const { result } = renderHook(() => + useDraggableReorder({ + originalList: mockOriginalList, + filteredList, + onUpdate, + idKey: 'id' + }) + ) + + // 视图索引 0 -> 'item-1' -> 原始索引 0 + expect(result.current.itemKey(0)).toBe(0) + // 视图索引 1 -> 'item-3' -> 原始索引 2 + expect(result.current.itemKey(1)).toBe(2) + // 视图索引 2 -> 'item-5' -> 原始索引 4 + expect(result.current.itemKey(2)).toBe(4) + }) + }) +}) diff --git a/src/renderer/src/components/DraggableList/index.tsx b/src/renderer/src/components/DraggableList/index.tsx index de98dd00d5..642b12bfd7 100644 --- a/src/renderer/src/components/DraggableList/index.tsx +++ b/src/renderer/src/components/DraggableList/index.tsx @@ -1,2 +1,3 @@ export { default as DraggableList } from './list' +export { useDraggableReorder } from './useDraggableReorder' export { default as DraggableVirtualList } from './virtual-list' diff --git a/src/renderer/src/components/DraggableList/useDraggableReorder.ts b/src/renderer/src/components/DraggableList/useDraggableReorder.ts new file mode 100644 index 0000000000..59a04788a6 --- /dev/null +++ b/src/renderer/src/components/DraggableList/useDraggableReorder.ts @@ -0,0 +1,70 @@ +import { DropResult } from '@hello-pangea/dnd' +import { Key, useCallback, useMemo } from 'react' + +interface UseDraggableReorderParams { + /** 原始的、完整的数据列表 */ + originalList: T[] + /** 当前在界面上渲染的、可能被过滤的列表 */ + filteredList: T[] + /** 用于更新原始列表状态的函数 */ + onUpdate: (newList: T[]) => void + /** 用于从列表项中获取唯一ID的属性名或函数 */ + idKey: keyof T | ((item: T) => Key) +} + +/** + * 增强拖拽排序能力,处理“过滤后列表”与“原始列表”的索引映射问题。 + * + * @template T 列表项的类型 + * @param params - { originalList, filteredList, onUpdate, idKey } + * @returns 返回可以直接传递给 DraggableVirtualList 的 props: { onDragEnd, itemKey } + */ +export function useDraggableReorder({ originalList, filteredList, onUpdate, idKey }: UseDraggableReorderParams) { + const getId = useCallback((item: T) => (typeof idKey === 'function' ? idKey(item) : (item[idKey] as Key)), [idKey]) + + // 创建从 item ID 到其在 *原始列表* 中索引的映射 + const itemIndexMap = useMemo(() => { + const map = new Map() + originalList.forEach((item, index) => { + map.set(getId(item), index) + }) + return map + }, [originalList, getId]) + + // 创建一个函数,将 *过滤后列表* 的视图索引转换为 *原始列表* 的数据索引 + const getItemKey = useCallback( + (index: number): Key => { + const item = filteredList[index] + // 如果找不到item,返回视图索引兜底 + if (!item) return index + + const originalIndex = itemIndexMap.get(getId(item)) + return originalIndex ?? index + }, + [filteredList, itemIndexMap, getId] + ) + + // 创建 onDragEnd 回调,封装了所有重排逻辑 + const onDragEnd = useCallback( + (result: DropResult) => { + if (!result.destination) return + + // 使用 getItemKey 将视图索引转换为数据索引 + const sourceOriginalIndex = getItemKey(result.source.index) as number + const destOriginalIndex = getItemKey(result.destination.index) as number + + if (sourceOriginalIndex === destOriginalIndex) return + + // 操作原始列表的副本 + const newList = [...originalList] + const [movedItem] = newList.splice(sourceOriginalIndex, 1) + newList.splice(destOriginalIndex, 0, movedItem) + + // 调用外部更新函数 + onUpdate(newList) + }, + [originalList, onUpdate, getItemKey] + ) + + return { onDragEnd, itemKey: getItemKey } +} diff --git a/src/renderer/src/components/DraggableList/virtual-list.tsx b/src/renderer/src/components/DraggableList/virtual-list.tsx index c8d868f1ba..b8020aa051 100644 --- a/src/renderer/src/components/DraggableList/virtual-list.tsx +++ b/src/renderer/src/components/DraggableList/virtual-list.tsx @@ -22,7 +22,7 @@ import { type Key, memo, useCallback, useRef } from 'react' * @property {React.CSSProperties} [itemStyle] 元素内容区域的附加样式 * @property {React.CSSProperties} [itemContainerStyle] 元素拖拽容器的附加样式 * @property {Partial} [droppableProps] 透传给 Droppable 的额外配置 - * @property {(list: T[]) => void} onUpdate 拖拽排序完成后的回调,返回新的列表顺序 + * @property {(list: T[]) => void} [onUpdate] 拖拽排序完成后的回调,返回新的列表顺序(可被 useDraggableReorder 替代) * @property {OnDragStartResponder} [onDragStart] 开始拖拽时的回调 * @property {OnDragEndResponder} [onDragEnd] 结束拖拽时的回调 * @property {T[]} list 渲染的数据源 @@ -39,7 +39,7 @@ interface DraggableVirtualListProps { itemStyle?: React.CSSProperties itemContainerStyle?: React.CSSProperties droppableProps?: Partial - onUpdate: (list: T[]) => void + onUpdate?: (list: T[]) => void onDragStart?: OnDragStartResponder onDragEnd?: OnDragEndResponder list: T[] @@ -48,6 +48,7 @@ interface DraggableVirtualListProps { overscan?: number header?: React.ReactNode children: (item: T, index: number) => React.ReactNode + disabled?: boolean } /** @@ -73,11 +74,12 @@ function DraggableVirtualList({ estimateSize: _estimateSize, overscan = 5, header, - children + children, + disabled }: DraggableVirtualListProps): React.ReactElement { const _onDragEnd = (result: DropResult, provided: ResponderProvided) => { onDragEnd?.(result, provided) - if (result.destination) { + if (onUpdate && result.destination) { const sourceIndex = result.source.index const destIndex = result.destination.index const reorderAgents = droppableReorder(list, sourceIndex, destIndex) @@ -157,6 +159,7 @@ function DraggableVirtualList({ itemContainerStyle={itemContainerStyle} virtualizer={virtualizer} children={children} + disabled={disabled} /> ))}
@@ -172,53 +175,56 @@ function DraggableVirtualList({ /** * 渲染单个可拖拽的虚拟列表项,高度为动态测量 */ -const VirtualRow = memo(({ virtualItem, list, children, itemStyle, itemContainerStyle, virtualizer }: any) => { - const item = list[virtualItem.index] - const draggableId = String(virtualItem.key) - return ( - - {(provided) => { - const setDragRefs = (el: HTMLElement | null) => { - provided.innerRef(el) - virtualizer.measureElement(el) - } +const VirtualRow = memo( + ({ virtualItem, list, children, itemStyle, itemContainerStyle, virtualizer, disabled }: any) => { + const item = list[virtualItem.index] + const draggableId = String(virtualItem.key) + return ( + + {(provided) => { + const setDragRefs = (el: HTMLElement | null) => { + provided.innerRef(el) + virtualizer.measureElement(el) + } - const dndStyle = provided.draggableProps.style - const virtualizerTransform = `translateY(${virtualItem.start}px)` + const dndStyle = provided.draggableProps.style + const virtualizerTransform = `translateY(${virtualItem.start}px)` - // dnd 的 transform 负责拖拽时的位移和让位动画, - // virtualizer 的 translateY 负责将项定位到虚拟列表的正确位置, - // 它们拼接起来可以同时实现拖拽视觉效果和虚拟化定位。 - const combinedTransform = dndStyle?.transform - ? `${dndStyle.transform} ${virtualizerTransform}` - : virtualizerTransform + // dnd 的 transform 负责拖拽时的位移和让位动画, + // virtualizer 的 translateY 负责将项定位到虚拟列表的正确位置, + // 它们拼接起来可以同时实现拖拽视觉效果和虚拟化定位。 + const combinedTransform = dndStyle?.transform + ? `${dndStyle.transform} ${virtualizerTransform}` + : virtualizerTransform - return ( -
-
- {item && children(item, virtualItem.index)} + return ( +
+
+ {item && children(item, virtualItem.index)} +
-
- ) - }} - - ) -}) + ) + }} + + ) + } +) export default DraggableVirtualList diff --git a/src/renderer/src/components/Icons/DownloadIcons.tsx b/src/renderer/src/components/Icons/DownloadIcons.tsx deleted file mode 100644 index 55c6f00f1a..0000000000 --- a/src/renderer/src/components/Icons/DownloadIcons.tsx +++ /dev/null @@ -1,68 +0,0 @@ -import { SVGProps } from 'react' - -// 基础下载图标 -export const DownloadIcon = (props: SVGProps) => ( - - - - - -) - -// 带有文件类型的下载图标基础组件 -const DownloadTypeIconBase = ({ type, ...props }: SVGProps & { type: string }) => ( - - - {type} - - - - - -) - -// JPG 文件下载图标 -export const DownloadJpgIcon = (props: SVGProps) => - -// PNG 文件下载图标 -export const DownloadPngIcon = (props: SVGProps) => - -// SVG 文件下载图标 -export const DownloadSvgIcon = (props: SVGProps) => diff --git a/src/renderer/src/components/Icons/FileIcons.tsx b/src/renderer/src/components/Icons/FileIcons.tsx new file mode 100644 index 0000000000..386c823ef4 --- /dev/null +++ b/src/renderer/src/components/Icons/FileIcons.tsx @@ -0,0 +1,70 @@ +import { CSSProperties, SVGProps } from 'react' + +interface BaseFileIconProps extends SVGProps { + size?: string + text?: string +} + +const textStyle: CSSProperties = { + fontStyle: 'italic', + fontSize: '7.70985px', + lineHeight: 0.8, + fontFamily: "'Times New Roman'", + textAlign: 'center', + writingMode: 'horizontal-tb', + direction: 'ltr', + textAnchor: 'middle', + fill: 'none', + stroke: '#000000', + strokeWidth: '0.289119', + strokeLinejoin: 'round', + strokeDasharray: 'none' +} + +const tspanStyle: CSSProperties = { + fontStyle: 'normal', + fontVariant: 'normal', + fontWeight: 'normal', + fontStretch: 'condensed', + fontSize: '7.70985px', + lineHeight: 0.8, + fontFamily: 'Arial', + fill: '#000000', + fillOpacity: 1, + strokeWidth: '0.289119', + strokeDasharray: 'none' +} + +const BaseFileIcon = ({ size = '1.1em', text = 'SVG', ...props }: BaseFileIconProps) => ( + + + + + + + {text} + + + +) + +export const FileSvgIcon = (props: Omit) => +export const FilePngIcon = (props: Omit) => diff --git a/src/renderer/src/components/Icons/index.ts b/src/renderer/src/components/Icons/index.ts index cc6f4c2b60..94714e73cc 100644 --- a/src/renderer/src/components/Icons/index.ts +++ b/src/renderer/src/components/Icons/index.ts @@ -1,8 +1,8 @@ export { default as CopyIcon } from './CopyIcon' export { default as DeleteIcon } from './DeleteIcon' -export * from './DownloadIcons' export { default as EditIcon } from './EditIcon' export { default as FallbackFavicon } from './FallbackFavicon' +export * from './FileIcons' export { default as MinAppIcon } from './MinAppIcon' export * from './NutstoreIcons' export { default as OcrIcon } from './OcrIcon' diff --git a/src/renderer/src/components/ImageViewer.tsx b/src/renderer/src/components/ImageViewer.tsx index ddb28a4d52..bdb891a074 100644 --- a/src/renderer/src/components/ImageViewer.tsx +++ b/src/renderer/src/components/ImageViewer.tsx @@ -86,7 +86,7 @@ const ImageViewer: React.FC = ({ src, style, ...props }) => { }, { key: 'copy-image', - label: t('code_block.preview.copy.image'), + label: t('preview.copy.image'), icon: , onClick: () => handleCopyImage(src) } @@ -101,6 +101,7 @@ const ImageViewer: React.FC = ({ src, style, ...props }) => { {...props} preview={{ mask: typeof props.preview === 'object' ? props.preview.mask : false, + ...(typeof props.preview === 'object' ? props.preview : {}), toolbarRender: ( _, { diff --git a/src/renderer/src/components/ModelIdWithTags.tsx b/src/renderer/src/components/ModelIdWithTags.tsx index 4b8ca86123..bf902ae1c4 100644 --- a/src/renderer/src/components/ModelIdWithTags.tsx +++ b/src/renderer/src/components/ModelIdWithTags.tsx @@ -26,7 +26,7 @@ const ModelIdWithTags = ({ maxWidth: '500px' } }} - destroyTooltipOnHide + destroyOnHidden title={ {model.id} diff --git a/src/renderer/src/components/ModelTagsWithLabel.tsx b/src/renderer/src/components/ModelTagsWithLabel.tsx index 86a04dd454..3da6ccfc8d 100644 --- a/src/renderer/src/components/ModelTagsWithLabel.tsx +++ b/src/renderer/src/components/ModelTagsWithLabel.tsx @@ -1,4 +1,3 @@ -import { EyeOutlined, GlobalOutlined, ToolOutlined } from '@ant-design/icons' import { isEmbeddingModel, isFunctionCallingModel, @@ -14,7 +13,15 @@ import { FC, memo, useLayoutEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import styled from 'styled-components' -import CustomTag from './CustomTag' +import CustomTag from './Tags/CustomTag' +import { + EmbeddingTag, + ReasoningTag, + RerankerTag, + ToolsCallingTag, + VisionTag, + WebSearchTag +} from './Tags/ModelCapabilities' interface ModelTagsProps { model: Model @@ -70,45 +77,17 @@ const ModelTagsWithLabel: FC = ({ return ( - {isVisionModel(model) && ( - } - tooltip={showTooltip ? t('models.type.vision') : undefined}> - {shouldShowLabel ? t('models.type.vision') : ''} - - )} - {isWebSearchModel(model) && ( - } - tooltip={showTooltip ? t('models.type.websearch') : undefined}> - {shouldShowLabel ? t('models.type.websearch') : ''} - - )} + {isVisionModel(model) && } + {isWebSearchModel(model) && } {showReasoning && isReasoningModel(model) && ( - } - tooltip={showTooltip ? t('models.type.reasoning') : undefined}> - {shouldShowLabel ? t('models.type.reasoning') : ''} - + )} {showToolsCalling && isFunctionCallingModel(model) && ( - } - tooltip={showTooltip ? t('models.type.function_calling') : undefined}> - {shouldShowLabel ? t('models.type.function_calling') : ''} - + )} - {isEmbeddingModel(model) && } + {isEmbeddingModel(model) && } {showFree && isFreeModel(model) && } - {isRerankModel(model) && } + {isRerankModel(model) && } ) } diff --git a/src/renderer/src/components/OAuth/OAuthButton.tsx b/src/renderer/src/components/OAuth/OAuthButton.tsx index ad0b806bf8..16faefe42c 100644 --- a/src/renderer/src/components/OAuth/OAuthButton.tsx +++ b/src/renderer/src/components/OAuth/OAuthButton.tsx @@ -1,6 +1,12 @@ import { getProviderLabel } from '@renderer/i18n/label' import { Provider } from '@renderer/types' -import { oauthWithAihubmix, oauthWithPPIO, oauthWithSiliconFlow, oauthWithTokenFlux } from '@renderer/utils/oauth' +import { + oauthWith302AI, + oauthWithAihubmix, + oauthWithPPIO, + oauthWithSiliconFlow, + oauthWithTokenFlux +} from '@renderer/utils/oauth' import { Button, ButtonProps } from 'antd' import { FC } from 'react' import { useTranslation } from 'react-i18next' @@ -36,6 +42,10 @@ const OAuthButton: FC = ({ provider, onSuccess, ...buttonProps }) => { if (provider.id === 'tokenflux') { oauthWithTokenFlux() } + + if (provider.id === '302ai') { + oauthWith302AI(handleSuccess) + } } return ( diff --git a/src/renderer/src/components/Popups/SaveToKnowledgePopup.tsx b/src/renderer/src/components/Popups/SaveToKnowledgePopup.tsx index 3997135f72..ae4792a36c 100644 --- a/src/renderer/src/components/Popups/SaveToKnowledgePopup.tsx +++ b/src/renderer/src/components/Popups/SaveToKnowledgePopup.tsx @@ -1,14 +1,18 @@ import { loggerService } from '@logger' -import CustomTag from '@renderer/components/CustomTag' +import CustomTag from '@renderer/components/Tags/CustomTag' import { TopView } from '@renderer/components/TopView' import { useKnowledge, useKnowledgeBases } from '@renderer/hooks/useKnowledge' +import { Topic } from '@renderer/types' import { Message } from '@renderer/types/newMessage' import { analyzeMessageContent, + analyzeTopicContent, CONTENT_TYPES, ContentType, MessageContentStats, - processMessageContent + processMessageContent, + processTopicContent, + TopicContentStats } from '@renderer/utils/knowledge' import { Flex, Form, Modal, Select, Tooltip, Typography } from 'antd' import { Check, CircleHelp } from 'lucide-react' @@ -20,11 +24,12 @@ const logger = loggerService.withContext('SaveToKnowledgePopup') const { Text } = Typography -// 内容类型配置 +// Base Content Type Config const CONTENT_TYPE_CONFIG = { [CONTENT_TYPES.TEXT]: { label: 'chat.save.knowledge.content.maintext.title', - description: 'chat.save.knowledge.content.maintext.description' + description: 'chat.save.knowledge.content.maintext.description', + topicDescription: 'chat.save.topic.knowledge.content.maintext.description' }, [CONTENT_TYPES.CODE]: { label: 'chat.save.knowledge.content.code.title', @@ -62,16 +67,20 @@ const TAG_COLORS = { UNSELECTED: '#8c8c8c' } as const +type ContentStats = MessageContentStats | TopicContentStats + interface ContentTypeOption { type: ContentType - label: string count: number enabled: boolean - description?: string + label: string + description: string } +type ContentSource = { type: 'message'; data: Message } | { type: 'topic'; data: Topic } + interface ShowParams { - message: Message + source: ContentSource title?: string } @@ -84,35 +93,73 @@ interface Props extends ShowParams { resolve: (data: SaveResult | null) => void } -const PopupContainer: React.FC = ({ message, title, resolve }) => { +const PopupContainer: React.FC = ({ source, title, resolve }) => { const [open, setOpen] = useState(true) const [loading, setLoading] = useState(false) + const [analysisLoading, setAnalysisLoading] = useState(true) const [selectedBaseId, setSelectedBaseId] = useState() const [selectedTypes, setSelectedTypes] = useState([]) const [hasInitialized, setHasInitialized] = useState(false) + const [contentStats, setContentStats] = useState(null) const { bases } = useKnowledgeBases() const { addNote, addFiles } = useKnowledge(selectedBaseId || '') const { t } = useTranslation() - // 分析消息内容统计 - const contentStats = useMemo(() => analyzeMessageContent(message), [message]) + const isTopicMode = source?.type === 'topic' - // 生成内容类型选项(只显示有内容的类型) + // 异步分析内容统计 + useEffect(() => { + const analyze = async () => { + setAnalysisLoading(true) + setContentStats(null) + try { + const stats = isTopicMode + ? await analyzeTopicContent(source?.data as Topic) + : analyzeMessageContent(source?.data as Message) + setContentStats(stats) + } catch (error) { + logger.error('analyze content failed:', error as Error) + setContentStats({ + text: 0, + code: 0, + thinking: 0, + images: 0, + files: 0, + tools: 0, + citations: 0, + translations: 0, + errors: 0, + ...(isTopicMode && { messages: 0 }) + }) + } finally { + setAnalysisLoading(false) + } + } + analyze() + }, [source, isTopicMode]) + + // 生成内容类型选项 const contentTypeOptions: ContentTypeOption[] = useMemo(() => { + if (!contentStats) return [] + return Object.entries(CONTENT_TYPE_CONFIG) .map(([type, config]) => { const contentType = type as ContentType - const count = contentStats[contentType as keyof MessageContentStats] || 0 + const count = contentStats[contentType as keyof ContentStats] || 0 + const descriptionKey = + isTopicMode && 'topicDescription' in config && config.topicDescription + ? config.topicDescription + : config.description return { type: contentType, count, enabled: count > 0, label: t(config.label), - description: t(config.description) + description: t(descriptionKey) } }) - .filter((option) => option.enabled) // 只显示有内容的类型 - }, [contentStats, t]) + .filter((option) => option.enabled) + }, [contentStats, t, isTopicMode]) // 知识库选项 const knowledgeBaseOptions = useMemo( @@ -120,12 +167,12 @@ const PopupContainer: React.FC = ({ message, title, resolve }) => { bases.map((base) => ({ label: base.name, value: base.id, - disabled: !base.version // 如果知识库没有配置好就禁用 + disabled: !base.version })), [bases] ) - // 合并状态计算 + // 表单状态 const formState = useMemo(() => { const hasValidBase = selectedBaseId && bases.find((base) => base.id === selectedBaseId)?.version const hasContent = contentTypeOptions.length > 0 @@ -142,7 +189,7 @@ const PopupContainer: React.FC = ({ message, title, resolve }) => { } }, [selectedBaseId, bases, contentTypeOptions, selectedTypes]) - // 默认选择第一个可用的知识库 + // 默认选择第一个可用知识库 useEffect(() => { if (!selectedBaseId) { const firstAvailableBase = bases.find((base) => base.version) @@ -152,49 +199,51 @@ const PopupContainer: React.FC = ({ message, title, resolve }) => { } }, [bases, selectedBaseId]) - // 默认选择所有可用的内容类型(仅在初始化时) + // 默认选择所有可用内容类型 useEffect(() => { if (!hasInitialized && contentTypeOptions.length > 0) { - const availableTypes = contentTypeOptions.map((option) => option.type) - setSelectedTypes(availableTypes) + setSelectedTypes(contentTypeOptions.map((option) => option.type)) setHasInitialized(true) } }, [contentTypeOptions, hasInitialized]) - // 计算UI状态 + // UI状态 const uiState = useMemo(() => { + if (analysisLoading) { + return { type: 'loading', message: t('chat.save.topic.knowledge.loading') } + } if (!formState.hasContent) { - return { type: 'empty', message: t('chat.save.knowledge.empty.no_content') } + return { + type: 'empty', + message: t(isTopicMode ? 'chat.save.topic.knowledge.empty.no_content' : 'chat.save.knowledge.empty.no_content') + } } if (bases.length === 0) { return { type: 'empty', message: t('chat.save.knowledge.empty.no_knowledge_base') } } return { type: 'form' } - }, [formState.hasContent, bases.length, t]) + }, [analysisLoading, formState.hasContent, bases.length, t, isTopicMode]) - // 处理内容类型选择切换 const handleContentTypeToggle = (type: ContentType) => { setSelectedTypes((prev) => (prev.includes(type) ? prev.filter((t) => t !== type) : [...prev, type])) } const onOk = async () => { - if (!formState.canSubmit) { - return - } + if (!formState.canSubmit) return setLoading(true) let savedCount = 0 try { - const result = processMessageContent(message, selectedTypes) + const result = isTopicMode + ? await processTopicContent(source?.data as Topic, selectedTypes) + : processMessageContent(source?.data as Message, selectedTypes) - // 保存文本内容 if (result.text.trim() && selectedTypes.some((type) => type !== CONTENT_TYPES.FILE)) { await addNote(result.text) savedCount++ } - // 保存文件 if (result.files.length > 0 && selectedTypes.includes(CONTENT_TYPES.FILE)) { addFiles(result.files) savedCount += result.files.length @@ -204,27 +253,22 @@ const PopupContainer: React.FC = ({ message, title, resolve }) => { resolve({ success: true, savedCount }) } catch (error) { logger.error('save failed:', error as Error) - window.message.error(t('chat.save.knowledge.error.save_failed')) + window.message.error( + t(isTopicMode ? 'chat.save.topic.knowledge.error.save_failed' : 'chat.save.knowledge.error.save_failed') + ) setLoading(false) } } - const onCancel = () => { - setOpen(false) - } + const onCancel = () => setOpen(false) + const onClose = () => resolve(null) - const onClose = () => { - resolve(null) - } - - // 渲染空状态 const renderEmptyState = () => ( {uiState.message} ) - // 渲染表单内容 const renderFormContent = () => ( <>
@@ -241,7 +285,10 @@ const PopupContainer: React.FC = ({ message, title, resolve }) => { /> - + {contentTypeOptions.map((option) => ( = ({ message, title, resolve }) => { - {formState.selectedCount > 0 && ( - + + {formState.selectedCount > 0 && ( - {t('chat.save.knowledge.select.content.tip', { count: formState.selectedCount })} + {t( + isTopicMode + ? 'chat.save.topic.knowledge.select.content.selected_tip' + : 'chat.save.knowledge.select.content.tip', + { + count: formState.selectedCount, + ...(isTopicMode && { messages: (contentStats as TopicContentStats)?.messages || 0 }) + } + )} - - )} - - {formState.hasNoSelection && ( - + )} + {formState.hasNoSelection && ( {t('chat.save.knowledge.error.no_content_selected')} - - )} + )} + {!formState.hasNoSelection && formState.selectedCount === 0 && ( + +   + + )} + ) return ( = ({ message, title, resolve }) => { width={500} okText={t('common.save')} cancelText={t('common.cancel')} - okButtonProps={{ - loading, - disabled: !formState.canSubmit - }}> - {uiState.type === 'empty' ? renderEmptyState() : renderFormContent()} + okButtonProps={{ loading, disabled: !formState.canSubmit || analysisLoading }}> + {uiState.type === 'form' ? renderFormContent() : renderEmptyState()} ) } @@ -327,11 +381,22 @@ export default class SaveToKnowledgePopup { ) }) } + + static showForMessage(message: Message, title?: string): Promise { + return this.show({ source: { type: 'message', data: message }, title }) + } + + static showForTopic(topic: Topic, title?: string): Promise { + return this.show({ source: { type: 'topic', data: topic }, title }) + } } const EmptyContainer = styled.div` + display: flex; + justify-content: center; + align-items: center; + min-height: 100px; text-align: center; - padding: 40px 20px; ` const ContentTypeItem = styled(Flex)` @@ -352,4 +417,7 @@ const InfoContainer = styled.div` padding: 12px; border-radius: 6px; margin-top: 16px; + min-height: 40px; /* To avoid layout shift */ + display: flex; + align-items: center; ` diff --git a/src/renderer/src/components/Popups/TextEditPopup.tsx b/src/renderer/src/components/Popups/TextEditPopup.tsx index d5cd04a1c3..403c218dc9 100644 --- a/src/renderer/src/components/Popups/TextEditPopup.tsx +++ b/src/renderer/src/components/Popups/TextEditPopup.tsx @@ -1,9 +1,7 @@ import { LoadingOutlined } from '@ant-design/icons' import { loggerService } from '@logger' -import { useDefaultModel } from '@renderer/hooks/useAssistant' import { useSettings } from '@renderer/hooks/useSettings' -import { fetchTranslate } from '@renderer/services/ApiService' -import { getDefaultTranslateAssistant } from '@renderer/services/AssistantService' +import { translateText } from '@renderer/services/TranslateService' import { getLanguageByLangcode } from '@renderer/utils/translate' import { Modal, ModalProps } from 'antd' import TextArea from 'antd/es/input/TextArea' @@ -43,7 +41,6 @@ const PopupContainer: React.FC = ({ const [textValue, setTextValue] = useState(text) const [isTranslating, setIsTranslating] = useState(false) const textareaRef = useRef(null) - const { translateModel } = useDefaultModel() const { targetLanguage, showTranslateConfirm } = useSettings() const isMounted = useRef(true) @@ -103,21 +100,12 @@ const PopupContainer: React.FC = ({ if (!confirmed) return } - if (!translateModel) { - window.message.error({ - content: t('translate.error.not_configured'), - key: 'translate-message' - }) - return - } - if (isMounted.current) { setIsTranslating(true) } try { - const assistant = getDefaultTranslateAssistant(getLanguageByLangcode(targetLanguage), textValue) - const translatedText = await fetchTranslate({ content: textValue, assistant }) + const translatedText = await translateText(textValue, getLanguageByLangcode(targetLanguage)) if (isMounted.current) { setTextValue(translatedText) } diff --git a/src/renderer/src/components/Preview/GraphvizPreview.tsx b/src/renderer/src/components/Preview/GraphvizPreview.tsx new file mode 100644 index 0000000000..c3c5c641a2 --- /dev/null +++ b/src/renderer/src/components/Preview/GraphvizPreview.tsx @@ -0,0 +1,56 @@ +import { AsyncInitializer } from '@renderer/utils/asyncInitializer' +import React, { memo, useCallback } from 'react' +import styled from 'styled-components' + +import { useDebouncedRender } from './hooks/useDebouncedRender' +import ImagePreviewLayout from './ImagePreviewLayout' +import { BasicPreviewHandles, BasicPreviewProps } from './types' +import { renderSvgInShadowHost } from './utils' + +// 管理 viz 实例 +const vizInitializer = new AsyncInitializer(async () => { + const module = await import('@viz-js/viz') + return await module.instance() +}) + +/** 预览 Graphviz 图表 + * 使用 usePreviewRenderer hook 大幅简化组件逻辑 + */ +const GraphvizPreview = ({ + children, + enableToolbar = false, + ref +}: BasicPreviewProps & { ref?: React.RefObject }) => { + // 定义渲染函数 + const renderGraphviz = useCallback(async (content: string, container: HTMLDivElement) => { + const viz = await vizInitializer.get() + const svg = viz.renderString(content, { format: 'svg' }) + renderSvgInShadowHost(svg, container) + }, []) + + // 使用预览渲染器 hook + const { containerRef, error, isLoading } = useDebouncedRender(children, renderGraphviz, { + debounceDelay: 300 + }) + + return ( + + + + ) +} + +const StyledGraphviz = styled.div` + overflow: auto; + position: relative; + width: 100%; + height: 100%; +` + +export default memo(GraphvizPreview) diff --git a/src/renderer/src/components/Preview/ImagePreviewLayout.tsx b/src/renderer/src/components/Preview/ImagePreviewLayout.tsx new file mode 100644 index 0000000000..cff446e250 --- /dev/null +++ b/src/renderer/src/components/Preview/ImagePreviewLayout.tsx @@ -0,0 +1,60 @@ +import { useImageTools } from '@renderer/components/ActionTools/hooks/useImageTools' +import { LoadingIcon } from '@renderer/components/Icons' +import { Spin } from 'antd' +import { memo, useImperativeHandle } from 'react' + +import ImageToolbar from './ImageToolbar' +import { PreviewContainer, PreviewError } from './styles' +import { BasicPreviewHandles } from './types' + +interface ImagePreviewLayoutProps { + children: React.ReactNode + ref?: React.RefObject + imageRef: React.RefObject + source: string + loading?: boolean + error?: string | null + enableToolbar?: boolean + className?: string +} + +const ImagePreviewLayout = ({ + children, + ref, + imageRef, + source, + loading, + error, + enableToolbar, + className +}: ImagePreviewLayoutProps) => { + // 使用通用图像工具 + const { pan, zoom, copy, download, dialog } = useImageTools(imageRef, { + imgSelector: 'svg', + prefix: source ?? 'svg', + enableDrag: true, + enableWheelZoom: true + }) + + useImperativeHandle(ref, () => { + return { + pan, + zoom, + copy, + download, + dialog + } + }) + + return ( + }> + + {error && {error}} + {children} + {!error && enableToolbar && } + + + ) +} + +export default memo(ImagePreviewLayout) diff --git a/src/renderer/src/components/Preview/ImageToolButton.tsx b/src/renderer/src/components/Preview/ImageToolButton.tsx new file mode 100644 index 0000000000..e14ae8fee0 --- /dev/null +++ b/src/renderer/src/components/Preview/ImageToolButton.tsx @@ -0,0 +1,18 @@ +import { Button, Tooltip } from 'antd' +import { memo } from 'react' + +interface ImageToolButtonProps { + tooltip: string + icon: React.ReactNode + onClick: () => void +} + +const ImageToolButton = ({ tooltip, icon, onClick }: ImageToolButtonProps) => { + return ( + + + )), + Tooltip: vi.fn(({ children, title }) =>
{children}
) +})) + +describe('ImageToolButton', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + const defaultProps = { + tooltip: 'Test tooltip', + icon: Icon, + onClick: vi.fn() + } + + it('should match snapshot', () => { + const { asFragment } = render() + expect(asFragment()).toMatchSnapshot() + }) +}) diff --git a/src/renderer/src/components/Preview/__tests__/ImageToolbar.test.tsx b/src/renderer/src/components/Preview/__tests__/ImageToolbar.test.tsx new file mode 100644 index 0000000000..a64076e3a4 --- /dev/null +++ b/src/renderer/src/components/Preview/__tests__/ImageToolbar.test.tsx @@ -0,0 +1,96 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import ImageToolbar from '../ImageToolbar' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key + }) +})) + +// Mock ImageToolButton +vi.mock('../ImageToolButton', () => ({ + default: vi.fn(({ tooltip, onClick, icon }) => ( + + )) +})) + +// Mock lucide-react icons +vi.mock('lucide-react', () => ({ + ChevronUp: () => , + ChevronDown: () => , + ChevronLeft: () => , + ChevronRight: () => , + ZoomIn: () => +, + ZoomOut: () => -, + Scan: () => +})) + +vi.mock('@renderer/components/Icons', () => ({ + ResetIcon: () => +})) + +// Mock utils +vi.mock('@renderer/utils', () => ({ + classNames: (...args: any[]) => args.filter(Boolean).join(' ') +})) + +describe('ImageToolbar', () => { + const mockPan = vi.fn() + const mockZoom = vi.fn() + const mockOpenDialog = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should match snapshot', () => { + const { asFragment } = render() + expect(asFragment()).toMatchSnapshot() + }) + + it('calls onPan with correct values when pan buttons are clicked', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'preview.pan_up' })) + expect(mockPan).toHaveBeenCalledWith(0, -20) + + fireEvent.click(screen.getByRole('button', { name: 'preview.pan_down' })) + expect(mockPan).toHaveBeenCalledWith(0, 20) + + fireEvent.click(screen.getByRole('button', { name: 'preview.pan_left' })) + expect(mockPan).toHaveBeenCalledWith(-20, 0) + + fireEvent.click(screen.getByRole('button', { name: 'preview.pan_right' })) + expect(mockPan).toHaveBeenCalledWith(20, 0) + }) + + it('calls onZoom with correct values when zoom buttons are clicked', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'preview.zoom_in' })) + expect(mockZoom).toHaveBeenCalledWith(0.1) + + fireEvent.click(screen.getByRole('button', { name: 'preview.zoom_out' })) + expect(mockZoom).toHaveBeenCalledWith(-0.1) + }) + + it('calls onReset with correct values when reset button is clicked', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'preview.reset' })) + expect(mockPan).toHaveBeenCalledWith(0, 0, true) + expect(mockZoom).toHaveBeenCalledWith(1, true) + }) + + it('calls onOpenDialog when dialog button is clicked', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'preview.dialog' })) + expect(mockOpenDialog).toHaveBeenCalled() + }) +}) diff --git a/src/renderer/src/components/Preview/__tests__/MermaidPreview.test.tsx b/src/renderer/src/components/Preview/__tests__/MermaidPreview.test.tsx new file mode 100644 index 0000000000..17ada0668c --- /dev/null +++ b/src/renderer/src/components/Preview/__tests__/MermaidPreview.test.tsx @@ -0,0 +1,259 @@ +import { render, screen } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest' + +import { MermaidPreview } from '..' + +const mocks = vi.hoisted(() => ({ + useMermaid: vi.fn(), + useDebouncedRender: vi.fn(), + ImagePreviewLayout: vi.fn(({ children, loading, error, enableToolbar, source }) => ( +
+ {enableToolbar &&
Toolbar
} + {loading &&
Loading...
} + {error &&
{error}
} +
{children}
+
+ )) +})) + +// Mock hooks +vi.mock('@renderer/hooks/useMermaid', () => ({ + useMermaid: () => mocks.useMermaid() +})) + +vi.mock('@renderer/components/Preview/ImagePreviewLayout', () => ({ + default: mocks.ImagePreviewLayout +})) + +vi.mock('@renderer/components/Preview/hooks/useDebouncedRender', () => ({ + useDebouncedRender: mocks.useDebouncedRender +})) + +// Mock nanoid +vi.mock('@reduxjs/toolkit', () => ({ + nanoid: () => 'test-id-123456' +})) + +describe('MermaidPreview', () => { + const mermaidCode = 'graph TD\nA-->B' + const mockContainerRef = { current: document.createElement('div') } + + const mockMermaid = { + parse: vi.fn(), + render: vi.fn() + } + + // Helper function to create mock useDebouncedRender return value + const createMockHookReturn = (overrides = {}) => ({ + containerRef: mockContainerRef, + error: null, + isLoading: false, + triggerRender: vi.fn(), + cancelRender: vi.fn(), + clearError: vi.fn(), + setLoading: vi.fn(), + ...overrides + }) + + beforeEach(() => { + // Setup default mocks + mocks.useMermaid.mockReturnValue({ + mermaid: mockMermaid, + isLoading: false, + error: null + }) + + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn()) + + mockMermaid.parse.mockResolvedValue(true) + mockMermaid.render.mockResolvedValue({ + svg: 'test diagram' + }) + + // Mock MutationObserver + global.MutationObserver = vi.fn().mockImplementation(() => ({ + observe: vi.fn(), + disconnect: vi.fn(), + takeRecords: vi.fn() + })) + }) + + afterEach(() => { + vi.clearAllMocks() + vi.restoreAllMocks() + }) + + describe('basic rendering', () => { + it('should match snapshot', () => { + const { container } = render({mermaidCode}) + expect(container).toMatchSnapshot() + }) + + it('should handle valid mermaid content', () => { + render({mermaidCode}) + + expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument() + expect(mocks.useDebouncedRender).toHaveBeenCalledWith( + mermaidCode, + expect.any(Function), + expect.objectContaining({ + debounceDelay: 300, + shouldRender: expect.any(Function) + }) + ) + }) + + it('should handle empty content', () => { + render({''}) + + expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument() + expect(mocks.useDebouncedRender).toHaveBeenCalledWith('', expect.any(Function), expect.any(Object)) + }) + }) + + describe('loading state', () => { + it('should show loading when useMermaid is loading', () => { + mocks.useMermaid.mockReturnValue({ + mermaid: mockMermaid, + isLoading: true, + error: null + }) + + render({mermaidCode}) + + expect(screen.getByTestId('loading')).toBeInTheDocument() + }) + + it('should show loading when useDebouncedRender is loading', () => { + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ isLoading: true })) + + render({mermaidCode}) + + expect(screen.getByTestId('loading')).toBeInTheDocument() + }) + + it('should not show loading when both are not loading', () => { + render({mermaidCode}) + + expect(screen.queryByTestId('loading')).not.toBeInTheDocument() + }) + }) + + describe('error handling', () => { + it('should show error from useMermaid', () => { + const mermaidError = 'Mermaid initialization failed' + mocks.useMermaid.mockReturnValue({ + mermaid: mockMermaid, + isLoading: false, + error: mermaidError + }) + + render({mermaidCode}) + + const errorElement = screen.getByTestId('error') + expect(errorElement).toBeInTheDocument() + expect(errorElement).toHaveTextContent(mermaidError) + }) + + it('should show error from useDebouncedRender', () => { + const renderError = 'Diagram rendering failed' + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: renderError })) + + render({mermaidCode}) + + const errorElement = screen.getByTestId('error') + expect(errorElement).toBeInTheDocument() + expect(errorElement).toHaveTextContent(renderError) + }) + + it('should prioritize useMermaid error over render error', () => { + const mermaidError = 'Mermaid initialization failed' + const renderError = 'Diagram rendering failed' + + mocks.useMermaid.mockReturnValue({ + mermaid: mockMermaid, + isLoading: false, + error: mermaidError + }) + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: renderError })) + + render({mermaidCode}) + + const errorElement = screen.getByTestId('error') + expect(errorElement).toHaveTextContent(mermaidError) + }) + }) + + describe('ref forwarding', () => { + it('should forward ref to ImagePreviewLayout', () => { + const ref = { current: null } + render({mermaidCode}) + + expect(mocks.ImagePreviewLayout).toHaveBeenCalledWith(expect.objectContaining({ ref }), undefined) + }) + }) + + describe('visibility detection', () => { + it('should observe parent elements up to fold className', () => { + // Create a DOM structure that simulates MessageGroup fold layout + const foldContainer = document.createElement('div') + foldContainer.className = 'fold selected' + + const messageWrapper = document.createElement('div') + messageWrapper.className = 'message-wrapper' + + const codeBlock = document.createElement('div') + codeBlock.className = 'code-block' + + foldContainer.appendChild(messageWrapper) + messageWrapper.appendChild(codeBlock) + document.body.appendChild(foldContainer) + + try { + render({mermaidCode}, { + container: codeBlock + }) + + const observerInstance = (global.MutationObserver as Mock).mock.results[0]?.value + expect(observerInstance.observe).toHaveBeenCalled() + } finally { + // Cleanup + document.body.removeChild(foldContainer) + } + }) + + it('should handle visibility changes and trigger re-render', () => { + const mockTriggerRender = vi.fn() + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ triggerRender: mockTriggerRender })) + + const { container } = render({mermaidCode}) + + // Get the MutationObserver callback + const observerCallback = (global.MutationObserver as Mock).mock.calls[0][0] + + // Mock the container element to be initially hidden + const mermaidElement = container.querySelector('.mermaid') + Object.defineProperty(mermaidElement, 'offsetParent', { + get: () => null, // Hidden + configurable: true + }) + + // Simulate MutationObserver detecting visibility change + observerCallback([]) + + // Now make it visible + Object.defineProperty(mermaidElement, 'offsetParent', { + get: () => document.body, // Visible + configurable: true + }) + + // Simulate another MutationObserver callback for visibility change + observerCallback([]) + + // The visibility change should have been detected and component should be ready to re-render + // We verify the component structure is correct for potential re-rendering + expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument() + expect(mermaidElement).toBeInTheDocument() + }) + }) +}) diff --git a/src/renderer/src/components/Preview/__tests__/PlantUmlPreview.test.tsx b/src/renderer/src/components/Preview/__tests__/PlantUmlPreview.test.tsx new file mode 100644 index 0000000000..08447046ec --- /dev/null +++ b/src/renderer/src/components/Preview/__tests__/PlantUmlPreview.test.tsx @@ -0,0 +1,169 @@ +import PlantUmlPreview from '@renderer/components/Preview/PlantUmlPreview' +import { render, screen } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// Use vi.hoisted to manage mocks +const mocks = vi.hoisted(() => ({ + ImagePreviewLayout: vi.fn(({ children, loading, error, enableToolbar, source }) => ( +
+ {enableToolbar &&
Toolbar
} + {loading &&
Loading...
} + {error &&
{error}
} +
{children}
+
+ )), + renderSvgInShadowHost: vi.fn(), + useDebouncedRender: vi.fn(), + logger: { + warn: vi.fn() + } +})) + +vi.mock('@renderer/components/Preview/ImagePreviewLayout', () => ({ + default: mocks.ImagePreviewLayout +})) + +vi.mock('@renderer/components/Preview/utils', () => ({ + renderSvgInShadowHost: mocks.renderSvgInShadowHost +})) + +vi.mock('@renderer/components/Preview/hooks/useDebouncedRender', () => ({ + useDebouncedRender: mocks.useDebouncedRender +})) + +describe('PlantUmlPreview', () => { + const diagram = '@startuml\nA -> B\n@enduml' + const mockContainerRef = { current: document.createElement('div') } + + // Helper function to create mock useDebouncedRender return value + const createMockHookReturn = (overrides = {}) => ({ + containerRef: mockContainerRef, + error: null, + isLoading: false, + triggerRender: vi.fn(), + cancelRender: vi.fn(), + clearError: vi.fn(), + setLoading: vi.fn(), + ...overrides + }) + + beforeEach(() => { + // Setup default successful state + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn()) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('basic rendering', () => { + it('should match snapshot', () => { + const { container } = render({diagram}) + expect(container).toMatchSnapshot() + }) + + it('should handle valid plantuml diagram', () => { + render({diagram}) + + // Component should render without throwing + expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument() + expect(mocks.useDebouncedRender).toHaveBeenCalledWith( + diagram, + expect.any(Function), + expect.objectContaining({ debounceDelay: 300 }) + ) + }) + + it('should handle empty content', () => { + render({''}) + + // Component should render without throwing + expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument() + expect(mocks.useDebouncedRender).toHaveBeenCalledWith('', expect.any(Function), expect.any(Object)) + }) + }) + + describe('loading state', () => { + it('should show loading indicator when rendering', () => { + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ isLoading: true })) + + render({diagram}) + + expect(screen.getByTestId('loading')).toBeInTheDocument() + }) + + it('should not show loading indicator when not rendering', () => { + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ isLoading: false })) + + render({diagram}) + + expect(screen.queryByTestId('loading')).not.toBeInTheDocument() + }) + }) + + describe('error handling', () => { + it('should show network error message', () => { + const networkError = 'Network Error: Unable to connect to PlantUML server. Please check your network connection.' + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: networkError })) + + render({diagram}) + + const errorElement = screen.getByTestId('error') + expect(errorElement).toBeInTheDocument() + expect(errorElement).toHaveTextContent(networkError) + }) + + it('should show syntax error message for invalid diagram', () => { + const syntaxError = + 'Diagram rendering failed (400): This is likely due to a syntax error in the diagram. Please check your code.' + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: syntaxError })) + + render({diagram}) + + const errorElement = screen.getByTestId('error') + expect(errorElement).toBeInTheDocument() + expect(errorElement).toHaveTextContent(syntaxError) + }) + + it('should show server error message', () => { + const serverError = + 'Diagram rendering failed (503): The PlantUML server is temporarily unavailable. Please try again later.' + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: serverError })) + + render({diagram}) + + const errorElement = screen.getByTestId('error') + expect(errorElement).toBeInTheDocument() + expect(errorElement).toHaveTextContent(serverError) + }) + + it('should show generic error message for other errors', () => { + const genericError = "Diagram rendering failed, server returned: 418 I'm a teapot" + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: genericError })) + + render({diagram}) + + const errorElement = screen.getByTestId('error') + expect(errorElement).toBeInTheDocument() + expect(errorElement).toHaveTextContent(genericError) + }) + + it('should not show error when rendering is successful', () => { + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: null })) + + render({diagram}) + + expect(screen.queryByTestId('error')).not.toBeInTheDocument() + }) + }) + + describe('ref forwarding', () => { + it('should forward ref to ImagePreviewLayout', () => { + const ref = { current: null } + render({diagram}) + + // The ref should be passed to ImagePreviewLayout + expect(mocks.ImagePreviewLayout).toHaveBeenCalledWith(expect.objectContaining({ ref }), undefined) + }) + }) +}) diff --git a/src/renderer/src/components/Preview/__tests__/SvgPreview.test.tsx b/src/renderer/src/components/Preview/__tests__/SvgPreview.test.tsx new file mode 100644 index 0000000000..e4a32d7c1c --- /dev/null +++ b/src/renderer/src/components/Preview/__tests__/SvgPreview.test.tsx @@ -0,0 +1,149 @@ +import SvgPreview from '@renderer/components/Preview/SvgPreview' +import { render, screen } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// Use vi.hoisted to manage mocks +const mocks = vi.hoisted(() => ({ + ImagePreviewLayout: vi.fn(({ children, loading, error, enableToolbar, source }) => ( +
+ {enableToolbar &&
Toolbar
} + {loading &&
Loading...
} + {error &&
{error}
} +
{children}
+
+ )), + renderSvgInShadowHost: vi.fn(), + useDebouncedRender: vi.fn() +})) + +vi.mock('@renderer/components/Preview/ImagePreviewLayout', () => ({ + default: mocks.ImagePreviewLayout +})) + +vi.mock('@renderer/components/Preview/utils', () => ({ + renderSvgInShadowHost: mocks.renderSvgInShadowHost +})) + +vi.mock('@renderer/components/Preview/hooks/useDebouncedRender', () => ({ + useDebouncedRender: mocks.useDebouncedRender +})) + +describe('SvgPreview', () => { + const svgContent = '' + const mockContainerRef = { current: document.createElement('div') } + + // Helper function to create mock useDebouncedRender return value + const createMockHookReturn = (overrides = {}) => ({ + containerRef: mockContainerRef, + error: null, + isLoading: false, + triggerRender: vi.fn(), + cancelRender: vi.fn(), + clearError: vi.fn(), + setLoading: vi.fn(), + ...overrides + }) + + beforeEach(() => { + // Setup default successful state + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn()) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('basic rendering', () => { + it('should match snapshot', () => { + const { container } = render({svgContent}) + expect(container).toMatchSnapshot() + }) + + it('should handle valid svg content', () => { + render({svgContent}) + + // Component should render without throwing + expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument() + expect(mocks.useDebouncedRender).toHaveBeenCalledWith( + svgContent, + expect.any(Function), + expect.objectContaining({ debounceDelay: 300 }) + ) + }) + + it('should handle empty content', () => { + render({''}) + + // Component should render without throwing + expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument() + expect(mocks.useDebouncedRender).toHaveBeenCalledWith('', expect.any(Function), expect.any(Object)) + }) + }) + + describe('loading state', () => { + it('should show loading indicator when rendering', () => { + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ isLoading: true })) + + render({svgContent}) + + expect(screen.getByTestId('loading')).toBeInTheDocument() + }) + + it('should not show loading indicator when not rendering', () => { + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ isLoading: false })) + + render({svgContent}) + + expect(screen.queryByTestId('loading')).not.toBeInTheDocument() + }) + }) + + describe('error handling', () => { + it('should show error message when rendering fails', () => { + const errorMessage = 'Invalid SVG content' + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: errorMessage })) + + render({svgContent}) + + const errorElement = screen.getByTestId('error') + expect(errorElement).toBeInTheDocument() + expect(errorElement).toHaveTextContent(errorMessage) + }) + + it('should not show error when rendering is successful', () => { + mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: null })) + + render({svgContent}) + + expect(screen.queryByTestId('error')).not.toBeInTheDocument() + }) + }) + + describe('custom styling', () => { + it('should use custom className when provided', () => { + render({svgContent}) + + const content = screen.getByTestId('preview-content') + const svgContainer = content.querySelector('.custom-svg-class') + expect(svgContainer).toBeInTheDocument() + }) + + it('should use default className when not provided', () => { + render({svgContent}) + + const content = screen.getByTestId('preview-content') + const svgContainer = content.querySelector('.svg-preview.special-preview') + expect(svgContainer).toBeInTheDocument() + }) + }) + + describe('ref forwarding', () => { + it('should forward ref to ImagePreviewLayout', () => { + const ref = { current: null } + render({svgContent}) + + // The ref should be passed to ImagePreviewLayout + expect(mocks.ImagePreviewLayout).toHaveBeenCalledWith(expect.objectContaining({ ref }), undefined) + }) + }) +}) diff --git a/src/renderer/src/components/Preview/__tests__/__snapshots__/GraphvizPreview.test.tsx.snap b/src/renderer/src/components/Preview/__tests__/__snapshots__/GraphvizPreview.test.tsx.snap new file mode 100644 index 0000000000..923f35b9e7 --- /dev/null +++ b/src/renderer/src/components/Preview/__tests__/__snapshots__/GraphvizPreview.test.tsx.snap @@ -0,0 +1,30 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`GraphvizPreview > basic rendering > should match snapshot 1`] = ` +.c0 { + overflow: auto; + position: relative; + width: 100%; + height: 100%; +} + +
+
+
+ Toolbar +
+
+
+
+
+
+`; diff --git a/src/renderer/src/components/Preview/__tests__/__snapshots__/ImagePreviewLayout.test.tsx.snap b/src/renderer/src/components/Preview/__tests__/__snapshots__/ImagePreviewLayout.test.tsx.snap new file mode 100644 index 0000000000..13e847fc31 --- /dev/null +++ b/src/renderer/src/components/Preview/__tests__/__snapshots__/ImagePreviewLayout.test.tsx.snap @@ -0,0 +1,18 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`ImagePreviewLayout > should match snapshot 1`] = ` +
+
+
+
+ Test Content +
+
+
+
+`; diff --git a/src/renderer/src/components/Preview/__tests__/__snapshots__/ImageToolButton.test.tsx.snap b/src/renderer/src/components/Preview/__tests__/__snapshots__/ImageToolButton.test.tsx.snap new file mode 100644 index 0000000000..0f155f1ff1 --- /dev/null +++ b/src/renderer/src/components/Preview/__tests__/__snapshots__/ImageToolButton.test.tsx.snap @@ -0,0 +1,18 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`ImageToolButton > should match snapshot 1`] = ` + +
+
+
+`; diff --git a/src/renderer/src/components/Preview/__tests__/__snapshots__/ImageToolbar.test.tsx.snap b/src/renderer/src/components/Preview/__tests__/__snapshots__/ImageToolbar.test.tsx.snap new file mode 100644 index 0000000000..b697f06aab --- /dev/null +++ b/src/renderer/src/components/Preview/__tests__/__snapshots__/ImageToolbar.test.tsx.snap @@ -0,0 +1,141 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`ImageToolbar > should match snapshot 1`] = ` + + .c0 { + display: flex; + flex-direction: column; + align-items: center; + position: absolute; + gap: 4px; + right: 1em; + bottom: 1em; + z-index: 5; +} + +.c0 .ant-btn { + line-height: 0; +} + +.c1 { + display: flex; + justify-content: center; + gap: 4px; + width: 100%; +} + +.c2 { + flex: 1; +} + +