Merge branch 'main' of https://github.com/CherryHQ/cherry-studio into wip/refactor/databases

This commit is contained in:
fullex 2025-08-09 09:44:28 +08:00
commit 92cd012037
210 changed files with 11778 additions and 4298 deletions

View File

@ -1,279 +0,0 @@
diff --git a/client.js b/client.js
index 33b4ff6309d5f29187dab4e285d07dac20340bab..8f568637ee9e4677585931fb0284c8165a933f69 100644
--- a/client.js
+++ b/client.js
@@ -433,7 +433,7 @@ class OpenAI {
'User-Agent': this.getUserAgent(),
'X-Stainless-Retry-Count': String(retryCount),
...(options.timeout ? { 'X-Stainless-Timeout': String(Math.trunc(options.timeout / 1000)) } : {}),
- ...(0, detect_platform_1.getPlatformHeaders)(),
+ // ...(0, detect_platform_1.getPlatformHeaders)(),
'OpenAI-Organization': this.organization,
'OpenAI-Project': this.project,
},
diff --git a/client.mjs b/client.mjs
index c34c18213073540ebb296ea540b1d1ad39527906..1ce1a98256d7e90e26ca963582f235b23e996e73 100644
--- a/client.mjs
+++ b/client.mjs
@@ -430,7 +430,7 @@ export class OpenAI {
'User-Agent': this.getUserAgent(),
'X-Stainless-Retry-Count': String(retryCount),
...(options.timeout ? { 'X-Stainless-Timeout': String(Math.trunc(options.timeout / 1000)) } : {}),
- ...getPlatformHeaders(),
+ // ...getPlatformHeaders(),
'OpenAI-Organization': this.organization,
'OpenAI-Project': this.project,
},
diff --git a/core/error.js b/core/error.js
index a12d9d9ccd242050161adeb0f82e1b98d9e78e20..fe3a5462480558bc426deea147f864f12b36f9bd 100644
--- a/core/error.js
+++ b/core/error.js
@@ -40,7 +40,7 @@ class APIError extends OpenAIError {
if (!status || !headers) {
return new APIConnectionError({ message, cause: (0, errors_1.castToError)(errorResponse) });
}
- const error = errorResponse?.['error'];
+ const error = errorResponse?.['error'] || errorResponse;
if (status === 400) {
return new BadRequestError(status, error, message, headers);
}
diff --git a/core/error.mjs b/core/error.mjs
index 83cefbaffeb8c657536347322d8de9516af479a2..63334b7972ec04882aa4a0800c1ead5982345045 100644
--- a/core/error.mjs
+++ b/core/error.mjs
@@ -36,7 +36,7 @@ export class APIError extends OpenAIError {
if (!status || !headers) {
return new APIConnectionError({ message, cause: castToError(errorResponse) });
}
- const error = errorResponse?.['error'];
+ const error = errorResponse?.['error'] || errorResponse;
if (status === 400) {
return new BadRequestError(status, error, message, headers);
}
diff --git a/resources/embeddings.js b/resources/embeddings.js
index 2404264d4ba0204322548945ebb7eab3bea82173..8f1bc45cc45e0797d50989d96b51147b90ae6790 100644
--- a/resources/embeddings.js
+++ b/resources/embeddings.js
@@ -5,52 +5,64 @@ exports.Embeddings = void 0;
const resource_1 = require("../core/resource.js");
const utils_1 = require("../internal/utils.js");
class Embeddings extends resource_1.APIResource {
- /**
- * Creates an embedding vector representing the input text.
- *
- * @example
- * ```ts
- * const createEmbeddingResponse =
- * await client.embeddings.create({
- * input: 'The quick brown fox jumped over the lazy dog',
- * model: 'text-embedding-3-small',
- * });
- * ```
- */
- create(body, options) {
- const hasUserProvidedEncodingFormat = !!body.encoding_format;
- // No encoding_format specified, defaulting to base64 for performance reasons
- // See https://github.com/openai/openai-node/pull/1312
- let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64';
- if (hasUserProvidedEncodingFormat) {
- (0, utils_1.loggerFor)(this._client).debug('embeddings/user defined encoding_format:', body.encoding_format);
- }
- const response = this._client.post('/embeddings', {
- body: {
- ...body,
- encoding_format: encoding_format,
- },
- ...options,
- });
- // if the user specified an encoding_format, return the response as-is
- if (hasUserProvidedEncodingFormat) {
- return response;
- }
- // in this stage, we are sure the user did not specify an encoding_format
- // and we defaulted to base64 for performance reasons
- // we are sure then that the response is base64 encoded, let's decode it
- // the returned result will be a float32 array since this is OpenAI API's default encoding
- (0, utils_1.loggerFor)(this._client).debug('embeddings/decoding base64 embeddings from base64');
- return response._thenUnwrap((response) => {
- if (response && response.data) {
- response.data.forEach((embeddingBase64Obj) => {
- const embeddingBase64Str = embeddingBase64Obj.embedding;
- embeddingBase64Obj.embedding = (0, utils_1.toFloat32Array)(embeddingBase64Str);
- });
- }
- return response;
- });
- }
+ /**
+ * Creates an embedding vector representing the input text.
+ *
+ * @example
+ * ```ts
+ * const createEmbeddingResponse =
+ * await client.embeddings.create({
+ * input: 'The quick brown fox jumped over the lazy dog',
+ * model: 'text-embedding-3-small',
+ * });
+ * ```
+ */
+ create(body, options) {
+ const hasUserProvidedEncodingFormat = !!body.encoding_format;
+ // No encoding_format specified, defaulting to base64 for performance reasons
+ // See https://github.com/openai/openai-node/pull/1312
+ let encoding_format = hasUserProvidedEncodingFormat
+ ? body.encoding_format
+ : "base64";
+ if (body.model.includes("jina")) {
+ encoding_format = undefined;
+ }
+ if (hasUserProvidedEncodingFormat) {
+ (0, utils_1.loggerFor)(this._client).debug(
+ "embeddings/user defined encoding_format:",
+ body.encoding_format
+ );
+ }
+ const response = this._client.post("/embeddings", {
+ body: {
+ ...body,
+ encoding_format: encoding_format,
+ },
+ ...options,
+ });
+ // if the user specified an encoding_format, return the response as-is
+ if (hasUserProvidedEncodingFormat || body.model.includes("jina")) {
+ return response;
+ }
+ // in this stage, we are sure the user did not specify an encoding_format
+ // and we defaulted to base64 for performance reasons
+ // we are sure then that the response is base64 encoded, let's decode it
+ // the returned result will be a float32 array since this is OpenAI API's default encoding
+ (0, utils_1.loggerFor)(this._client).debug(
+ "embeddings/decoding base64 embeddings from base64"
+ );
+ return response._thenUnwrap((response) => {
+ if (response && response.data && typeof response.data[0]?.embedding === 'string') {
+ response.data.forEach((embeddingBase64Obj) => {
+ const embeddingBase64Str = embeddingBase64Obj.embedding;
+ embeddingBase64Obj.embedding = (0, utils_1.toFloat32Array)(
+ embeddingBase64Str
+ );
+ });
+ }
+ return response;
+ });
+ }
}
exports.Embeddings = Embeddings;
//# sourceMappingURL=embeddings.js.map
diff --git a/resources/embeddings.mjs b/resources/embeddings.mjs
index 19dcaef578c194a89759c4360073cfd4f7dd2cbf..0284e9cc615c900eff508eb595f7360a74bd9200 100644
--- a/resources/embeddings.mjs
+++ b/resources/embeddings.mjs
@@ -2,51 +2,61 @@
import { APIResource } from "../core/resource.mjs";
import { loggerFor, toFloat32Array } from "../internal/utils.mjs";
export class Embeddings extends APIResource {
- /**
- * Creates an embedding vector representing the input text.
- *
- * @example
- * ```ts
- * const createEmbeddingResponse =
- * await client.embeddings.create({
- * input: 'The quick brown fox jumped over the lazy dog',
- * model: 'text-embedding-3-small',
- * });
- * ```
- */
- create(body, options) {
- const hasUserProvidedEncodingFormat = !!body.encoding_format;
- // No encoding_format specified, defaulting to base64 for performance reasons
- // See https://github.com/openai/openai-node/pull/1312
- let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64';
- if (hasUserProvidedEncodingFormat) {
- loggerFor(this._client).debug('embeddings/user defined encoding_format:', body.encoding_format);
- }
- const response = this._client.post('/embeddings', {
- body: {
- ...body,
- encoding_format: encoding_format,
- },
- ...options,
- });
- // if the user specified an encoding_format, return the response as-is
- if (hasUserProvidedEncodingFormat) {
- return response;
- }
- // in this stage, we are sure the user did not specify an encoding_format
- // and we defaulted to base64 for performance reasons
- // we are sure then that the response is base64 encoded, let's decode it
- // the returned result will be a float32 array since this is OpenAI API's default encoding
- loggerFor(this._client).debug('embeddings/decoding base64 embeddings from base64');
- return response._thenUnwrap((response) => {
- if (response && response.data) {
- response.data.forEach((embeddingBase64Obj) => {
- const embeddingBase64Str = embeddingBase64Obj.embedding;
- embeddingBase64Obj.embedding = toFloat32Array(embeddingBase64Str);
- });
- }
- return response;
- });
- }
+ /**
+ * Creates an embedding vector representing the input text.
+ *
+ * @example
+ * ```ts
+ * const createEmbeddingResponse =
+ * await client.embeddings.create({
+ * input: 'The quick brown fox jumped over the lazy dog',
+ * model: 'text-embedding-3-small',
+ * });
+ * ```
+ */
+ create(body, options) {
+ const hasUserProvidedEncodingFormat = !!body.encoding_format;
+ // No encoding_format specified, defaulting to base64 for performance reasons
+ // See https://github.com/openai/openai-node/pull/1312
+ let encoding_format = hasUserProvidedEncodingFormat
+ ? body.encoding_format
+ : "base64";
+ if (body.model.includes("jina")) {
+ encoding_format = undefined;
+ }
+ if (hasUserProvidedEncodingFormat) {
+ loggerFor(this._client).debug(
+ "embeddings/user defined encoding_format:",
+ body.encoding_format
+ );
+ }
+ const response = this._client.post("/embeddings", {
+ body: {
+ ...body,
+ encoding_format: encoding_format,
+ },
+ ...options,
+ });
+ // if the user specified an encoding_format, return the response as-is
+ if (hasUserProvidedEncodingFormat || body.model.includes("jina")) {
+ return response;
+ }
+ // in this stage, we are sure the user did not specify an encoding_format
+ // and we defaulted to base64 for performance reasons
+ // we are sure then that the response is base64 encoded, let's decode it
+ // the returned result will be a float32 array since this is OpenAI API's default encoding
+ loggerFor(this._client).debug(
+ "embeddings/decoding base64 embeddings from base64"
+ );
+ return response._thenUnwrap((response) => {
+ if (response && response.data && typeof response.data[0]?.embedding === 'string') {
+ response.data.forEach((embeddingBase64Obj) => {
+ const embeddingBase64Str = embeddingBase64Obj.embedding;
+ embeddingBase64Obj.embedding = toFloat32Array(embeddingBase64Str);
+ });
+ }
+ return response;
+ });
+ }
}
//# sourceMappingURL=embeddings.mjs.map

View File

@ -0,0 +1,344 @@
diff --git a/client.js b/client.js
index 22cc08d77ce849842a28f684c20dd5738152efa4..0c20f96405edbe7724b87517115fa2a61934b343 100644
--- a/client.js
+++ b/client.js
@@ -444,7 +444,7 @@ class OpenAI {
'User-Agent': this.getUserAgent(),
'X-Stainless-Retry-Count': String(retryCount),
...(options.timeout ? { 'X-Stainless-Timeout': String(Math.trunc(options.timeout / 1000)) } : {}),
- ...(0, detect_platform_1.getPlatformHeaders)(),
+ // ...(0, detect_platform_1.getPlatformHeaders)(),
'OpenAI-Organization': this.organization,
'OpenAI-Project': this.project,
},
diff --git a/client.mjs b/client.mjs
index 7f1af99fb30d2cae03eea6687b53e6c7828faceb..fd66373a5eff31a5846084387a3fd97956c9ad48 100644
--- a/client.mjs
+++ b/client.mjs
@@ -1,43 +1,41 @@
// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
var _OpenAI_instances, _a, _OpenAI_encoder, _OpenAI_baseURLOverridden;
-import { __classPrivateFieldGet, __classPrivateFieldSet } from "./internal/tslib.mjs";
-import { uuid4 } from "./internal/utils/uuid.mjs";
-import { validatePositiveInteger, isAbsoluteURL, safeJSON } from "./internal/utils/values.mjs";
-import { sleep } from "./internal/utils/sleep.mjs";
-import { castToError, isAbortError } from "./internal/errors.mjs";
-import { getPlatformHeaders } from "./internal/detect-platform.mjs";
-import * as Shims from "./internal/shims.mjs";
-import * as Opts from "./internal/request-options.mjs";
-import * as qs from "./internal/qs/index.mjs";
-import { VERSION } from "./version.mjs";
+import { APIPromise } from "./core/api-promise.mjs";
import * as Errors from "./core/error.mjs";
import * as Pagination from "./core/pagination.mjs";
import * as Uploads from "./core/uploads.mjs";
-import * as API from "./resources/index.mjs";
-import { APIPromise } from "./core/api-promise.mjs";
-import { Batches, } from "./resources/batches.mjs";
-import { Completions, } from "./resources/completions.mjs";
-import { Embeddings, } from "./resources/embeddings.mjs";
-import { Files, } from "./resources/files.mjs";
-import { Images, } from "./resources/images.mjs";
-import { Models } from "./resources/models.mjs";
-import { Moderations, } from "./resources/moderations.mjs";
-import { Webhooks } from "./resources/webhooks.mjs";
+import { isRunningInBrowser } from "./internal/detect-platform.mjs";
+import { castToError, isAbortError } from "./internal/errors.mjs";
+import { buildHeaders } from "./internal/headers.mjs";
+import * as qs from "./internal/qs/index.mjs";
+import * as Opts from "./internal/request-options.mjs";
+import * as Shims from "./internal/shims.mjs";
+import { __classPrivateFieldGet, __classPrivateFieldSet } from "./internal/tslib.mjs";
+import { readEnv } from "./internal/utils/env.mjs";
+import { formatRequestDetails, loggerFor, parseLogLevel, } from "./internal/utils/log.mjs";
+import { sleep } from "./internal/utils/sleep.mjs";
+import { uuid4 } from "./internal/utils/uuid.mjs";
+import { isAbsoluteURL, isEmptyObj, safeJSON, validatePositiveInteger } from "./internal/utils/values.mjs";
import { Audio } from "./resources/audio/audio.mjs";
+import { Batches, } from "./resources/batches.mjs";
import { Beta } from "./resources/beta/beta.mjs";
import { Chat } from "./resources/chat/chat.mjs";
+import { Completions, } from "./resources/completions.mjs";
import { Containers, } from "./resources/containers/containers.mjs";
+import { Embeddings, } from "./resources/embeddings.mjs";
import { Evals, } from "./resources/evals/evals.mjs";
+import { Files, } from "./resources/files.mjs";
import { FineTuning } from "./resources/fine-tuning/fine-tuning.mjs";
import { Graders } from "./resources/graders/graders.mjs";
+import { Images, } from "./resources/images.mjs";
+import * as API from "./resources/index.mjs";
+import { Models } from "./resources/models.mjs";
+import { Moderations, } from "./resources/moderations.mjs";
import { Responses } from "./resources/responses/responses.mjs";
import { Uploads as UploadsAPIUploads, } from "./resources/uploads/uploads.mjs";
import { VectorStores, } from "./resources/vector-stores/vector-stores.mjs";
-import { isRunningInBrowser } from "./internal/detect-platform.mjs";
-import { buildHeaders } from "./internal/headers.mjs";
-import { readEnv } from "./internal/utils/env.mjs";
-import { formatRequestDetails, loggerFor, parseLogLevel, } from "./internal/utils/log.mjs";
-import { isEmptyObj } from "./internal/utils/values.mjs";
+import { Webhooks } from "./resources/webhooks.mjs";
+import { VERSION } from "./version.mjs";
/**
* API Client for interfacing with the OpenAI API.
*/
@@ -441,7 +439,7 @@ export class OpenAI {
'User-Agent': this.getUserAgent(),
'X-Stainless-Retry-Count': String(retryCount),
...(options.timeout ? { 'X-Stainless-Timeout': String(Math.trunc(options.timeout / 1000)) } : {}),
- ...getPlatformHeaders(),
+ // ...getPlatformHeaders(),
'OpenAI-Organization': this.organization,
'OpenAI-Project': this.project,
},
diff --git a/core/error.js b/core/error.js
index c302cc356f0f24b50c3f5a0aa3ea0b79ae1e9a8d..164ee2ee31cd7eea8f70139e25d140b763e91d36 100644
--- a/core/error.js
+++ b/core/error.js
@@ -40,7 +40,7 @@ class APIError extends OpenAIError {
if (!status || !headers) {
return new APIConnectionError({ message, cause: (0, errors_1.castToError)(errorResponse) });
}
- const error = errorResponse?.['error'];
+ const error = errorResponse?.['error'] || errorResponse;
if (status === 400) {
return new BadRequestError(status, error, message, headers);
}
diff --git a/core/error.mjs b/core/error.mjs
index 75f5b0c328cc4894478f3490a00dbf6abd96fc12..269f46f96e9fad1f7a1649a3810562abc7fae37f 100644
--- a/core/error.mjs
+++ b/core/error.mjs
@@ -36,7 +36,7 @@ export class APIError extends OpenAIError {
if (!status || !headers) {
return new APIConnectionError({ message, cause: castToError(errorResponse) });
}
- const error = errorResponse?.['error'];
+ const error = errorResponse?.['error'] || errorResponse;
if (status === 400) {
return new BadRequestError(status, error, message, headers);
}
diff --git a/resources/embeddings.js b/resources/embeddings.js
index 2404264d4ba0204322548945ebb7eab3bea82173..93b9e286f62101b5aa7532e96ddba61f682ece3f 100644
--- a/resources/embeddings.js
+++ b/resources/embeddings.js
@@ -5,52 +5,64 @@ exports.Embeddings = void 0;
const resource_1 = require("../core/resource.js");
const utils_1 = require("../internal/utils.js");
class Embeddings extends resource_1.APIResource {
- /**
- * Creates an embedding vector representing the input text.
- *
- * @example
- * ```ts
- * const createEmbeddingResponse =
- * await client.embeddings.create({
- * input: 'The quick brown fox jumped over the lazy dog',
- * model: 'text-embedding-3-small',
- * });
- * ```
- */
- create(body, options) {
- const hasUserProvidedEncodingFormat = !!body.encoding_format;
- // No encoding_format specified, defaulting to base64 for performance reasons
- // See https://github.com/openai/openai-node/pull/1312
- let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64';
- if (hasUserProvidedEncodingFormat) {
- (0, utils_1.loggerFor)(this._client).debug('embeddings/user defined encoding_format:', body.encoding_format);
- }
- const response = this._client.post('/embeddings', {
- body: {
- ...body,
- encoding_format: encoding_format,
- },
- ...options,
- });
- // if the user specified an encoding_format, return the response as-is
- if (hasUserProvidedEncodingFormat) {
- return response;
- }
- // in this stage, we are sure the user did not specify an encoding_format
- // and we defaulted to base64 for performance reasons
- // we are sure then that the response is base64 encoded, let's decode it
- // the returned result will be a float32 array since this is OpenAI API's default encoding
- (0, utils_1.loggerFor)(this._client).debug('embeddings/decoding base64 embeddings from base64');
- return response._thenUnwrap((response) => {
- if (response && response.data) {
- response.data.forEach((embeddingBase64Obj) => {
- const embeddingBase64Str = embeddingBase64Obj.embedding;
- embeddingBase64Obj.embedding = (0, utils_1.toFloat32Array)(embeddingBase64Str);
- });
- }
- return response;
- });
+ /**
+ * Creates an embedding vector representing the input text.
+ *
+ * @example
+ * ```ts
+ * const createEmbeddingResponse =
+ * await client.embeddings.create({
+ * input: 'The quick brown fox jumped over the lazy dog',
+ * model: 'text-embedding-3-small',
+ * });
+ * ```
+ */
+ create(body, options) {
+ const hasUserProvidedEncodingFormat = !!body.encoding_format;
+ // No encoding_format specified, defaulting to base64 for performance reasons
+ // See https://github.com/openai/openai-node/pull/1312
+ let encoding_format = hasUserProvidedEncodingFormat
+ ? body.encoding_format
+ : "base64";
+ if (body.model.includes("jina")) {
+ encoding_format = undefined;
+ }
+ if (hasUserProvidedEncodingFormat) {
+ (0, utils_1.loggerFor)(this._client).debug(
+ "embeddings/user defined encoding_format:",
+ body.encoding_format
+ );
}
+ const response = this._client.post("/embeddings", {
+ body: {
+ ...body,
+ encoding_format: encoding_format,
+ },
+ ...options,
+ });
+ // if the user specified an encoding_format, return the response as-is
+ if (hasUserProvidedEncodingFormat || body.model.includes("jina")) {
+ return response;
+ }
+ // in this stage, we are sure the user did not specify an encoding_format
+ // and we defaulted to base64 for performance reasons
+ // we are sure then that the response is base64 encoded, let's decode it
+ // the returned result will be a float32 array since this is OpenAI API's default encoding
+ (0, utils_1.loggerFor)(this._client).debug(
+ "embeddings/decoding base64 embeddings from base64"
+ );
+ return response._thenUnwrap((response) => {
+ if (response && response.data && typeof response.data[0]?.embedding === 'string') {
+ response.data.forEach((embeddingBase64Obj) => {
+ const embeddingBase64Str = embeddingBase64Obj.embedding;
+ embeddingBase64Obj.embedding = (0, utils_1.toFloat32Array)(
+ embeddingBase64Str
+ );
+ });
+ }
+ return response;
+ });
+ }
}
exports.Embeddings = Embeddings;
//# sourceMappingURL=embeddings.js.map
diff --git a/resources/embeddings.mjs b/resources/embeddings.mjs
index 19dcaef578c194a89759c4360073cfd4f7dd2cbf..42c903fadb03c707356a983603ff09e4152ecf11 100644
--- a/resources/embeddings.mjs
+++ b/resources/embeddings.mjs
@@ -2,51 +2,61 @@
import { APIResource } from "../core/resource.mjs";
import { loggerFor, toFloat32Array } from "../internal/utils.mjs";
export class Embeddings extends APIResource {
- /**
- * Creates an embedding vector representing the input text.
- *
- * @example
- * ```ts
- * const createEmbeddingResponse =
- * await client.embeddings.create({
- * input: 'The quick brown fox jumped over the lazy dog',
- * model: 'text-embedding-3-small',
- * });
- * ```
- */
- create(body, options) {
- const hasUserProvidedEncodingFormat = !!body.encoding_format;
- // No encoding_format specified, defaulting to base64 for performance reasons
- // See https://github.com/openai/openai-node/pull/1312
- let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64';
- if (hasUserProvidedEncodingFormat) {
- loggerFor(this._client).debug('embeddings/user defined encoding_format:', body.encoding_format);
- }
- const response = this._client.post('/embeddings', {
- body: {
- ...body,
- encoding_format: encoding_format,
- },
- ...options,
- });
- // if the user specified an encoding_format, return the response as-is
- if (hasUserProvidedEncodingFormat) {
- return response;
- }
- // in this stage, we are sure the user did not specify an encoding_format
- // and we defaulted to base64 for performance reasons
- // we are sure then that the response is base64 encoded, let's decode it
- // the returned result will be a float32 array since this is OpenAI API's default encoding
- loggerFor(this._client).debug('embeddings/decoding base64 embeddings from base64');
- return response._thenUnwrap((response) => {
- if (response && response.data) {
- response.data.forEach((embeddingBase64Obj) => {
- const embeddingBase64Str = embeddingBase64Obj.embedding;
- embeddingBase64Obj.embedding = toFloat32Array(embeddingBase64Str);
- });
- }
- return response;
- });
+ /**
+ * Creates an embedding vector representing the input text.
+ *
+ * @example
+ * ```ts
+ * const createEmbeddingResponse =
+ * await client.embeddings.create({
+ * input: 'The quick brown fox jumped over the lazy dog',
+ * model: 'text-embedding-3-small',
+ * });
+ * ```
+ */
+ create(body, options) {
+ const hasUserProvidedEncodingFormat = !!body.encoding_format;
+ // No encoding_format specified, defaulting to base64 for performance reasons
+ // See https://github.com/openai/openai-node/pull/1312
+ let encoding_format = hasUserProvidedEncodingFormat
+ ? body.encoding_format
+ : "base64";
+ if (body.model.includes("jina")) {
+ encoding_format = undefined;
+ }
+ if (hasUserProvidedEncodingFormat) {
+ loggerFor(this._client).debug(
+ "embeddings/user defined encoding_format:",
+ body.encoding_format
+ );
}
+ const response = this._client.post("/embeddings", {
+ body: {
+ ...body,
+ encoding_format: encoding_format,
+ },
+ ...options,
+ });
+ // if the user specified an encoding_format, return the response as-is
+ if (hasUserProvidedEncodingFormat || body.model.includes("jina")) {
+ return response;
+ }
+ // in this stage, we are sure the user did not specify an encoding_format
+ // and we defaulted to base64 for performance reasons
+ // we are sure then that the response is base64 encoded, let's decode it
+ // the returned result will be a float32 array since this is OpenAI API's default encoding
+ loggerFor(this._client).debug(
+ "embeddings/decoding base64 embeddings from base64"
+ );
+ return response._thenUnwrap((response) => {
+ if (response && response.data && typeof response.data[0]?.embedding === 'string') {
+ response.data.forEach((embeddingBase64Obj) => {
+ const embeddingBase64Str = embeddingBase64Obj.embedding;
+ embeddingBase64Obj.embedding = toFloat32Array(embeddingBase64Str);
+ });
+ }
+ return response;
+ });
+ }
}
//# sourceMappingURL=embeddings.mjs.map

View File

@ -0,0 +1,23 @@
diff --git a/dist/index.js b/dist/index.js
index b54962b2d332c1a3affadbdb37d39fdf90ab9f82..7906b4ea3bf9dffe60d74c279e9cfe885489c9f9 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -36,12 +36,12 @@ async function getWindowsSystemProxy() {
const proxies = Object.fromEntries(proxyConfigString
.split(';')
.map((proxyPair) => proxyPair.split('=')));
- const proxyUrl = proxies['https']
- ? `https://${proxies['https']}`
- : proxies['http']
- ? `http://${proxies['http']}`
- : proxies['socks']
- ? `socks://${proxies['socks']}`
+ const proxyUrl = proxies['http']
+ ? `http://${proxies['http']}`
+ : proxies['socks']
+ ? `socks://${proxies['socks']}`
+ : proxies['https']
+ ? `https://${proxies['https']}`
: undefined;
if (!proxyUrl) {
throw new Error(`Could not get usable proxy URL from ${proxyConfigString}`);

View File

@ -0,0 +1,180 @@
# CodeBlockView Component Structure
## Overview
CodeBlockView is the core component in Cherry Studio for displaying and manipulating code blocks. It supports multiple view modes and visual previews for special languages, providing rich interactive tools.
## Component Structure
```mermaid
graph TD
A[CodeBlockView] --> B[CodeToolbar]
A --> C[SourceView]
A --> D[SpecialView]
A --> E[StatusBar]
B --> F[CodeToolButton]
C --> G[CodeEditor / CodeViewer]
D --> H[MermaidPreview]
D --> I[PlantUmlPreview]
D --> J[SvgPreview]
D --> K[GraphvizPreview]
F --> L[useCopyTool]
F --> M[useDownloadTool]
F --> N[useViewSourceTool]
F --> O[useSplitViewTool]
F --> P[useRunTool]
F --> Q[useExpandTool]
F --> R[useWrapTool]
F --> S[useSaveTool]
```
## Core Concepts
### View Types
- **preview**: Preview view, where non-source code is displayed as special views
- **edit**: Edit view
### View Modes
- **source**: Source code view mode
- **special**: Special view mode (Mermaid, PlantUML, SVG)
- **split**: Split view mode (source code and special view displayed side by side)
### Special View Languages
- mermaid
- plantuml
- svg
- dot
- graphviz
## Component Details
### CodeBlockView Main Component
Main responsibilities:
1. Managing view mode state
2. Coordinating the display of source code view and special view
3. Managing toolbar tools
4. Handling code execution state
### Subcomponents
#### CodeToolbar
- Toolbar displayed at the top-right corner of the code block
- Contains core and quick tools
- Dynamically displays relevant tools based on context
#### CodeEditor/CodeViewer Source View
- Editable code editor or read-only code viewer
- Uses either component based on settings
- Supports syntax highlighting for multiple programming languages
#### Special View Components
- **MermaidPreview**: Mermaid diagram preview
- **PlantUmlPreview**: PlantUML diagram preview
- **SvgPreview**: SVG image preview
- **GraphvizPreview**: Graphviz diagram preview
All special view components share a common architecture for consistent user experience and functionality. For detailed information about these components and their implementation, see [Image Preview Components Documentation](./ImagePreview-en.md).
#### StatusBar
- Displays Python code execution results
- Can show both text and image results
## Tool System
CodeBlockView uses a hook-based tool system:
```mermaid
graph TD
A[CodeBlockView] --> B[useCopyTool]
A --> C[useDownloadTool]
A --> D[useViewSourceTool]
A --> E[useSplitViewTool]
A --> F[useRunTool]
A --> G[useExpandTool]
A --> H[useWrapTool]
A --> I[useSaveTool]
B --> J[ToolManager]
C --> J
D --> J
E --> J
F --> J
G --> J
H --> J
I --> J
J --> K[CodeToolbar]
```
Each tool hook is responsible for registering specific function tool buttons to the tool manager, which then passes these tools to the CodeToolbar component for rendering.
### Tool Types
- **core**: Core tools, always displayed in the toolbar
- **quick**: Quick tools, displayed in a dropdown menu when there are more than one
### Tool List
1. **Copy**: Copy code or image
2. **Download**: Download code or image
3. **View Source**: Switch between special view and source code view
4. **Split View**: Toggle split view mode
5. **Run**: Run Python code
6. **Expand/Collapse**: Control code block expansion/collapse
7. **Wrap**: Control automatic line wrapping
8. **Save**: Save edited code
## State Management
CodeBlockView manages the following states through React hooks:
1. **viewMode**: Current view mode ('source' | 'special' | 'split')
2. **isRunning**: Python code execution status
3. **executionResult**: Python code execution result
4. **tools**: Toolbar tool list
5. **expandOverride/unwrapOverride**: User override settings for expand/wrap
6. **sourceScrollHeight**: Source code view scroll height
## Interaction Flow
```mermaid
sequenceDiagram
participant U as User
participant CB as CodeBlockView
participant CT as CodeToolbar
participant SV as SpecialView
participant SE as SourceEditor
U->>CB: View code block
CB->>CB: Initialize state
CB->>CT: Register tools
CB->>SV: Render special view (if applicable)
CB->>SE: Render source view
U->>CT: Click tool button
CT->>CB: Trigger tool callback
CB->>CB: Update state
CB->>CT: Re-register tools (if needed)
```
## Special Handling
### HTML Code Blocks
HTML code blocks are specially handled using the HtmlArtifactsCard component.
### Python Code Execution
Supports executing Python code and displaying results using Pyodide to run Python code in the browser.

View File

@ -0,0 +1,180 @@
# CodeBlockView 组件结构说明
## 概述
CodeBlockView 是 Cherry Studio 中用于显示和操作代码块的核心组件。它支持多种视图模式和特殊语言的可视化预览,提供丰富的交互工具。
## 组件结构
```mermaid
graph TD
A[CodeBlockView] --> B[CodeToolbar]
A --> C[SourceView]
A --> D[SpecialView]
A --> E[StatusBar]
B --> F[CodeToolButton]
C --> G[CodeEditor / CodeViewer]
D --> H[MermaidPreview]
D --> I[PlantUmlPreview]
D --> J[SvgPreview]
D --> K[GraphvizPreview]
F --> L[useCopyTool]
F --> M[useDownloadTool]
F --> N[useViewSourceTool]
F --> O[useSplitViewTool]
F --> P[useRunTool]
F --> Q[useExpandTool]
F --> R[useWrapTool]
F --> S[useSaveTool]
```
## 核心概念
### 视图类型
- **preview**: 预览视图,非源代码的是特殊视图
- **edit**: 编辑视图
### 视图模式
- **source**: 源代码视图模式
- **special**: 特殊视图模式Mermaid、PlantUML、SVG
- **split**: 分屏模式(源代码和特殊视图并排显示)
### 特殊视图语言
- mermaid
- plantuml
- svg
- dot
- graphviz
## 组件详细说明
### CodeBlockView 主组件
主要负责:
1. 管理视图模式状态
2. 协调源代码视图和特殊视图的显示
3. 管理工具栏工具
4. 处理代码执行状态
### 子组件
#### CodeToolbar 工具栏
- 显示在代码块右上角的工具栏
- 包含核心(core)和快捷(quick)两类工具
- 根据上下文动态显示相关工具
#### CodeEditor/CodeViewer 源代码视图
- 可编辑的代码编辑器或只读的代码查看器
- 根据设置决定使用哪个组件
- 支持多种编程语言高亮
#### 特殊视图组件
- **MermaidPreview**: Mermaid 图表预览
- **PlantUmlPreview**: PlantUML 图表预览
- **SvgPreview**: SVG 图像预览
- **GraphvizPreview**: Graphviz 图表预览
所有特殊视图组件共享通用架构,以确保一致的用户体验和功能。有关这些组件及其实现的详细信息,请参阅 [图像预览组件文档](./ImagePreview-zh.md)。
#### StatusBar 状态栏
- 显示 Python 代码执行结果
- 可显示文本和图像结果
## 工具系统
CodeBlockView 使用基于 hooks 的工具系统:
```mermaid
graph TD
A[CodeBlockView] --> B[useCopyTool]
A --> C[useDownloadTool]
A --> D[useViewSourceTool]
A --> E[useSplitViewTool]
A --> F[useRunTool]
A --> G[useExpandTool]
A --> H[useWrapTool]
A --> I[useSaveTool]
B --> J[ToolManager]
C --> J
D --> J
E --> J
F --> J
G --> J
H --> J
I --> J
J --> K[CodeToolbar]
```
每个工具 hook 负责注册特定功能的工具按钮到工具管理器,工具管理器再将这些工具传递给 CodeToolbar 组件进行渲染。
### 工具类型
- **core**: 核心工具,始终显示在工具栏
- **quick**: 快捷工具当数量大于1时通过下拉菜单显示
### 工具列表
1. **复制(copy)**: 复制代码或图像
2. **下载(download)**: 下载代码或图像
3. **查看源码(view-source)**: 在特殊视图和源码视图间切换
4. **分屏(split-view)**: 切换分屏模式
5. **运行(run)**: 运行 Python 代码
6. **展开/折叠(expand)**: 控制代码块的展开/折叠
7. **换行(wrap)**: 控制代码的自动换行
8. **保存(save)**: 保存编辑的代码
## 状态管理
CodeBlockView 通过 React hooks 管理以下状态:
1. **viewMode**: 当前视图模式 ('source' | 'special' | 'split')
2. **isRunning**: Python 代码执行状态
3. **executionResult**: Python 代码执行结果
4. **tools**: 工具栏工具列表
5. **expandOverride/unwrapOverride**: 用户展开/换行的覆盖设置
6. **sourceScrollHeight**: 源代码视图滚动高度
## 交互流程
```mermaid
sequenceDiagram
participant U as User
participant CB as CodeBlockView
participant CT as CodeToolbar
participant SV as SpecialView
participant SE as SourceEditor
U->>CB: 查看代码块
CB->>CB: 初始化状态
CB->>CT: 注册工具
CB->>SV: 渲染特殊视图(如果适用)
CB->>SE: 渲染源码视图
U->>CT: 点击工具按钮
CT->>CB: 触发工具回调
CB->>CB: 更新状态
CB->>CT: 重新注册工具(如果需要)
```
## 特殊处理
### HTML 代码块
HTML 代码块会被特殊处理,使用 HtmlArtifactsCard 组件显示。
### Python 代码执行
支持执行 Python 代码并显示结果,使用 Pyodide 在浏览器中运行 Python 代码。

View File

@ -0,0 +1,195 @@
# Image Preview Components
## Overview
Image Preview Components are a set of specialized components in Cherry Studio for rendering and displaying various diagram and image formats. They provide a consistent user experience across different preview types with shared functionality for loading states, error handling, and interactive controls.
## Supported Formats
- **Mermaid**: Interactive diagrams and flowcharts
- **PlantUML**: UML diagrams and system architecture
- **SVG**: Scalable vector graphics
- **Graphviz/DOT**: Graph visualization and network diagrams
## Architecture
```mermaid
graph TD
A[MermaidPreview] --> D[ImagePreviewLayout]
B[PlantUmlPreview] --> D
C[SvgPreview] --> D
E[GraphvizPreview] --> D
D --> F[ImageToolbar]
D --> G[useDebouncedRender]
F --> H[Pan Controls]
F --> I[Zoom Controls]
F --> J[Reset Function]
F --> K[Dialog Control]
G --> L[Debounced Rendering]
G --> M[Error Handling]
G --> N[Loading State]
G --> O[Dependency Management]
```
## Core Components
### ImagePreviewLayout
A common layout wrapper that provides the foundation for all image preview components.
**Features:**
- **Loading State Management**: Shows loading spinner during rendering
- **Error Display**: Displays error messages when rendering fails
- **Toolbar Integration**: Conditionally renders ImageToolbar when enabled
- **Container Management**: Wraps preview content with consistent styling
- **Responsive Design**: Adapts to different container sizes
**Props:**
- `children`: The preview content to be displayed
- `loading`: Boolean indicating if content is being rendered
- `error`: Error message to display if rendering fails
- `enableToolbar`: Whether to show the interactive toolbar
- `imageRef`: Reference to the container element for image manipulation
### ImageToolbar
Interactive toolbar component providing image manipulation controls.
**Features:**
- **Pan Controls**: 4-directional pan buttons (up, down, left, right)
- **Zoom Controls**: Zoom in/out functionality with configurable increments
- **Reset Function**: Restore original pan and zoom state
- **Dialog Control**: Open preview in expanded dialog view
- **Accessible Design**: Full keyboard navigation and screen reader support
**Layout:**
- 3x3 grid layout positioned at bottom-right of preview
- Responsive button sizing
- Tooltip support for all controls
### useDebouncedRender Hook
A specialized React hook for managing preview rendering with performance optimizations.
**Features:**
- **Debounced Rendering**: Prevents excessive re-renders during rapid content changes (default 300ms delay)
- **Automatic Dependency Management**: Handles dependencies for render and condition functions
- **Error Handling**: Catches and manages rendering errors with detailed error messages
- **Loading State**: Tracks rendering progress with automatic state updates
- **Conditional Rendering**: Supports pre-render condition checks
- **Manual Controls**: Provides trigger, cancel, and state management functions
**API:**
```typescript
const { containerRef, error, isLoading, triggerRender, cancelRender, clearError, setLoading } = useDebouncedRender(
value,
renderFunction,
options
)
```
**Options:**
- `debounceDelay`: Customize debounce timing
- `shouldRender`: Function for conditional rendering logic
## Component Implementations
### MermaidPreview
Renders Mermaid diagrams with special handling for visibility detection.
**Special Features:**
- Syntax validation before rendering
- Visibility detection to handle collapsed containers
- SVG coordinate fixing for edge cases
- Integration with mermaid.js library
### PlantUmlPreview
Renders PlantUML diagrams using the online PlantUML server.
**Special Features:**
- Network error handling and retry logic
- Diagram encoding using deflate compression
- Support for light/dark themes
- Server status monitoring
### SvgPreview
Renders SVG content using Shadow DOM for isolation.
**Special Features:**
- Shadow DOM rendering for style isolation
- Direct SVG content injection
- Minimal processing overhead
- Cross-browser compatibility
### GraphvizPreview
Renders Graphviz/DOT diagrams using the viz.js library.
**Special Features:**
- Client-side rendering with viz.js
- Lazy loading of viz.js library
- SVG element generation
- Memory-efficient processing
## Shared Functionality
### Error Handling
All preview components provide consistent error handling:
- Network errors (connection failures)
- Syntax errors (invalid diagram code)
- Server errors (external service failures)
- Rendering errors (library failures)
### Loading States
Standardized loading indicators across all components:
- Spinner animation during processing
- Progress feedback for long operations
- Smooth transitions between states
### Interactive Controls
Common interaction patterns:
- Pan and zoom functionality
- Reset to original view
- Full-screen dialog mode
- Keyboard accessibility
### Performance Optimizations
- Debounced rendering to prevent excessive updates
- Lazy loading of heavy libraries
- Memory management for large diagrams
- Efficient re-rendering strategies
## Integration with CodeBlockView
Image Preview Components integrate seamlessly with CodeBlockView:
- Automatic format detection based on language tags
- Consistent toolbar integration
- Shared state management
- Responsive layout adaptation
For more information about the overall CodeBlockView architecture, see [CodeBlockView Documentation](./CodeBlockView-en.md).

View File

@ -0,0 +1,195 @@
# 图像预览组件
## 概述
图像预览组件是 Cherry Studio 中用于渲染和显示各种图表和图像格式的专用组件集合。它们为不同预览类型提供一致的用户体验,具有共享的加载状态、错误处理和交互控制功能。
## 支持格式
- **Mermaid**: 交互式图表和流程图
- **PlantUML**: UML 图表和系统架构
- **SVG**: 可缩放矢量图形
- **Graphviz/DOT**: 图形可视化和网络图表
## 架构
```mermaid
graph TD
A[MermaidPreview] --> D[ImagePreviewLayout]
B[PlantUmlPreview] --> D
C[SvgPreview] --> D
E[GraphvizPreview] --> D
D --> F[ImageToolbar]
D --> G[useDebouncedRender]
F --> H[平移控制]
F --> I[缩放控制]
F --> J[重置功能]
F --> K[对话框控制]
G --> L[防抖渲染]
G --> M[错误处理]
G --> N[加载状态]
G --> O[依赖管理]
```
## 核心组件
### ImagePreviewLayout 图像预览布局
为所有图像预览组件提供基础的通用布局包装器。
**功能特性:**
- **加载状态管理**: 在渲染期间显示加载动画
- **错误显示**: 渲染失败时显示错误信息
- **工具栏集成**: 启用时有条件地渲染 ImageToolbar
- **容器管理**: 使用一致的样式包装预览内容
- **响应式设计**: 适应不同的容器尺寸
**属性:**
- `children`: 要显示的预览内容
- `loading`: 指示内容是否正在渲染的布尔值
- `error`: 渲染失败时显示的错误信息
- `enableToolbar`: 是否显示交互式工具栏
- `imageRef`: 用于图像操作的容器元素引用
### ImageToolbar 图像工具栏
提供图像操作控制的交互式工具栏组件。
**功能特性:**
- **平移控制**: 4方向平移按钮上、下、左、右
- **缩放控制**: 放大/缩小功能,支持可配置的增量
- **重置功能**: 恢复原始平移和缩放状态
- **对话框控制**: 在展开对话框中打开预览
- **无障碍设计**: 完整的键盘导航和屏幕阅读器支持
**布局:**
- 3x3 网格布局,位于预览右下角
- 响应式按钮尺寸
- 所有控件的工具提示支持
### useDebouncedRender Hook 防抖渲染钩子
用于管理预览渲染的专用 React Hook具有性能优化功能。
**功能特性:**
- **防抖渲染**: 防止内容快速变化时的过度重新渲染(默认 300ms 延迟)
- **自动依赖管理**: 处理渲染和条件函数的依赖项
- **错误处理**: 捕获和管理渲染错误,提供详细的错误信息
- **加载状态**: 跟踪渲染进度并自动更新状态
- **条件渲染**: 支持预渲染条件检查
- **手动控制**: 提供触发、取消和状态管理功能
**API:**
```typescript
const { containerRef, error, isLoading, triggerRender, cancelRender, clearError, setLoading } = useDebouncedRender(
value,
renderFunction,
options
)
```
**选项:**
- `debounceDelay`: 自定义防抖时间
- `shouldRender`: 条件渲染逻辑函数
## 组件实现
### MermaidPreview Mermaid 预览
渲染 Mermaid 图表,具有可见性检测的特殊处理。
**特殊功能:**
- 渲染前语法验证
- 可见性检测以处理折叠的容器
- 边缘情况的 SVG 坐标修复
- 与 mermaid.js 库集成
### PlantUmlPreview PlantUML 预览
使用在线 PlantUML 服务器渲染 PlantUML 图表。
**特殊功能:**
- 网络错误处理和重试逻辑
- 使用 deflate 压缩的图表编码
- 支持明/暗主题
- 服务器状态监控
### SvgPreview SVG 预览
使用 Shadow DOM 隔离渲染 SVG 内容。
**特殊功能:**
- Shadow DOM 渲染实现样式隔离
- 直接 SVG 内容注入
- 最小化处理开销
- 跨浏览器兼容性
### GraphvizPreview Graphviz 预览
使用 viz.js 库渲染 Graphviz/DOT 图表。
**特殊功能:**
- 使用 viz.js 进行客户端渲染
- viz.js 库的懒加载
- SVG 元素生成
- 内存高效处理
## 共享功能
### 错误处理
所有预览组件提供一致的错误处理:
- 网络错误(连接失败)
- 语法错误(无效的图表代码)
- 服务器错误(外部服务失败)
- 渲染错误(库失败)
### 加载状态
所有组件的标准化加载指示器:
- 处理期间的动画
- 长时间操作的进度反馈
- 状态间的平滑过渡
### 交互控制
通用交互模式:
- 平移和缩放功能
- 重置到原始视图
- 全屏对话框模式
- 键盘无障碍访问
### 性能优化
- 防抖渲染以防止过度更新
- 重型库的懒加载
- 大型图表的内存管理
- 高效的重新渲染策略
## 与 CodeBlockView 的集成
图像预览组件与 CodeBlockView 无缝集成:
- 基于语言标签的自动格式检测
- 一致的工具栏集成
- 共享状态管理
- 响应式布局适应
有关整体 CodeBlockView 架构的更多信息,请参阅 [CodeBlockView 文档](./CodeBlockView-zh.md)。

View File

@ -50,6 +50,7 @@ files:
- '!node_modules/rollup-plugin-visualizer'
- '!node_modules/js-tiktoken'
- '!node_modules/@tavily/core/node_modules/js-tiktoken'
- '!node_modules/pdf-parse/lib/pdf.js/{v1.9.426,v1.10.88,v2.0.550}'
- '!node_modules/mammoth/{mammoth.browser.js,mammoth.browser.min.js}'
- '!node_modules/selection-hook/prebuilds/**/*' # we rebuild .node, don't use prebuilds
- '!node_modules/selection-hook/node_modules' # we don't need what in the node_modules dir
@ -117,18 +118,4 @@ afterSign: scripts/notarize.js
artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo:
releaseNotes: |
新增服务商AWS Bedrock
富文本编辑器支持:提升提示词编辑体验,支持更丰富的格式调整
拖拽输入优化:支持从其他软件直接拖拽文本至输入框,简化内容输入流程
参数调节增强:新增 Top-P 和 Temperature 开关设置,提供更灵活的模型调控选项
翻译任务后台执行:翻译任务支持后台运行,提升多任务处理效率
新模型支持:新增 Qwen-MT、Qwen3235BA22Bthinking 和 sonar-deep-research 模型,扩展推理能力
推理稳定性提升:修复部分模型思考内容无法输出的问题,确保推理结果完整
Mistral 模型修复:解决 Mistral 模型无法使用的问题,恢复其推理功能
备份目录优化:支持相对路径输入,提升备份配置灵活性
数据导出调整:新增引用内容导出开关,提供更精细的导出控制
文本流完整性:修复文本流末尾文字丢失问题,确保输出内容完整
内存泄漏修复:优化代码逻辑,解决内存泄漏问题,提升运行稳定性
嵌入模型简化:降低嵌入模型配置复杂度,提高易用性
MCP Tool 长时间运行:增强 MCP 工具的稳定性,支持长时间任务执行
设置页面优化:优化设置页面布局,提升用户体验
稳定性改进和错误修复

View File

@ -1,6 +1,6 @@
{
"name": "CherryStudio",
"version": "1.5.4",
"version": "1.5.5",
"private": true,
"description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js",
@ -219,7 +219,7 @@
"motion": "^12.10.5",
"notion-helper": "^1.3.22",
"npx-scope-finder": "^1.2.0",
"openai": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
"openai": "patch:openai@npm%3A5.12.0#~/.yarn/patches/openai-npm-5.12.0-a06a6369b2.patch",
"p-queue": "^8.1.0",
"pdf-lib": "^1.17.1",
"playwright": "^1.52.0",
@ -273,20 +273,22 @@
"zod": "^3.25.74"
},
"resolutions": {
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch",
"@langchain/openai@npm:>=0.1.0 <0.4.0": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch",
"libsql@npm:^0.4.4": "patch:libsql@npm%3A0.4.7#~/.yarn/patches/libsql-npm-0.4.7-444e260fb1.patch",
"openai@npm:^4.77.0": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
"openai@npm:^4.77.0": "patch:openai@npm%3A5.12.0#~/.yarn/patches/openai-npm-5.12.0-a06a6369b2.patch",
"pkce-challenge@npm:^4.1.0": "patch:pkce-challenge@npm%3A4.1.0#~/.yarn/patches/pkce-challenge-npm-4.1.0-fbc51695a3.patch",
"app-builder-lib@npm:26.0.13": "patch:app-builder-lib@npm%3A26.0.13#~/.yarn/patches/app-builder-lib-npm-26.0.13-a064c9e1d0.patch",
"openai@npm:^4.87.3": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
"openai@npm:^4.87.3": "patch:openai@npm%3A5.12.0#~/.yarn/patches/openai-npm-5.12.0-a06a6369b2.patch",
"app-builder-lib@npm:26.0.15": "patch:app-builder-lib@npm%3A26.0.15#~/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch",
"@langchain/core@npm:^0.3.26": "patch:@langchain/core@npm%3A0.3.44#~/.yarn/patches/@langchain-core-npm-0.3.44-41d5c3cb0a.patch",
"node-abi": "4.12.0",
"undici": "6.21.2",
"vite": "npm:rolldown-vite@latest",
"atomically@npm:^1.7.0": "patch:atomically@npm%3A1.7.0#~/.yarn/patches/atomically-npm-1.7.0-e742e5293b.patch",
"file-stream-rotator@npm:^0.6.1": "patch:file-stream-rotator@npm%3A0.6.1#~/.yarn/patches/file-stream-rotator-npm-0.6.1-eab45fb13d.patch"
"file-stream-rotator@npm:^0.6.1": "patch:file-stream-rotator@npm%3A0.6.1#~/.yarn/patches/file-stream-rotator-npm-0.6.1-eab45fb13d.patch",
"windows-system-proxy@npm:^1.0.0": "patch:windows-system-proxy@npm%3A1.0.0#~/.yarn/patches/windows-system-proxy-npm-1.0.0-ff2a828eec.patch"
},
"packageManager": "yarn@4.9.1",
"lint-staged": {

View File

@ -94,17 +94,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
let proxyConfig: ProxyConfig
if (proxy === 'system') {
// system proxy will use the system filter by themselves
proxyConfig = { mode: 'system' }
} else if (proxy) {
proxyConfig = { mode: 'fixed_servers', proxyRules: proxy }
proxyConfig = { mode: 'fixed_servers', proxyRules: proxy, proxyBypassRules: bypassRules }
} else {
proxyConfig = { mode: 'direct' }
}
if (bypassRules) {
proxyConfig.proxyBypassRules = bypassRules
}
await proxyManager.configureProxy(proxyConfig)
})

View File

@ -1,5 +1,4 @@
import { loggerService } from '@logger'
import { defaultByPassRules } from '@shared/config/constant'
import axios from 'axios'
import { app, ProxyConfig, session } from 'electron'
import { socksDispatcher } from 'fetch-socks'
@ -10,9 +9,13 @@ import { ProxyAgent } from 'proxy-agent'
import { Dispatcher, EnvHttpProxyAgent, getGlobalDispatcher, setGlobalDispatcher } from 'undici'
const logger = loggerService.withContext('ProxyManager')
let byPassRules = defaultByPassRules.split(',')
let byPassRules: string[] = []
const isByPass = (hostname: string) => {
if (byPassRules.length === 0) {
return false
}
return byPassRules.includes(hostname)
}
@ -98,7 +101,7 @@ export class ProxyManager {
await this.configureProxy({
mode: 'system',
proxyRules: currentProxy?.proxyUrl.toLowerCase(),
proxyBypassRules: this.config.proxyBypassRules
proxyBypassRules: undefined
})
}, 1000 * 60)
}
@ -131,7 +134,7 @@ export class ProxyManager {
this.monitorSystemProxy()
}
byPassRules = config.proxyBypassRules?.split(',') || defaultByPassRules.split(',')
byPassRules = config.proxyBypassRules?.split(',') || []
this.setGlobalProxy(this.config)
} catch (error) {
logger.error('Failed to config proxy:', error as Error)

View File

@ -252,7 +252,9 @@ export class WindowService {
'https://cloud.siliconflow.cn/expensebill',
'https://aihubmix.com/token',
'https://aihubmix.com/topup',
'https://aihubmix.com/statistics'
'https://aihubmix.com/statistics',
'https://dash.302.ai/sso/login',
'https://dash.302.ai/charge'
]
if (oauthProviderUrls.some((link) => url.startsWith(link))) {

View File

@ -3,25 +3,28 @@ import {
isFunctionCallingModel,
isNotSupportTemperatureAndTopP,
isOpenAIModel,
isSupportedFlexServiceTier
isSupportFlexServiceTierModel
} from '@renderer/config/models'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import { isSupportServiceTierProviders } from '@renderer/config/providers'
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import { SettingsState } from '@renderer/store/settings'
import {
Assistant,
FileTypes,
GenerateImageParams,
GroqServiceTiers,
isGroqServiceTier,
isOpenAIServiceTier,
KnowledgeReference,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
MemoryItem,
Model,
OpenAIServiceTier,
OpenAIServiceTiers,
Provider,
SystemProviderIds,
ToolCallResponse,
WebSearchProviderResponse,
WebSearchResponse
@ -201,29 +204,37 @@ export abstract class BaseApiClient<
return assistantSettings?.enableTopP ? assistantSettings?.topP : undefined
}
// NOTE: 这个也许可以迁移到OpenAIBaseClient
protected getServiceTier(model: Model) {
if (!isOpenAIModel(model) || model.provider === 'github' || model.provider === 'copilot') {
const serviceTierSetting = this.provider.serviceTier
if (!isSupportServiceTierProviders(this.provider) || !isOpenAIModel(model) || !serviceTierSetting) {
return undefined
}
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
let serviceTier = 'auto' as OpenAIServiceTier
if (openAI && openAI?.serviceTier === 'flex') {
if (isSupportedFlexServiceTier(model)) {
serviceTier = 'flex'
} else {
serviceTier = 'auto'
// 处理不同供应商需要 fallback 到默认值的情况
if (this.provider.id === SystemProviderIds.groq) {
if (
!isGroqServiceTier(serviceTierSetting) ||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
} else {
serviceTier = openAI.serviceTier
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
if (
!isOpenAIServiceTier(serviceTierSetting) ||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
}
return serviceTier
return serviceTierSetting
}
protected getTimeout(model: Model) {
if (isSupportedFlexServiceTier(model)) {
if (isSupportFlexServiceTierModel(model)) {
return 15 * 1000 * 60
}
return defaultTimeout

View File

@ -11,7 +11,6 @@ import {
import {
ContentBlock,
ContentBlockParam,
MessageCreateParams,
MessageCreateParamsBase,
RedactedThinkingBlockParam,
ServerToolUseBlockParam,
@ -70,6 +69,7 @@ import {
mcpToolsToAnthropicTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { t } from 'i18next'
import { BaseApiClient } from '../BaseApiClient'
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
@ -494,22 +494,14 @@ export class AnthropicAPIClient extends BaseApiClient<
system: systemMessage ? [systemMessage] : undefined,
thinking: this.getBudgetToken(assistant, model),
tools: tools.length > 0 ? tools : undefined,
stream: streamOutput,
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
// 注意:用户自定义参数总是应该覆盖其他参数
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
}
const finalParams: MessageCreateParams = streamOutput
? {
...commonParams,
stream: true
}
: {
...commonParams,
stream: false
}
const timeout = this.getTimeout(model)
return { payload: finalParams, messages: sdkMessages, metadata: { timeout } }
return { payload: commonParams, messages: sdkMessages, metadata: { timeout } }
}
}
}
@ -520,6 +512,14 @@ export class AnthropicAPIClient extends BaseApiClient<
const toolCalls: Record<number, ToolUseBlock> = {}
return {
async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
if (typeof rawChunk === 'string') {
try {
rawChunk = JSON.parse(rawChunk)
} catch (error) {
logger.error('invalid chunk', { rawChunk, error })
throw new Error(t('error.chat.chunk.non_json'))
}
}
switch (rawChunk.type) {
case 'message': {
let i = 0

View File

@ -42,6 +42,7 @@ import {
mcpToolsToAwsBedrockTools
} from '@renderer/utils/mcp-tools'
import { findImageBlocks } from '@renderer/utils/messageUtils/find'
import { t } from 'i18next'
import { BaseApiClient } from '../BaseApiClient'
import { RequestTransformer, ResponseChunkTransformer } from '../types'
@ -417,7 +418,10 @@ export class AwsBedrockAPIClient extends BaseApiClient<
temperature: this.getTemperature(assistant, model),
topP: this.getTopP(assistant, model),
stream: streamOutput !== false,
tools: tools.length > 0 ? tools : undefined
tools: tools.length > 0 ? tools : undefined,
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
// 注意:用户自定义参数总是应该覆盖其他参数
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
}
const timeout = this.getTimeout(model)
@ -436,6 +440,15 @@ export class AwsBedrockAPIClient extends BaseApiClient<
async transform(rawChunk: AwsBedrockSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
logger.silly('Processing AWS Bedrock chunk:', rawChunk)
if (typeof rawChunk === 'string') {
try {
rawChunk = JSON.parse(rawChunk)
} catch (error) {
logger.error('invalid chunk', { rawChunk, error })
throw new Error(t('error.chat.chunk.non_json'))
}
}
// 处理消息开始事件
if (rawChunk.messageStart) {
controller.enqueue({

View File

@ -60,6 +60,7 @@ import {
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { defaultTimeout, MB } from '@shared/config/constant'
import { t } from 'i18next'
import { BaseApiClient } from '../BaseApiClient'
import { RequestTransformer, ResponseChunkTransformer } from '../types'
@ -531,6 +532,7 @@ export class GeminiAPIClient extends BaseApiClient<
...(enableGenerateImage ? this.getGenerateImageParameter() : {}),
...this.getBudgetToken(assistant, model),
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
// 注意:用户自定义参数总是应该覆盖其他参数
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
}
@ -557,6 +559,14 @@ export class GeminiAPIClient extends BaseApiClient<
return () => ({
async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
logger.silly('chunk', chunk)
if (typeof chunk === 'string') {
try {
chunk = JSON.parse(chunk)
} catch (error) {
logger.error('invalid chunk', { chunk, error })
throw new Error(t('error.chat.chunk.non_json'))
}
}
if (chunk.candidates && chunk.candidates.length > 0) {
for (const candidate of chunk.candidates) {
if (candidate.content) {

View File

@ -4,10 +4,11 @@ import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
getOpenAIWebSearchParams,
getThinkModelType,
isDoubaoThinkingAutoModel,
isGrokReasoningModel,
isNotSupportSystemMessageModel,
isQwen3235BA22BThinkingModel,
isQwenAlwaysThinkModel,
isQwenMTModel,
isQwenReasoningModel,
isReasoningModel,
@ -20,12 +21,13 @@ import {
isSupportedThinkingTokenModel,
isSupportedThinkingTokenQwenModel,
isSupportedThinkingTokenZhipuModel,
isVisionModel
isVisionModel,
MODEL_SUPPORTED_REASONING_EFFORT
} from '@renderer/config/models'
import {
isSupportArrayContentProvider,
isSupportDeveloperRoleProvider,
isSupportQwen3EnableThinkingProvider,
isSupportEnableThinkingProvider,
isSupportStreamOptionsProvider
} from '@renderer/config/providers'
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
@ -39,6 +41,7 @@ import {
MCPTool,
MCPToolResponse,
Model,
OpenAIServiceTier,
Provider,
ToolCallResponse,
TranslateAssistant,
@ -63,6 +66,7 @@ import {
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { t } from 'i18next'
import OpenAI, { AzureOpenAI } from 'openai'
import { ChatCompletionContentPart, ChatCompletionContentPartRefusal, ChatCompletionTool } from 'openai/resources'
@ -146,10 +150,11 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
}
return { reasoning: { enabled: false, exclude: true } }
}
if (isSupportedThinkingTokenQwenModel(model) || isSupportedThinkingTokenHunyuanModel(model)) {
if (isQwen3235BA22BThinkingModel(model)) {
return {}
}
if (
isSupportEnableThinkingProvider(this.provider) &&
(isSupportedThinkingTokenQwenModel(model) || isSupportedThinkingTokenHunyuanModel(model))
) {
return { enable_thinking: false }
}
@ -178,6 +183,8 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
return {}
}
// reasoningEffort有效的情况
const effortRatio = EFFORT_RATIO[reasoningEffort]
const budgetTokens = Math.floor(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min!
@ -195,9 +202,10 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
}
// Qwen models
if (isSupportedThinkingTokenQwenModel(model)) {
if (isQwenReasoningModel(model)) {
const thinkConfig = {
enable_thinking: isQwen3235BA22BThinkingModel(model) ? undefined : true,
enable_thinking:
isQwenAlwaysThinkModel(model) || !isSupportEnableThinkingProvider(this.provider) ? undefined : true,
thinking_budget: budgetTokens
}
if (this.provider.id === 'dashscope') {
@ -210,7 +218,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
}
// Hunyuan models
if (isSupportedThinkingTokenHunyuanModel(model)) {
if (isSupportedThinkingTokenHunyuanModel(model) && isSupportEnableThinkingProvider(this.provider)) {
return {
enable_thinking: true
}
@ -218,8 +226,18 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
// Grok models/Perplexity models/OpenAI models
if (isSupportedReasoningEffortModel(model)) {
return {
reasoning_effort: reasoningEffort
// 检查模型是否支持所选选项
const modelType = getThinkModelType(model)
const supportedOptions = MODEL_SUPPORTED_REASONING_EFFORT[modelType]
if (supportedOptions.includes(reasoningEffort)) {
return {
reasoning_effort: reasoningEffort
}
} else {
// 如果不支持fallback到第一个支持的值
return {
reasoning_effort: supportedOptions[0]
}
}
}
@ -530,7 +548,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
if (
lastUserMsg &&
isSupportedThinkingTokenQwenModel(model) &&
!isSupportQwen3EnableThinkingProvider(this.provider)
!isSupportEnableThinkingProvider(this.provider)
) {
const postsuffix = '/no_think'
const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true
@ -550,7 +568,11 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
reqMessages = processReqMessages(model, reqMessages)
// 5. 创建通用参数
const commonParams = {
// Create the appropriate parameters object based on whether streaming is enabled
// Note: Some providers like Mistral don't support stream_options
const shouldIncludeStreamOptions = streamOutput && isSupportStreamOptionsProvider(this.provider)
const commonParams: OpenAISdkParams = {
model: model.id,
messages:
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
@ -560,35 +582,24 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
top_p: this.getTopP(assistant, model),
max_tokens: maxTokens,
tools: tools.length > 0 ? tools : undefined,
service_tier: this.getServiceTier(model),
stream: streamOutput,
...(shouldIncludeStreamOptions ? { stream_options: { include_usage: true } } : {}),
// groq 有不同的 service tier 配置,不符合 openai 接口类型
service_tier: this.getServiceTier(model) as OpenAIServiceTier,
...this.getProviderSpecificParameters(assistant, model),
...this.getReasoningEffort(assistant, model),
...getOpenAIWebSearchParams(model, enableWebSearch),
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {}),
// OpenRouter usage tracking
...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}),
...(isQwenMTModel(model) ? extra_body : {})
...(isQwenMTModel(model) ? extra_body : {}),
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
// 注意:用户自定义参数总是应该覆盖其他参数
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
}
// Create the appropriate parameters object based on whether streaming is enabled
// Note: Some providers like Mistral don't support stream_options
const shouldIncludeStreamOptions = streamOutput && isSupportStreamOptionsProvider(this.provider)
const sdkParams: OpenAISdkParams = streamOutput
? {
...commonParams,
stream: true,
...(shouldIncludeStreamOptions ? { stream_options: { include_usage: true } } : {})
}
: {
...commonParams,
stream: false
}
const timeout = this.getTimeout(model)
return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
return { payload: commonParams, messages: reqMessages, metadata: { timeout } }
}
}
}
@ -758,6 +769,15 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
return
}
if (typeof chunk === 'string') {
try {
chunk = JSON.parse(chunk)
} catch (error) {
logger.error('invalid chunk', { chunk, error })
throw new Error(t('error.chat.chunk.non_json'))
}
}
// 处理chunk
if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) {
for (const choice of chunk.choices) {

View File

@ -99,8 +99,12 @@ export abstract class OpenAIBaseClient<
override async listModels(): Promise<OpenAI.Models.Model[]> {
try {
const sdk = await this.getSdkInstance()
const response = await sdk.models.list()
if (this.provider.id === 'github') {
// GitHub Models 其 models 和 chat completions 两个接口的 baseUrl 不一样
const baseUrl = 'https://models.github.ai/catalog/'
const newSdk = sdk.withOptions({ baseURL: baseUrl })
const response = await newSdk.models.list()
// @ts-ignore key is not typed
return response?.body
.map((model) => ({
@ -111,6 +115,7 @@ export abstract class OpenAIBaseClient<
}))
.filter(isSupportedModel)
}
const response = await sdk.models.list()
if (this.provider.id === 'together') {
// @ts-ignore key is not typed
return response?.body.map((model) => ({

View File

@ -1,3 +1,4 @@
import { loggerService } from '@logger'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
import {
@ -6,7 +7,7 @@ import {
isSupportedReasoningEffortOpenAIModel,
isVisionModel
} from '@renderer/config/models'
import { isSupportDeveloperRoleProvider } from '@renderer/config/providers'
import { isSupportDeveloperRoleProvider, isSupportStreamOptionsProvider } from '@renderer/config/providers'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
FileMetadata,
@ -15,6 +16,7 @@ import {
MCPTool,
MCPToolResponse,
Model,
OpenAIServiceTier,
Provider,
ToolCallResponse,
WebSearchSource
@ -38,6 +40,7 @@ import {
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { MB } from '@shared/config/constant'
import { t } from 'i18next'
import { isEmpty } from 'lodash'
import OpenAI, { AzureOpenAI } from 'openai'
import { ResponseInput } from 'openai/resources/responses/responses'
@ -46,6 +49,7 @@ import { RequestTransformer, ResponseChunkTransformer } from '../types'
import { OpenAIAPIClient } from './OpenAIApiClient'
import { OpenAIBaseClient } from './OpenAIBaseClient'
const logger = loggerService.withContext('OpenAIResponseAPIClient')
export class OpenAIResponseAPIClient extends OpenAIBaseClient<
OpenAI,
OpenAIResponseSdkParams,
@ -338,8 +342,8 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
}
public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] {
if (typeof sdkPayload.input === 'string') {
return [{ role: 'user', content: sdkPayload.input }]
if (!sdkPayload.input || typeof sdkPayload.input === 'string') {
return [{ role: 'user', content: sdkPayload.input ?? '' }]
}
return sdkPayload.input
}
@ -437,7 +441,10 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
}
tools = tools.concat(extraTools)
const commonParams = {
const shouldIncludeStreamOptions = streamOutput && isSupportStreamOptionsProvider(this.provider)
const commonParams: OpenAIResponseSdkParams = {
model: model.id,
input:
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
@ -447,23 +454,17 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
top_p: this.getTopP(assistant, model),
max_output_tokens: maxTokens,
stream: streamOutput,
...(shouldIncludeStreamOptions ? { stream_options: { include_usage: true } } : {}),
tools: !isEmpty(tools) ? tools : undefined,
service_tier: this.getServiceTier(model),
// groq 有不同的 service tier 配置,不符合 openai 接口类型
service_tier: this.getServiceTier(model) as OpenAIServiceTier,
...(this.getReasoningEffort(assistant, model) as OpenAI.Reasoning),
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
// 注意:用户自定义参数总是应该覆盖其他参数
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
}
const sdkParams: OpenAIResponseSdkParams = streamOutput
? {
...commonParams,
stream: true
}
: {
...commonParams,
stream: false
}
const timeout = this.getTimeout(model)
return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
return { payload: commonParams, messages: reqMessages, metadata: { timeout } }
}
}
}
@ -477,6 +478,14 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
let isFirstTextChunk = true
return () => ({
async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
if (typeof chunk === 'string') {
try {
chunk = JSON.parse(chunk)
} catch (error) {
logger.error('invalid chunk', { chunk, error })
throw new Error(t('error.chat.chunk.non_json'))
}
}
// 处理chunk
if ('output' in chunk) {
if (ctx._internal?.toolProcessingState) {

View File

@ -123,7 +123,10 @@ export default class AiProvider {
}
const middlewares = builder.build()
logger.silly('middlewares', middlewares)
logger.silly(
'middlewares',
middlewares.map((m) => m.name)
)
// 3. Create the wrapped SDK method with middlewares
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)

View File

@ -85,9 +85,15 @@ const FinalChunkConsumerMiddleware: CompletionsMiddleware =
logger.warn(`Received undefined chunk before stream was done.`)
}
}
} catch (error) {
} catch (error: any) {
logger.error(`Error consuming stream:`, error as Error)
throw error
// FIXME: 临时解决方案。该中间件的异常无法被 ErrorHandlerMiddleware捕获。
if (params.onError) {
params.onError(error)
}
if (params.shouldThrow) {
throw error
}
} finally {
if (params.onChunk && !isRecursiveCall) {
params.onChunk({

View File

@ -0,0 +1,555 @@
import { useImageTools } from '@renderer/components/ActionTools'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock dependencies
const mocks = vi.hoisted(() => ({
i18n: {
t: (key: string) => key
},
svgToPngBlob: vi.fn(),
svgToSvgBlob: vi.fn(),
download: vi.fn(),
ImagePreviewService: {
show: vi.fn()
}
}))
vi.mock('@renderer/utils/image', () => ({
svgToPngBlob: mocks.svgToPngBlob,
svgToSvgBlob: mocks.svgToSvgBlob
}))
vi.mock('@renderer/utils/download', () => ({
download: mocks.download
}))
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: mocks.i18n.t
})
}))
vi.mock('@renderer/services/ImagePreviewService', () => ({
ImagePreviewService: mocks.ImagePreviewService
}))
vi.mock('@renderer/context/ThemeProvider', () => ({
useTheme: () => ({
theme: 'light'
})
}))
// Mock navigator.clipboard
const mockWrite = vi.fn()
// Mock window.message
const mockMessage = {
success: vi.fn(),
error: vi.fn()
}
// Mock ClipboardItem
class MockClipboardItem {
constructor(items: any) {
return items
}
}
// Mock URL
const mockCreateObjectURL = vi.fn(() => 'blob:test-url')
const mockRevokeObjectURL = vi.fn()
describe('useImageTools', () => {
beforeEach(() => {
// Setup global mocks
Object.defineProperty(global.navigator, 'clipboard', {
value: { write: mockWrite },
writable: true
})
Object.defineProperty(global.window, 'message', {
value: mockMessage,
writable: true
})
// Mock ClipboardItem
global.ClipboardItem = MockClipboardItem as any
// Mock URL
global.URL = {
createObjectURL: mockCreateObjectURL,
revokeObjectURL: mockRevokeObjectURL
} as any
// Mock DOMMatrix
global.DOMMatrix = class DOMMatrix {
m41 = 0
m42 = 0
a = 1
d = 1
constructor(transform?: string) {
if (transform) {
// 简单解析 translate(x, y)
const translateMatch = transform.match(/translate\(([^,]+),\s*([^)]+)\)/)
if (translateMatch) {
this.m41 = parseFloat(translateMatch[1])
this.m42 = parseFloat(translateMatch[2])
}
// 解析 scale(s)
const scaleMatch = transform.match(/scale\(([^)]+)\)/)
if (scaleMatch) {
const scaleValue = parseFloat(scaleMatch[1])
this.a = scaleValue
this.d = scaleValue
}
}
}
static fromMatrix() {
return new DOMMatrix()
}
} as any
vi.clearAllMocks()
})
// 创建模拟的 DOM 环境
const createMockContainer = () => {
const mockContainer = {
addEventListener: vi.fn(),
removeEventListener: vi.fn(),
contains: vi.fn().mockReturnValue(true),
style: {
cursor: ''
},
querySelector: vi.fn(),
shadowRoot: null
} as unknown as HTMLDivElement
return mockContainer
}
const createMockSvgElement = () => {
const mockSvg = {
style: {
transform: '',
transformOrigin: ''
},
cloneNode: vi.fn().mockReturnThis()
} as unknown as SVGElement
return mockSvg
}
describe('initialization', () => {
it('should initialize with default scale', () => {
const mockContainer = createMockContainer()
const { result } = renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg'
}
)
)
const transform = result.current.getCurrentTransform()
expect(transform.scale).toBe(1)
})
})
describe('pan function', () => {
it('should pan with relative and absolute coordinates', () => {
const mockContainer = createMockContainer()
const mockSvg = createMockSvgElement()
mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg)
const { result } = renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg'
}
)
)
// 相对坐标平移
act(() => {
result.current.pan(10, 20)
})
expect(mockSvg.style.transform).toContain('translate(10px, 20px)')
// 绝对坐标平移
act(() => {
result.current.pan(50, 60, true)
})
expect(mockSvg.style.transform).toContain('translate(50px, 60px)')
})
})
describe('zoom function', () => {
it('should zoom in/out and set absolute zoom level', () => {
const mockContainer = createMockContainer()
const mockSvg = createMockSvgElement()
mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg)
const { result } = renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg'
}
)
)
// 放大
act(() => {
result.current.zoom(0.5)
})
expect(result.current.getCurrentTransform().scale).toBe(1.5)
expect(mockSvg.style.transform).toContain('scale(1.5)')
// 缩小
act(() => {
result.current.zoom(-0.3)
})
expect(result.current.getCurrentTransform().scale).toBe(1.2)
expect(mockSvg.style.transform).toContain('scale(1.2)')
// 设置绝对缩放级别
act(() => {
result.current.zoom(2.5, true)
})
expect(result.current.getCurrentTransform().scale).toBe(2.5)
})
it('should constrain zoom between 0.1 and 3', () => {
const mockContainer = createMockContainer()
const mockSvg = createMockSvgElement()
mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg)
const { result } = renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg'
}
)
)
// 尝试过度缩小
act(() => {
result.current.zoom(-10)
})
expect(result.current.getCurrentTransform().scale).toBe(0.1)
// 尝试过度放大
act(() => {
result.current.zoom(10)
})
expect(result.current.getCurrentTransform().scale).toBe(3)
})
})
describe('copy and download functions', () => {
it('should copy image to clipboard successfully', async () => {
const mockContainer = createMockContainer()
const mockSvg = createMockSvgElement()
mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg)
// Mock svgToPngBlob to return a blob
const mockBlob = new Blob(['test'], { type: 'image/png' })
mocks.svgToPngBlob.mockResolvedValue(mockBlob)
const { result } = renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg'
}
)
)
await act(async () => {
await result.current.copy()
})
expect(mocks.svgToPngBlob).toHaveBeenCalledWith(mockSvg)
expect(mockWrite).toHaveBeenCalled()
expect(mockMessage.success).toHaveBeenCalledWith('message.copy.success')
})
it('should download image as PNG and SVG', async () => {
const mockContainer = createMockContainer()
const mockSvg = createMockSvgElement()
mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg)
// Mock svgToPngBlob to return a blob
const pngBlob = new Blob(['test'], { type: 'image/png' })
mocks.svgToPngBlob.mockResolvedValue(pngBlob)
// Mock svgToSvgBlob to return a blob
const svgBlob = new Blob(['<svg></svg>'], { type: 'image/svg+xml' })
mocks.svgToSvgBlob.mockReturnValue(svgBlob)
const { result } = renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg'
}
)
)
// 下载 PNG
await act(async () => {
await result.current.download('png')
})
expect(mocks.svgToPngBlob).toHaveBeenCalledWith(mockSvg)
// 下载 SVG
await act(async () => {
await result.current.download('svg')
})
expect(mocks.svgToSvgBlob).toHaveBeenCalledWith(mockSvg)
// 验证通用的下载流程
expect(mockCreateObjectURL).toHaveBeenCalledTimes(2)
expect(mocks.download).toHaveBeenCalledTimes(2)
expect(mockRevokeObjectURL).toHaveBeenCalledTimes(2)
})
it('should handle copy/download failures and missing elements', async () => {
const mockContainer = createMockContainer()
const mockSvg = createMockSvgElement()
// 测试无元素情况
mockContainer.querySelector = vi.fn().mockReturnValue(null)
const { result } = renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg'
}
)
)
// 复制无元素
await act(async () => {
await result.current.copy()
})
expect(mocks.svgToPngBlob).not.toHaveBeenCalled()
// 下载无元素
await act(async () => {
await result.current.download('png')
})
expect(mocks.svgToPngBlob).not.toHaveBeenCalled()
// 测试失败情况
mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg)
mocks.svgToPngBlob.mockRejectedValue(new Error('Conversion failed'))
// 复制失败
await act(async () => {
await result.current.copy()
})
expect(mockMessage.error).toHaveBeenCalledWith('message.copy.failed')
// 下载失败
await act(async () => {
await result.current.download('png')
})
expect(mockMessage.error).toHaveBeenCalledWith('message.download.failed')
})
})
describe('dialog function', () => {
it('should preview image successfully', async () => {
const mockContainer = createMockContainer()
const mockSvg = createMockSvgElement()
mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg)
mocks.ImagePreviewService.show.mockResolvedValue(undefined)
const { result } = renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg'
}
)
)
await act(async () => {
await result.current.dialog()
})
expect(mocks.ImagePreviewService.show).toHaveBeenCalledWith(mockSvg, { format: 'svg' })
})
it('should handle preview failure', async () => {
const mockContainer = createMockContainer()
const mockSvg = createMockSvgElement()
mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg)
mocks.ImagePreviewService.show.mockRejectedValue(new Error('Preview failed'))
const { result } = renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg'
}
)
)
await act(async () => {
await result.current.dialog()
})
expect(mockMessage.error).toHaveBeenCalledWith('message.dialog.failed')
})
it('should do nothing when no element is found', async () => {
const mockContainer = createMockContainer()
mockContainer.querySelector = vi.fn().mockReturnValue(null)
const { result } = renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg'
}
)
)
await act(async () => {
await result.current.dialog()
})
expect(mocks.ImagePreviewService.show).not.toHaveBeenCalled()
})
})
describe('event listener management', () => {
it('should attach/remove event listeners based on options', () => {
const mockContainer = createMockContainer()
// 启用拖拽和滚轮缩放
renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg',
enableDrag: true,
enableWheelZoom: true
}
)
)
expect(mockContainer.addEventListener).toHaveBeenCalledWith('mousedown', expect.any(Function))
expect(mockContainer.addEventListener).toHaveBeenCalledWith('wheel', expect.any(Function), { passive: true })
// 重置并测试禁用情况
vi.clearAllMocks()
renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg',
enableDrag: false,
enableWheelZoom: false
}
)
)
expect(mockContainer.addEventListener).not.toHaveBeenCalledWith('mousedown', expect.any(Function))
expect(mockContainer.addEventListener).not.toHaveBeenCalledWith('wheel', expect.any(Function))
})
})
describe('getCurrentTransform function', () => {
it('should return current scale and position', () => {
const mockContainer = createMockContainer()
const mockSvg = createMockSvgElement()
mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg)
const { result } = renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg'
}
)
)
// 初始状态
const initialTransform = result.current.getCurrentTransform()
expect(initialTransform).toEqual({ scale: 1, x: 0, y: 0 })
// 缩放后状态
act(() => {
result.current.zoom(0.5)
})
const zoomedTransform = result.current.getCurrentTransform()
expect(zoomedTransform.scale).toBe(1.5)
expect(zoomedTransform.x).toBe(0)
expect(zoomedTransform.y).toBe(0)
// 平移后状态
act(() => {
result.current.pan(10, 20)
})
const pannedTransform = result.current.getCurrentTransform()
expect(pannedTransform.scale).toBe(1.5)
expect(pannedTransform.x).toBe(10)
expect(pannedTransform.y).toBe(20)
})
it('should get position from DOMMatrix when element has transform', () => {
const mockContainer = createMockContainer()
const mockSvg = createMockSvgElement()
mockSvg.style.transform = 'translate(30px, 40px) scale(2)'
mockContainer.querySelector = vi.fn().mockReturnValue(mockSvg)
const { result } = renderHook(() =>
useImageTools(
{ current: mockContainer },
{
prefix: 'test',
imgSelector: 'svg'
}
)
)
// 手动设置 transformRef 以匹配 DOM 状态
act(() => {
result.current.pan(30, 40, true)
result.current.zoom(2, true)
})
const transform = result.current.getCurrentTransform()
expect(transform.scale).toBe(2)
expect(transform.x).toBe(30)
expect(transform.y).toBe(40)
})
})
})

View File

@ -0,0 +1,215 @@
import { ActionTool, useToolManager } from '@renderer/components/ActionTools'
import { act, renderHook } from '@testing-library/react'
import { useState } from 'react'
import { describe, expect, it } from 'vitest'
// 创建测试工具数据
const createTestTool = (overrides: Partial<ActionTool> = {}): ActionTool => ({
id: 'test-tool',
type: 'core',
order: 10,
icon: 'TestIcon',
tooltip: 'Test Tool',
...overrides
})
describe('useToolManager', () => {
describe('registerTool', () => {
it('should register a new tool', () => {
const { result } = renderHook(() => {
const [tools, setTools] = useState<ActionTool[]>([])
const { registerTool } = useToolManager(setTools)
return { tools, registerTool }
})
const testTool = createTestTool()
act(() => {
result.current.registerTool(testTool)
})
expect(result.current.tools).toHaveLength(1)
expect(result.current.tools[0]).toEqual(testTool)
})
it('should replace existing tool with same id', () => {
const { result } = renderHook(() => {
const [tools, setTools] = useState<ActionTool[]>([])
const { registerTool } = useToolManager(setTools)
return { tools, registerTool }
})
const originalTool = createTestTool({ tooltip: 'Original' })
const updatedTool = createTestTool({ tooltip: 'Updated' })
act(() => {
result.current.registerTool(originalTool)
result.current.registerTool(updatedTool)
})
expect(result.current.tools).toHaveLength(1)
expect(result.current.tools[0]).toEqual(updatedTool)
})
it('should sort tools by order (descending)', () => {
const { result } = renderHook(() => {
const [tools, setTools] = useState<ActionTool[]>([])
const { registerTool } = useToolManager(setTools)
return { tools, registerTool }
})
const tool1 = createTestTool({ id: 'tool1', order: 10 })
const tool2 = createTestTool({ id: 'tool2', order: 30 })
const tool3 = createTestTool({ id: 'tool3', order: 20 })
act(() => {
result.current.registerTool(tool1)
result.current.registerTool(tool2)
result.current.registerTool(tool3)
})
// 应该按 order 降序排列
expect(result.current.tools[0].id).toBe('tool2') // order: 30
expect(result.current.tools[1].id).toBe('tool3') // order: 20
expect(result.current.tools[2].id).toBe('tool1') // order: 10
})
it('should handle tools with children', () => {
const { result } = renderHook(() => {
const [tools, setTools] = useState<ActionTool[]>([])
const { registerTool } = useToolManager(setTools)
return { tools, registerTool }
})
const childTool = createTestTool({ id: 'child-tool', order: 5 })
const parentTool = createTestTool({
id: 'parent-tool',
order: 15,
children: [childTool]
})
act(() => {
result.current.registerTool(parentTool)
})
expect(result.current.tools).toHaveLength(1)
expect(result.current.tools[0]).toEqual(parentTool)
expect(result.current.tools[0].children).toEqual([childTool])
})
it('should not modify state if setTools is not provided', () => {
const { result } = renderHook(() => useToolManager(undefined))
// 不应该抛出错误
expect(() => {
act(() => {
result.current.registerTool(createTestTool())
})
}).not.toThrow()
})
})
describe('removeTool', () => {
it('should remove tool by id', () => {
const { result } = renderHook(() => {
const [tools, setTools] = useState<ActionTool[]>([createTestTool()])
const { registerTool, removeTool } = useToolManager(setTools)
return { tools, registerTool, removeTool }
})
expect(result.current.tools).toHaveLength(1)
act(() => {
result.current.removeTool('test-tool')
})
expect(result.current.tools).toHaveLength(0)
})
it('should not affect other tools when removing one', () => {
const { result } = renderHook(() => {
const toolsData = [
createTestTool({ id: 'tool1' }),
createTestTool({ id: 'tool2' }),
createTestTool({ id: 'tool3' })
]
const [tools, setTools] = useState<ActionTool[]>(toolsData)
const { removeTool } = useToolManager(setTools)
return { tools, removeTool }
})
expect(result.current.tools).toHaveLength(3)
act(() => {
result.current.removeTool('tool2')
})
expect(result.current.tools).toHaveLength(2)
expect(result.current.tools[0].id).toBe('tool1')
expect(result.current.tools[1].id).toBe('tool3')
})
it('should handle removing non-existent tool', () => {
const { result } = renderHook(() => {
const [tools, setTools] = useState<ActionTool[]>([createTestTool()])
const { removeTool } = useToolManager(setTools)
return { tools, removeTool }
})
expect(result.current.tools).toHaveLength(1)
act(() => {
result.current.removeTool('non-existent-tool')
})
expect(result.current.tools).toHaveLength(1) // 应该没有变化
})
it('should not modify state if setTools is not provided', () => {
const { result } = renderHook(() => useToolManager(undefined))
// 不应该抛出错误
expect(() => {
act(() => {
result.current.removeTool('test-tool')
})
}).not.toThrow()
})
})
describe('integration', () => {
it('should handle register and remove operations together', () => {
const { result } = renderHook(() => {
const [tools, setTools] = useState<ActionTool[]>([])
const { registerTool, removeTool } = useToolManager(setTools)
return { tools, registerTool, removeTool }
})
const tool1 = createTestTool({ id: 'tool1' })
const tool2 = createTestTool({ id: 'tool2' })
// 注册两个工具
act(() => {
result.current.registerTool(tool1)
result.current.registerTool(tool2)
})
expect(result.current.tools).toHaveLength(2)
// 移除一个工具
act(() => {
result.current.removeTool('tool1')
})
expect(result.current.tools).toHaveLength(1)
expect(result.current.tools[0].id).toBe('tool2')
// 再次注册被移除的工具
act(() => {
result.current.registerTool(tool1)
})
expect(result.current.tools).toHaveLength(2)
})
})
})

View File

@ -1,6 +1,6 @@
import { CodeToolSpec } from './types'
import { ActionToolSpec } from './types'
export const TOOL_SPECS: Record<string, CodeToolSpec> = {
export const TOOL_SPECS: Record<string, ActionToolSpec> = {
// Core tools
copy: {
id: 'copy',

View File

@ -0,0 +1,292 @@
import { loggerService } from '@logger'
import { useTheme } from '@renderer/context/ThemeProvider'
import { ImagePreviewService } from '@renderer/services/ImagePreviewService'
import { download as downloadFile } from '@renderer/utils/download'
import { svgToPngBlob, svgToSvgBlob } from '@renderer/utils/image'
import { RefObject, useCallback, useEffect, useRef } from 'react'
import { useTranslation } from 'react-i18next'
const logger = loggerService.withContext('usePreviewToolHandlers')
/**
* 使Hook
*
*/
export const useImageTools = (
containerRef: RefObject<HTMLDivElement | null>,
options: {
prefix: string
imgSelector: string
enableDrag?: boolean
enableWheelZoom?: boolean
}
) => {
const transformRef = useRef({ scale: 1, x: 0, y: 0 }) // 管理变换状态
const { imgSelector, prefix, enableDrag, enableWheelZoom } = options
const { t } = useTranslation()
const { theme } = useTheme()
// 创建选择器函数
const getImgElement = useCallback(() => {
if (!containerRef.current) return null
// 优先尝试从 Shadow DOM 中查找
const shadowRoot = containerRef.current.shadowRoot
if (shadowRoot) {
return shadowRoot.querySelector(imgSelector) as SVGElement | null
}
// 降级到常规 DOM 查找
return containerRef.current.querySelector(imgSelector) as SVGElement | null
}, [containerRef, imgSelector])
// 获取原始图像元素(移除所有变换)
const getCleanImgElement = useCallback((): SVGElement | null => {
const imgElement = getImgElement()
if (!imgElement) return null
const clonedElement = imgElement.cloneNode(true) as SVGElement
clonedElement.style.transform = ''
clonedElement.style.transformOrigin = ''
return clonedElement
}, [getImgElement])
// 查询当前位置
const getCurrentPosition = useCallback(() => {
const imgElement = getImgElement()
if (!imgElement) return transformRef.current
const transform = imgElement.style.transform
if (!transform || transform === 'none') return transformRef.current
// 使用CSS矩阵解析
const matrix = new DOMMatrix(transform)
return { x: matrix.m41, y: matrix.m42 }
}, [getImgElement])
/**
*
* @param element
* @param x X轴偏移量
* @param y Y轴偏移量
* @param scale
*/
const applyTransform = useCallback((element: SVGElement | null, x: number, y: number, scale: number) => {
if (!element) return
element.style.transformOrigin = 'top left'
element.style.transform = `translate(${x}px, ${y}px) scale(${scale})`
}, [])
/**
* -
* @param dx X轴偏移量
* @param dy Y轴偏移量
* @param absolute truefalse
*/
const pan = useCallback(
(dx: number, dy: number, absolute = false) => {
const currentPos = getCurrentPosition()
const newX = absolute ? dx : currentPos.x + dx
const newY = absolute ? dy : currentPos.y + dy
transformRef.current.x = newX
transformRef.current.y = newY
const imgElement = getImgElement()
applyTransform(imgElement, newX, newY, transformRef.current.scale)
},
[getCurrentPosition, getImgElement, applyTransform]
)
// 拖拽平移支持
useEffect(() => {
if (!enableDrag || !containerRef.current) return
const container = containerRef.current
const startPos = { x: 0, y: 0 }
const handleMouseMove = (e: MouseEvent) => {
const dx = e.clientX - startPos.x
const dy = e.clientY - startPos.y
// 直接使用 transformRef 中的初始偏移量进行计算
const newX = transformRef.current.x + dx
const newY = transformRef.current.y + dy
const imgElement = getImgElement()
// 实时应用变换,但不更新 ref避免累积误差
applyTransform(imgElement, newX, newY, transformRef.current.scale)
e.preventDefault()
}
const handleMouseUp = (e: MouseEvent) => {
document.removeEventListener('mousemove', handleMouseMove)
document.removeEventListener('mouseup', handleMouseUp)
container.style.cursor = 'default'
// 拖拽结束后,计算最终位置并更新 ref
const dx = e.clientX - startPos.x
const dy = e.clientY - startPos.y
transformRef.current.x += dx
transformRef.current.y += dy
}
const handleMouseDown = (e: MouseEvent) => {
if (e.button !== 0) return // 只响应左键
// 每次拖拽开始时,都以 ref 中当前的位置为基准
const currentPos = getCurrentPosition()
transformRef.current.x = currentPos.x
transformRef.current.y = currentPos.y
startPos.x = e.clientX
startPos.y = e.clientY
container.style.cursor = 'grabbing'
e.preventDefault()
document.addEventListener('mousemove', handleMouseMove)
document.addEventListener('mouseup', handleMouseUp)
}
container.addEventListener('mousedown', handleMouseDown)
return () => {
container.removeEventListener('mousedown', handleMouseDown)
// 清理以防万一,例如组件在拖拽过程中被卸载
document.removeEventListener('mousemove', handleMouseMove)
document.removeEventListener('mouseup', handleMouseUp)
}
}, [containerRef, getImgElement, applyTransform, getCurrentPosition, enableDrag])
/**
*
* @param delta
*/
const zoom = useCallback(
(delta: number, absolute = false) => {
const newScale = absolute
? Math.max(0.1, Math.min(3, delta))
: Math.max(0.1, Math.min(3, transformRef.current.scale + delta))
transformRef.current.scale = newScale
const imgElement = getImgElement()
applyTransform(imgElement, transformRef.current.x, transformRef.current.y, newScale)
},
[getImgElement, applyTransform]
)
// 滚轮缩放支持
useEffect(() => {
if (!enableWheelZoom || !containerRef.current) return
const container = containerRef.current
const handleWheel = (e: WheelEvent) => {
if ((e.ctrlKey || e.metaKey) && e.target) {
// 确认事件发生在容器内部
if (container.contains(e.target as Node)) {
const delta = e.deltaY < 0 ? 0.1 : -0.1
zoom(delta)
}
}
}
container.addEventListener('wheel', handleWheel, { passive: true })
return () => container.removeEventListener('wheel', handleWheel)
}, [containerRef, zoom, enableWheelZoom])
/**
*
*
* 使
*/
const copy = useCallback(async () => {
try {
const imgElement = getCleanImgElement()
if (!imgElement) return
const blob = await svgToPngBlob(imgElement)
await navigator.clipboard.write([new ClipboardItem({ 'image/png': blob })])
window.message.success(t('message.copy.success'))
} catch (error) {
logger.error('Copy failed:', error as Error)
window.message.error(t('message.copy.failed'))
}
}, [getCleanImgElement, t])
/**
*
*
* 使
*/
const download = useCallback(
async (format: 'svg' | 'png') => {
try {
const imgElement = getCleanImgElement()
if (!imgElement) return
const timestamp = Date.now()
if (format === 'svg') {
const blob = svgToSvgBlob(imgElement)
const url = URL.createObjectURL(blob)
downloadFile(url, `${prefix}-${timestamp}.svg`)
URL.revokeObjectURL(url)
} else {
const blob = await svgToPngBlob(imgElement)
const pngUrl = URL.createObjectURL(blob)
downloadFile(pngUrl, `${prefix}-${timestamp}.png`)
URL.revokeObjectURL(pngUrl)
}
} catch (error) {
logger.error('Download failed:', error as Error)
window.message.error(t('message.download.failed'))
}
},
[getCleanImgElement, prefix, t]
)
/**
* dialog
*
* 使
*/
const dialog = useCallback(async () => {
try {
const imgElement = getCleanImgElement()
if (!imgElement) return
await ImagePreviewService.show(imgElement, { format: 'svg' })
} catch (error) {
logger.error('Dialog preview failed:', error as Error)
window.message.error(t('message.dialog.failed'))
}
}, [getCleanImgElement, t])
// 获取当前变换状态
const getCurrentTransform = useCallback(() => {
return {
scale: transformRef.current.scale,
x: transformRef.current.x,
y: transformRef.current.y
}
}, [transformRef])
// 切换主题时重置变换
useEffect(() => {
pan(0, 0, true)
zoom(1, true)
}, [pan, zoom, theme])
return {
zoom,
pan,
copy,
download,
dialog,
getCurrentTransform
}
}

View File

@ -1,11 +1,11 @@
import { useCallback } from 'react'
import { CodeTool } from './types'
import { ActionTool, ToolRegisterProps } from '../types'
export const useCodeTool = (setTools?: (value: React.SetStateAction<CodeTool[]>) => void) => {
export const useToolManager = (setTools?: ToolRegisterProps['setTools']) => {
// 注册工具如果已存在同ID工具则替换
const registerTool = useCallback(
(tool: CodeTool) => {
(tool: ActionTool) => {
setTools?.((prev) => {
const filtered = prev.filter((t) => t.id !== tool.id)
return [...filtered, tool].sort((a, b) => b.order - a.order)

View File

@ -0,0 +1,4 @@
export * from './constants'
export * from './hooks/useImageTools'
export * from './hooks/useToolManager'
export * from './types'

View File

@ -0,0 +1,34 @@
/**
*
*/
export interface ActionToolSpec {
id: string
type: 'core' | 'quick'
order: number
}
/**
*
* @param id
* @param type
* @param order
* @param icon
* @param tooltip
* @param visible
* @param onClick
* @param children more
*/
export interface ActionTool extends ActionToolSpec {
icon: React.ReactNode
tooltip?: string
visible?: () => boolean
onClick?: () => void
children?: Omit<ActionTool, 'children'>[]
}
/**
* props
*/
export interface ToolRegisterProps {
setTools?: (value: React.SetStateAction<ActionTool[]>) => void
}

View File

@ -1,102 +0,0 @@
import { usePreviewToolHandlers, usePreviewTools } from '@renderer/components/CodeToolbar'
import { LoadingIcon } from '@renderer/components/Icons'
import { AsyncInitializer } from '@renderer/utils/asyncInitializer'
import { Flex, Spin } from 'antd'
import { debounce } from 'lodash'
import React, { memo, startTransition, useCallback, useEffect, useMemo, useRef, useState } from 'react'
import styled from 'styled-components'
import PreviewError from './PreviewError'
import { BasicPreviewProps } from './types'
// 管理 viz 实例
const vizInitializer = new AsyncInitializer(async () => {
const module = await import('@viz-js/viz')
return await module.instance()
})
/** Graphviz
*
*/
const GraphvizPreview: React.FC<BasicPreviewProps> = ({ children, setTools }) => {
const graphvizRef = useRef<HTMLDivElement>(null)
const [error, setError] = useState<string | null>(null)
const [isLoading, setIsLoading] = useState(false)
// 使用通用图像工具
const { handleZoom, handleCopyImage, handleDownload } = usePreviewToolHandlers(graphvizRef, {
imgSelector: 'svg',
prefix: 'graphviz',
enableWheelZoom: true
})
// 使用工具栏
usePreviewTools({
setTools,
handleZoom,
handleCopyImage,
handleDownload
})
// 实际的渲染函数
const renderGraphviz = useCallback(async (content: string) => {
if (!content || !graphvizRef.current) return
try {
setIsLoading(true)
const viz = await vizInitializer.get()
const svgElement = viz.renderSVGElement(content)
// 清空容器并添加新的 SVG
graphvizRef.current.innerHTML = ''
graphvizRef.current.appendChild(svgElement)
// 渲染成功,清除错误记录
setError(null)
} catch (error) {
setError((error as Error).message || 'DOT syntax error or rendering failed')
} finally {
setIsLoading(false)
}
}, [])
// debounce 渲染
const debouncedRender = useMemo(
() =>
debounce((content: string) => {
startTransition(() => renderGraphviz(content))
}, 300),
[renderGraphviz]
)
// 触发渲染
useEffect(() => {
if (children) {
setIsLoading(true)
debouncedRender(children)
} else {
debouncedRender.cancel()
setIsLoading(false)
}
return () => {
debouncedRender.cancel()
}
}, [children, debouncedRender])
return (
<Spin spinning={isLoading} indicator={<LoadingIcon color="var(--color-text-2)" />}>
<Flex vertical style={{ minHeight: isLoading ? '2rem' : 'auto' }}>
{error && <PreviewError>{error}</PreviewError>}
<StyledGraphviz ref={graphvizRef} className="graphviz special-preview" />
</Flex>
</Spin>
)
}
const StyledGraphviz = styled.div`
overflow: auto;
`
export default memo(GraphvizPreview)

View File

@ -22,45 +22,51 @@ const HtmlArtifactsPopup: React.FC<HtmlArtifactsPopupProps> = ({ open, title, ht
const [currentHtml, setCurrentHtml] = useState(html)
const [isFullscreen, setIsFullscreen] = useState(false)
// 预览刷新相关状态
// Preview refresh related state
const [previewHtml, setPreviewHtml] = useState(html)
const intervalRef = useRef<NodeJS.Timeout | null>(null)
const latestHtmlRef = useRef(html)
const currentPreviewHtmlRef = useRef(html)
// 当外部html更新时同步更新内部状态
// Sync internal state when external html updates
useEffect(() => {
setCurrentHtml(html)
latestHtmlRef.current = html
}, [html])
// 当内部编辑的html更新时更新引用
// Update reference when internally edited html changes
useEffect(() => {
latestHtmlRef.current = currentHtml
}, [currentHtml])
// 2秒定时检查并刷新预览仅在内容变化时
// Update reference when preview content changes
useEffect(() => {
currentPreviewHtmlRef.current = previewHtml
}, [previewHtml])
// Check and refresh preview every 2 seconds (only when content changes)
useEffect(() => {
if (!open) return
// 立即设置初始预览内容
setPreviewHtml(currentHtml)
// Set initial preview content immediately
setPreviewHtml(latestHtmlRef.current)
// 设置定时器每2秒检查一次内容是否有变化
// Set timer to check for content changes every 2 seconds
intervalRef.current = setInterval(() => {
if (latestHtmlRef.current !== previewHtml) {
if (latestHtmlRef.current !== currentPreviewHtmlRef.current) {
setPreviewHtml(latestHtmlRef.current)
}
}, 2000)
// 清理函数
// Cleanup function
return () => {
if (intervalRef.current) {
clearInterval(intervalRef.current)
}
}
}, [currentHtml, open, previewHtml])
}, [open])
// 全屏时防止 body 滚动
// Prevent body scroll when fullscreen
useEffect(() => {
if (!open || !isFullscreen) return
@ -147,9 +153,10 @@ const HtmlArtifactsPopup: React.FC<HtmlArtifactsPopupProps> = ({ open, title, ht
editable={true}
onSave={setCurrentHtml}
style={{ height: '100%' }}
expanded
unwrapped={false}
options={{
stream: false,
collapsible: false
stream: false
}}
/>
</CodeSection>
@ -159,7 +166,7 @@ const HtmlArtifactsPopup: React.FC<HtmlArtifactsPopupProps> = ({ open, title, ht
<PreviewSection>
{previewHtml.trim() ? (
<PreviewFrame
key={previewHtml} // 强制重新创建iframe当预览内容变化时
key={previewHtml} // Force recreate iframe when preview content changes
srcDoc={previewHtml}
title="HTML Preview"
sandbox="allow-scripts allow-same-origin allow-forms"
@ -176,7 +183,6 @@ const HtmlArtifactsPopup: React.FC<HtmlArtifactsPopupProps> = ({ open, title, ht
)
}
// 简化的样式组件
const StyledModal = styled(Modal)<{ $isFullscreen?: boolean }>`
${(props) =>
props.$isFullscreen

View File

@ -1,155 +0,0 @@
import { nanoid } from '@reduxjs/toolkit'
import { usePreviewToolHandlers, usePreviewTools } from '@renderer/components/CodeToolbar'
import { LoadingIcon } from '@renderer/components/Icons'
import { useMermaid } from '@renderer/hooks/useMermaid'
import { Flex, Spin } from 'antd'
import { debounce } from 'lodash'
import React, { memo, startTransition, useCallback, useEffect, useMemo, useRef, useState } from 'react'
import styled from 'styled-components'
import PreviewError from './PreviewError'
import { BasicPreviewProps } from './types'
/** Mermaid
*
* FIXME: 等将来容易判断代码块结束位置时再重构
*/
const MermaidPreview: React.FC<BasicPreviewProps> = ({ children, setTools }) => {
const { mermaid, isLoading: isLoadingMermaid, error: mermaidError } = useMermaid()
const mermaidRef = useRef<HTMLDivElement>(null)
const diagramId = useRef<string>(`mermaid-${nanoid(6)}`).current
const [error, setError] = useState<string | null>(null)
const [isRendering, setIsRendering] = useState(false)
const [isVisible, setIsVisible] = useState(true)
// 使用通用图像工具
const { handleZoom, handleCopyImage, handleDownload } = usePreviewToolHandlers(mermaidRef, {
imgSelector: 'svg',
prefix: 'mermaid',
enableWheelZoom: true
})
// 使用工具栏
usePreviewTools({
setTools,
handleZoom,
handleCopyImage,
handleDownload
})
// 实际的渲染函数
const renderMermaid = useCallback(
async (content: string) => {
if (!content || !mermaidRef.current) return
try {
setIsRendering(true)
// 验证语法,提前抛出异常
await mermaid.parse(content)
const { svg } = await mermaid.render(diagramId, content, mermaidRef.current)
// 避免不可见时产生 undefined 和 NaN
const fixedSvg = svg.replace(/translate\(undefined,\s*NaN\)/g, 'translate(0, 0)')
mermaidRef.current.innerHTML = fixedSvg
// 渲染成功,清除错误记录
setError(null)
} catch (error) {
setError((error as Error).message)
} finally {
setIsRendering(false)
}
},
[diagramId, mermaid]
)
// debounce 渲染
const debouncedRender = useMemo(
() =>
debounce((content: string) => {
startTransition(() => renderMermaid(content))
}, 300),
[renderMermaid]
)
/**
*
* `MessageGroup` `fold` `display: none`
* `fold` className `MessageWrapper`
* FIXME: 将来 mermaid-js
*/
useEffect(() => {
if (!mermaidRef.current) return
const checkVisibility = () => {
const element = mermaidRef.current
if (!element) return
const currentlyVisible = element.offsetParent !== null
setIsVisible(currentlyVisible)
}
// 初始检查
checkVisibility()
const observer = new MutationObserver(() => {
checkVisibility()
})
let targetElement = mermaidRef.current.parentElement
while (targetElement) {
observer.observe(targetElement, {
attributes: true,
attributeFilter: ['class', 'style']
})
if (targetElement.className?.includes('fold')) {
break
}
targetElement = targetElement.parentElement
}
return () => {
observer.disconnect()
}
}, [])
// 触发渲染
useEffect(() => {
if (isLoadingMermaid) return
if (mermaidRef.current?.offsetParent === null) return
if (children) {
setIsRendering(true)
debouncedRender(children)
} else {
debouncedRender.cancel()
setIsRendering(false)
}
return () => {
debouncedRender.cancel()
}
}, [children, isLoadingMermaid, debouncedRender, isVisible])
const isLoading = isLoadingMermaid || isRendering
return (
<Spin spinning={isLoading} indicator={<LoadingIcon color="var(--color-text-2)" />}>
<Flex vertical style={{ minHeight: isLoading ? '2rem' : 'auto' }}>
{(mermaidError || error) && <PreviewError>{mermaidError || error}</PreviewError>}
<StyledMermaid ref={mermaidRef} className="mermaid special-preview" />
</Flex>
</Spin>
)
}
const StyledMermaid = styled.div`
overflow: auto;
`
export default memo(MermaidPreview)

View File

@ -1,192 +0,0 @@
import { LoadingOutlined } from '@ant-design/icons'
import { usePreviewToolHandlers, usePreviewTools } from '@renderer/components/CodeToolbar'
import { Spin } from 'antd'
import pako from 'pako'
import React, { memo, useCallback, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
import { BasicPreviewProps } from './types'
const PlantUMLServer = 'https://www.plantuml.com/plantuml'
function encode64(data: Uint8Array) {
let r = ''
for (let i = 0; i < data.length; i += 3) {
if (i + 2 === data.length) {
r += append3bytes(data[i], data[i + 1], 0)
} else if (i + 1 === data.length) {
r += append3bytes(data[i], 0, 0)
} else {
r += append3bytes(data[i], data[i + 1], data[i + 2])
}
}
return r
}
function encode6bit(b: number) {
if (b < 10) {
return String.fromCharCode(48 + b)
}
b -= 10
if (b < 26) {
return String.fromCharCode(65 + b)
}
b -= 26
if (b < 26) {
return String.fromCharCode(97 + b)
}
b -= 26
if (b === 0) {
return '-'
}
if (b === 1) {
return '_'
}
return '?'
}
function append3bytes(b1: number, b2: number, b3: number) {
const c1 = b1 >> 2
const c2 = ((b1 & 0x3) << 4) | (b2 >> 4)
const c3 = ((b2 & 0xf) << 2) | (b3 >> 6)
const c4 = b3 & 0x3f
let r = ''
r += encode6bit(c1 & 0x3f)
r += encode6bit(c2 & 0x3f)
r += encode6bit(c3 & 0x3f)
r += encode6bit(c4 & 0x3f)
return r
}
/**
* https://plantuml.com/zh/code-javascript-synchronous
* To use PlantUML image generation, a text diagram description have to be :
1. Encoded in UTF-8
2. Compressed using Deflate algorithm
3. Reencoded in ASCII using a transformation _close_ to base64
*/
function encodeDiagram(diagram: string): string {
const utf8text = new TextEncoder().encode(diagram)
const compressed = pako.deflateRaw(utf8text)
return encode64(compressed)
}
async function downloadUrl(url: string, filename: string) {
const response = await fetch(url)
if (!response.ok) {
window.message.warning({ content: response.statusText, duration: 1.5 })
return
}
const blob = await response.blob()
const link = document.createElement('a')
link.href = URL.createObjectURL(blob)
link.download = filename
document.body.appendChild(link)
link.click()
document.body.removeChild(link)
URL.revokeObjectURL(link.href)
}
type PlantUMLServerImageProps = {
format: 'png' | 'svg'
diagram: string
onClick?: React.MouseEventHandler<HTMLDivElement>
className?: string
}
function getPlantUMLImageUrl(format: 'png' | 'svg', diagram: string, isDark?: boolean) {
const encodedDiagram = encodeDiagram(diagram)
if (isDark) {
return `${PlantUMLServer}/d${format}/${encodedDiagram}`
}
return `${PlantUMLServer}/${format}/${encodedDiagram}`
}
const PlantUMLServerImage: React.FC<PlantUMLServerImageProps> = ({ format, diagram, onClick, className }) => {
const [loading, setLoading] = useState(true)
// FIXME: 黑暗模式背景太黑了,目前让 PlantUML 和 SVG 一样保持白色背景
const url = getPlantUMLImageUrl(format, diagram, false)
return (
<StyledPlantUML onClick={onClick} className={className}>
<Spin
spinning={loading}
indicator={
<LoadingOutlined
spin
style={{
fontSize: 32
}}
/>
}>
<img
src={url}
onLoad={() => {
setLoading(false)
}}
onError={(e) => {
setLoading(false)
const target = e.target as HTMLImageElement
target.style.opacity = '0.5'
target.style.filter = 'blur(2px)'
}}
/>
</Spin>
</StyledPlantUML>
)
}
const PlantUmlPreview: React.FC<BasicPreviewProps> = ({ children, setTools }) => {
const { t } = useTranslation()
const containerRef = useRef<HTMLDivElement>(null)
const encodedDiagram = encodeDiagram(children)
// 自定义 PlantUML 下载方法
const customDownload = useCallback(
(format: 'svg' | 'png') => {
const timestamp = Date.now()
const url = `${PlantUMLServer}/${format}/${encodedDiagram}`
const filename = `plantuml-diagram-${timestamp}.${format}`
downloadUrl(url, filename).catch(() => {
window.message.error(t('code_block.download.failed.network'))
})
},
[encodedDiagram, t]
)
// 使用通用图像工具,提供自定义下载方法
const { handleZoom, handleCopyImage } = usePreviewToolHandlers(containerRef, {
imgSelector: '.plantuml-preview img',
prefix: 'plantuml-diagram',
enableWheelZoom: true,
customDownloader: customDownload
})
// 使用工具栏
usePreviewTools({
setTools,
handleZoom,
handleCopyImage,
handleDownload: customDownload
})
return (
<div ref={containerRef}>
<PlantUMLServerImage format="svg" diagram={children} className="plantuml-preview special-preview" />
</div>
)
}
const StyledPlantUML = styled.div`
max-height: calc(80vh - 100px);
text-align: left;
overflow-y: auto;
background-color: white;
img {
max-width: 100%;
height: auto;
min-height: 100px;
transition: transform 0.2s ease;
}
`
export default memo(PlantUmlPreview)

View File

@ -1,14 +0,0 @@
import { memo } from 'react'
import { styled } from 'styled-components'
const PreviewError = styled.div`
overflow: auto;
padding: 16px;
color: #ff4d4f;
border: 1px solid #ff4d4f;
border-radius: 4px;
word-wrap: break-word;
white-space: pre-wrap;
`
export default memo(PreviewError)

View File

@ -18,6 +18,7 @@ const Container = styled(Flex)`
gap: 8px;
overflow-y: auto;
text-wrap: wrap;
border-radius: 0 0 8px 8px;
`
export default memo(StatusBar)

View File

@ -1,61 +0,0 @@
import { usePreviewToolHandlers, usePreviewTools } from '@renderer/components/CodeToolbar'
import { memo, useEffect, useRef } from 'react'
import { BasicPreviewProps } from './types'
/**
* 使 Shadow DOM SVG
*/
const SvgPreview: React.FC<BasicPreviewProps> = ({ children, setTools }) => {
const svgContainerRef = useRef<HTMLDivElement>(null)
useEffect(() => {
const container = svgContainerRef.current
if (!container) return
const shadowRoot = container.shadowRoot || container.attachShadow({ mode: 'open' })
// 添加基础样式
const style = document.createElement('style')
style.textContent = `
:host {
padding: 1em;
background-color: white;
overflow: auto;
border: 0.5px solid var(--color-code-background);
border-top-left-radius: 0;
border-top-right-radius: 0;
display: block;
}
svg {
max-width: 100%;
height: auto;
}
`
// 清空并重新添加内容
shadowRoot.innerHTML = ''
shadowRoot.appendChild(style)
const svgContainer = document.createElement('div')
svgContainer.innerHTML = children
shadowRoot.appendChild(svgContainer)
}, [children])
// 使用通用图像工具
const { handleCopyImage, handleDownload } = usePreviewToolHandlers(svgContainerRef, {
imgSelector: 'svg',
prefix: 'svg-image'
})
// 使用工具栏
usePreviewTools({
setTools,
handleCopyImage,
handleDownload
})
return <div ref={svgContainerRef} className="svg-preview special-preview" />
}
export default memo(SvgPreview)

View File

@ -1,7 +1,4 @@
import GraphvizPreview from './GraphvizPreview'
import MermaidPreview from './MermaidPreview'
import PlantUmlPreview from './PlantUmlPreview'
import SvgPreview from './SvgPreview'
import { GraphvizPreview, MermaidPreview, PlantUmlPreview, SvgPreview } from '@renderer/components/Preview'
/**
*

View File

@ -1,13 +1,3 @@
import { CodeTool } from '@renderer/components/CodeToolbar'
/**
* props
*/
export interface BasicPreviewProps {
children: string
setTools?: (value: React.SetStateAction<CodeTool[]>) => void
}
/**
*
*/

View File

@ -1,19 +1,30 @@
import { loggerService } from '@logger'
import CodeEditor from '@renderer/components/CodeEditor'
import { CodeTool, CodeToolbar, TOOL_SPECS, useCodeTool } from '@renderer/components/CodeToolbar'
import { LoadingIcon } from '@renderer/components/Icons'
import { ActionTool } from '@renderer/components/ActionTools'
import CodeEditor, { CodeEditorHandles } from '@renderer/components/CodeEditor'
import {
CodeToolbar,
useCopyTool,
useDownloadTool,
useExpandTool,
useRunTool,
useSaveTool,
useSplitViewTool,
useViewSourceTool,
useWrapTool
} from '@renderer/components/CodeToolbar'
import CodeViewer from '@renderer/components/CodeViewer'
import ImageViewer from '@renderer/components/ImageViewer'
import { BasicPreviewHandles } from '@renderer/components/Preview'
import { MAX_COLLAPSED_CODE_HEIGHT } from '@renderer/config/constant'
import { useSettings } from '@renderer/hooks/useSettings'
import { pyodideService } from '@renderer/services/PyodideService'
import { extractTitle } from '@renderer/utils/formats'
import { getExtensionByLanguage, isHtmlCode, isValidPlantUML } from '@renderer/utils/markdown'
import { getExtensionByLanguage, isHtmlCode } from '@renderer/utils/markdown'
import dayjs from 'dayjs'
import { CirclePlay, CodeXml, Copy, Download, Eye, Square, SquarePen, SquareSplitHorizontal } from 'lucide-react'
import React, { memo, useCallback, useEffect, useMemo, useState } from 'react'
import React, { memo, startTransition, useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
import styled, { css } from 'styled-components'
import ImageViewer from '../ImageViewer'
import CodePreview from './CodePreview'
import { SPECIAL_VIEW_COMPONENTS, SPECIAL_VIEWS } from './constants'
import HtmlArtifactsCard from './HtmlArtifactsCard'
import StatusBar from './StatusBar'
@ -45,31 +56,83 @@ interface Props {
*/
export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave }) => {
const { t } = useTranslation()
const { codeEditor, codeExecution } = useSettings()
const { codeEditor, codeExecution, codeImageTools, codeCollapsible, codeWrappable } = useSettings()
const [viewState, setViewState] = useState({
mode: 'special' as ViewMode,
previousMode: 'special' as ViewMode
})
const { mode: viewMode } = viewState
const setViewMode = useCallback((newMode: ViewMode) => {
setViewState((current) => ({
mode: newMode,
// 当新模式不是 'split' 时才更新
previousMode: newMode !== 'split' ? newMode : current.previousMode
}))
}, [])
const toggleSplitView = useCallback(() => {
setViewState((current) => {
// 如果当前是 split 模式,恢复到上一个模式
if (current.mode === 'split') {
return { ...current, mode: current.previousMode }
}
return { mode: 'split', previousMode: current.mode }
})
}, [])
const [viewMode, setViewMode] = useState<ViewMode>('special')
const [isRunning, setIsRunning] = useState(false)
const [executionResult, setExecutionResult] = useState<{ text: string; image?: string } | null>(null)
const [tools, setTools] = useState<CodeTool[]>([])
const { registerTool, removeTool } = useCodeTool(setTools)
const [tools, setTools] = useState<ActionTool[]>([])
const isExecutable = useMemo(() => {
return codeExecution.enabled && language === 'python'
}, [codeExecution.enabled, language])
const sourceViewRef = useRef<CodeEditorHandles>(null)
const specialViewRef = useRef<BasicPreviewHandles>(null)
const hasSpecialView = useMemo(() => SPECIAL_VIEWS.includes(language), [language])
const isInSpecialView = useMemo(() => {
return hasSpecialView && viewMode === 'special'
}, [hasSpecialView, viewMode])
const [expandOverride, setExpandOverride] = useState(!codeCollapsible)
const [unwrapOverride, setUnwrapOverride] = useState(!codeWrappable)
// 重置用户操作
useEffect(() => {
setExpandOverride(!codeCollapsible)
}, [codeCollapsible])
// 重置用户操作
useEffect(() => {
setUnwrapOverride(!codeWrappable)
}, [codeWrappable])
const shouldExpand = useMemo(() => !codeCollapsible || expandOverride, [codeCollapsible, expandOverride])
const shouldUnwrap = useMemo(() => !codeWrappable || unwrapOverride, [codeWrappable, unwrapOverride])
const [sourceScrollHeight, setSourceScrollHeight] = useState(0)
const expandable = useMemo(() => {
return codeCollapsible && sourceScrollHeight > MAX_COLLAPSED_CODE_HEIGHT
}, [codeCollapsible, sourceScrollHeight])
const handleHeightChange = useCallback((height: number) => {
startTransition(() => {
setSourceScrollHeight((prev) => (prev === height ? prev : height))
})
}, [])
const handleCopySource = useCallback(() => {
navigator.clipboard.writeText(children)
window.message.success({ content: t('code_block.copy.success'), key: 'copy-code' })
}, [children, t])
const handleDownloadSource = useCallback(async () => {
const handleDownloadSource = useCallback(() => {
let fileName = ''
// 尝试提取 HTML 标题
@ -82,7 +145,7 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
fileName = `${dayjs().format('YYYYMMDDHHmm')}`
}
const ext = await getExtensionByLanguage(language)
const ext = getExtensionByLanguage(language)
window.api.file.save(`${fileName}${ext}`, children)
}, [children, language])
@ -106,101 +169,103 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
})
}, [children, codeExecution.timeoutMinutes])
useEffect(() => {
// 复制按钮
registerTool({
...TOOL_SPECS.copy,
icon: <Copy className="icon" />,
tooltip: t('code_block.copy.source'),
onClick: handleCopySource
})
const showPreviewTools = useMemo(() => {
return viewMode !== 'source' && hasSpecialView
}, [hasSpecialView, viewMode])
// 下载按钮
registerTool({
...TOOL_SPECS.download,
icon: <Download className="icon" />,
tooltip: t('code_block.download.source'),
onClick: handleDownloadSource
})
return () => {
removeTool(TOOL_SPECS.copy.id)
removeTool(TOOL_SPECS.download.id)
}
}, [handleCopySource, handleDownloadSource, registerTool, removeTool, t])
// 复制按钮
useCopyTool({
showPreviewTools,
previewRef: specialViewRef,
onCopySource: handleCopySource,
setTools
})
// 特殊视图的编辑按钮,在分屏模式下不可用
useEffect(() => {
if (!hasSpecialView || viewMode === 'split') return
// 下载按钮
useDownloadTool({
showPreviewTools,
previewRef: specialViewRef,
onDownloadSource: handleDownloadSource,
setTools
})
const viewSourceToolSpec = codeEditor.enabled ? TOOL_SPECS.edit : TOOL_SPECS['view-source']
// 特殊视图的编辑/查看源码按钮,在分屏模式下不可用
useViewSourceTool({
enabled: hasSpecialView,
editable: codeEditor.enabled,
viewMode,
onViewModeChange: setViewMode,
setTools
})
if (codeEditor.enabled) {
registerTool({
...viewSourceToolSpec,
icon: viewMode === 'source' ? <Eye className="icon" /> : <SquarePen className="icon" />,
tooltip: viewMode === 'source' ? t('code_block.preview.label') : t('code_block.edit.label'),
onClick: () => setViewMode(viewMode === 'source' ? 'special' : 'source')
})
} else {
registerTool({
...viewSourceToolSpec,
icon: viewMode === 'source' ? <Eye className="icon" /> : <CodeXml className="icon" />,
tooltip: viewMode === 'source' ? t('code_block.preview.label') : t('code_block.preview.source'),
onClick: () => setViewMode(viewMode === 'source' ? 'special' : 'source')
})
}
return () => removeTool(viewSourceToolSpec.id)
}, [codeEditor.enabled, hasSpecialView, viewMode, registerTool, removeTool, t])
// 特殊视图的分屏按钮
useEffect(() => {
if (!hasSpecialView) return
registerTool({
...TOOL_SPECS['split-view'],
icon: viewMode === 'split' ? <Square className="icon" /> : <SquareSplitHorizontal className="icon" />,
tooltip: viewMode === 'split' ? t('code_block.split.restore') : t('code_block.split.label'),
onClick: () => setViewMode(viewMode === 'split' ? 'special' : 'split')
})
return () => removeTool(TOOL_SPECS['split-view'].id)
}, [hasSpecialView, viewMode, registerTool, removeTool, t])
// 特殊视图存在时的分屏按钮
useSplitViewTool({
enabled: hasSpecialView,
viewMode,
onToggleSplitView: toggleSplitView,
setTools
})
// 运行按钮
useEffect(() => {
if (!isExecutable) return
useRunTool({
enabled: isExecutable,
isRunning,
onRun: handleRunScript,
setTools
})
registerTool({
...TOOL_SPECS.run,
icon: isRunning ? <LoadingIcon /> : <CirclePlay className="icon" />,
tooltip: t('code_block.run'),
onClick: () => !isRunning && handleRunScript()
})
// 源代码视图的展开/折叠按钮
useExpandTool({
enabled: !isInSpecialView,
expanded: shouldExpand,
expandable,
toggle: useCallback(() => setExpandOverride((prev) => !prev), []),
setTools
})
return () => isExecutable && removeTool(TOOL_SPECS.run.id)
}, [isExecutable, isRunning, handleRunScript, registerTool, removeTool, t])
// 源代码视图的自动换行按钮
useWrapTool({
enabled: !isInSpecialView,
unwrapped: shouldUnwrap,
wrappable: codeWrappable,
toggle: useCallback(() => setUnwrapOverride((prev) => !prev), []),
setTools
})
// 代码编辑器的保存按钮
useSaveTool({
enabled: codeEditor.enabled && !isInSpecialView,
sourceViewRef,
setTools
})
// 源代码视图组件
const sourceView = useMemo(() => {
if (codeEditor.enabled) {
return (
const sourceView = useMemo(
() =>
codeEditor.enabled ? (
<CodeEditor
className="source-view"
ref={sourceViewRef}
value={children}
language={language}
onSave={onSave}
onHeightChange={handleHeightChange}
options={{ stream: true }}
setTools={setTools}
expanded={shouldExpand}
unwrapped={shouldUnwrap}
/>
)
} else {
return (
<CodePreview language={language} setTools={setTools}>
) : (
<CodeViewer
className="source-view"
language={language}
expanded={shouldExpand}
unwrapped={shouldUnwrap}
onHeightChange={handleHeightChange}>
{children}
</CodePreview>
)
}
}, [children, codeEditor.enabled, language, onSave, setTools])
</CodeViewer>
),
[children, codeEditor.enabled, handleHeightChange, language, onSave, shouldExpand, shouldUnwrap]
)
// 特殊视图组件映射
const specialView = useMemo(() => {
@ -208,13 +273,12 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
if (!SpecialView) return null
// PlantUML 语法验证
if (language === 'plantuml' && !isValidPlantUML(children)) {
return null
}
return <SpecialView setTools={setTools}>{children}</SpecialView>
}, [children, language])
return (
<SpecialView ref={specialViewRef} enableToolbar={codeImageTools}>
{children}
</SpecialView>
)
}, [children, codeImageTools, language])
const renderHeader = useMemo(() => {
const langTag = '<' + language.toUpperCase() + '>'
@ -227,7 +291,7 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
const showSourceView = !specialView || viewMode !== 'special'
return (
<SplitViewWrapper className="split-view-wrapper">
<SplitViewWrapper className="split-view-wrapper" $viewMode={viewMode}>
{showSpecialView && specialView}
{showSourceView && sourceView}
</SplitViewWrapper>
@ -260,7 +324,7 @@ const CodeBlockWrapper = styled.div<{ $isInSpecialView: boolean }>`
position: relative;
width: 100%;
/* FIXME:
* CodePreview
* CodeViewer
* toolbar title
*/
min-width: 45ch;
@ -295,9 +359,10 @@ const CodeHeader = styled.div<{ $isInSpecialView: boolean }>`
border-top-right-radius: 8px;
margin-top: ${(props) => (props.$isInSpecialView ? '6px' : '0')};
height: ${(props) => (props.$isInSpecialView ? '16px' : '34px')};
background-color: ${(props) => (props.$isInSpecialView ? 'transparent' : 'var(--color-background-mute)')};
`
const SplitViewWrapper = styled.div`
const SplitViewWrapper = styled.div<{ $viewMode?: ViewMode }>`
display: flex;
> * {
@ -306,7 +371,27 @@ const SplitViewWrapper = styled.div`
}
&:not(:has(+ [class*='Container'])) {
border-radius: 0 0 8px 8px;
// 特殊视图的 header 会隐藏,所以全都使用圆角
border-radius: ${(props) => (props.$viewMode === 'special' ? '8px' : '0 0 8px 8px')};
overflow: hidden;
}
// 在 split 模式下添加中间分隔线
${(props) =>
props.$viewMode === 'split' &&
css`
position: relative;
&:before {
content: '';
position: absolute;
top: 0;
bottom: 0;
left: 50%;
width: 1px;
background-color: var(--color-background-mute);
transform: translateX(-50%);
z-index: 1;
}
`}
`

View File

@ -175,3 +175,26 @@ export function useBlurHandler({ onBlur }: UseBlurHandlerProps) {
})
}, [onBlur])
}
interface UseHeightListenerProps {
onHeightChange?: (scrollHeight: number) => void
}
/**
* CodeMirror
* @param onHeightChange
* @returns
*/
export function useHeightListener({ onHeightChange }: UseHeightListenerProps) {
return useMemo(() => {
if (!onHeightChange) {
return []
}
return EditorView.updateListener.of((update) => {
if (update.docChanged || update.heightChanged) {
onHeightChange(update.view.scrollDOM?.scrollHeight ?? 0)
}
})
}, [onHeightChange])
}

View File

@ -1,32 +1,29 @@
import { CodeTool, TOOL_SPECS, useCodeTool } from '@renderer/components/CodeToolbar'
import { MAX_COLLAPSED_CODE_HEIGHT } from '@renderer/config/constant'
import { useCodeStyle } from '@renderer/context/CodeStyleProvider'
import { useSettings } from '@renderer/hooks/useSettings'
import CodeMirror, { Annotation, BasicSetupOptions, EditorView, Extension } from '@uiw/react-codemirror'
import diff from 'fast-diff'
import {
ChevronsDownUp,
ChevronsUpDown,
Save as SaveIcon,
Text as UnWrapIcon,
WrapText as WrapIcon
} from 'lucide-react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useCallback, useEffect, useImperativeHandle, useMemo, useRef } from 'react'
import { memo } from 'react'
import { useTranslation } from 'react-i18next'
import { useBlurHandler, useLanguageExtensions, useSaveKeymap } from './hooks'
import { useBlurHandler, useHeightListener, useLanguageExtensions, useSaveKeymap } from './hooks'
// 标记非用户编辑的变更
const External = Annotation.define<boolean>()
interface Props {
export interface CodeEditorHandles {
save?: () => void
}
interface CodeEditorProps {
ref?: React.RefObject<CodeEditorHandles | null>
value: string
placeholder?: string | HTMLElement
language: string
onSave?: (newContent: string) => void
onChange?: (newContent: string) => void
onBlur?: (newContent: string) => void
setTools?: (value: React.SetStateAction<CodeTool[]>) => void
onHeightChange?: (scrollHeight: number) => void
height?: string
minHeight?: string
maxHeight?: string
@ -35,15 +32,16 @@ interface Props {
options?: {
stream?: boolean // 用于流式响应场景,默认 false
lint?: boolean
collapsible?: boolean
wrappable?: boolean
keymap?: boolean
} & BasicSetupOptions
/** 用于追加 extensions */
extensions?: Extension[]
/** 用于覆写编辑器的样式,会直接传给 CodeMirror 的 style 属性 */
style?: React.CSSProperties
className?: string
editable?: boolean
expanded?: boolean
unwrapped?: boolean
}
/**
@ -52,13 +50,14 @@ interface Props {
* CodeToolbar 使
*/
const CodeEditor = ({
ref,
value,
placeholder,
language,
onSave,
onChange,
onBlur,
setTools,
onHeightChange,
height,
minHeight,
maxHeight,
@ -66,17 +65,12 @@ const CodeEditor = ({
options,
extensions,
style,
editable = true
}: Props) => {
const {
fontSize: _fontSize,
codeShowLineNumbers: _lineNumbers,
codeCollapsible: _collapsible,
codeWrappable: _wrappable,
codeEditor
} = useSettings()
const collapsible = useMemo(() => options?.collapsible ?? _collapsible, [options?.collapsible, _collapsible])
const wrappable = useMemo(() => options?.wrappable ?? _wrappable, [options?.wrappable, _wrappable])
className,
editable = true,
expanded = true,
unwrapped = false
}: CodeEditorProps) => {
const { fontSize: _fontSize, codeShowLineNumbers: _lineNumbers, codeEditor } = useSettings()
const enableKeymap = useMemo(() => options?.keymap ?? codeEditor.keymap, [options?.keymap, codeEditor.keymap])
// 合并 codeEditor 和 options 的 basicSetupoptions 优先
@ -91,63 +85,16 @@ const CodeEditor = ({
const customFontSize = useMemo(() => fontSize ?? `${_fontSize - 1}px`, [fontSize, _fontSize])
const { activeCmTheme } = useCodeStyle()
const [isExpanded, setIsExpanded] = useState(!collapsible)
const [isUnwrapped, setIsUnwrapped] = useState(!wrappable)
const initialContent = useRef(options?.stream ? (value ?? '').trimEnd() : (value ?? ''))
const [editorReady, setEditorReady] = useState(false)
const editorViewRef = useRef<EditorView | null>(null)
const { t } = useTranslation()
const langExtensions = useLanguageExtensions(language, options?.lint)
const { registerTool, removeTool } = useCodeTool(setTools)
// 展开/折叠工具
useEffect(() => {
registerTool({
...TOOL_SPECS.expand,
icon: isExpanded ? <ChevronsDownUp className="icon" /> : <ChevronsUpDown className="icon" />,
tooltip: isExpanded ? t('code_block.collapse') : t('code_block.expand'),
visible: () => {
const scrollHeight = editorViewRef?.current?.scrollDOM?.scrollHeight
return collapsible && (scrollHeight ?? 0) > 350
},
onClick: () => setIsExpanded((prev) => !prev)
})
return () => removeTool(TOOL_SPECS.expand.id)
}, [collapsible, isExpanded, registerTool, removeTool, t, editorReady])
// 自动换行工具
useEffect(() => {
registerTool({
...TOOL_SPECS.wrap,
icon: isUnwrapped ? <WrapIcon className="icon" /> : <UnWrapIcon className="icon" />,
tooltip: isUnwrapped ? t('code_block.wrap.on') : t('code_block.wrap.off'),
visible: () => wrappable,
onClick: () => setIsUnwrapped((prev) => !prev)
})
return () => removeTool(TOOL_SPECS.wrap.id)
}, [wrappable, isUnwrapped, registerTool, removeTool, t])
const handleSave = useCallback(() => {
const currentDoc = editorViewRef.current?.state.doc.toString() ?? ''
onSave?.(currentDoc)
}, [onSave])
// 保存按钮
useEffect(() => {
registerTool({
...TOOL_SPECS.save,
icon: <SaveIcon className="icon" />,
tooltip: t('code_block.edit.save.label'),
onClick: handleSave
})
return () => removeTool(TOOL_SPECS.save.id)
}, [handleSave, registerTool, removeTool, t])
// 流式响应过程中计算 changes 来更新 EditorView
// 无法处理用户在流式响应过程中编辑代码的情况(应该也不必处理)
useEffect(() => {
@ -166,26 +113,24 @@ const CodeEditor = ({
}
}, [options?.stream, value])
useEffect(() => {
setIsExpanded(!collapsible)
}, [collapsible])
useEffect(() => {
setIsUnwrapped(!wrappable)
}, [wrappable])
const saveKeymapExtension = useSaveKeymap({ onSave, enabled: enableKeymap })
const blurExtension = useBlurHandler({ onBlur })
const heightListenerExtension = useHeightListener({ onHeightChange })
const customExtensions = useMemo(() => {
return [
...(extensions ?? []),
...langExtensions,
...(isUnwrapped ? [] : [EditorView.lineWrapping]),
...(unwrapped ? [] : [EditorView.lineWrapping]),
saveKeymapExtension,
blurExtension
blurExtension,
heightListenerExtension
].flat()
}, [extensions, langExtensions, isUnwrapped, saveKeymapExtension, blurExtension])
}, [extensions, langExtensions, unwrapped, saveKeymapExtension, blurExtension, heightListenerExtension])
useImperativeHandle(ref, () => ({
save: handleSave
}))
return (
<CodeMirror
@ -195,14 +140,14 @@ const CodeEditor = ({
width="100%"
height={height}
minHeight={minHeight}
maxHeight={collapsible && !isExpanded ? (maxHeight ?? '350px') : 'none'}
maxHeight={expanded ? 'none' : (maxHeight ?? `${MAX_COLLAPSED_CODE_HEIGHT}px`)}
editable={editable}
// @ts-ignore 强制使用,见 react-codemirror 的 Example.tsx
theme={activeCmTheme}
extensions={customExtensions}
onCreateEditor={(view: EditorView) => {
editorViewRef.current = view
setEditorReady(true)
onHeightChange?.(view.scrollDOM?.scrollHeight ?? 0)
}}
onChange={(value, viewUpdate) => {
if (onChange && viewUpdate.docChanged) onChange(value)
@ -230,6 +175,7 @@ const CodeEditor = ({
borderRadius: 'inherit',
...style
}}
className={`code-editor ${className ?? ''}`}
/>
)
}

View File

@ -0,0 +1,164 @@
import { ActionTool } from '@renderer/components/ActionTools'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import CodeToolButton from '../button'
// Mock Antd components
const mocks = vi.hoisted(() => ({
Tooltip: vi.fn(({ children, title }) => (
<div data-testid="tooltip" data-title={title}>
{children}
</div>
)),
Dropdown: vi.fn(({ children, menu }) => (
<div data-testid="dropdown" data-menu={JSON.stringify(menu)}>
{children}
</div>
))
}))
vi.mock('antd', () => ({
Tooltip: mocks.Tooltip,
Dropdown: mocks.Dropdown
}))
// Mock ToolWrapper
vi.mock('../styles', () => ({
ToolWrapper: ({ children, onClick }: { children: React.ReactNode; onClick?: () => void }) => (
<button type="button" data-testid="tool-wrapper" onClick={onClick}>
{children}
</button>
)
}))
// Helper function to create mock tools
const createMockTool = (overrides: Partial<ActionTool> = {}): ActionTool => ({
id: 'test-tool',
type: 'core',
order: 10,
icon: <span data-testid="test-icon">Test Icon</span>,
tooltip: 'Test Tool',
onClick: vi.fn(),
...overrides
})
const createMockChildTool = (id: string, tooltip: string): Omit<ActionTool, 'children'> => ({
id,
type: 'quick',
order: 10,
icon: <span data-testid={`${id}-icon`}>{tooltip} Icon</span>,
tooltip,
onClick: vi.fn()
})
describe('CodeToolButton', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('rendering modes', () => {
it('should render as simple button when no children', () => {
const tool = createMockTool()
render(<CodeToolButton tool={tool} />)
// Should render button with tooltip
expect(screen.getByTestId('tooltip')).toBeInTheDocument()
expect(screen.getByTestId('tool-wrapper')).toBeInTheDocument()
expect(screen.getByTestId('test-icon')).toBeInTheDocument()
// Should not render dropdown
expect(screen.queryByTestId('dropdown')).not.toBeInTheDocument()
})
it('should render as simple button when children array is empty', () => {
const tool = createMockTool({ children: [] })
render(<CodeToolButton tool={tool} />)
expect(screen.queryByTestId('dropdown')).not.toBeInTheDocument()
expect(screen.getByTestId('tooltip')).toBeInTheDocument()
})
it('should render as dropdown when has children', () => {
const children = [createMockChildTool('child1', 'Child 1')]
const tool = createMockTool({ children })
render(<CodeToolButton tool={tool} />)
// Should render dropdown containing the main button
expect(screen.getByTestId('dropdown')).toBeInTheDocument()
expect(screen.getByTestId('tooltip')).toBeInTheDocument()
expect(screen.getByTestId('tool-wrapper')).toBeInTheDocument()
})
})
describe('user interactions', () => {
it('should trigger onClick when simple button is clicked', () => {
const mockOnClick = vi.fn()
const tool = createMockTool({ onClick: mockOnClick })
render(<CodeToolButton tool={tool} />)
fireEvent.click(screen.getByTestId('tool-wrapper'))
expect(mockOnClick).toHaveBeenCalledTimes(1)
})
it('should handle missing onClick gracefully', () => {
const tool = createMockTool({ onClick: undefined })
render(<CodeToolButton tool={tool} />)
expect(() => {
fireEvent.click(screen.getByTestId('tool-wrapper'))
}).not.toThrow()
})
})
describe('dropdown functionality', () => {
it('should configure dropdown with correct menu structure', () => {
const mockOnClick1 = vi.fn()
const mockOnClick2 = vi.fn()
const children = [createMockChildTool('child1', 'Child 1'), createMockChildTool('child2', 'Child 2')]
children[0].onClick = mockOnClick1
children[1].onClick = mockOnClick2
const tool = createMockTool({ children })
render(<CodeToolButton tool={tool} />)
// Verify dropdown was called with correct menu structure
expect(mocks.Dropdown).toHaveBeenCalled()
const dropdownProps = mocks.Dropdown.mock.calls[0][0]
expect(dropdownProps.menu.items).toHaveLength(2)
expect(dropdownProps.menu.items[0].key).toBe('child1')
expect(dropdownProps.menu.items[0].label).toBe('Child 1')
expect(dropdownProps.menu.items[0].onClick).toBe(mockOnClick1)
expect(dropdownProps.trigger).toEqual(['click'])
})
})
describe('accessibility', () => {
it('should provide accessible button element with tooltip', () => {
const tool = createMockTool({ tooltip: 'Accessible Tool' })
render(<CodeToolButton tool={tool} />)
const button = screen.getByTestId('tool-wrapper')
expect(button.tagName).toBe('BUTTON')
expect(screen.getByTestId('tooltip')).toHaveAttribute('data-title', 'Accessible Tool')
})
})
describe('error handling', () => {
it('should render without crashing for minimal tool configuration', () => {
const minimalTool: ActionTool = {
id: 'minimal',
type: 'core',
order: 1,
icon: null,
tooltip: ''
}
expect(() => {
render(<CodeToolButton tool={minimalTool} />)
}).not.toThrow()
})
})
})

View File

@ -0,0 +1,262 @@
import { ActionTool } from '@renderer/components/ActionTools'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import CodeToolbar from '../toolbar'
// Test constants
const MORE_BUTTON_TOOLTIP = 'code_block.more'
// Mock components
const mocks = vi.hoisted(() => ({
CodeToolButton: vi.fn(({ tool }) => (
<div data-testid={`tool-button-${tool.id}`} data-tool-id={tool.id} data-tool-type={tool.type}>
{tool.icon}
</div>
)),
Tooltip: vi.fn(({ children, title }) => (
<div data-testid="tooltip" data-title={title}>
{children}
</div>
)),
HStack: vi.fn(({ children, className }) => (
<div data-testid="hstack" className={className}>
{children}
</div>
)),
ToolWrapper: vi.fn(({ children, onClick, className }) => (
<div data-testid="tool-wrapper" onClick={onClick} className={className} role="button" tabIndex={0}>
{children}
</div>
)),
EllipsisVertical: vi.fn(() => <div data-testid="ellipsis-icon" className="tool-icon" />),
useTranslation: vi.fn(() => ({
t: vi.fn((key: string) => key)
}))
}))
vi.mock('../button', () => ({
default: mocks.CodeToolButton
}))
vi.mock('antd', () => ({
Tooltip: mocks.Tooltip
}))
vi.mock('@renderer/components/Layout', () => ({
HStack: mocks.HStack
}))
vi.mock('./styles', () => ({
ToolWrapper: mocks.ToolWrapper
}))
vi.mock('lucide-react', () => ({
EllipsisVertical: mocks.EllipsisVertical
}))
vi.mock('react-i18next', () => ({
useTranslation: mocks.useTranslation
}))
// Helper function to create mock tools
const createMockTool = (overrides: Partial<ActionTool> = {}): ActionTool => ({
id: 'test-tool',
type: 'core',
order: 1,
icon: <div data-testid="test-icon">Icon</div>,
tooltip: 'Test Tool',
onClick: vi.fn(),
...overrides
})
// Common test data
const createMixedTools = () => [
createMockTool({ id: 'quick1', type: 'quick' }),
createMockTool({ id: 'quick2', type: 'quick' }),
createMockTool({ id: 'core1', type: 'core' })
]
const createCoreOnlyTools = () => [
createMockTool({ id: 'core1', type: 'core' }),
createMockTool({ id: 'core2', type: 'core' })
]
// Helper function to click more button
const clickMoreButton = () => {
const tooltip = screen.getByTestId('tooltip')
fireEvent.click(tooltip.firstChild as Element)
}
describe('CodeToolbar', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('basic rendering', () => {
it('should match snapshot with mixed tools', () => {
const { container } = render(<CodeToolbar tools={createMixedTools()} />)
expect(container).toMatchSnapshot()
})
it('should match snapshot with only core tools', () => {
const { container } = render(<CodeToolbar tools={[createMockTool({ id: 'core1', type: 'core' })]} />)
expect(container).toMatchSnapshot()
})
})
describe('empty state', () => {
it('should render nothing when no tools provided', () => {
const { container } = render(<CodeToolbar tools={[]} />)
expect(container.firstChild).toBeNull()
})
it('should render nothing when all tools are not visible', () => {
const tools = [
createMockTool({ id: 'tool1', visible: () => false }),
createMockTool({ id: 'tool2', visible: () => false })
]
const { container } = render(<CodeToolbar tools={tools} />)
expect(container.firstChild).toBeNull()
})
})
describe('tool visibility filtering', () => {
it('should only render visible tools', () => {
const tools = [
createMockTool({ id: 'visible-tool', visible: () => true }),
createMockTool({ id: 'hidden-tool', visible: () => false }),
createMockTool({ id: 'no-visible-prop' }) // Should be visible by default
]
render(<CodeToolbar tools={tools} />)
expect(screen.getByTestId('tool-button-visible-tool')).toBeInTheDocument()
expect(screen.getByTestId('tool-button-no-visible-prop')).toBeInTheDocument()
expect(screen.queryByTestId('tool-button-hidden-tool')).not.toBeInTheDocument()
})
it('should show tools without visible function by default', () => {
const tools = [createMockTool({ id: 'default-visible' })]
render(<CodeToolbar tools={tools} />)
expect(screen.getByTestId('tool-button-default-visible')).toBeInTheDocument()
})
})
describe('tool type grouping and quick tools behavior', () => {
it('should separate core and quick tools - show quick tools when expanded', () => {
const tools = [
createMockTool({ id: 'core1', type: 'core' }),
createMockTool({ id: 'quick1', type: 'quick' }),
createMockTool({ id: 'core2', type: 'core' }),
createMockTool({ id: 'quick2', type: 'quick' })
]
render(<CodeToolbar tools={tools} />)
// Initial state: core tools visible, quick tools hidden
expect(screen.getByTestId('tool-button-core1')).toBeInTheDocument()
expect(screen.getByTestId('tool-button-core2')).toBeInTheDocument()
expect(screen.queryByTestId('tool-button-quick1')).not.toBeInTheDocument()
expect(screen.queryByTestId('tool-button-quick2')).not.toBeInTheDocument()
// After clicking more button, quick tools should be visible
clickMoreButton()
expect(screen.getByTestId('tool-button-quick1')).toBeInTheDocument()
expect(screen.getByTestId('tool-button-quick2')).toBeInTheDocument()
})
it('should render only core tools when no quick tools exist', () => {
render(<CodeToolbar tools={createCoreOnlyTools()} />)
expect(screen.getByTestId('tool-button-core1')).toBeInTheDocument()
expect(screen.getByTestId('tool-button-core2')).toBeInTheDocument()
expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument() // No more button
})
it('should show single quick tool directly without more button', () => {
const tools = [createMockTool({ id: 'quick1', type: 'quick' }), createMockTool({ id: 'core1', type: 'core' })]
render(<CodeToolbar tools={tools} />)
expect(screen.getByTestId('tool-button-quick1')).toBeInTheDocument()
expect(screen.getByTestId('tool-button-core1')).toBeInTheDocument()
expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument() // No more button
})
it('should show more button when multiple quick tools exist', () => {
render(<CodeToolbar tools={createMixedTools()} />)
// Initially quick tools should be hidden
expect(screen.queryByTestId('tool-button-quick1')).not.toBeInTheDocument()
expect(screen.queryByTestId('tool-button-quick2')).not.toBeInTheDocument()
expect(screen.getByTestId('tool-button-core1')).toBeInTheDocument()
expect(screen.getByTestId('tooltip')).toBeInTheDocument() // More button exists
})
it('should toggle quick tools visibility when more button is clicked', () => {
render(<CodeToolbar tools={createMixedTools()} />)
// Initial state: quick tools hidden
expect(screen.queryByTestId('tool-button-quick1')).not.toBeInTheDocument()
expect(screen.queryByTestId('tool-button-quick2')).not.toBeInTheDocument()
// Click more button: quick tools visible
clickMoreButton()
expect(screen.getByTestId('tool-button-quick1')).toBeInTheDocument()
expect(screen.getByTestId('tool-button-quick2')).toBeInTheDocument()
// Click more button again: quick tools hidden
clickMoreButton()
expect(screen.queryByTestId('tool-button-quick1')).not.toBeInTheDocument()
expect(screen.queryByTestId('tool-button-quick2')).not.toBeInTheDocument()
})
it('should apply active class to more button when quick tools are shown', () => {
const tools = [createMockTool({ id: 'quick1', type: 'quick' }), createMockTool({ id: 'quick2', type: 'quick' })]
render(<CodeToolbar tools={tools} />)
const tooltip = screen.getByTestId('tooltip')
const moreButton = tooltip.firstChild as Element
// Initial state: no active class
expect(moreButton).not.toHaveClass('active')
// After click: has active class
fireEvent.click(moreButton)
expect(moreButton).toHaveClass('active')
// After second click: no active class
fireEvent.click(moreButton)
expect(moreButton).not.toHaveClass('active')
})
it('should display correct tooltip and icon for more button', () => {
render(<CodeToolbar tools={createMixedTools()} />)
const tooltip = screen.getByTestId('tooltip')
expect(tooltip).toHaveAttribute('data-title', MORE_BUTTON_TOOLTIP)
expect(screen.getByTestId('ellipsis-icon')).toBeInTheDocument()
expect(screen.getByTestId('ellipsis-icon')).toHaveClass('tool-icon')
})
it('should render core tools regardless of quick tools state', () => {
const tools = [
createMockTool({ id: 'quick1', type: 'quick' }),
createMockTool({ id: 'quick2', type: 'quick' }),
createMockTool({ id: 'core1', type: 'core' }),
createMockTool({ id: 'core2', type: 'core' })
]
render(<CodeToolbar tools={tools} />)
// Core tools always visible
expect(screen.getByTestId('tool-button-core1')).toBeInTheDocument()
expect(screen.getByTestId('tool-button-core2')).toBeInTheDocument()
// After clicking more button, core tools still visible
clickMoreButton()
expect(screen.getByTestId('tool-button-core1')).toBeInTheDocument()
expect(screen.getByTestId('tool-button-core2')).toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,129 @@
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
exports[`CodeToolbar > basic rendering > should match snapshot with mixed tools 1`] = `
.c2 {
display: flex;
align-items: center;
justify-content: center;
width: 24px;
height: 24px;
border-radius: 4px;
cursor: pointer;
user-select: none;
transition: all 0.2s ease;
color: var(--color-text-3);
}
.c2:hover {
background-color: var(--color-background-soft);
}
.c2:hover .tool-icon {
color: var(--color-text-1);
}
.c2.active {
color: var(--color-primary);
}
.c2.active .tool-icon {
color: var(--color-primary);
}
.c2 .tool-icon {
width: 14px;
height: 14px;
color: var(--color-text-3);
}
.c0 {
position: sticky;
top: 28px;
z-index: 10;
}
.c1 {
position: absolute;
align-items: center;
bottom: 0.3rem;
right: 0.5rem;
height: 24px;
gap: 4px;
}
<div>
<div
class="c0"
>
<div
class="c1 code-toolbar"
data-testid="hstack"
>
<div
data-testid="tooltip"
data-title="code_block.more"
>
<div
class="c2"
>
<div
class="tool-icon"
data-testid="ellipsis-icon"
/>
</div>
</div>
<div
data-testid="tool-button-core1"
data-tool-id="core1"
data-tool-type="core"
>
<div
data-testid="test-icon"
>
Icon
</div>
</div>
</div>
</div>
</div>
`;
exports[`CodeToolbar > basic rendering > should match snapshot with only core tools 1`] = `
.c0 {
position: sticky;
top: 28px;
z-index: 10;
}
.c1 {
position: absolute;
align-items: center;
bottom: 0.3rem;
right: 0.5rem;
height: 24px;
gap: 4px;
}
<div>
<div
class="c0"
>
<div
class="c1 code-toolbar"
data-testid="hstack"
>
<div
data-testid="tool-button-core1"
data-tool-id="core1"
data-tool-type="core"
>
<div
data-testid="test-icon"
>
Icon
</div>
</div>
</div>
</div>
</div>
`;

View File

@ -0,0 +1,251 @@
import { useCopyTool } from '@renderer/components/CodeToolbar/hooks/useCopyTool'
import { BasicPreviewHandles } from '@renderer/components/Preview'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock dependencies
const mocks = vi.hoisted(() => ({
i18n: {
t: vi.fn((key: string) => key)
},
useTemporaryValue: vi.fn(),
useToolManager: vi.fn(),
TOOL_SPECS: {
copy: {
id: 'copy',
type: 'core',
order: 11
},
'copy-image': {
id: 'copy-image',
type: 'quick',
order: 30
}
}
}))
vi.mock('lucide-react', () => ({
Check: () => <div data-testid="check-icon" />,
Image: () => <div data-testid="image-icon" />
}))
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: mocks.i18n.t
})
}))
vi.mock('@renderer/components/Icons', () => ({
CopyIcon: () => <div data-testid="copy-icon" />
}))
vi.mock('@renderer/components/ActionTools', () => ({
TOOL_SPECS: mocks.TOOL_SPECS,
useToolManager: mocks.useToolManager
}))
vi.mock('@renderer/hooks/useTemporaryValue', () => ({
useTemporaryValue: mocks.useTemporaryValue
}))
// Mock useToolManager
const mockRegisterTool = vi.fn()
const mockRemoveTool = vi.fn()
mocks.useToolManager.mockImplementation(() => ({
registerTool: mockRegisterTool,
removeTool: mockRemoveTool
}))
// Mock useTemporaryValue setters
const mockSetCopiedTemporarily = vi.fn()
const mockSetCopiedImageTemporarily = vi.fn()
describe('useCopyTool', () => {
beforeEach(() => {
vi.clearAllMocks()
// Reset mocks for each test to ensure isolation
mocks.useTemporaryValue
.mockImplementationOnce(() => [false, mockSetCopiedTemporarily])
.mockImplementationOnce(() => [false, mockSetCopiedImageTemporarily])
})
// Helper function to create mock props
const createMockProps = (overrides: Partial<Parameters<typeof useCopyTool>[0]> = {}) => ({
showPreviewTools: false,
previewRef: { current: null },
onCopySource: vi.fn(),
setTools: vi.fn(),
...overrides
})
const createMockPreviewHandles = (): BasicPreviewHandles => ({
pan: vi.fn(),
zoom: vi.fn(),
copy: vi.fn(),
download: vi.fn()
})
describe('tool registration', () => {
it('should register only the copy-source tool when showPreviewTools is false', () => {
const props = createMockProps({ showPreviewTools: false })
renderHook(() => useCopyTool(props))
expect(mocks.useToolManager).toHaveBeenCalledWith(props.setTools)
expect(mockRegisterTool).toHaveBeenCalledTimes(1)
expect(mockRegisterTool).toHaveBeenCalledWith(
expect.objectContaining({
id: 'copy',
tooltip: 'code_block.copy.source'
})
)
})
it('should register only the copy-source tool when previewRef is null', () => {
const props = createMockProps({ showPreviewTools: true, previewRef: { current: null } })
renderHook(() => useCopyTool(props))
expect(mockRegisterTool).toHaveBeenCalledTimes(1)
expect(mockRegisterTool).toHaveBeenCalledWith(
expect.objectContaining({
id: 'copy'
})
)
})
it('should register both copy-source and copy-image tools when preview is available', () => {
const props = createMockProps({
showPreviewTools: true,
previewRef: { current: createMockPreviewHandles() }
})
renderHook(() => useCopyTool(props))
expect(mockRegisterTool).toHaveBeenCalledTimes(2)
// Check first tool: copy source
expect(mockRegisterTool).toHaveBeenCalledWith(
expect.objectContaining({
id: 'copy',
tooltip: 'code_block.copy.source',
onClick: expect.any(Function)
})
)
// Check second tool: copy image
expect(mockRegisterTool).toHaveBeenCalledWith(
expect.objectContaining({
id: 'copy-image',
tooltip: 'preview.copy.image',
onClick: expect.any(Function)
})
)
})
})
describe('copy functionality', () => {
it('should execute copy source behavior when copy-source tool is clicked', () => {
const mockOnCopySource = vi.fn()
const props = createMockProps({ onCopySource: mockOnCopySource })
renderHook(() => useCopyTool(props))
const copySourceTool = mockRegisterTool.mock.calls[0][0]
act(() => {
copySourceTool.onClick()
})
expect(mockOnCopySource).toHaveBeenCalledTimes(1)
expect(mockSetCopiedTemporarily).toHaveBeenCalledWith(true)
})
it('should execute copy image behavior when copy-image tool is clicked', () => {
const mockPreviewHandles = createMockPreviewHandles()
const props = createMockProps({
showPreviewTools: true,
previewRef: { current: mockPreviewHandles }
})
renderHook(() => useCopyTool(props))
// The copy-image tool is the second one registered
const copyImageTool = mockRegisterTool.mock.calls[1][0]
act(() => {
copyImageTool.onClick()
})
expect(mockPreviewHandles.copy).toHaveBeenCalledTimes(1)
expect(mockSetCopiedImageTemporarily).toHaveBeenCalledWith(true)
})
})
describe('cleanup', () => {
it('should remove both tools on unmount when both are registered', () => {
const props = createMockProps({
showPreviewTools: true,
previewRef: { current: createMockPreviewHandles() }
})
const { unmount } = renderHook(() => useCopyTool(props))
unmount()
expect(mockRemoveTool).toHaveBeenCalledTimes(2)
expect(mockRemoveTool).toHaveBeenCalledWith('copy')
expect(mockRemoveTool).toHaveBeenCalledWith('copy-image')
})
it('should attempt to remove both tools on unmount even if only one is registered', () => {
const props = createMockProps({ showPreviewTools: false })
const { unmount } = renderHook(() => useCopyTool(props))
unmount()
// The cleanup function is static and always tries to remove both
expect(mockRemoveTool).toHaveBeenCalledTimes(2)
expect(mockRemoveTool).toHaveBeenCalledWith('copy')
expect(mockRemoveTool).toHaveBeenCalledWith('copy-image')
})
})
describe('edge cases', () => {
it('should handle copy source failure gracefully', () => {
const mockOnCopySource = vi.fn().mockImplementation(() => {
throw new Error('Copy failed')
})
const props = createMockProps({ onCopySource: mockOnCopySource })
renderHook(() => useCopyTool(props))
const copySourceTool = mockRegisterTool.mock.calls[0][0]
expect(() => {
act(() => {
copySourceTool.onClick()
})
}).toThrow('Copy failed')
expect(mockOnCopySource).toHaveBeenCalledTimes(1)
expect(mockSetCopiedTemporarily).toHaveBeenCalledWith(false)
})
it('should handle copy image failure gracefully', () => {
const mockPreviewHandles = createMockPreviewHandles()
mockPreviewHandles.copy = vi.fn().mockImplementation(() => {
throw new Error('Image copy failed')
})
const props = createMockProps({
showPreviewTools: true,
previewRef: { current: mockPreviewHandles }
})
renderHook(() => useCopyTool(props))
const copyImageTool = mockRegisterTool.mock.calls[1][0]
expect(() => {
act(() => {
copyImageTool.onClick()
})
}).toThrow('Image copy failed')
expect(mockPreviewHandles.copy).toHaveBeenCalledTimes(1)
expect(mockSetCopiedImageTemporarily).toHaveBeenCalledWith(false)
})
})
})

View File

@ -0,0 +1,348 @@
import { useDownloadTool } from '@renderer/components/CodeToolbar/hooks/useDownloadTool'
import { BasicPreviewHandles } from '@renderer/components/Preview'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock dependencies
const mocks = vi.hoisted(() => ({
i18n: {
t: vi.fn((key: string) => key)
},
useToolManager: vi.fn(),
TOOL_SPECS: {
download: {
id: 'download',
type: 'core',
order: 10
},
'download-svg': {
id: 'download-svg',
type: 'quick',
order: 31
},
'download-png': {
id: 'download-png',
type: 'quick',
order: 32
}
}
}))
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: mocks.i18n.t
})
}))
vi.mock('@renderer/components/Icons', () => ({
FilePngIcon: () => <div data-testid="file-png-icon" />,
FileSvgIcon: () => <div data-testid="file-svg-icon" />
}))
vi.mock('@renderer/components/ActionTools', () => ({
TOOL_SPECS: mocks.TOOL_SPECS,
useToolManager: mocks.useToolManager
}))
// Mock useToolManager
const mockRegisterTool = vi.fn()
const mockRemoveTool = vi.fn()
mocks.useToolManager.mockImplementation(() => ({
registerTool: mockRegisterTool,
removeTool: mockRemoveTool
}))
describe('useDownloadTool', () => {
beforeEach(() => {
vi.clearAllMocks()
// Note: mock implementations are already set in vi.hoisted() above
})
// Helper function to create mock props
const createMockProps = (overrides: Partial<Parameters<typeof useDownloadTool>[0]> = {}) => {
const defaultProps = {
showPreviewTools: false,
previewRef: { current: null },
onDownloadSource: vi.fn(),
setTools: vi.fn()
}
return { ...defaultProps, ...overrides }
}
// Helper function to create mock preview handles
const createMockPreviewHandles = (): BasicPreviewHandles => ({
pan: vi.fn(),
zoom: vi.fn(),
copy: vi.fn(),
download: vi.fn()
})
// Helper function for tool registration assertions
const expectToolRegistration = (times: number, toolConfig?: object) => {
expect(mockRegisterTool).toHaveBeenCalledTimes(times)
if (times > 0 && toolConfig) {
expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig))
}
}
const expectNoChildren = () => {
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(registeredTool).not.toHaveProperty('children')
}
describe('tool registration', () => {
it('should register single download tool when showPreviewTools is false', () => {
const props = createMockProps({ showPreviewTools: false })
renderHook(() => useDownloadTool(props))
expect(mocks.useToolManager).toHaveBeenCalledWith(props.setTools)
expectToolRegistration(1, {
id: 'download',
type: 'core',
order: 10,
tooltip: 'code_block.download.source',
onClick: expect.any(Function),
icon: expect.any(Object)
})
expectNoChildren()
})
it('should register single download tool when showPreviewTools is true but previewRef.current is null', () => {
const props = createMockProps({ showPreviewTools: true, previewRef: { current: null } })
renderHook(() => useDownloadTool(props))
expectToolRegistration(1, {
id: 'download',
type: 'core',
order: 10,
tooltip: 'code_block.download.source', // When previewRef.current is null, showPreviewTools is false
onClick: expect.any(Function),
icon: expect.any(Object)
})
expectNoChildren()
})
it('should register download tool with children when showPreviewTools is true and previewRef.current is not null', () => {
const mockPreviewHandles = createMockPreviewHandles()
const props = createMockProps({
showPreviewTools: true,
previewRef: { current: mockPreviewHandles }
})
renderHook(() => useDownloadTool(props))
expectToolRegistration(1, {
id: 'download',
type: 'core',
order: 10,
tooltip: undefined,
icon: expect.any(Object),
children: expect.arrayContaining([
expect.objectContaining({
id: 'download',
type: 'core',
order: 10,
tooltip: 'code_block.download.source',
onClick: expect.any(Function),
icon: expect.any(Object)
}),
expect.objectContaining({
id: 'download-svg',
type: 'quick',
order: 31,
tooltip: 'code_block.download.svg',
onClick: expect.any(Function),
icon: expect.any(Object)
}),
expect.objectContaining({
id: 'download-png',
type: 'quick',
order: 32,
tooltip: 'code_block.download.png',
onClick: expect.any(Function),
icon: expect.any(Object)
})
])
})
})
})
describe('download functionality', () => {
it('should execute download source behavior when tool is activated', () => {
const mockOnDownloadSource = vi.fn()
const props = createMockProps({ onDownloadSource: mockOnDownloadSource })
renderHook(() => useDownloadTool(props))
// Get the onClick handler from the registered tool
const registeredTool = mockRegisterTool.mock.calls[0][0]
act(() => {
registeredTool.onClick()
})
expect(mockOnDownloadSource).toHaveBeenCalledTimes(1)
})
it('should execute download SVG behavior when SVG download tool is activated', () => {
const mockPreviewHandles = createMockPreviewHandles()
const props = createMockProps({
showPreviewTools: true,
previewRef: { current: mockPreviewHandles }
})
renderHook(() => useDownloadTool(props))
// Get the download-svg child tool
const registeredTool = mockRegisterTool.mock.calls[0][0]
const downloadSvgTool = registeredTool.children?.find((child: any) => child.tooltip === 'code_block.download.svg')
expect(downloadSvgTool).toBeDefined()
act(() => {
downloadSvgTool.onClick()
})
expect(mockPreviewHandles.download).toHaveBeenCalledTimes(1)
expect(mockPreviewHandles.download).toHaveBeenCalledWith('svg')
})
it('should execute download PNG behavior when PNG download tool is activated', () => {
const mockPreviewHandles = createMockPreviewHandles()
const props = createMockProps({
showPreviewTools: true,
previewRef: { current: mockPreviewHandles }
})
renderHook(() => useDownloadTool(props))
// Get the download-png child tool
const registeredTool = mockRegisterTool.mock.calls[0][0]
const downloadPngTool = registeredTool.children?.find((child: any) => child.tooltip === 'code_block.download.png')
expect(downloadPngTool).toBeDefined()
act(() => {
downloadPngTool.onClick()
})
expect(mockPreviewHandles.download).toHaveBeenCalledTimes(1)
expect(mockPreviewHandles.download).toHaveBeenCalledWith('png')
})
it('should execute download source behavior from child tool', () => {
const mockOnDownloadSource = vi.fn()
const props = createMockProps({
showPreviewTools: true,
onDownloadSource: mockOnDownloadSource,
previewRef: { current: createMockPreviewHandles() }
})
renderHook(() => useDownloadTool(props))
// Get the download source child tool
const registeredTool = mockRegisterTool.mock.calls[0][0]
const downloadSourceTool = registeredTool.children?.find(
(child: any) => child.tooltip === 'code_block.download.source'
)
expect(downloadSourceTool).toBeDefined()
act(() => {
downloadSourceTool.onClick()
})
expect(mockOnDownloadSource).toHaveBeenCalledTimes(1)
})
})
describe('cleanup', () => {
it('should remove tool on unmount', () => {
const props = createMockProps()
const { unmount } = renderHook(() => useDownloadTool(props))
unmount()
expect(mockRemoveTool).toHaveBeenCalledWith('download')
})
})
describe('edge cases', () => {
it('should handle missing setTools gracefully', () => {
const props = createMockProps({ setTools: undefined })
expect(() => {
renderHook(() => useDownloadTool(props))
}).not.toThrow()
// Should still call useToolManager (but won't actually register)
expect(mocks.useToolManager).toHaveBeenCalledWith(undefined)
})
it('should handle missing previewRef.current gracefully', () => {
const props = createMockProps({
showPreviewTools: true,
previewRef: { current: null }
})
expect(() => {
renderHook(() => useDownloadTool(props))
}).not.toThrow()
// Should register single tool without children
expectToolRegistration(1)
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(registeredTool).not.toHaveProperty('children')
})
it('should handle download source operation failures gracefully', () => {
const mockOnDownloadSource = vi.fn().mockImplementation(() => {
throw new Error('Download failed')
})
const props = createMockProps({ onDownloadSource: mockOnDownloadSource })
renderHook(() => useDownloadTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
// Errors should be propagated up
expect(() => {
act(() => {
registeredTool.onClick()
})
}).toThrow('Download failed')
// Callback should still be called
expect(mockOnDownloadSource).toHaveBeenCalledTimes(1)
})
it('should handle download image operation failures gracefully', () => {
const mockPreviewHandles = createMockPreviewHandles()
mockPreviewHandles.download = vi.fn().mockImplementation(() => {
throw new Error('Image download failed')
})
const props = createMockProps({
showPreviewTools: true,
previewRef: { current: mockPreviewHandles }
})
renderHook(() => useDownloadTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
const downloadSvgTool = registeredTool.children?.find((child: any) => child.tooltip === 'code_block.download.svg')
expect(downloadSvgTool).toBeDefined()
// Errors should be propagated up
expect(() => {
act(() => {
downloadSvgTool.onClick()
})
}).toThrow('Image download failed')
// Callback should still be called
expect(mockPreviewHandles.download).toHaveBeenCalledTimes(1)
expect(mockPreviewHandles.download).toHaveBeenCalledWith('svg')
})
})
})

View File

@ -0,0 +1,190 @@
import { useExpandTool } from '@renderer/components/CodeToolbar/hooks/useExpandTool'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock dependencies
const mocks = vi.hoisted(() => ({
i18n: {
t: vi.fn((key: string) => key)
},
useToolManager: vi.fn(),
TOOL_SPECS: {
expand: {
id: 'expand',
type: 'quick',
order: 12
}
}
}))
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: mocks.i18n.t
})
}))
vi.mock('@renderer/components/ActionTools', () => ({
TOOL_SPECS: mocks.TOOL_SPECS,
useToolManager: mocks.useToolManager
}))
// Mock useToolManager
const mockRegisterTool = vi.fn()
const mockRemoveTool = vi.fn()
mocks.useToolManager.mockImplementation(() => ({
registerTool: mockRegisterTool,
removeTool: mockRemoveTool
}))
vi.mock('lucide-react', () => ({
ChevronsDownUp: () => <div data-testid="chevrons-down-up" />,
ChevronsUpDown: () => <div data-testid="chevrons-up-down" />
}))
describe('useExpandTool', () => {
beforeEach(() => {
vi.clearAllMocks()
})
// Helper function to create mock props
const createMockProps = (overrides: Partial<Parameters<typeof useExpandTool>[0]> = {}) => {
const defaultProps = {
enabled: true,
expanded: false,
expandable: true,
toggle: vi.fn(),
setTools: vi.fn()
}
return { ...defaultProps, ...overrides }
}
// Helper function for tool registration assertions
const expectToolRegistration = (times: number, toolConfig?: object) => {
expect(mockRegisterTool).toHaveBeenCalledTimes(times)
if (times > 0 && toolConfig) {
expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig))
}
}
describe('tool registration', () => {
it('should register expand tool when enabled', () => {
const props = createMockProps({ enabled: true })
renderHook(() => useExpandTool(props))
expect(mocks.useToolManager).toHaveBeenCalledWith(props.setTools)
expectToolRegistration(1, {
id: 'expand',
type: 'quick',
order: 12,
tooltip: 'code_block.expand',
onClick: expect.any(Function),
visible: expect.any(Function)
})
})
it('should not register tool when disabled', () => {
const props = createMockProps({ enabled: false })
renderHook(() => useExpandTool(props))
expect(mockRegisterTool).not.toHaveBeenCalled()
})
it('should re-register tool when expanded changes', () => {
const props = createMockProps({ expanded: false })
const { rerender } = renderHook((hookProps) => useExpandTool(hookProps), {
initialProps: props
})
expect(mockRegisterTool).toHaveBeenCalledTimes(1)
const firstCall = mockRegisterTool.mock.calls[0][0]
expect(firstCall.tooltip).toBe('code_block.expand')
// Change expanded to true and rerender
const newProps = { ...props, expanded: true }
rerender(newProps)
expect(mockRegisterTool).toHaveBeenCalledTimes(2)
const secondCall = mockRegisterTool.mock.calls[1][0]
expect(secondCall.tooltip).toBe('code_block.collapse')
})
})
describe('visibility behavior', () => {
it('should be visible when expandable is true', () => {
const props = createMockProps({ expandable: true })
renderHook(() => useExpandTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(registeredTool.visible()).toBe(true)
})
it('should not be visible when expandable is false', () => {
const props = createMockProps({ expandable: false })
renderHook(() => useExpandTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(registeredTool.visible()).toBe(false)
})
it('should not be visible when expandable is undefined', () => {
const props = createMockProps({ expandable: undefined })
renderHook(() => useExpandTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(registeredTool.visible()).toBe(false)
})
})
describe('toggle functionality', () => {
it('should execute toggle function when tool is clicked', () => {
const mockToggle = vi.fn()
const props = createMockProps({ toggle: mockToggle })
renderHook(() => useExpandTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
act(() => {
registeredTool.onClick()
})
expect(mockToggle).toHaveBeenCalledTimes(1)
})
})
describe('cleanup', () => {
it('should remove tool on unmount', () => {
const props = createMockProps()
const { unmount } = renderHook(() => useExpandTool(props))
unmount()
expect(mockRemoveTool).toHaveBeenCalledWith('expand')
})
})
describe('edge cases', () => {
it('should handle missing setTools gracefully', () => {
const props = createMockProps({ setTools: undefined })
expect(() => {
renderHook(() => useExpandTool(props))
}).not.toThrow()
// Should still call useToolManager (but won't actually register)
expect(mocks.useToolManager).toHaveBeenCalledWith(undefined)
})
it('should not break when toggle is undefined', () => {
const props = createMockProps({ toggle: undefined })
renderHook(() => useExpandTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(() => {
act(() => {
registeredTool.onClick()
})
}).not.toThrow()
})
})
})

View File

@ -0,0 +1,165 @@
import { useRunTool } from '@renderer/components/CodeToolbar/hooks/useRunTool'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock dependencies
const mocks = vi.hoisted(() => ({
i18n: {
t: vi.fn((key: string) => key)
},
useToolManager: vi.fn(),
TOOL_SPECS: {
run: {
id: 'run',
type: 'quick',
order: 11
}
}
}))
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: mocks.i18n.t
})
}))
vi.mock('lucide-react', () => ({
CirclePlay: () => <div>CirclePlay</div>
}))
vi.mock('@renderer/components/Icons', () => ({
LoadingIcon: () => <div>Loading</div>
}))
vi.mock('@renderer/components/ActionTools', () => ({
TOOL_SPECS: mocks.TOOL_SPECS,
useToolManager: mocks.useToolManager
}))
const mockRegisterTool = vi.fn()
const mockRemoveTool = vi.fn()
mocks.useToolManager.mockImplementation(() => ({
registerTool: mockRegisterTool,
removeTool: mockRemoveTool
}))
describe('useRunTool', () => {
beforeEach(() => {
vi.clearAllMocks()
})
const createMockProps = (overrides: Partial<Parameters<typeof useRunTool>[0]> = {}) => {
const defaultProps = {
enabled: true,
isRunning: false,
onRun: vi.fn(),
setTools: vi.fn()
}
return { ...defaultProps, ...overrides }
}
const expectToolRegistration = (times: number, toolConfig?: object) => {
expect(mockRegisterTool).toHaveBeenCalledTimes(times)
if (times > 0 && toolConfig) {
expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig))
}
}
describe('tool registration', () => {
it('should not register tool when disabled', () => {
const props = createMockProps({ enabled: false })
renderHook(() => useRunTool(props))
expect(mockRegisterTool).not.toHaveBeenCalled()
})
it('should register run tool when enabled', () => {
const props = createMockProps({ enabled: true })
renderHook(() => useRunTool(props))
expectToolRegistration(1, {
id: 'run',
type: 'quick',
order: 11,
tooltip: 'code_block.run'
})
})
it('should re-register tool when isRunning changes', () => {
const props = createMockProps({ isRunning: false })
const { rerender } = renderHook((hookProps) => useRunTool(hookProps), {
initialProps: props
})
expect(mockRegisterTool).toHaveBeenCalledTimes(1)
const newProps = { ...props, isRunning: true }
rerender(newProps)
expect(mockRegisterTool).toHaveBeenCalledTimes(2)
})
})
describe('run functionality', () => {
it('should execute onRun when tool is clicked and not running', () => {
const mockOnRun = vi.fn()
const props = createMockProps({ onRun: mockOnRun, isRunning: false })
renderHook(() => useRunTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
act(() => {
registeredTool.onClick()
})
expect(mockOnRun).toHaveBeenCalledTimes(1)
})
it('should not execute onRun when tool is clicked and already running', () => {
const mockOnRun = vi.fn()
const props = createMockProps({ onRun: mockOnRun, isRunning: true })
renderHook(() => useRunTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
act(() => {
registeredTool.onClick()
})
expect(mockOnRun).not.toHaveBeenCalled()
})
})
describe('cleanup', () => {
it('should remove tool on unmount', () => {
const props = createMockProps()
const { unmount } = renderHook(() => useRunTool(props))
unmount()
expect(mockRemoveTool).toHaveBeenCalledWith('run')
})
})
describe('edge cases', () => {
it('should handle missing setTools gracefully', () => {
const props = createMockProps({ setTools: undefined })
expect(() => {
renderHook(() => useRunTool(props))
}).not.toThrow()
})
it('should not break when onRun is undefined', () => {
const props = createMockProps({ onRun: undefined })
renderHook(() => useRunTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(() => {
act(() => {
registeredTool.onClick()
})
}).not.toThrow()
})
})
})

View File

@ -0,0 +1,193 @@
import { useSaveTool } from '@renderer/components/CodeToolbar/hooks/useSaveTool'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock dependencies
const mocks = vi.hoisted(() => ({
i18n: {
t: vi.fn((key: string) => key)
},
useToolManager: vi.fn(),
useTemporaryValue: vi.fn(),
TOOL_SPECS: {
save: {
id: 'save',
type: 'core',
order: 14
}
}
}))
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: mocks.i18n.t
})
}))
vi.mock('@renderer/components/ActionTools', () => ({
TOOL_SPECS: mocks.TOOL_SPECS,
useToolManager: mocks.useToolManager
}))
// Mock useTemporaryValue
const mockSetTemporaryValue = vi.fn()
mocks.useTemporaryValue.mockImplementation(() => [false, mockSetTemporaryValue])
vi.mock('@renderer/hooks/useTemporaryValue', () => ({
useTemporaryValue: mocks.useTemporaryValue
}))
// Mock useToolManager
const mockRegisterTool = vi.fn()
const mockRemoveTool = vi.fn()
mocks.useToolManager.mockImplementation(() => ({
registerTool: mockRegisterTool,
removeTool: mockRemoveTool
}))
vi.mock('lucide-react', () => ({
Check: () => <div data-testid="check-icon" />,
SaveIcon: () => <div data-testid="save-icon" />
}))
describe('useSaveTool', () => {
beforeEach(() => {
vi.clearAllMocks()
// Reset to default values
mocks.useTemporaryValue.mockImplementation(() => [false, mockSetTemporaryValue])
})
// Helper function to create mock props
const createMockProps = (overrides: Partial<Parameters<typeof useSaveTool>[0]> = {}) => {
const defaultProps = {
enabled: true,
sourceViewRef: { current: null },
setTools: vi.fn()
}
return { ...defaultProps, ...overrides }
}
// Helper function for tool registration assertions
const expectToolRegistration = (times: number, toolConfig?: object) => {
expect(mockRegisterTool).toHaveBeenCalledTimes(times)
if (times > 0 && toolConfig) {
expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig))
}
}
describe('tool registration', () => {
it('should register save tool when enabled', () => {
const props = createMockProps({ enabled: true })
renderHook(() => useSaveTool(props))
expect(mocks.useToolManager).toHaveBeenCalledWith(props.setTools)
expectToolRegistration(1, {
id: 'save',
type: 'core',
order: 14,
tooltip: 'code_block.edit.save.label',
onClick: expect.any(Function)
})
})
it('should not register tool when disabled', () => {
const props = createMockProps({ enabled: false })
renderHook(() => useSaveTool(props))
expect(mockRegisterTool).not.toHaveBeenCalled()
})
it('should re-register tool when saved state changes', () => {
// Initially not saved
mocks.useTemporaryValue.mockImplementation(() => [false, mockSetTemporaryValue])
const props = createMockProps()
const { rerender } = renderHook(() => useSaveTool(props))
expect(mockRegisterTool).toHaveBeenCalledTimes(1)
// Change to saved state and rerender
mocks.useTemporaryValue.mockImplementation(() => [true, mockSetTemporaryValue])
rerender()
expect(mockRegisterTool).toHaveBeenCalledTimes(2)
})
})
describe('save functionality', () => {
it('should execute save behavior when tool is clicked', () => {
const mockSave = vi.fn()
const mockEditorHandles = { save: mockSave }
const props = createMockProps({
sourceViewRef: { current: mockEditorHandles }
})
renderHook(() => useSaveTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
act(() => {
registeredTool.onClick()
})
expect(mockSave).toHaveBeenCalledTimes(1)
expect(mockSetTemporaryValue).toHaveBeenCalledWith(true)
})
it('should handle when sourceViewRef.current is null', () => {
const props = createMockProps({
sourceViewRef: { current: null }
})
renderHook(() => useSaveTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(() => {
act(() => {
registeredTool.onClick()
})
}).not.toThrow()
expect(mockSetTemporaryValue).toHaveBeenCalledWith(true)
})
it('should handle when sourceViewRef.current.save is undefined', () => {
const props = createMockProps({
sourceViewRef: { current: {} }
})
renderHook(() => useSaveTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(() => {
act(() => {
registeredTool.onClick()
})
}).not.toThrow()
expect(mockSetTemporaryValue).toHaveBeenCalledWith(true)
})
})
describe('cleanup', () => {
it('should remove tool on unmount', () => {
const props = createMockProps()
const { unmount } = renderHook(() => useSaveTool(props))
unmount()
expect(mockRemoveTool).toHaveBeenCalledWith('save')
})
})
describe('edge cases', () => {
it('should handle missing setTools gracefully', () => {
const props = createMockProps({ setTools: undefined })
expect(() => {
renderHook(() => useSaveTool(props))
}).not.toThrow()
// Should still call useToolManager (but won't actually register)
expect(mocks.useToolManager).toHaveBeenCalledWith(undefined)
})
})
})

View File

@ -0,0 +1,180 @@
import { ViewMode } from '@renderer/components/CodeBlockView/types'
import { useSplitViewTool } from '@renderer/components/CodeToolbar/hooks/useSplitViewTool'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock dependencies
const mocks = vi.hoisted(() => ({
i18n: {
t: vi.fn((key: string) => key)
},
useToolManager: vi.fn(),
TOOL_SPECS: {
'split-view': {
id: 'split-view',
type: 'quick',
order: 10
}
}
}))
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: mocks.i18n.t
})
}))
vi.mock('@renderer/components/ActionTools', () => ({
TOOL_SPECS: mocks.TOOL_SPECS,
useToolManager: mocks.useToolManager
}))
// Mock useToolManager
const mockRegisterTool = vi.fn()
const mockRemoveTool = vi.fn()
mocks.useToolManager.mockImplementation(() => ({
registerTool: mockRegisterTool,
removeTool: mockRemoveTool
}))
describe('useSplitViewTool', () => {
beforeEach(() => {
vi.clearAllMocks()
})
// Helper function to create mock props
const createMockProps = (overrides: Partial<Parameters<typeof useSplitViewTool>[0]> = {}) => {
const defaultProps = {
enabled: true,
viewMode: 'special' as ViewMode,
onToggleSplitView: vi.fn(),
setTools: vi.fn()
}
return { ...defaultProps, ...overrides }
}
// Helper function for tool registration assertions
const expectToolRegistration = (times: number, toolConfig?: object) => {
expect(mockRegisterTool).toHaveBeenCalledTimes(times)
if (times > 0 && toolConfig) {
expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig))
}
}
describe('tool registration', () => {
it('should not register tool when disabled', () => {
const props = createMockProps({ enabled: false })
renderHook(() => useSplitViewTool(props))
expect(mocks.useToolManager).toHaveBeenCalledWith(props.setTools)
expect(mockRegisterTool).not.toHaveBeenCalled()
})
it('should register split view tool when enabled', () => {
const props = createMockProps({ enabled: true })
renderHook(() => useSplitViewTool(props))
expectToolRegistration(1, {
id: 'split-view',
type: 'quick',
order: 10,
tooltip: 'code_block.split.label',
onClick: expect.any(Function),
icon: expect.any(Object)
})
})
it('should show different tooltip when in split mode', () => {
const props = createMockProps({ viewMode: 'split' })
renderHook(() => useSplitViewTool(props))
expectToolRegistration(1, {
tooltip: 'code_block.split.restore'
})
})
it('should show different tooltip when not in split mode', () => {
const props = createMockProps({ viewMode: 'special' })
renderHook(() => useSplitViewTool(props))
expectToolRegistration(1, {
tooltip: 'code_block.split.label'
})
})
it('should re-register tool when viewMode changes', () => {
const props = createMockProps({ viewMode: 'special' })
const { rerender } = renderHook((hookProps) => useSplitViewTool(hookProps), {
initialProps: props
})
expect(mockRegisterTool).toHaveBeenCalledTimes(1)
// Change viewMode and rerender
const newProps = { ...props, viewMode: 'split' as ViewMode }
rerender(newProps)
// Should register tool again with updated state
expect(mockRegisterTool).toHaveBeenCalledTimes(2)
// Verify the new registration has correct tooltip
const secondRegistration = mockRegisterTool.mock.calls[1][0]
expect(secondRegistration.tooltip).toBe('code_block.split.restore')
})
})
describe('view mode switching', () => {
it('should call onToggleSplitView when tool is clicked', () => {
const mockOnToggleSplitView = vi.fn()
const props = createMockProps({
onToggleSplitView: mockOnToggleSplitView
})
renderHook(() => useSplitViewTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
act(() => {
registeredTool.onClick()
})
expect(mockOnToggleSplitView).toHaveBeenCalledTimes(1)
})
})
describe('cleanup', () => {
it('should remove tool on unmount', () => {
const props = createMockProps()
const { unmount } = renderHook(() => useSplitViewTool(props))
unmount()
expect(mockRemoveTool).toHaveBeenCalledWith('split-view')
})
})
describe('edge cases', () => {
it('should handle missing setTools gracefully', () => {
const props = createMockProps({ setTools: undefined })
expect(() => {
renderHook(() => useSplitViewTool(props))
}).not.toThrow()
// Should still call useToolManager (but won't actually register)
expect(mocks.useToolManager).toHaveBeenCalledWith(undefined)
})
it('should not break when onToggleSplitView is undefined', () => {
const props = createMockProps({ onToggleSplitView: undefined })
renderHook(() => useSplitViewTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(() => {
act(() => {
registeredTool.onClick()
})
}).not.toThrow()
})
})
})

View File

@ -0,0 +1,226 @@
import { ViewMode } from '@renderer/components/CodeBlockView/types'
import { useViewSourceTool } from '@renderer/components/CodeToolbar/hooks/useViewSourceTool'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock dependencies
const mocks = vi.hoisted(() => ({
i18n: {
t: vi.fn((key: string) => key)
},
useToolManager: vi.fn(),
TOOL_SPECS: {
edit: {
id: 'edit',
type: 'core',
order: 12
},
'view-source': {
id: 'view-source',
type: 'core',
order: 12
}
}
}))
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: mocks.i18n.t
})
}))
vi.mock('@renderer/components/ActionTools', () => ({
TOOL_SPECS: mocks.TOOL_SPECS,
useToolManager: mocks.useToolManager
}))
const mockRegisterTool = vi.fn()
const mockRemoveTool = vi.fn()
mocks.useToolManager.mockImplementation(() => ({
registerTool: mockRegisterTool,
removeTool: mockRemoveTool
}))
describe('useViewSourceTool', () => {
beforeEach(() => {
vi.clearAllMocks()
})
const createMockProps = (overrides: Partial<Parameters<typeof useViewSourceTool>[0]> = {}) => {
const defaultProps = {
enabled: true,
editable: false,
viewMode: 'special' as ViewMode,
onViewModeChange: vi.fn(),
setTools: vi.fn()
}
return { ...defaultProps, ...overrides }
}
const expectToolRegistration = (times: number, toolConfig?: object) => {
expect(mockRegisterTool).toHaveBeenCalledTimes(times)
if (times > 0 && toolConfig) {
expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig))
}
}
describe('tool registration', () => {
it('should not register tool when disabled', () => {
const props = createMockProps({ enabled: false })
renderHook(() => useViewSourceTool(props))
expect(mockRegisterTool).not.toHaveBeenCalled()
})
it('should not register tool when in split mode', () => {
const props = createMockProps({ viewMode: 'split' })
renderHook(() => useViewSourceTool(props))
expect(mockRegisterTool).not.toHaveBeenCalled()
})
it('should register view-source tool when not editable', () => {
const props = createMockProps({ editable: false })
renderHook(() => useViewSourceTool(props))
expectToolRegistration(1, {
id: 'view-source',
type: 'core',
order: 12
})
})
it('should register edit tool when editable', () => {
const props = createMockProps({ editable: true })
renderHook(() => useViewSourceTool(props))
expectToolRegistration(1, {
id: 'edit',
type: 'core',
order: 12
})
})
it('should re-register tool when editable changes', () => {
const props = createMockProps({ editable: false })
const { rerender } = renderHook((hookProps) => useViewSourceTool(hookProps), {
initialProps: props
})
expect(mockRegisterTool).toHaveBeenCalledTimes(1)
const newProps = { ...props, editable: true }
rerender(newProps)
expect(mockRegisterTool).toHaveBeenCalledTimes(2)
expect(mockRemoveTool).toHaveBeenCalledWith('view-source')
})
})
describe('tooltip variations', () => {
it('should show correct tooltips for edit mode', () => {
const props = createMockProps({ editable: true, viewMode: 'source' })
renderHook(() => useViewSourceTool(props))
expectToolRegistration(1, {
tooltip: 'preview.label'
})
vi.clearAllMocks()
const propsSpecial = createMockProps({ editable: true, viewMode: 'special' })
renderHook(() => useViewSourceTool(propsSpecial))
expectToolRegistration(1, {
tooltip: 'code_block.edit.label'
})
})
it('should show correct tooltips for view-source mode', () => {
const props = createMockProps({ editable: false, viewMode: 'source' })
renderHook(() => useViewSourceTool(props))
expectToolRegistration(1, {
tooltip: 'preview.label'
})
vi.clearAllMocks()
const propsSpecial = createMockProps({ editable: false, viewMode: 'special' })
renderHook(() => useViewSourceTool(propsSpecial))
expectToolRegistration(1, {
tooltip: 'preview.source'
})
})
})
describe('view mode switching', () => {
it('should switch from special to source when tool is clicked', () => {
const mockOnViewModeChange = vi.fn()
const props = createMockProps({
viewMode: 'special',
onViewModeChange: mockOnViewModeChange
})
renderHook(() => useViewSourceTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
act(() => {
registeredTool.onClick()
})
expect(mockOnViewModeChange).toHaveBeenCalledWith('source')
})
it('should switch from source to special when tool is clicked', () => {
const mockOnViewModeChange = vi.fn()
const props = createMockProps({
viewMode: 'source',
onViewModeChange: mockOnViewModeChange
})
renderHook(() => useViewSourceTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
act(() => {
registeredTool.onClick()
})
expect(mockOnViewModeChange).toHaveBeenCalledWith('special')
})
})
describe('cleanup', () => {
it('should remove tool on unmount', () => {
const props = createMockProps()
const { unmount } = renderHook(() => useViewSourceTool(props))
unmount()
expect(mockRemoveTool).toHaveBeenCalledWith('view-source')
})
})
describe('edge cases', () => {
it('should handle missing setTools gracefully', () => {
const props = createMockProps({ setTools: undefined })
expect(() => {
renderHook(() => useViewSourceTool(props))
}).not.toThrow()
})
it('should not break when onViewModeChange is undefined', () => {
const props = createMockProps({ onViewModeChange: undefined })
renderHook(() => useViewSourceTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(() => {
act(() => {
registeredTool.onClick()
})
}).not.toThrow()
})
})
})

View File

@ -0,0 +1,190 @@
import { useWrapTool } from '@renderer/components/CodeToolbar/hooks/useWrapTool'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock dependencies
const mocks = vi.hoisted(() => ({
i18n: {
t: vi.fn((key: string) => key)
},
useToolManager: vi.fn(),
TOOL_SPECS: {
wrap: {
id: 'wrap',
type: 'quick',
order: 13
}
}
}))
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: mocks.i18n.t
})
}))
vi.mock('@renderer/components/ActionTools', () => ({
TOOL_SPECS: mocks.TOOL_SPECS,
useToolManager: mocks.useToolManager
}))
// Mock useToolManager
const mockRegisterTool = vi.fn()
const mockRemoveTool = vi.fn()
mocks.useToolManager.mockImplementation(() => ({
registerTool: mockRegisterTool,
removeTool: mockRemoveTool
}))
vi.mock('lucide-react', () => ({
Text: () => <div data-testid="text-icon" />,
WrapText: () => <div data-testid="wrap-text-icon" />
}))
describe('useWrapTool', () => {
beforeEach(() => {
vi.clearAllMocks()
})
// Helper function to create mock props
const createMockProps = (overrides: Partial<Parameters<typeof useWrapTool>[0]> = {}) => {
const defaultProps = {
enabled: true,
unwrapped: false,
wrappable: true,
toggle: vi.fn(),
setTools: vi.fn()
}
return { ...defaultProps, ...overrides }
}
// Helper function for tool registration assertions
const expectToolRegistration = (times: number, toolConfig?: object) => {
expect(mockRegisterTool).toHaveBeenCalledTimes(times)
if (times > 0 && toolConfig) {
expect(mockRegisterTool).toHaveBeenCalledWith(expect.objectContaining(toolConfig))
}
}
describe('tool registration', () => {
it('should register wrap tool when enabled', () => {
const props = createMockProps({ enabled: true })
renderHook(() => useWrapTool(props))
expect(mocks.useToolManager).toHaveBeenCalledWith(props.setTools)
expectToolRegistration(1, {
id: 'wrap',
type: 'quick',
order: 13,
tooltip: 'code_block.wrap.off',
onClick: expect.any(Function),
visible: expect.any(Function)
})
})
it('should not register tool when disabled', () => {
const props = createMockProps({ enabled: false })
renderHook(() => useWrapTool(props))
expect(mockRegisterTool).not.toHaveBeenCalled()
})
it('should re-register tool when unwrapped changes', () => {
const props = createMockProps({ unwrapped: false })
const { rerender } = renderHook((hookProps) => useWrapTool(hookProps), {
initialProps: props
})
expect(mockRegisterTool).toHaveBeenCalledTimes(1)
const firstCall = mockRegisterTool.mock.calls[0][0]
expect(firstCall.tooltip).toBe('code_block.wrap.off')
// Change unwrapped to true and rerender
const newProps = { ...props, unwrapped: true }
rerender(newProps)
expect(mockRegisterTool).toHaveBeenCalledTimes(2)
const secondCall = mockRegisterTool.mock.calls[1][0]
expect(secondCall.tooltip).toBe('code_block.wrap.on')
})
})
describe('visibility behavior', () => {
it('should be visible when wrappable is true', () => {
const props = createMockProps({ wrappable: true })
renderHook(() => useWrapTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(registeredTool.visible()).toBe(true)
})
it('should not be visible when wrappable is false', () => {
const props = createMockProps({ wrappable: false })
renderHook(() => useWrapTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(registeredTool.visible()).toBe(false)
})
it('should not be visible when wrappable is undefined', () => {
const props = createMockProps({ wrappable: undefined })
renderHook(() => useWrapTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(registeredTool.visible()).toBe(false)
})
})
describe('toggle functionality', () => {
it('should execute toggle function when tool is clicked', () => {
const mockToggle = vi.fn()
const props = createMockProps({ toggle: mockToggle })
renderHook(() => useWrapTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
act(() => {
registeredTool.onClick()
})
expect(mockToggle).toHaveBeenCalledTimes(1)
})
})
describe('cleanup', () => {
it('should remove tool on unmount', () => {
const props = createMockProps()
const { unmount } = renderHook(() => useWrapTool(props))
unmount()
expect(mockRemoveTool).toHaveBeenCalledWith('wrap')
})
})
describe('edge cases', () => {
it('should handle missing setTools gracefully', () => {
const props = createMockProps({ setTools: undefined })
expect(() => {
renderHook(() => useWrapTool(props))
}).not.toThrow()
// Should still call useToolManager (but won't actually register)
expect(mocks.useToolManager).toHaveBeenCalledWith(undefined)
})
it('should not break when toggle is undefined', () => {
const props = createMockProps({ toggle: undefined })
renderHook(() => useWrapTool(props))
const registeredTool = mockRegisterTool.mock.calls[0][0]
expect(() => {
act(() => {
registeredTool.onClick()
})
}).not.toThrow()
})
})
})

View File

@ -0,0 +1,41 @@
import { ActionTool } from '@renderer/components/ActionTools'
import { Dropdown, Tooltip } from 'antd'
import { memo, useMemo } from 'react'
import { ToolWrapper } from './styles'
interface CodeToolButtonProps {
tool: ActionTool
}
const CodeToolButton = ({ tool }: CodeToolButtonProps) => {
const mainTool = useMemo(
() => (
<Tooltip key={tool.id} title={tool.tooltip} mouseEnterDelay={0.5} mouseLeaveDelay={0}>
<ToolWrapper onClick={tool.onClick}>{tool.icon}</ToolWrapper>
</Tooltip>
),
[tool]
)
if (tool.children?.length && tool.children.length > 0) {
return (
<Dropdown
menu={{
items: tool.children.map((child) => ({
key: child.id,
label: child.tooltip,
icon: child.icon,
onClick: child.onClick
}))
}}
trigger={['click']}>
{mainTool}
</Dropdown>
)
}
return mainTool
}
export default memo(CodeToolButton)

View File

@ -0,0 +1,8 @@
export * from './useCopyTool'
export * from './useDownloadTool'
export * from './useExpandTool'
export * from './useRunTool'
export * from './useSaveTool'
export * from './useSplitViewTool'
export * from './useViewSourceTool'
export * from './useWrapTool'

View File

@ -0,0 +1,89 @@
import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools'
import { CopyIcon } from '@renderer/components/Icons'
import { BasicPreviewHandles } from '@renderer/components/Preview'
import { useTemporaryValue } from '@renderer/hooks/useTemporaryValue'
import { Check, Image } from 'lucide-react'
import { useCallback, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
interface UseCopyToolProps {
showPreviewTools?: boolean
previewRef: React.RefObject<BasicPreviewHandles | null>
onCopySource: () => void
setTools: React.Dispatch<React.SetStateAction<ActionTool[]>>
}
export const useCopyTool = ({ showPreviewTools, previewRef, onCopySource, setTools }: UseCopyToolProps) => {
const [copied, setCopiedTemporarily] = useTemporaryValue(false)
const [copiedImage, setCopiedImageTemporarily] = useTemporaryValue(false)
const { t } = useTranslation()
const { registerTool, removeTool } = useToolManager(setTools)
const handleCopySource = useCallback(() => {
try {
onCopySource()
setCopiedTemporarily(true)
} catch (error) {
setCopiedTemporarily(false)
throw error
}
}, [onCopySource, setCopiedTemporarily])
const handleCopyImage = useCallback(() => {
try {
previewRef.current?.copy()
setCopiedImageTemporarily(true)
} catch (error) {
setCopiedImageTemporarily(false)
throw error
}
}, [previewRef, setCopiedImageTemporarily])
useEffect(() => {
const includePreviewTools = showPreviewTools && previewRef.current !== null
const baseTool = {
...TOOL_SPECS.copy,
icon: copied ? (
<Check className="tool-icon" color="var(--color-status-success)" />
) : (
<CopyIcon className="tool-icon" />
),
tooltip: t('code_block.copy.source'),
onClick: handleCopySource
}
const copyImageTool = {
...TOOL_SPECS['copy-image'],
icon: copiedImage ? (
<Check className="tool-icon" color="var(--color-status-success)" />
) : (
<Image className="tool-icon" />
),
tooltip: t('preview.copy.image'),
onClick: handleCopyImage
}
registerTool(baseTool)
if (includePreviewTools) {
registerTool(copyImageTool)
}
return () => {
removeTool(TOOL_SPECS.copy.id)
removeTool(TOOL_SPECS['copy-image'].id)
}
}, [
onCopySource,
registerTool,
removeTool,
t,
copied,
copiedImage,
handleCopySource,
handleCopyImage,
showPreviewTools,
previewRef
])
}

View File

@ -0,0 +1,61 @@
import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools'
import { FilePngIcon, FileSvgIcon } from '@renderer/components/Icons'
import { BasicPreviewHandles } from '@renderer/components/Preview'
import { Download, FileCode } from 'lucide-react'
import { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
interface UseDownloadToolProps {
showPreviewTools?: boolean
previewRef: React.RefObject<BasicPreviewHandles | null>
onDownloadSource: () => void
setTools: React.Dispatch<React.SetStateAction<ActionTool[]>>
}
export const useDownloadTool = ({ showPreviewTools, previewRef, onDownloadSource, setTools }: UseDownloadToolProps) => {
const { t } = useTranslation()
const { registerTool, removeTool } = useToolManager(setTools)
useEffect(() => {
const includePreviewTools = showPreviewTools && previewRef.current !== null
const baseTool = {
...TOOL_SPECS.download,
icon: <Download className="tool-icon" />,
tooltip: includePreviewTools ? undefined : t('code_block.download.source')
}
if (includePreviewTools) {
registerTool({
...baseTool,
children: [
{
...TOOL_SPECS.download,
icon: <FileCode size={'1rem'} />,
tooltip: t('code_block.download.source'),
onClick: onDownloadSource
},
{
...TOOL_SPECS['download-svg'],
icon: <FileSvgIcon size={'1rem'} className="lucide" />,
tooltip: t('code_block.download.svg'),
onClick: () => previewRef.current?.download('svg')
},
{
...TOOL_SPECS['download-png'],
icon: <FilePngIcon size={'1rem'} className="lucide" />,
tooltip: t('code_block.download.png'),
onClick: () => previewRef.current?.download('png')
}
]
})
} else {
registerTool({
...baseTool,
onClick: onDownloadSource
})
}
return () => removeTool(TOOL_SPECS.download.id)
}, [onDownloadSource, registerTool, removeTool, t, showPreviewTools, previewRef])
}

View File

@ -0,0 +1,35 @@
import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools'
import { ChevronsDownUp, ChevronsUpDown } from 'lucide-react'
import { useCallback, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
interface UseExpandToolProps {
enabled?: boolean
expanded?: boolean
expandable?: boolean
toggle: () => void
setTools: React.Dispatch<React.SetStateAction<ActionTool[]>>
}
export const useExpandTool = ({ enabled, expanded, expandable, toggle, setTools }: UseExpandToolProps) => {
const { t } = useTranslation()
const { registerTool, removeTool } = useToolManager(setTools)
const handleToggle = useCallback(() => {
toggle?.()
}, [toggle])
useEffect(() => {
if (enabled) {
registerTool({
...TOOL_SPECS.expand,
icon: expanded ? <ChevronsDownUp className="tool-icon" /> : <ChevronsUpDown className="tool-icon" />,
tooltip: expanded ? t('code_block.collapse') : t('code_block.expand'),
visible: () => expandable ?? false,
onClick: handleToggle
})
}
return () => removeTool(TOOL_SPECS.expand.id)
}, [enabled, expandable, expanded, handleToggle, registerTool, removeTool, t])
}

View File

@ -0,0 +1,30 @@
import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools'
import { LoadingIcon } from '@renderer/components/Icons'
import { CirclePlay } from 'lucide-react'
import { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
interface UseRunToolProps {
enabled: boolean
isRunning: boolean
onRun: () => void
setTools: React.Dispatch<React.SetStateAction<ActionTool[]>>
}
export const useRunTool = ({ enabled, isRunning, onRun, setTools }: UseRunToolProps) => {
const { t } = useTranslation()
const { registerTool, removeTool } = useToolManager(setTools)
useEffect(() => {
if (!enabled) return
registerTool({
...TOOL_SPECS.run,
icon: isRunning ? <LoadingIcon className="tool-icon" /> : <CirclePlay className="tool-icon" />,
tooltip: t('code_block.run'),
onClick: () => !isRunning && onRun?.()
})
return () => removeTool(TOOL_SPECS.run.id)
}, [enabled, isRunning, onRun, registerTool, removeTool, t])
}

View File

@ -0,0 +1,40 @@
import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools'
import { CodeEditorHandles } from '@renderer/components/CodeEditor'
import { useTemporaryValue } from '@renderer/hooks/useTemporaryValue'
import { Check, SaveIcon } from 'lucide-react'
import { useCallback, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
interface UseSaveToolProps {
enabled?: boolean
sourceViewRef: React.RefObject<CodeEditorHandles | null>
setTools: React.Dispatch<React.SetStateAction<ActionTool[]>>
}
export const useSaveTool = ({ enabled, sourceViewRef, setTools }: UseSaveToolProps) => {
const [saved, setSavedTemporarily] = useTemporaryValue(false)
const { t } = useTranslation()
const { registerTool, removeTool } = useToolManager(setTools)
const handleSave = useCallback(() => {
sourceViewRef.current?.save?.()
setSavedTemporarily(true)
}, [sourceViewRef, setSavedTemporarily])
useEffect(() => {
if (enabled) {
registerTool({
...TOOL_SPECS.save,
icon: saved ? (
<Check className="tool-icon" color="var(--color-status-success)" />
) : (
<SaveIcon className="tool-icon" />
),
tooltip: t('code_block.edit.save.label'),
onClick: handleSave
})
}
return () => removeTool(TOOL_SPECS.save.id)
}, [enabled, handleSave, registerTool, removeTool, saved, t])
}

View File

@ -0,0 +1,34 @@
import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools'
import { ViewMode } from '@renderer/components/CodeBlockView/types'
import { Square, SquareSplitHorizontal } from 'lucide-react'
import { useCallback, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
interface UseSplitViewToolProps {
enabled: boolean
viewMode: ViewMode
onToggleSplitView: () => void
setTools: React.Dispatch<React.SetStateAction<ActionTool[]>>
}
export const useSplitViewTool = ({ enabled, viewMode, onToggleSplitView, setTools }: UseSplitViewToolProps) => {
const { t } = useTranslation()
const { registerTool, removeTool } = useToolManager(setTools)
const handleToggleSplitView = useCallback(() => {
onToggleSplitView?.()
}, [onToggleSplitView])
useEffect(() => {
if (!enabled) return
registerTool({
...TOOL_SPECS['split-view'],
icon: viewMode === 'split' ? <Square className="tool-icon" /> : <SquareSplitHorizontal className="tool-icon" />,
tooltip: viewMode === 'split' ? t('code_block.split.restore') : t('code_block.split.label'),
onClick: handleToggleSplitView
})
return () => removeTool(TOOL_SPECS['split-view'].id)
}, [enabled, viewMode, registerTool, removeTool, t, handleToggleSplitView])
}

View File

@ -0,0 +1,53 @@
import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools'
import { ViewMode } from '@renderer/components/CodeBlockView/types'
import { CodeXml, Eye, SquarePen } from 'lucide-react'
import { useCallback, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
interface UseViewSourceToolProps {
enabled: boolean
editable: boolean
viewMode: ViewMode
onViewModeChange: (mode: ViewMode) => void
setTools: React.Dispatch<React.SetStateAction<ActionTool[]>>
}
export const useViewSourceTool = ({
enabled,
editable,
viewMode,
onViewModeChange,
setTools
}: UseViewSourceToolProps) => {
const { t } = useTranslation()
const { registerTool, removeTool } = useToolManager(setTools)
const handleToggleViewMode = useCallback(() => {
const newMode = viewMode === 'source' ? 'special' : 'source'
onViewModeChange?.(newMode)
}, [viewMode, onViewModeChange])
useEffect(() => {
if (!enabled || viewMode === 'split') return
const toolSpec = editable ? TOOL_SPECS.edit : TOOL_SPECS['view-source']
if (editable) {
registerTool({
...toolSpec,
icon: viewMode === 'source' ? <Eye className="tool-icon" /> : <SquarePen className="tool-icon" />,
tooltip: viewMode === 'source' ? t('preview.label') : t('code_block.edit.label'),
onClick: handleToggleViewMode
})
} else {
registerTool({
...toolSpec,
icon: viewMode === 'source' ? <Eye className="tool-icon" /> : <CodeXml className="tool-icon" />,
tooltip: viewMode === 'source' ? t('preview.label') : t('preview.source'),
onClick: handleToggleViewMode
})
}
return () => removeTool(toolSpec.id)
}, [enabled, editable, viewMode, registerTool, removeTool, t, handleToggleViewMode])
}

View File

@ -0,0 +1,35 @@
import { ActionTool, TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools'
import { Text as UnWrapIcon, WrapText as WrapIcon } from 'lucide-react'
import { useCallback, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
interface UseWrapToolProps {
enabled?: boolean
unwrapped?: boolean
wrappable?: boolean
toggle: () => void
setTools: React.Dispatch<React.SetStateAction<ActionTool[]>>
}
export const useWrapTool = ({ enabled, unwrapped, wrappable, toggle, setTools }: UseWrapToolProps) => {
const { t } = useTranslation()
const { registerTool, removeTool } = useToolManager(setTools)
const handleToggle = useCallback(() => {
toggle?.()
}, [toggle])
useEffect(() => {
if (enabled) {
registerTool({
...TOOL_SPECS.wrap,
icon: unwrapped ? <WrapIcon className="tool-icon" /> : <UnWrapIcon className="tool-icon" />,
tooltip: unwrapped ? t('code_block.wrap.on') : t('code_block.wrap.off'),
visible: () => wrappable ?? false,
onClick: handleToggle
})
}
return () => removeTool(TOOL_SPECS.wrap.id)
}, [enabled, handleToggle, registerTool, removeTool, t, unwrapped, wrappable])
}

View File

@ -1,5 +1,3 @@
export * from './constants'
export * from './hook'
export * from './toolbar'
export * from './types'
export * from './usePreviewTools'
export { default as CodeToolButton } from './button'
export * from './hooks'
export { default as CodeToolbar } from './toolbar'

View File

@ -0,0 +1,35 @@
import styled from 'styled-components'
export const ToolWrapper = styled.div`
display: flex;
align-items: center;
justify-content: center;
width: 24px;
height: 24px;
border-radius: 4px;
cursor: pointer;
user-select: none;
transition: all 0.2s ease;
color: var(--color-text-3);
&:hover {
background-color: var(--color-background-soft);
.tool-icon {
color: var(--color-text-1);
}
}
&.active {
color: var(--color-primary);
.tool-icon {
color: var(--color-primary);
}
}
/* For Lucide icons */
.tool-icon {
width: 14px;
height: 14px;
color: var(--color-text-3);
}
`

View File

@ -1,25 +1,15 @@
import { ActionTool } from '@renderer/components/ActionTools'
import { HStack } from '@renderer/components/Layout'
import { Tooltip } from 'antd'
import { EllipsisVertical } from 'lucide-react'
import React, { memo, useMemo, useState } from 'react'
import { memo, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
import { CodeTool } from './types'
import CodeToolButton from './button'
import { ToolWrapper } from './styles'
interface CodeToolButtonProps {
tool: CodeTool
}
const CodeToolButton: React.FC<CodeToolButtonProps> = memo(({ tool }) => {
return (
<Tooltip key={tool.id} title={tool.tooltip} mouseEnterDelay={0.5}>
<ToolWrapper onClick={() => tool.onClick()}>{tool.icon}</ToolWrapper>
</Tooltip>
)
})
export const CodeToolbar: React.FC<{ tools: CodeTool[] }> = memo(({ tools }) => {
const CodeToolbar = ({ tools }: { tools: ActionTool[] }) => {
const [showQuickTools, setShowQuickTools] = useState(false)
const { t } = useTranslation()
@ -51,7 +41,7 @@ export const CodeToolbar: React.FC<{ tools: CodeTool[] }> = memo(({ tools }) =>
{quickTools.length > 1 && (
<Tooltip title={t('code_block.more')} mouseEnterDelay={0.5}>
<ToolWrapper onClick={() => setShowQuickTools(!showQuickTools)} className={showQuickTools ? 'active' : ''}>
<EllipsisVertical className="icon" />
<EllipsisVertical className="tool-icon" />
</ToolWrapper>
</Tooltip>
)}
@ -63,7 +53,7 @@ export const CodeToolbar: React.FC<{ tools: CodeTool[] }> = memo(({ tools }) =>
</ToolbarWrapper>
</StickyWrapper>
)
})
}
const StickyWrapper = styled.div`
position: sticky;
@ -80,36 +70,4 @@ const ToolbarWrapper = styled(HStack)`
gap: 4px;
`
const ToolWrapper = styled.div`
display: flex;
align-items: center;
justify-content: center;
width: 24px;
height: 24px;
border-radius: 4px;
cursor: pointer;
user-select: none;
transition: all 0.2s ease;
color: var(--color-text-3);
&:hover {
background-color: var(--color-background-soft);
.icon {
color: var(--color-text-1);
}
}
&.active {
color: var(--color-primary);
.icon {
color: var(--color-primary);
}
}
/* For Lucide icons */
.icon {
width: 14px;
height: 14px;
color: var(--color-text-3);
}
`
export default memo(CodeToolbar)

View File

@ -1,25 +0,0 @@
/**
*
*/
export interface CodeToolSpec {
id: string
type: 'core' | 'quick'
order: number
}
/**
*
* @param id
* @param type
* @param icon
* @param tooltip
* @param condition
* @param onClick
* @param order
*/
export interface CodeTool extends CodeToolSpec {
icon: React.ReactNode
tooltip: string
visible?: () => boolean
onClick: () => void
}

View File

@ -1,363 +0,0 @@
import { loggerService } from '@logger'
import { download } from '@renderer/utils/download'
import { FileImage, ZoomIn, ZoomOut } from 'lucide-react'
import { RefObject, useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { DownloadPngIcon, DownloadSvgIcon } from '../Icons/DownloadIcons'
import { TOOL_SPECS } from './constants'
import { useCodeTool } from './hook'
import { CodeTool } from './types'
const logger = loggerService.withContext('usePreviewToolHandlers')
// 预编译正则表达式用于查询位置
const TRANSFORM_REGEX = /translate\((-?\d+\.?\d*)px,\s*(-?\d+\.?\d*)px\)/
/**
* 使Hook
*
*/
export const usePreviewToolHandlers = (
containerRef: RefObject<HTMLDivElement | null>,
options: {
prefix: string
imgSelector: string
enableWheelZoom?: boolean
customDownloader?: (format: 'svg' | 'png') => void
}
) => {
const transformRef = useRef({ scale: 1, x: 0, y: 0 }) // 管理变换状态
const [renderTrigger, setRenderTrigger] = useState(0) // 仅用于触发组件重渲染的状态
const { imgSelector, prefix, customDownloader, enableWheelZoom } = options
const { t } = useTranslation()
// 创建选择器函数
const getImgElement = useCallback(() => {
if (!containerRef.current) return null
// 优先尝试从 Shadow DOM 中查找
const shadowRoot = containerRef.current.shadowRoot
if (shadowRoot) {
return shadowRoot.querySelector(imgSelector) as SVGElement | null
}
// 降级到常规 DOM 查找
return containerRef.current.querySelector(imgSelector) as SVGElement | null
}, [containerRef, imgSelector])
// 查询当前位置
const getCurrentPosition = useCallback(() => {
const imgElement = getImgElement()
if (!imgElement) return { x: transformRef.current.x, y: transformRef.current.y }
const transform = imgElement.style.transform
if (!transform || transform === 'none') return { x: transformRef.current.x, y: transformRef.current.y }
const match = transform.match(TRANSFORM_REGEX)
if (match && match.length >= 3) {
return {
x: parseFloat(match[1]),
y: parseFloat(match[2])
}
}
return { x: transformRef.current.x, y: transformRef.current.y }
}, [getImgElement])
// 平移缩放变换
const applyTransform = useCallback((element: SVGElement | null, x: number, y: number, scale: number) => {
if (!element) return
element.style.transformOrigin = 'top left'
element.style.transform = `translate(${x}px, ${y}px) scale(${scale})`
}, [])
// 拖拽平移支持
useEffect(() => {
const container = containerRef.current
if (!container) return
let isDragging = false
const startPos = { x: 0, y: 0 }
const startOffset = { x: 0, y: 0 }
const onMouseDown = (e: MouseEvent) => {
if (e.button !== 0) return // 只响应左键
// 更新当前实际位置
const position = getCurrentPosition()
transformRef.current.x = position.x
transformRef.current.y = position.y
isDragging = true
startPos.x = e.clientX
startPos.y = e.clientY
startOffset.x = position.x
startOffset.y = position.y
container.style.cursor = 'grabbing'
e.preventDefault()
}
const onMouseMove = (e: MouseEvent) => {
if (!isDragging) return
const dx = e.clientX - startPos.x
const dy = e.clientY - startPos.y
const newX = startOffset.x + dx
const newY = startOffset.y + dy
const imgElement = getImgElement()
applyTransform(imgElement, newX, newY, transformRef.current.scale)
e.preventDefault()
}
const stopDrag = () => {
if (!isDragging) return
// 更新位置但不立即触发状态变更
const position = getCurrentPosition()
transformRef.current.x = position.x
transformRef.current.y = position.y
// 只触发一次渲染以保持组件状态同步
setRenderTrigger((prev) => prev + 1)
isDragging = false
container.style.cursor = 'default'
}
// 绑定到document以确保拖拽可以在鼠标离开容器后继续
container.addEventListener('mousedown', onMouseDown)
document.addEventListener('mousemove', onMouseMove)
document.addEventListener('mouseup', stopDrag)
return () => {
container.removeEventListener('mousedown', onMouseDown)
document.removeEventListener('mousemove', onMouseMove)
document.removeEventListener('mouseup', stopDrag)
}
}, [containerRef, getCurrentPosition, getImgElement, applyTransform])
// 缩放处理函数
const handleZoom = useCallback(
(delta: number) => {
const newScale = Math.max(0.1, Math.min(3, transformRef.current.scale + delta))
transformRef.current.scale = newScale
const imgElement = getImgElement()
applyTransform(imgElement, transformRef.current.x, transformRef.current.y, newScale)
// 触发重渲染以保持组件状态同步
setRenderTrigger((prev) => prev + 1)
},
[getImgElement, applyTransform]
)
// 滚轮缩放支持
useEffect(() => {
if (!enableWheelZoom || !containerRef.current) return
const container = containerRef.current
const handleWheel = (e: WheelEvent) => {
if ((e.ctrlKey || e.metaKey) && e.target) {
// 确认事件发生在容器内部
if (container.contains(e.target as Node)) {
const delta = e.deltaY < 0 ? 0.1 : -0.1
handleZoom(delta)
}
}
}
container.addEventListener('wheel', handleWheel, { passive: true })
return () => container.removeEventListener('wheel', handleWheel)
}, [containerRef, handleZoom, enableWheelZoom])
// 复制图像处理函数
const handleCopyImage = useCallback(async () => {
try {
const imgElement = getImgElement()
if (!imgElement) return
const canvas = document.createElement('canvas')
const ctx = canvas.getContext('2d')
const img = new Image()
img.crossOrigin = 'anonymous'
const viewBox = imgElement.getAttribute('viewBox')?.split(' ').map(Number) || []
const width = viewBox[2] || imgElement.clientWidth || imgElement.getBoundingClientRect().width
const height = viewBox[3] || imgElement.clientHeight || imgElement.getBoundingClientRect().height
const svgData = new XMLSerializer().serializeToString(imgElement)
const svgBase64 = `data:image/svg+xml;base64,${btoa(unescape(encodeURIComponent(svgData)))}`
img.onload = async () => {
const scale = 3
canvas.width = width * scale
canvas.height = height * scale
if (ctx) {
ctx.scale(scale, scale)
ctx.drawImage(img, 0, 0, width, height)
const blob = await new Promise<Blob>((resolve) => canvas.toBlob((b) => resolve(b!), 'image/png'))
await navigator.clipboard.write([new ClipboardItem({ 'image/png': blob })])
window.message.success(t('message.copy.success'))
}
}
img.src = svgBase64
} catch (error) {
logger.error('Copy failed:', error as Error)
window.message.error(t('message.copy.failed'))
}
}, [getImgElement, t])
// 下载处理函数
const handleDownload = useCallback(
(format: 'svg' | 'png') => {
// 如果有自定义下载器,使用自定义实现
if (customDownloader) {
customDownloader(format)
return
}
try {
const imgElement = getImgElement()
if (!imgElement) return
const timestamp = Date.now()
if (format === 'svg') {
const svgData = new XMLSerializer().serializeToString(imgElement)
const blob = new Blob([svgData], { type: 'image/svg+xml' })
const url = URL.createObjectURL(blob)
download(url, `${prefix}-${timestamp}.svg`)
URL.revokeObjectURL(url)
} else if (format === 'png') {
const canvas = document.createElement('canvas')
const ctx = canvas.getContext('2d')
const img = new Image()
img.crossOrigin = 'anonymous'
const viewBox = imgElement.getAttribute('viewBox')?.split(' ').map(Number) || []
const width = viewBox[2] || imgElement.clientWidth || imgElement.getBoundingClientRect().width
const height = viewBox[3] || imgElement.clientHeight || imgElement.getBoundingClientRect().height
const svgData = new XMLSerializer().serializeToString(imgElement)
const svgBase64 = `data:image/svg+xml;base64,${btoa(unescape(encodeURIComponent(svgData)))}`
img.onload = () => {
const scale = 3
canvas.width = width * scale
canvas.height = height * scale
if (ctx) {
ctx.scale(scale, scale)
ctx.drawImage(img, 0, 0, width, height)
}
canvas.toBlob((blob) => {
if (blob) {
const pngUrl = URL.createObjectURL(blob)
download(pngUrl, `${prefix}-${timestamp}.png`)
URL.revokeObjectURL(pngUrl)
}
}, 'image/png')
}
img.src = svgBase64
}
} catch (error) {
logger.error('Download failed:', error as Error)
}
},
[getImgElement, prefix, customDownloader]
)
return {
scale: transformRef.current.scale,
handleZoom,
handleCopyImage,
handleDownload,
renderTrigger // 导出渲染触发器,万一要用
}
}
export interface PreviewToolsOptions {
setTools?: (value: React.SetStateAction<CodeTool[]>) => void
handleZoom?: (delta: number) => void
handleCopyImage?: () => Promise<void>
handleDownload?: (format: 'svg' | 'png') => void
}
/**
* Hook
*/
export const usePreviewTools = ({ setTools, handleZoom, handleCopyImage, handleDownload }: PreviewToolsOptions) => {
const { t } = useTranslation()
const { registerTool, removeTool } = useCodeTool(setTools)
useEffect(() => {
// 根据提供的功能有选择性地注册工具
if (handleZoom) {
// 放大工具
registerTool({
...TOOL_SPECS['zoom-in'],
icon: <ZoomIn className="icon" />,
tooltip: t('code_block.preview.zoom_in'),
onClick: () => handleZoom(0.1)
})
// 缩小工具
registerTool({
...TOOL_SPECS['zoom-out'],
icon: <ZoomOut className="icon" />,
tooltip: t('code_block.preview.zoom_out'),
onClick: () => handleZoom(-0.1)
})
}
if (handleCopyImage) {
// 复制图片工具
registerTool({
...TOOL_SPECS['copy-image'],
icon: <FileImage className="icon" />,
tooltip: t('code_block.preview.copy.image'),
onClick: handleCopyImage
})
}
if (handleDownload) {
// 下载 SVG 工具
registerTool({
...TOOL_SPECS['download-svg'],
icon: <DownloadSvgIcon />,
tooltip: t('code_block.download.svg'),
onClick: () => handleDownload('svg')
})
// 下载 PNG 工具
registerTool({
...TOOL_SPECS['download-png'],
icon: <DownloadPngIcon />,
tooltip: t('code_block.download.png'),
onClick: () => handleDownload('png')
})
}
// 清理函数
return () => {
if (handleZoom) {
removeTool(TOOL_SPECS['zoom-in'].id)
removeTool(TOOL_SPECS['zoom-out'].id)
}
if (handleCopyImage) {
removeTool(TOOL_SPECS['copy-image'].id)
}
if (handleDownload) {
removeTool(TOOL_SPECS['download-svg'].id)
removeTool(TOOL_SPECS['download-png'].id)
}
}
}, [handleCopyImage, handleDownload, handleZoom, registerTool, removeTool, t])
}

View File

@ -1,4 +1,4 @@
import { TOOL_SPECS, useCodeTool } from '@renderer/components/CodeToolbar'
import { MAX_COLLAPSED_CODE_HEIGHT } from '@renderer/config/constant'
import { useCodeStyle } from '@renderer/context/CodeStyleProvider'
import { useCodeHighlight } from '@renderer/hooks/useCodeHighlight'
import { useSettings } from '@renderer/hooks/useSettings'
@ -6,82 +6,34 @@ import { uuid } from '@renderer/utils'
import { getReactStyleFromToken } from '@renderer/utils/shiki'
import { useVirtualizer } from '@tanstack/react-virtual'
import { debounce } from 'lodash'
import { ChevronsDownUp, ChevronsUpDown, Text as UnWrapIcon, WrapText as WrapIcon } from 'lucide-react'
import React, { memo, useCallback, useEffect, useLayoutEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import React, { memo, useCallback, useEffect, useLayoutEffect, useMemo, useRef } from 'react'
import { ThemedToken } from 'shiki/core'
import styled from 'styled-components'
import { BasicPreviewProps } from './types'
interface CodePreviewProps extends BasicPreviewProps {
interface CodeViewerProps {
language: string
children: string
expanded?: boolean
unwrapped?: boolean
onHeightChange?: (scrollHeight: number) => void
className?: string
}
const MAX_COLLAPSE_HEIGHT = 350
/**
* Shiki
* - shiki tokenizer
* - 使
* -
*/
const CodePreview = ({ children, language, setTools }: CodePreviewProps) => {
const { codeShowLineNumbers, fontSize, codeCollapsible, codeWrappable } = useSettings()
const CodeViewer = ({ children, language, expanded, unwrapped, onHeightChange, className }: CodeViewerProps) => {
const { codeShowLineNumbers, fontSize } = useSettings()
const { getShikiPreProperties, isShikiThemeDark } = useCodeStyle()
const [expandOverride, setExpandOverride] = useState(!codeCollapsible)
const [unwrapOverride, setUnwrapOverride] = useState(!codeWrappable)
const shikiThemeRef = useRef<HTMLDivElement>(null)
const scrollerRef = useRef<HTMLDivElement>(null)
const callerId = useRef(`${Date.now()}-${uuid()}`).current
const rawLines = useMemo(() => (typeof children === 'string' ? children.trimEnd().split('\n') : []), [children])
const { t } = useTranslation()
const { registerTool, removeTool } = useCodeTool(setTools)
// 展开/折叠工具
useEffect(() => {
registerTool({
...TOOL_SPECS.expand,
icon: expandOverride ? <ChevronsDownUp className="icon" /> : <ChevronsUpDown className="icon" />,
tooltip: expandOverride ? t('code_block.collapse') : t('code_block.expand'),
visible: () => {
const scrollHeight = scrollerRef.current?.scrollHeight
return codeCollapsible && (scrollHeight ?? 0) > MAX_COLLAPSE_HEIGHT
},
onClick: () => setExpandOverride((prev) => !prev)
})
return () => removeTool(TOOL_SPECS.expand.id)
}, [codeCollapsible, expandOverride, registerTool, removeTool, t])
// 自动换行工具
useEffect(() => {
registerTool({
...TOOL_SPECS.wrap,
icon: unwrapOverride ? <WrapIcon className="icon" /> : <UnWrapIcon className="icon" />,
tooltip: unwrapOverride ? t('code_block.wrap.on') : t('code_block.wrap.off'),
visible: () => codeWrappable,
onClick: () => setUnwrapOverride((prev) => !prev)
})
return () => removeTool(TOOL_SPECS.wrap.id)
}, [codeWrappable, unwrapOverride, registerTool, removeTool, t])
// 重置用户操作(可以考虑移除,保持用户操作结果)
useEffect(() => {
setExpandOverride(!codeCollapsible)
}, [codeCollapsible])
// 重置用户操作(可以考虑移除,保持用户操作结果)
useEffect(() => {
setUnwrapOverride(!codeWrappable)
}, [codeWrappable])
const shouldCollapse = useMemo(() => codeCollapsible && !expandOverride, [codeCollapsible, expandOverride])
const shouldWrap = useMemo(() => codeWrappable && !unwrapOverride, [codeWrappable, unwrapOverride])
// 计算行号数字位数
const gutterDigits = useMemo(
() => (codeShowLineNumbers ? Math.max(rawLines.length.toString().length, 1) : 0),
@ -90,10 +42,12 @@ const CodePreview = ({ children, language, setTools }: CodePreviewProps) => {
// 设置 pre 标签属性
useLayoutEffect(() => {
let mounted = true
getShikiPreProperties(language).then((properties) => {
if (!mounted) return
const shikiTheme = shikiThemeRef.current
if (shikiTheme) {
shikiTheme.className = `${properties.class || 'shiki'}`
shikiTheme.className = `${properties.class || 'shiki'} code-viewer ${className ?? ''}`
// 滚动条适应 shiki 主题变化而非应用主题
shikiTheme.classList.add(isShikiThemeDark ? 'shiki-dark' : 'shiki-light')
@ -103,7 +57,10 @@ const CodePreview = ({ children, language, setTools }: CodePreviewProps) => {
shikiTheme.tabIndex = properties.tabindex
}
})
}, [language, getShikiPreProperties, isShikiThemeDark])
return () => {
mounted = false
}
}, [language, getShikiPreProperties, isShikiThemeDark, className])
// Virtualizer 配置
const getScrollElement = useCallback(() => scrollerRef.current, [])
@ -140,19 +97,25 @@ const CodePreview = ({ children, language, setTools }: CodePreviewProps) => {
}
}, [virtualItems, debouncedHighlightLines])
// Report scrollHeight when it might change
useLayoutEffect(() => {
onHeightChange?.(scrollerRef.current?.scrollHeight ?? 0)
}, [rawLines.length, onHeightChange])
return (
<div ref={shikiThemeRef}>
<ScrollContainer
ref={scrollerRef}
className="shiki-scroller"
$wrap={shouldWrap}
$wrap={!unwrapped}
$expanded={expanded}
$lineHeight={estimateSize()}
style={
{
'--gutter-width': `${gutterDigits}ch`,
fontSize: `${fontSize - 1}px`,
maxHeight: shouldCollapse ? MAX_COLLAPSE_HEIGHT : undefined,
overflowY: shouldCollapse ? 'auto' : 'hidden'
maxHeight: expanded ? undefined : MAX_COLLAPSED_CODE_HEIGHT,
overflowY: expanded ? 'hidden' : 'auto'
} as React.CSSProperties
}>
<div
@ -170,7 +133,7 @@ const CodePreview = ({ children, language, setTools }: CodePreviewProps) => {
width: '100%',
transform: `translateY(${virtualItems[0]?.start ?? 0}px)`
}}>
{virtualizer.getVirtualItems().map((virtualItem) => (
{virtualItems.map((virtualItem) => (
<div key={virtualItem.key} data-index={virtualItem.index} ref={virtualizer.measureElement}>
<VirtualizedRow
rawLine={rawLines[virtualItem.index]}
@ -187,7 +150,7 @@ const CodePreview = ({ children, language, setTools }: CodePreviewProps) => {
)
}
CodePreview.displayName = 'CodePreview'
CodeViewer.displayName = 'CodeViewer'
const plainTokenStyle = {
color: 'inherit',
@ -259,20 +222,24 @@ VirtualizedRow.displayName = 'VirtualizedRow'
const ScrollContainer = styled.div<{
$wrap?: boolean
$expanded?: boolean
$lineHeight?: number
}>`
display: block;
overflow-x: auto;
position: relative;
border-radius: inherit;
padding: 0.5em 1em;
/* padding right 下沉到 line-content 中 */
padding: 0.5em 0 0.5em 1em;
.line {
display: flex;
align-items: flex-start;
width: 100%;
line-height: ${(props) => props.$lineHeight}px;
contain: content;
/* contain 优化 wrap 时滚动性能will-change 优化 unwrap 时滚动性能 */
contain: ${(props) => (props.$wrap ? 'content' : 'none')};
will-change: ${(props) => (!props.$wrap && !props.$expanded ? 'transform' : 'auto')};
.line-number {
width: var(--gutter-width, 1.2ch);
@ -288,6 +255,7 @@ const ScrollContainer = styled.div<{
.line-content {
flex: 1;
padding-right: 1em;
* {
white-space: ${(props) => (props.$wrap ? 'pre-wrap' : 'pre')};
overflow-wrap: ${(props) => (props.$wrap ? 'break-word' : 'normal')};
@ -296,4 +264,4 @@ const ScrollContainer = styled.div<{
}
`
export default memo(CodePreview)
export default memo(CodeViewer)

View File

@ -78,7 +78,7 @@ const CustomCollapse: FC<CustomCollapseProps> = ({
style={collapseStyle}
defaultActiveKey={defaultActiveKey}
activeKey={activeKey}
destroyInactivePanel={destroyInactivePanel}
destroyOnHidden={destroyInactivePanel}
collapsible={collapsible}
onChange={(keys) => {
setActiveKeys(keys)

View File

@ -3,7 +3,7 @@
import { render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import { DraggableList } from '../DraggableList'
import { DraggableList } from '../'
// mock @hello-pangea/dnd 组件
vi.mock('@hello-pangea/dnd', () => {

View File

@ -3,7 +3,7 @@
import { render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import DraggableVirtualList from '../DraggableList/virtual-list'
import { DraggableVirtualList } from '../'
// Mock 依赖项
vi.mock('@hello-pangea/dnd', () => ({

View File

@ -0,0 +1,151 @@
import { DropResult } from '@hello-pangea/dnd'
import { act, renderHook } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import { useDraggableReorder } from '../useDraggableReorder'
// 辅助函数和模拟数据
const createMockItem = (id: number) => ({ id: `item-${id}`, name: `Item ${id}` })
const mockOriginalList = [createMockItem(1), createMockItem(2), createMockItem(3), createMockItem(4), createMockItem(5)]
/**
* DropResult
* @param sourceIndex -
* @param destIndex -
* @param draggableId - ID itemKey
*/
const createMockDropResult = (sourceIndex: number, destIndex: number | null, draggableId: string): DropResult => ({
reason: 'DROP',
source: { index: sourceIndex, droppableId: 'droppable' },
destination: destIndex !== null ? { index: destIndex, droppableId: 'droppable' } : null,
combine: null,
mode: 'FLUID',
draggableId,
type: 'DEFAULT'
})
describe('useDraggableReorder', () => {
describe('reorder', () => {
it('should correctly reorder the list when it is not filtered', () => {
const onUpdate = vi.fn()
const { result } = renderHook(() =>
useDraggableReorder({
originalList: mockOriginalList,
filteredList: mockOriginalList, // 列表未过滤
onUpdate,
idKey: 'id'
})
)
// 模拟将第一项 (视图索引 0, 原始索引 0) 拖到第三项的位置 (视图索引 2)
// 在未过滤列表中itemKey(0) 返回 0
const dropResult = createMockDropResult(0, 2, '0')
act(() => {
result.current.onDragEnd(dropResult)
})
expect(onUpdate).toHaveBeenCalledTimes(1)
const newList = onUpdate.mock.calls[0][0]
// 原始: [1, 2, 3, 4, 5] -> 拖拽后预期: [2, 3, 1, 4, 5]
expect(newList.map((i) => i.id)).toEqual(['item-2', 'item-3', 'item-1', 'item-4', 'item-5'])
})
it('should correctly reorder the original list when the list is filtered', () => {
const onUpdate = vi.fn()
// 过滤后只剩下奇数项: [item-1, item-3, item-5]
const filteredList = [mockOriginalList[0], mockOriginalList[2], mockOriginalList[4]]
const { result } = renderHook(() =>
useDraggableReorder({
originalList: mockOriginalList,
filteredList,
onUpdate,
idKey: 'id'
})
)
// 在过滤后的列表中,将最后一项 'item-5' (视图索引 2) 拖到第一项 'item-1' (视图索引 0) 的位置
// 'item-5' 的原始索引是 4, 所以 itemKey(2) 返回 4
const dropResult = createMockDropResult(2, 0, '4')
act(() => {
result.current.onDragEnd(dropResult)
})
expect(onUpdate).toHaveBeenCalledTimes(1)
const newList = onUpdate.mock.calls[0][0]
// 原始: [1, 2, 3, 4, 5]
// 拖拽后预期: 'item-5' 移动到 'item-1' 的位置 -> [5, 1, 2, 3, 4]
expect(newList.map((i) => i.id)).toEqual(['item-5', 'item-1', 'item-2', 'item-3', 'item-4'])
})
})
describe('onUpdate', () => {
it('should not call onUpdate if destination is null', () => {
const onUpdate = vi.fn()
const { result } = renderHook(() =>
useDraggableReorder({
originalList: mockOriginalList,
filteredList: mockOriginalList,
onUpdate,
idKey: 'id'
})
)
// 模拟拖拽到列表外
const dropResult = createMockDropResult(0, null, '0')
act(() => {
result.current.onDragEnd(dropResult)
})
expect(onUpdate).not.toHaveBeenCalled()
})
it('should not call onUpdate if source and destination are the same', () => {
const onUpdate = vi.fn()
const { result } = renderHook(() =>
useDraggableReorder({
originalList: mockOriginalList,
filteredList: mockOriginalList,
onUpdate,
idKey: 'id'
})
)
// 模拟拖拽后放回原位
const dropResult = createMockDropResult(1, 1, '1')
act(() => {
result.current.onDragEnd(dropResult)
})
expect(onUpdate).not.toHaveBeenCalled()
})
})
describe('itemKey', () => {
it('should return the correct original index from a filtered list index', () => {
const onUpdate = vi.fn()
// 过滤后只剩下奇数项: [item-1, item-3, item-5]
const filteredList = [mockOriginalList[0], mockOriginalList[2], mockOriginalList[4]]
const { result } = renderHook(() =>
useDraggableReorder({
originalList: mockOriginalList,
filteredList,
onUpdate,
idKey: 'id'
})
)
// 视图索引 0 -> 'item-1' -> 原始索引 0
expect(result.current.itemKey(0)).toBe(0)
// 视图索引 1 -> 'item-3' -> 原始索引 2
expect(result.current.itemKey(1)).toBe(2)
// 视图索引 2 -> 'item-5' -> 原始索引 4
expect(result.current.itemKey(2)).toBe(4)
})
})
})

View File

@ -1,2 +1,3 @@
export { default as DraggableList } from './list'
export { useDraggableReorder } from './useDraggableReorder'
export { default as DraggableVirtualList } from './virtual-list'

View File

@ -0,0 +1,70 @@
import { DropResult } from '@hello-pangea/dnd'
import { Key, useCallback, useMemo } from 'react'
interface UseDraggableReorderParams<T> {
/** 原始的、完整的数据列表 */
originalList: T[]
/** 当前在界面上渲染的、可能被过滤的列表 */
filteredList: T[]
/** 用于更新原始列表状态的函数 */
onUpdate: (newList: T[]) => void
/** 用于从列表项中获取唯一ID的属性名或函数 */
idKey: keyof T | ((item: T) => Key)
}
/**
*
*
* @template T
* @param params - { originalList, filteredList, onUpdate, idKey }
* @returns DraggableVirtualList props: { onDragEnd, itemKey }
*/
export function useDraggableReorder<T>({ originalList, filteredList, onUpdate, idKey }: UseDraggableReorderParams<T>) {
const getId = useCallback((item: T) => (typeof idKey === 'function' ? idKey(item) : (item[idKey] as Key)), [idKey])
// 创建从 item ID 到其在 *原始列表* 中索引的映射
const itemIndexMap = useMemo(() => {
const map = new Map<Key, number>()
originalList.forEach((item, index) => {
map.set(getId(item), index)
})
return map
}, [originalList, getId])
// 创建一个函数,将 *过滤后列表* 的视图索引转换为 *原始列表* 的数据索引
const getItemKey = useCallback(
(index: number): Key => {
const item = filteredList[index]
// 如果找不到item返回视图索引兜底
if (!item) return index
const originalIndex = itemIndexMap.get(getId(item))
return originalIndex ?? index
},
[filteredList, itemIndexMap, getId]
)
// 创建 onDragEnd 回调,封装了所有重排逻辑
const onDragEnd = useCallback(
(result: DropResult) => {
if (!result.destination) return
// 使用 getItemKey 将视图索引转换为数据索引
const sourceOriginalIndex = getItemKey(result.source.index) as number
const destOriginalIndex = getItemKey(result.destination.index) as number
if (sourceOriginalIndex === destOriginalIndex) return
// 操作原始列表的副本
const newList = [...originalList]
const [movedItem] = newList.splice(sourceOriginalIndex, 1)
newList.splice(destOriginalIndex, 0, movedItem)
// 调用外部更新函数
onUpdate(newList)
},
[originalList, onUpdate, getItemKey]
)
return { onDragEnd, itemKey: getItemKey }
}

View File

@ -22,7 +22,7 @@ import { type Key, memo, useCallback, useRef } from 'react'
* @property {React.CSSProperties} [itemStyle]
* @property {React.CSSProperties} [itemContainerStyle]
* @property {Partial<DroppableProps>} [droppableProps] Droppable
* @property {(list: T[]) => void} onUpdate
* @property {(list: T[]) => void} [onUpdate] useDraggableReorder
* @property {OnDragStartResponder} [onDragStart]
* @property {OnDragEndResponder} [onDragEnd]
* @property {T[]} list
@ -39,7 +39,7 @@ interface DraggableVirtualListProps<T> {
itemStyle?: React.CSSProperties
itemContainerStyle?: React.CSSProperties
droppableProps?: Partial<DroppableProps>
onUpdate: (list: T[]) => void
onUpdate?: (list: T[]) => void
onDragStart?: OnDragStartResponder
onDragEnd?: OnDragEndResponder
list: T[]
@ -48,6 +48,7 @@ interface DraggableVirtualListProps<T> {
overscan?: number
header?: React.ReactNode
children: (item: T, index: number) => React.ReactNode
disabled?: boolean
}
/**
@ -73,11 +74,12 @@ function DraggableVirtualList<T>({
estimateSize: _estimateSize,
overscan = 5,
header,
children
children,
disabled
}: DraggableVirtualListProps<T>): React.ReactElement {
const _onDragEnd = (result: DropResult, provided: ResponderProvided) => {
onDragEnd?.(result, provided)
if (result.destination) {
if (onUpdate && result.destination) {
const sourceIndex = result.source.index
const destIndex = result.destination.index
const reorderAgents = droppableReorder(list, sourceIndex, destIndex)
@ -157,6 +159,7 @@ function DraggableVirtualList<T>({
itemContainerStyle={itemContainerStyle}
virtualizer={virtualizer}
children={children}
disabled={disabled}
/>
))}
</div>
@ -172,53 +175,56 @@ function DraggableVirtualList<T>({
/**
*
*/
const VirtualRow = memo(({ virtualItem, list, children, itemStyle, itemContainerStyle, virtualizer }: any) => {
const item = list[virtualItem.index]
const draggableId = String(virtualItem.key)
return (
<Draggable
key={`draggable_${draggableId}_${virtualItem.index}`}
draggableId={draggableId}
index={virtualItem.index}>
{(provided) => {
const setDragRefs = (el: HTMLElement | null) => {
provided.innerRef(el)
virtualizer.measureElement(el)
}
const VirtualRow = memo(
({ virtualItem, list, children, itemStyle, itemContainerStyle, virtualizer, disabled }: any) => {
const item = list[virtualItem.index]
const draggableId = String(virtualItem.key)
return (
<Draggable
key={`draggable_${draggableId}_${virtualItem.index}`}
draggableId={draggableId}
isDragDisabled={disabled}
index={virtualItem.index}>
{(provided) => {
const setDragRefs = (el: HTMLElement | null) => {
provided.innerRef(el)
virtualizer.measureElement(el)
}
const dndStyle = provided.draggableProps.style
const virtualizerTransform = `translateY(${virtualItem.start}px)`
const dndStyle = provided.draggableProps.style
const virtualizerTransform = `translateY(${virtualItem.start}px)`
// dnd 的 transform 负责拖拽时的位移和让位动画,
// virtualizer 的 translateY 负责将项定位到虚拟列表的正确位置,
// 它们拼接起来可以同时实现拖拽视觉效果和虚拟化定位。
const combinedTransform = dndStyle?.transform
? `${dndStyle.transform} ${virtualizerTransform}`
: virtualizerTransform
// dnd 的 transform 负责拖拽时的位移和让位动画,
// virtualizer 的 translateY 负责将项定位到虚拟列表的正确位置,
// 它们拼接起来可以同时实现拖拽视觉效果和虚拟化定位。
const combinedTransform = dndStyle?.transform
? `${dndStyle.transform} ${virtualizerTransform}`
: virtualizerTransform
return (
<div
{...provided.draggableProps}
ref={setDragRefs}
className="draggable-item"
data-index={virtualItem.index}
style={{
...itemContainerStyle,
...dndStyle,
position: 'absolute',
top: 0,
left: 0,
width: '100%',
transform: combinedTransform
}}>
<div {...provided.dragHandleProps} className="draggable-content" style={{ ...itemStyle }}>
{item && children(item, virtualItem.index)}
return (
<div
{...provided.draggableProps}
ref={setDragRefs}
className="draggable-item"
data-index={virtualItem.index}
style={{
...itemContainerStyle,
...dndStyle,
position: 'absolute',
top: 0,
left: 0,
width: '100%',
transform: combinedTransform
}}>
<div {...provided.dragHandleProps} className="draggable-content" style={itemStyle}>
{item && children(item, virtualItem.index)}
</div>
</div>
</div>
)
}}
</Draggable>
)
})
)
}}
</Draggable>
)
}
)
export default DraggableVirtualList

View File

@ -1,68 +0,0 @@
import { SVGProps } from 'react'
// 基础下载图标
export const DownloadIcon = (props: SVGProps<SVGSVGElement>) => (
<svg
xmlns="http://www.w3.org/2000/svg"
width="1.1em"
height="1.1em"
fill="none"
stroke="currentColor"
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
viewBox="0 0 24 24"
{...props}>
<path d="M21 15v4a2 2 0 01-2 2H5a2 2 0 01-2-2v-4" />
<path d="M12 15V3" />
<polygon points="12,15 9,11 15,11" fill="currentColor" stroke="none" />
</svg>
)
// 带有文件类型的下载图标基础组件
const DownloadTypeIconBase = ({ type, ...props }: SVGProps<SVGSVGElement> & { type: string }) => (
<svg
xmlns="http://www.w3.org/2000/svg"
width="1.1em"
height="1.1em"
fill="none"
stroke="currentColor"
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
viewBox="0 0 24 24"
{...props}>
<text
x="12"
y="7"
fontSize="8"
textAnchor="middle"
fill="currentColor"
stroke="currentColor"
strokeWidth="0.3"
letterSpacing="1"
fontFamily="Arial Black, sans-serif"
style={{
paintOrder: 'stroke',
fontStretch: 'expanded',
userSelect: 'none',
WebkitUserSelect: 'none',
MozUserSelect: 'none',
msUserSelect: 'none'
}}>
{type}
</text>
<path d="M21 16v3a2 2 0 01-2 2H5a2 2 0 01-2-2v-3" />
<path d="M12 17V10" />
<polygon points="12,17 9.5,14 14.5,14" fill="currentColor" stroke="none" />
</svg>
)
// JPG 文件下载图标
export const DownloadJpgIcon = (props: SVGProps<SVGSVGElement>) => <DownloadTypeIconBase type="JPG" {...props} />
// PNG 文件下载图标
export const DownloadPngIcon = (props: SVGProps<SVGSVGElement>) => <DownloadTypeIconBase type="PNG" {...props} />
// SVG 文件下载图标
export const DownloadSvgIcon = (props: SVGProps<SVGSVGElement>) => <DownloadTypeIconBase type="SVG" {...props} />

View File

@ -0,0 +1,70 @@
import { CSSProperties, SVGProps } from 'react'
interface BaseFileIconProps extends SVGProps<SVGSVGElement> {
size?: string
text?: string
}
const textStyle: CSSProperties = {
fontStyle: 'italic',
fontSize: '7.70985px',
lineHeight: 0.8,
fontFamily: "'Times New Roman'",
textAlign: 'center',
writingMode: 'horizontal-tb',
direction: 'ltr',
textAnchor: 'middle',
fill: 'none',
stroke: '#000000',
strokeWidth: '0.289119',
strokeLinejoin: 'round',
strokeDasharray: 'none'
}
const tspanStyle: CSSProperties = {
fontStyle: 'normal',
fontVariant: 'normal',
fontWeight: 'normal',
fontStretch: 'condensed',
fontSize: '7.70985px',
lineHeight: 0.8,
fontFamily: 'Arial',
fill: '#000000',
fillOpacity: 1,
strokeWidth: '0.289119',
strokeDasharray: 'none'
}
const BaseFileIcon = ({ size = '1.1em', text = 'SVG', ...props }: BaseFileIconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
version="1.1"
id="svg4"
xmlns="http://www.w3.org/2000/svg"
{...props}>
<defs id="defs4" />
<path d="m 14,2 v 4 a 2,2 0 0 0 2,2 h 4" id="path3" />
<path d="M 15,2 H 6 A 2,2 0 0 0 4,4 v 16 a 2,2 0 0 0 2,2 h 12 a 2,2 0 0 0 2,-2 V 7 Z" id="path4" />
<text
xmlSpace="preserve"
style={textStyle}
x="12.478625"
y="17.170216"
id="text4"
transform="scale(0.96196394,1.03954)">
<tspan id="tspan4" x="12.478625" y="17.170216" style={tspanStyle}>
{text}
</tspan>
</text>
</svg>
)
export const FileSvgIcon = (props: Omit<BaseFileIconProps, 'text'>) => <BaseFileIcon text="SVG" {...props} />
export const FilePngIcon = (props: Omit<BaseFileIconProps, 'text'>) => <BaseFileIcon text="PNG" {...props} />

View File

@ -1,8 +1,8 @@
export { default as CopyIcon } from './CopyIcon'
export { default as DeleteIcon } from './DeleteIcon'
export * from './DownloadIcons'
export { default as EditIcon } from './EditIcon'
export { default as FallbackFavicon } from './FallbackFavicon'
export * from './FileIcons'
export { default as MinAppIcon } from './MinAppIcon'
export * from './NutstoreIcons'
export { default as OcrIcon } from './OcrIcon'

View File

@ -86,7 +86,7 @@ const ImageViewer: React.FC<ImageViewerProps> = ({ src, style, ...props }) => {
},
{
key: 'copy-image',
label: t('code_block.preview.copy.image'),
label: t('preview.copy.image'),
icon: <FileImageOutlined />,
onClick: () => handleCopyImage(src)
}
@ -101,6 +101,7 @@ const ImageViewer: React.FC<ImageViewerProps> = ({ src, style, ...props }) => {
{...props}
preview={{
mask: typeof props.preview === 'object' ? props.preview.mask : false,
...(typeof props.preview === 'object' ? props.preview : {}),
toolbarRender: (
_,
{

View File

@ -26,7 +26,7 @@ const ModelIdWithTags = ({
maxWidth: '500px'
}
}}
destroyTooltipOnHide
destroyOnHidden
title={
<Typography.Text style={{ color: 'white' }} copyable={{ text: model.id }}>
{model.id}

View File

@ -1,4 +1,3 @@
import { EyeOutlined, GlobalOutlined, ToolOutlined } from '@ant-design/icons'
import {
isEmbeddingModel,
isFunctionCallingModel,
@ -14,7 +13,15 @@ import { FC, memo, useLayoutEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
import CustomTag from './CustomTag'
import CustomTag from './Tags/CustomTag'
import {
EmbeddingTag,
ReasoningTag,
RerankerTag,
ToolsCallingTag,
VisionTag,
WebSearchTag
} from './Tags/ModelCapabilities'
interface ModelTagsProps {
model: Model
@ -70,45 +77,17 @@ const ModelTagsWithLabel: FC<ModelTagsProps> = ({
return (
<Container ref={containerRef} style={style}>
{isVisionModel(model) && (
<CustomTag
size={size}
color="#00b96b"
icon={<EyeOutlined style={{ fontSize: size }} />}
tooltip={showTooltip ? t('models.type.vision') : undefined}>
{shouldShowLabel ? t('models.type.vision') : ''}
</CustomTag>
)}
{isWebSearchModel(model) && (
<CustomTag
size={size}
color="#1677ff"
icon={<GlobalOutlined style={{ fontSize: size }} />}
tooltip={showTooltip ? t('models.type.websearch') : undefined}>
{shouldShowLabel ? t('models.type.websearch') : ''}
</CustomTag>
)}
{isVisionModel(model) && <VisionTag size={size} showTooltip={showTooltip} showLabel={shouldShowLabel} />}
{isWebSearchModel(model) && <WebSearchTag size={size} showTooltip={showTooltip} showLabel={shouldShowLabel} />}
{showReasoning && isReasoningModel(model) && (
<CustomTag
size={size}
color="#6372bd"
icon={<i className="iconfont icon-thinking" />}
tooltip={showTooltip ? t('models.type.reasoning') : undefined}>
{shouldShowLabel ? t('models.type.reasoning') : ''}
</CustomTag>
<ReasoningTag size={size} showTooltip={showTooltip} showLabel={shouldShowLabel} />
)}
{showToolsCalling && isFunctionCallingModel(model) && (
<CustomTag
size={size}
color="#f18737"
icon={<ToolOutlined style={{ fontSize: size }} />}
tooltip={showTooltip ? t('models.type.function_calling') : undefined}>
{shouldShowLabel ? t('models.type.function_calling') : ''}
</CustomTag>
<ToolsCallingTag size={size} showTooltip={showTooltip} showLabel={shouldShowLabel} />
)}
{isEmbeddingModel(model) && <CustomTag size={size} color="#FFA500" icon={t('models.type.embedding')} />}
{isEmbeddingModel(model) && <EmbeddingTag size={size} />}
{showFree && isFreeModel(model) && <CustomTag size={size} color="#7cb305" icon={t('models.type.free')} />}
{isRerankModel(model) && <CustomTag size={size} color="#6495ED" icon={t('models.type.rerank')} />}
{isRerankModel(model) && <RerankerTag size={size} />}
</Container>
)
}

View File

@ -1,6 +1,12 @@
import { getProviderLabel } from '@renderer/i18n/label'
import { Provider } from '@renderer/types'
import { oauthWithAihubmix, oauthWithPPIO, oauthWithSiliconFlow, oauthWithTokenFlux } from '@renderer/utils/oauth'
import {
oauthWith302AI,
oauthWithAihubmix,
oauthWithPPIO,
oauthWithSiliconFlow,
oauthWithTokenFlux
} from '@renderer/utils/oauth'
import { Button, ButtonProps } from 'antd'
import { FC } from 'react'
import { useTranslation } from 'react-i18next'
@ -36,6 +42,10 @@ const OAuthButton: FC<Props> = ({ provider, onSuccess, ...buttonProps }) => {
if (provider.id === 'tokenflux') {
oauthWithTokenFlux()
}
if (provider.id === '302ai') {
oauthWith302AI(handleSuccess)
}
}
return (

View File

@ -1,14 +1,18 @@
import { loggerService } from '@logger'
import CustomTag from '@renderer/components/CustomTag'
import CustomTag from '@renderer/components/Tags/CustomTag'
import { TopView } from '@renderer/components/TopView'
import { useKnowledge, useKnowledgeBases } from '@renderer/hooks/useKnowledge'
import { Topic } from '@renderer/types'
import { Message } from '@renderer/types/newMessage'
import {
analyzeMessageContent,
analyzeTopicContent,
CONTENT_TYPES,
ContentType,
MessageContentStats,
processMessageContent
processMessageContent,
processTopicContent,
TopicContentStats
} from '@renderer/utils/knowledge'
import { Flex, Form, Modal, Select, Tooltip, Typography } from 'antd'
import { Check, CircleHelp } from 'lucide-react'
@ -20,11 +24,12 @@ const logger = loggerService.withContext('SaveToKnowledgePopup')
const { Text } = Typography
// 内容类型配置
// Base Content Type Config
const CONTENT_TYPE_CONFIG = {
[CONTENT_TYPES.TEXT]: {
label: 'chat.save.knowledge.content.maintext.title',
description: 'chat.save.knowledge.content.maintext.description'
description: 'chat.save.knowledge.content.maintext.description',
topicDescription: 'chat.save.topic.knowledge.content.maintext.description'
},
[CONTENT_TYPES.CODE]: {
label: 'chat.save.knowledge.content.code.title',
@ -62,16 +67,20 @@ const TAG_COLORS = {
UNSELECTED: '#8c8c8c'
} as const
type ContentStats = MessageContentStats | TopicContentStats
interface ContentTypeOption {
type: ContentType
label: string
count: number
enabled: boolean
description?: string
label: string
description: string
}
type ContentSource = { type: 'message'; data: Message } | { type: 'topic'; data: Topic }
interface ShowParams {
message: Message
source: ContentSource
title?: string
}
@ -84,35 +93,73 @@ interface Props extends ShowParams {
resolve: (data: SaveResult | null) => void
}
const PopupContainer: React.FC<Props> = ({ message, title, resolve }) => {
const PopupContainer: React.FC<Props> = ({ source, title, resolve }) => {
const [open, setOpen] = useState(true)
const [loading, setLoading] = useState(false)
const [analysisLoading, setAnalysisLoading] = useState(true)
const [selectedBaseId, setSelectedBaseId] = useState<string>()
const [selectedTypes, setSelectedTypes] = useState<ContentType[]>([])
const [hasInitialized, setHasInitialized] = useState(false)
const [contentStats, setContentStats] = useState<ContentStats | null>(null)
const { bases } = useKnowledgeBases()
const { addNote, addFiles } = useKnowledge(selectedBaseId || '')
const { t } = useTranslation()
// 分析消息内容统计
const contentStats = useMemo(() => analyzeMessageContent(message), [message])
const isTopicMode = source?.type === 'topic'
// 生成内容类型选项(只显示有内容的类型)
// 异步分析内容统计
useEffect(() => {
const analyze = async () => {
setAnalysisLoading(true)
setContentStats(null)
try {
const stats = isTopicMode
? await analyzeTopicContent(source?.data as Topic)
: analyzeMessageContent(source?.data as Message)
setContentStats(stats)
} catch (error) {
logger.error('analyze content failed:', error as Error)
setContentStats({
text: 0,
code: 0,
thinking: 0,
images: 0,
files: 0,
tools: 0,
citations: 0,
translations: 0,
errors: 0,
...(isTopicMode && { messages: 0 })
})
} finally {
setAnalysisLoading(false)
}
}
analyze()
}, [source, isTopicMode])
// 生成内容类型选项
const contentTypeOptions: ContentTypeOption[] = useMemo(() => {
if (!contentStats) return []
return Object.entries(CONTENT_TYPE_CONFIG)
.map(([type, config]) => {
const contentType = type as ContentType
const count = contentStats[contentType as keyof MessageContentStats] || 0
const count = contentStats[contentType as keyof ContentStats] || 0
const descriptionKey =
isTopicMode && 'topicDescription' in config && config.topicDescription
? config.topicDescription
: config.description
return {
type: contentType,
count,
enabled: count > 0,
label: t(config.label),
description: t(config.description)
description: t(descriptionKey)
}
})
.filter((option) => option.enabled) // 只显示有内容的类型
}, [contentStats, t])
.filter((option) => option.enabled)
}, [contentStats, t, isTopicMode])
// 知识库选项
const knowledgeBaseOptions = useMemo(
@ -120,12 +167,12 @@ const PopupContainer: React.FC<Props> = ({ message, title, resolve }) => {
bases.map((base) => ({
label: base.name,
value: base.id,
disabled: !base.version // 如果知识库没有配置好就禁用
disabled: !base.version
})),
[bases]
)
// 合并状态计算
// 表单状态
const formState = useMemo(() => {
const hasValidBase = selectedBaseId && bases.find((base) => base.id === selectedBaseId)?.version
const hasContent = contentTypeOptions.length > 0
@ -142,7 +189,7 @@ const PopupContainer: React.FC<Props> = ({ message, title, resolve }) => {
}
}, [selectedBaseId, bases, contentTypeOptions, selectedTypes])
// 默认选择第一个可用知识库
// 默认选择第一个可用知识库
useEffect(() => {
if (!selectedBaseId) {
const firstAvailableBase = bases.find((base) => base.version)
@ -152,49 +199,51 @@ const PopupContainer: React.FC<Props> = ({ message, title, resolve }) => {
}
}, [bases, selectedBaseId])
// 默认选择所有可用内容类型(仅在初始化时)
// 默认选择所有可用内容类型
useEffect(() => {
if (!hasInitialized && contentTypeOptions.length > 0) {
const availableTypes = contentTypeOptions.map((option) => option.type)
setSelectedTypes(availableTypes)
setSelectedTypes(contentTypeOptions.map((option) => option.type))
setHasInitialized(true)
}
}, [contentTypeOptions, hasInitialized])
// 计算UI状态
// UI状态
const uiState = useMemo(() => {
if (analysisLoading) {
return { type: 'loading', message: t('chat.save.topic.knowledge.loading') }
}
if (!formState.hasContent) {
return { type: 'empty', message: t('chat.save.knowledge.empty.no_content') }
return {
type: 'empty',
message: t(isTopicMode ? 'chat.save.topic.knowledge.empty.no_content' : 'chat.save.knowledge.empty.no_content')
}
}
if (bases.length === 0) {
return { type: 'empty', message: t('chat.save.knowledge.empty.no_knowledge_base') }
}
return { type: 'form' }
}, [formState.hasContent, bases.length, t])
}, [analysisLoading, formState.hasContent, bases.length, t, isTopicMode])
// 处理内容类型选择切换
const handleContentTypeToggle = (type: ContentType) => {
setSelectedTypes((prev) => (prev.includes(type) ? prev.filter((t) => t !== type) : [...prev, type]))
}
const onOk = async () => {
if (!formState.canSubmit) {
return
}
if (!formState.canSubmit) return
setLoading(true)
let savedCount = 0
try {
const result = processMessageContent(message, selectedTypes)
const result = isTopicMode
? await processTopicContent(source?.data as Topic, selectedTypes)
: processMessageContent(source?.data as Message, selectedTypes)
// 保存文本内容
if (result.text.trim() && selectedTypes.some((type) => type !== CONTENT_TYPES.FILE)) {
await addNote(result.text)
savedCount++
}
// 保存文件
if (result.files.length > 0 && selectedTypes.includes(CONTENT_TYPES.FILE)) {
addFiles(result.files)
savedCount += result.files.length
@ -204,27 +253,22 @@ const PopupContainer: React.FC<Props> = ({ message, title, resolve }) => {
resolve({ success: true, savedCount })
} catch (error) {
logger.error('save failed:', error as Error)
window.message.error(t('chat.save.knowledge.error.save_failed'))
window.message.error(
t(isTopicMode ? 'chat.save.topic.knowledge.error.save_failed' : 'chat.save.knowledge.error.save_failed')
)
setLoading(false)
}
}
const onCancel = () => {
setOpen(false)
}
const onCancel = () => setOpen(false)
const onClose = () => resolve(null)
const onClose = () => {
resolve(null)
}
// 渲染空状态
const renderEmptyState = () => (
<EmptyContainer>
<Text type="secondary">{uiState.message}</Text>
</EmptyContainer>
)
// 渲染表单内容
const renderFormContent = () => (
<>
<Form layout="vertical">
@ -241,7 +285,10 @@ const PopupContainer: React.FC<Props> = ({ message, title, resolve }) => {
/>
</Form.Item>
<Form.Item label={t('chat.save.knowledge.select.content.title')}>
<Form.Item
label={t(
isTopicMode ? 'chat.save.topic.knowledge.select.content.label' : 'chat.save.knowledge.select.content.title'
)}>
<Flex gap={8} style={{ flexDirection: 'column' }}>
{contentTypeOptions.map((option) => (
<ContentTypeItem
@ -267,27 +314,37 @@ const PopupContainer: React.FC<Props> = ({ message, title, resolve }) => {
</Form.Item>
</Form>
{formState.selectedCount > 0 && (
<InfoContainer>
<InfoContainer>
{formState.selectedCount > 0 && (
<Text type="secondary" style={{ fontSize: '12px' }}>
{t('chat.save.knowledge.select.content.tip', { count: formState.selectedCount })}
{t(
isTopicMode
? 'chat.save.topic.knowledge.select.content.selected_tip'
: 'chat.save.knowledge.select.content.tip',
{
count: formState.selectedCount,
...(isTopicMode && { messages: (contentStats as TopicContentStats)?.messages || 0 })
}
)}
</Text>
</InfoContainer>
)}
{formState.hasNoSelection && (
<InfoContainer>
)}
{formState.hasNoSelection && (
<Text type="warning" style={{ fontSize: '12px' }}>
{t('chat.save.knowledge.error.no_content_selected')}
</Text>
</InfoContainer>
)}
)}
{!formState.hasNoSelection && formState.selectedCount === 0 && (
<Text type="secondary" style={{ fontSize: '12px', opacity: 0 }}>
&nbsp;
</Text>
)}
</InfoContainer>
</>
)
return (
<Modal
title={title || t('chat.save.knowledge.title')}
title={title || t(isTopicMode ? 'chat.save.topic.knowledge.title' : 'chat.save.knowledge.title')}
open={open}
onOk={onOk}
onCancel={onCancel}
@ -297,11 +354,8 @@ const PopupContainer: React.FC<Props> = ({ message, title, resolve }) => {
width={500}
okText={t('common.save')}
cancelText={t('common.cancel')}
okButtonProps={{
loading,
disabled: !formState.canSubmit
}}>
{uiState.type === 'empty' ? renderEmptyState() : renderFormContent()}
okButtonProps={{ loading, disabled: !formState.canSubmit || analysisLoading }}>
{uiState.type === 'form' ? renderFormContent() : renderEmptyState()}
</Modal>
)
}
@ -327,11 +381,22 @@ export default class SaveToKnowledgePopup {
)
})
}
static showForMessage(message: Message, title?: string): Promise<SaveResult | null> {
return this.show({ source: { type: 'message', data: message }, title })
}
static showForTopic(topic: Topic, title?: string): Promise<SaveResult | null> {
return this.show({ source: { type: 'topic', data: topic }, title })
}
}
const EmptyContainer = styled.div`
display: flex;
justify-content: center;
align-items: center;
min-height: 100px;
text-align: center;
padding: 40px 20px;
`
const ContentTypeItem = styled(Flex)`
@ -352,4 +417,7 @@ const InfoContainer = styled.div`
padding: 12px;
border-radius: 6px;
margin-top: 16px;
min-height: 40px; /* To avoid layout shift */
display: flex;
align-items: center;
`

View File

@ -1,9 +1,7 @@
import { LoadingOutlined } from '@ant-design/icons'
import { loggerService } from '@logger'
import { useDefaultModel } from '@renderer/hooks/useAssistant'
import { useSettings } from '@renderer/hooks/useSettings'
import { fetchTranslate } from '@renderer/services/ApiService'
import { getDefaultTranslateAssistant } from '@renderer/services/AssistantService'
import { translateText } from '@renderer/services/TranslateService'
import { getLanguageByLangcode } from '@renderer/utils/translate'
import { Modal, ModalProps } from 'antd'
import TextArea from 'antd/es/input/TextArea'
@ -43,7 +41,6 @@ const PopupContainer: React.FC<Props> = ({
const [textValue, setTextValue] = useState(text)
const [isTranslating, setIsTranslating] = useState(false)
const textareaRef = useRef<TextAreaRef>(null)
const { translateModel } = useDefaultModel()
const { targetLanguage, showTranslateConfirm } = useSettings()
const isMounted = useRef(true)
@ -103,21 +100,12 @@ const PopupContainer: React.FC<Props> = ({
if (!confirmed) return
}
if (!translateModel) {
window.message.error({
content: t('translate.error.not_configured'),
key: 'translate-message'
})
return
}
if (isMounted.current) {
setIsTranslating(true)
}
try {
const assistant = getDefaultTranslateAssistant(getLanguageByLangcode(targetLanguage), textValue)
const translatedText = await fetchTranslate({ content: textValue, assistant })
const translatedText = await translateText(textValue, getLanguageByLangcode(targetLanguage))
if (isMounted.current) {
setTextValue(translatedText)
}

View File

@ -0,0 +1,56 @@
import { AsyncInitializer } from '@renderer/utils/asyncInitializer'
import React, { memo, useCallback } from 'react'
import styled from 'styled-components'
import { useDebouncedRender } from './hooks/useDebouncedRender'
import ImagePreviewLayout from './ImagePreviewLayout'
import { BasicPreviewHandles, BasicPreviewProps } from './types'
import { renderSvgInShadowHost } from './utils'
// 管理 viz 实例
const vizInitializer = new AsyncInitializer(async () => {
const module = await import('@viz-js/viz')
return await module.instance()
})
/** Graphviz
* 使 usePreviewRenderer hook
*/
const GraphvizPreview = ({
children,
enableToolbar = false,
ref
}: BasicPreviewProps & { ref?: React.RefObject<BasicPreviewHandles | null> }) => {
// 定义渲染函数
const renderGraphviz = useCallback(async (content: string, container: HTMLDivElement) => {
const viz = await vizInitializer.get()
const svg = viz.renderString(content, { format: 'svg' })
renderSvgInShadowHost(svg, container)
}, [])
// 使用预览渲染器 hook
const { containerRef, error, isLoading } = useDebouncedRender(children, renderGraphviz, {
debounceDelay: 300
})
return (
<ImagePreviewLayout
loading={isLoading}
error={error}
enableToolbar={enableToolbar}
ref={ref}
imageRef={containerRef}
source="graphviz">
<StyledGraphviz ref={containerRef} className="graphviz special-preview" />
</ImagePreviewLayout>
)
}
const StyledGraphviz = styled.div`
overflow: auto;
position: relative;
width: 100%;
height: 100%;
`
export default memo(GraphvizPreview)

View File

@ -0,0 +1,60 @@
import { useImageTools } from '@renderer/components/ActionTools/hooks/useImageTools'
import { LoadingIcon } from '@renderer/components/Icons'
import { Spin } from 'antd'
import { memo, useImperativeHandle } from 'react'
import ImageToolbar from './ImageToolbar'
import { PreviewContainer, PreviewError } from './styles'
import { BasicPreviewHandles } from './types'
interface ImagePreviewLayoutProps {
children: React.ReactNode
ref?: React.RefObject<BasicPreviewHandles | null>
imageRef: React.RefObject<HTMLDivElement | null>
source: string
loading?: boolean
error?: string | null
enableToolbar?: boolean
className?: string
}
const ImagePreviewLayout = ({
children,
ref,
imageRef,
source,
loading,
error,
enableToolbar,
className
}: ImagePreviewLayoutProps) => {
// 使用通用图像工具
const { pan, zoom, copy, download, dialog } = useImageTools(imageRef, {
imgSelector: 'svg',
prefix: source ?? 'svg',
enableDrag: true,
enableWheelZoom: true
})
useImperativeHandle(ref, () => {
return {
pan,
zoom,
copy,
download,
dialog
}
})
return (
<Spin spinning={loading} indicator={<LoadingIcon color="var(--color-text-2)" />}>
<PreviewContainer vertical className={`image-preview-layout ${className ?? ''}`}>
{error && <PreviewError>{error}</PreviewError>}
{children}
{!error && enableToolbar && <ImageToolbar pan={pan} zoom={zoom} dialog={dialog} />}
</PreviewContainer>
</Spin>
)
}
export default memo(ImagePreviewLayout)

View File

@ -0,0 +1,18 @@
import { Button, Tooltip } from 'antd'
import { memo } from 'react'
interface ImageToolButtonProps {
tooltip: string
icon: React.ReactNode
onClick: () => void
}
const ImageToolButton = ({ tooltip, icon, onClick }: ImageToolButtonProps) => {
return (
<Tooltip title={tooltip} mouseEnterDelay={0.5} mouseLeaveDelay={0}>
<Button shape="circle" icon={icon} onClick={onClick} role="button" aria-label={tooltip} />
</Tooltip>
)
}
export default memo(ImageToolButton)

View File

@ -0,0 +1,107 @@
import { ResetIcon } from '@renderer/components/Icons'
import { classNames } from '@renderer/utils'
import { ChevronDown, ChevronLeft, ChevronRight, ChevronUp, Scan, ZoomIn, ZoomOut } from 'lucide-react'
import { memo, useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
import ImageToolButton from './ImageToolButton'
interface ImageToolbarProps {
pan: (dx: number, dy: number, absolute?: boolean) => void
zoom: (delta: number, absolute?: boolean) => void
dialog: () => void
className?: string
}
const ImageToolbar = ({ pan, zoom, dialog, className }: ImageToolbarProps) => {
const { t } = useTranslation()
// 定义平移距离
const panDistance = 20
// 定义缩放增量
const zoomDelta = 0.1
const handleReset = useCallback(() => {
pan(0, 0, true)
zoom(1, true)
}, [pan, zoom])
return (
<ToolbarWrapper className={classNames('preview-toolbar', className)} role="toolbar" aria-label={t('preview.label')}>
{/* Up */}
<ActionButtonRow>
<Spacer />
<ImageToolButton
tooltip={t('preview.pan_up')}
icon={<ChevronUp size={'1rem'} />}
onClick={() => pan(0, -panDistance)}
/>
<ImageToolButton tooltip={t('preview.dialog')} icon={<Scan size={'1rem'} />} onClick={dialog} />
</ActionButtonRow>
{/* Left, Reset, Right */}
<ActionButtonRow>
<ImageToolButton
tooltip={t('preview.pan_left')}
icon={<ChevronLeft size={'1rem'} />}
onClick={() => pan(-panDistance, 0)}
/>
<ImageToolButton tooltip={t('preview.reset')} icon={<ResetIcon size={'1rem'} />} onClick={handleReset} />
<ImageToolButton
tooltip={t('preview.pan_right')}
icon={<ChevronRight size={'1rem'} />}
onClick={() => pan(panDistance, 0)}
/>
</ActionButtonRow>
{/* Down, Zoom */}
<ActionButtonRow>
<ImageToolButton
tooltip={t('preview.zoom_out')}
icon={<ZoomOut size={'1rem'} />}
onClick={() => zoom(-zoomDelta)}
/>
<ImageToolButton
tooltip={t('preview.pan_down')}
icon={<ChevronDown size={'1rem'} />}
onClick={() => pan(0, panDistance)}
/>
<ImageToolButton
tooltip={t('preview.zoom_in')}
icon={<ZoomIn size={'1rem'} />}
onClick={() => zoom(zoomDelta)}
/>
</ActionButtonRow>
</ToolbarWrapper>
)
}
const ToolbarWrapper = styled.div`
display: flex;
flex-direction: column;
align-items: center;
position: absolute;
gap: 4px;
right: 1em;
bottom: 1em;
z-index: 5;
.ant-btn {
line-height: 0;
}
`
const ActionButtonRow = styled.div`
display: flex;
justify-content: center;
gap: 4px;
width: 100%;
`
const Spacer = styled.div`
flex: 1;
`
export default memo(ImageToolbar)

View File

@ -0,0 +1,120 @@
import { nanoid } from '@reduxjs/toolkit'
import { useMermaid } from '@renderer/hooks/useMermaid'
import React, { memo, useCallback, useEffect, useRef, useState } from 'react'
import styled from 'styled-components'
import { useDebouncedRender } from './hooks/useDebouncedRender'
import ImagePreviewLayout from './ImagePreviewLayout'
import { BasicPreviewHandles, BasicPreviewProps } from './types'
/** Mermaid
* 使 usePreviewRenderer hook
* FIXME: 等将来 mermaid-js
*/
const MermaidPreview = ({
children,
enableToolbar = false,
ref
}: BasicPreviewProps & { ref?: React.RefObject<BasicPreviewHandles | null> }) => {
const { mermaid, isLoading: isLoadingMermaid, error: mermaidError } = useMermaid()
const diagramId = useRef<string>(`mermaid-${nanoid(6)}`).current
const [isVisible, setIsVisible] = useState(true)
// 定义渲染函数
const renderMermaid = useCallback(
async (content: string, container: HTMLDivElement) => {
// 验证语法,提前抛出异常
await mermaid.parse(content)
const { svg } = await mermaid.render(diagramId, content, container)
// 避免不可见时产生 undefined 和 NaN
const fixedSvg = svg.replace(/translate\(undefined,\s*NaN\)/g, 'translate(0, 0)')
container.innerHTML = fixedSvg
},
[diagramId, mermaid]
)
// 可见性检测函数
const shouldRender = useCallback(() => {
return !isLoadingMermaid && isVisible
}, [isLoadingMermaid, isVisible])
// 使用预览渲染器 hook
const {
containerRef,
error: renderError,
isLoading: isRendering
} = useDebouncedRender(children, renderMermaid, {
debounceDelay: 300,
shouldRender
})
/**
*
* `MessageGroup` `fold` `display: none`
* `fold` className `MessageWrapper`
* FIXME: 将来 mermaid-js
*/
useEffect(() => {
if (!containerRef.current) return
const checkVisibility = () => {
const element = containerRef.current
if (!element) return
const currentlyVisible = element.offsetParent !== null
setIsVisible(currentlyVisible)
}
// 初始检查
checkVisibility()
const observer = new MutationObserver(() => {
checkVisibility()
})
let targetElement = containerRef.current.parentElement
while (targetElement) {
observer.observe(targetElement, {
attributes: true,
attributeFilter: ['class', 'style']
})
if (targetElement.className?.includes('fold')) {
break
}
targetElement = targetElement.parentElement
}
return () => {
observer.disconnect()
}
}, [containerRef])
// 合并加载状态和错误状态
const isLoading = isLoadingMermaid || isRendering
const error = mermaidError || renderError
return (
<ImagePreviewLayout
loading={isLoading}
error={error}
enableToolbar={enableToolbar}
ref={ref}
imageRef={containerRef}
source="mermaid">
<StyledMermaid ref={containerRef} className="mermaid special-preview" />
</ImagePreviewLayout>
)
}
const StyledMermaid = styled.div`
overflow: auto;
position: relative;
width: 100%;
height: 100%;
`
export default memo(MermaidPreview)

View File

@ -0,0 +1,136 @@
import { loggerService } from '@logger'
import pako from 'pako'
import React, { memo, useCallback, useEffect } from 'react'
import { useDebouncedRender } from './hooks/useDebouncedRender'
import ImagePreviewLayout from './ImagePreviewLayout'
import { BasicPreviewHandles, BasicPreviewProps } from './types'
import { renderSvgInShadowHost } from './utils'
const logger = loggerService.withContext('PlantUmlPreview')
const PlantUMLServer = 'https://www.plantuml.com/plantuml'
function encode64(data: Uint8Array) {
let r = ''
for (let i = 0; i < data.length; i += 3) {
if (i + 2 === data.length) {
r += append3bytes(data[i], data[i + 1], 0)
} else if (i + 1 === data.length) {
r += append3bytes(data[i], 0, 0)
} else {
r += append3bytes(data[i], data[i + 1], data[i + 2])
}
}
return r
}
function encode6bit(b: number) {
if (b < 10) {
return String.fromCharCode(48 + b)
}
b -= 10
if (b < 26) {
return String.fromCharCode(65 + b)
}
b -= 26
if (b < 26) {
return String.fromCharCode(97 + b)
}
b -= 26
if (b === 0) {
return '-'
}
if (b === 1) {
return '_'
}
return '?'
}
function append3bytes(b1: number, b2: number, b3: number) {
const c1 = b1 >> 2
const c2 = ((b1 & 0x3) << 4) | (b2 >> 4)
const c3 = ((b2 & 0xf) << 2) | (b3 >> 6)
const c4 = b3 & 0x3f
let r = ''
r += encode6bit(c1 & 0x3f)
r += encode6bit(c2 & 0x3f)
r += encode6bit(c3 & 0x3f)
r += encode6bit(c4 & 0x3f)
return r
}
/**
* https://plantuml.com/zh/code-javascript-synchronous
* To use PlantUML image generation, a text diagram description have to be :
1. Encoded in UTF-8
2. Compressed using Deflate algorithm
3. Reencoded in ASCII using a transformation _close_ to base64
*/
function encodeDiagram(diagram: string): string {
const utf8text = new TextEncoder().encode(diagram)
const compressed = pako.deflateRaw(utf8text)
return encode64(compressed)
}
function getPlantUMLImageUrl(format: 'png' | 'svg', diagram: string, isDark?: boolean) {
const encodedDiagram = encodeDiagram(diagram)
if (isDark) {
return `${PlantUMLServer}/d${format}/${encodedDiagram}`
}
return `${PlantUMLServer}/${format}/${encodedDiagram}`
}
const PlantUmlPreview = ({
children,
enableToolbar = false,
ref
}: BasicPreviewProps & { ref?: React.RefObject<BasicPreviewHandles | null> }) => {
// 定义渲染函数
const renderPlantUml = useCallback(async (content: string, container: HTMLDivElement) => {
const url = getPlantUMLImageUrl('svg', content, false)
const response = await fetch(url)
if (!response.ok) {
if (response.status === 400) {
throw new Error(
'Diagram rendering failed (400): This is likely due to a syntax error in the diagram. Please check your code.'
)
}
if (response.status >= 500) {
throw new Error(
`Diagram rendering failed (${response.status}): The PlantUML server is temporarily unavailable. Please try again later.`
)
}
throw new Error(`Diagram rendering failed, server returned: ${response.status} ${response.statusText}`)
}
const text = await response.text()
renderSvgInShadowHost(text, container)
}, [])
// 使用预览渲染器 hook
const { containerRef, error, isLoading } = useDebouncedRender(children, renderPlantUml, {
debounceDelay: 300
})
// 记录网络错误
useEffect(() => {
if (error && error.includes('Failed to fetch')) {
logger.warn('Network Error: Unable to connect to PlantUML server. Please check your network connection.')
} else if (error) {
logger.warn(error)
}
}, [error])
return (
<ImagePreviewLayout
loading={isLoading}
error={error}
enableToolbar={enableToolbar}
ref={ref}
imageRef={containerRef}
source="plantuml">
<div ref={containerRef} className="plantuml-preview special-preview" />
</ImagePreviewLayout>
)
}
export default memo(PlantUmlPreview)

View File

@ -0,0 +1,42 @@
import { memo, useCallback } from 'react'
import { useDebouncedRender } from './hooks/useDebouncedRender'
import ImagePreviewLayout from './ImagePreviewLayout'
import { BasicPreviewHandles } from './types'
import { renderSvgInShadowHost } from './utils'
interface SvgPreviewProps {
children: string
enableToolbar?: boolean
className?: string
ref?: React.RefObject<BasicPreviewHandles | null>
}
/**
* 使 Shadow DOM SVG
*/
const SvgPreview = ({ children, enableToolbar = false, className, ref }: SvgPreviewProps) => {
// 定义渲染函数
const renderSvg = useCallback(async (content: string, container: HTMLDivElement) => {
renderSvgInShadowHost(content, container)
}, [])
// 使用预览渲染器 hook
const { containerRef, error, isLoading } = useDebouncedRender(children, renderSvg, {
debounceDelay: 300
})
return (
<ImagePreviewLayout
loading={isLoading}
error={error}
enableToolbar={enableToolbar}
ref={ref}
imageRef={containerRef}
source="svg">
<div ref={containerRef} className={className ?? 'svg-preview special-preview'}></div>
</ImagePreviewLayout>
)
}
export default memo(SvgPreview)

View File

@ -0,0 +1,140 @@
import GraphvizPreview from '@renderer/components/Preview/GraphvizPreview'
import { render, screen } from '@testing-library/react'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
// Use vi.hoisted to manage mocks
const mocks = vi.hoisted(() => ({
vizInstance: {
renderSVGElement: vi.fn()
},
vizInitializer: {
get: vi.fn()
},
ImagePreviewLayout: vi.fn(({ children, loading, error, enableToolbar, source }) => (
<div data-testid="image-preview-layout" data-source={source}>
{enableToolbar && <div data-testid="toolbar">Toolbar</div>}
{loading && <div data-testid="loading">Loading...</div>}
{error && <div data-testid="error">{error}</div>}
<div data-testid="preview-content">{children}</div>
</div>
)),
useDebouncedRender: vi.fn()
}))
vi.mock('@renderer/components/Preview/ImagePreviewLayout', () => ({
default: mocks.ImagePreviewLayout
}))
vi.mock('@renderer/utils/asyncInitializer', () => ({
AsyncInitializer: class {
constructor() {
return mocks.vizInitializer
}
}
}))
vi.mock('@renderer/components/Preview/hooks/useDebouncedRender', () => ({
useDebouncedRender: mocks.useDebouncedRender
}))
describe('GraphvizPreview', () => {
const dotCode = 'digraph { a -> b }'
const mockContainerRef = { current: document.createElement('div') }
// Helper function to create mock useDebouncedRender return value
const createMockHookReturn = (overrides = {}) => ({
containerRef: mockContainerRef,
error: null,
isLoading: false,
triggerRender: vi.fn(),
cancelRender: vi.fn(),
clearError: vi.fn(),
setLoading: vi.fn(),
...overrides
})
beforeEach(() => {
// Setup default successful state
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn())
})
afterEach(() => {
vi.clearAllMocks()
})
describe('basic rendering', () => {
it('should match snapshot', () => {
const { container } = render(<GraphvizPreview enableToolbar>{dotCode}</GraphvizPreview>)
expect(container).toMatchSnapshot()
})
it('should handle valid dot code', () => {
render(<GraphvizPreview>{dotCode}</GraphvizPreview>)
// Component should render without throwing
expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument()
expect(mocks.useDebouncedRender).toHaveBeenCalledWith(
dotCode,
expect.any(Function),
expect.objectContaining({ debounceDelay: 300 })
)
})
it('should handle empty content', () => {
render(<GraphvizPreview>{''}</GraphvizPreview>)
// Component should render without throwing
expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument()
expect(mocks.useDebouncedRender).toHaveBeenCalledWith('', expect.any(Function), expect.any(Object))
})
})
describe('loading state', () => {
it('should show loading indicator when rendering', () => {
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ isLoading: true }))
render(<GraphvizPreview>{dotCode}</GraphvizPreview>)
expect(screen.getByTestId('loading')).toBeInTheDocument()
})
it('should not show loading indicator when not rendering', () => {
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ isLoading: false }))
render(<GraphvizPreview>{dotCode}</GraphvizPreview>)
expect(screen.queryByTestId('loading')).not.toBeInTheDocument()
})
})
describe('error handling', () => {
it('should show error message when rendering fails', () => {
const errorMessage = 'Invalid dot syntax'
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: errorMessage }))
render(<GraphvizPreview>{dotCode}</GraphvizPreview>)
const errorElement = screen.getByTestId('error')
expect(errorElement).toBeInTheDocument()
expect(errorElement).toHaveTextContent(errorMessage)
})
it('should not show error when rendering is successful', () => {
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: null }))
render(<GraphvizPreview>{dotCode}</GraphvizPreview>)
expect(screen.queryByTestId('error')).not.toBeInTheDocument()
})
})
describe('ref forwarding', () => {
it('should forward ref to ImagePreviewLayout', () => {
const ref = { current: null }
render(<GraphvizPreview ref={ref}>{dotCode}</GraphvizPreview>)
// The ref should be passed to ImagePreviewLayout
expect(mocks.ImagePreviewLayout).toHaveBeenCalledWith(expect.objectContaining({ ref }), undefined)
})
})
})

View File

@ -0,0 +1,122 @@
import { render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import ImagePreviewLayout from '../ImagePreviewLayout'
const mocks = vi.hoisted(() => ({
useImageTools: vi.fn(() => ({
pan: vi.fn(),
zoom: vi.fn(),
copy: vi.fn(),
download: vi.fn(),
dialog: vi.fn()
}))
}))
// Mock antd components
vi.mock('antd', () => ({
Spin: ({ children, spinning }: any) => (
<div data-testid="spin" data-spinning={spinning}>
{children}
</div>
)
}))
vi.mock('@renderer/components/Icons', () => ({
LoadingIcon: () => <div data-testid="spinner">Spinner</div>
}))
// Mock ImageToolbar
vi.mock('../ImageToolbar', () => ({
default: () => <div data-testid="image-toolbar">ImageToolbar</div>
}))
// Mock styles
vi.mock('../styles', () => ({
PreviewContainer: ({ children, vertical }: any) => (
<div data-testid="preview-container" data-vertical={vertical}>
{children}
</div>
),
PreviewError: ({ children }: any) => <div data-testid="preview-error">{children}</div>
}))
// Mock useImageTools
vi.mock('@renderer/components/ActionTools/hooks/useImageTools', () => ({
useImageTools: mocks.useImageTools
}))
describe('ImagePreviewLayout', () => {
const mockImageRef = { current: null }
const defaultProps = {
imageRef: mockImageRef,
source: 'test-source',
children: <div>Test Content</div>
}
beforeEach(() => {
vi.clearAllMocks()
})
it('should match snapshot', () => {
const { container } = render(<ImagePreviewLayout {...defaultProps} />)
expect(container).toMatchSnapshot()
})
it('should render children correctly', () => {
render(<ImagePreviewLayout {...defaultProps} />)
expect(screen.getByText('Test Content')).toBeInTheDocument()
})
it('should show loading state when loading is true', () => {
render(<ImagePreviewLayout {...defaultProps} loading={true} />)
expect(screen.getByTestId('spin')).toHaveAttribute('data-spinning', 'true')
})
it('should not show loading state when loading is false', () => {
render(<ImagePreviewLayout {...defaultProps} loading={false} />)
expect(screen.getByTestId('spin')).toHaveAttribute('data-spinning', 'false')
})
it('should display error message when error is provided', () => {
const errorMessage = 'Test error message'
render(<ImagePreviewLayout {...defaultProps} error={errorMessage} />)
expect(screen.getByText(errorMessage)).toBeInTheDocument()
})
it('should not display error message when error is null', () => {
render(<ImagePreviewLayout {...defaultProps} error={null} />)
expect(screen.queryByText('preview-error')).not.toBeInTheDocument()
})
it('should render ImageToolbar when enableToolbar is true and no error', () => {
render(<ImagePreviewLayout {...defaultProps} enableToolbar={true} />)
expect(screen.getByTestId('image-toolbar')).toBeInTheDocument()
})
it('should not render ImageToolbar when enableToolbar is false', () => {
render(<ImagePreviewLayout {...defaultProps} enableToolbar={false} />)
expect(screen.queryByTestId('image-toolbar')).not.toBeInTheDocument()
})
it('should not render ImageToolbar when there is an error', () => {
render(<ImagePreviewLayout {...defaultProps} enableToolbar={true} error="Error occurred" />)
expect(screen.queryByTestId('image-toolbar')).not.toBeInTheDocument()
})
it('should call useImageTools with correct parameters', () => {
render(<ImagePreviewLayout {...defaultProps} />)
// Verify useImageTools was called with correct parameters
expect(mocks.useImageTools).toHaveBeenCalledWith(
mockImageRef,
expect.objectContaining({
imgSelector: 'svg',
prefix: 'test-source',
enableDrag: true,
enableWheelZoom: true
})
)
})
})

View File

@ -0,0 +1,31 @@
import { render } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import ImageToolButton from '../ImageToolButton'
// Mock antd components
vi.mock('antd', () => ({
Button: vi.fn(({ children, onClick, ...props }) => (
<button type="button" data-testid="custom-button" onClick={onClick} {...props}>
{children}
</button>
)),
Tooltip: vi.fn(({ children, title }) => <div title={title}>{children}</div>)
}))
describe('ImageToolButton', () => {
beforeEach(() => {
vi.clearAllMocks()
})
const defaultProps = {
tooltip: 'Test tooltip',
icon: <span data-testid="test-icon">Icon</span>,
onClick: vi.fn()
}
it('should match snapshot', () => {
const { asFragment } = render(<ImageToolButton {...defaultProps} />)
expect(asFragment()).toMatchSnapshot()
})
})

View File

@ -0,0 +1,96 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import ImageToolbar from '../ImageToolbar'
// Mock react-i18next
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key
})
}))
// Mock ImageToolButton
vi.mock('../ImageToolButton', () => ({
default: vi.fn(({ tooltip, onClick, icon }) => (
<button type="button" onClick={onClick} role="button" aria-label={tooltip}>
{icon}
</button>
))
}))
// Mock lucide-react icons
vi.mock('lucide-react', () => ({
ChevronUp: () => <span data-testid="chevron-up"></span>,
ChevronDown: () => <span data-testid="chevron-down"></span>,
ChevronLeft: () => <span data-testid="chevron-left"></span>,
ChevronRight: () => <span data-testid="chevron-right"></span>,
ZoomIn: () => <span data-testid="zoom-in">+</span>,
ZoomOut: () => <span data-testid="zoom-out">-</span>,
Scan: () => <span data-testid="scan"></span>
}))
vi.mock('@renderer/components/Icons', () => ({
ResetIcon: () => <span data-testid="reset"></span>
}))
// Mock utils
vi.mock('@renderer/utils', () => ({
classNames: (...args: any[]) => args.filter(Boolean).join(' ')
}))
describe('ImageToolbar', () => {
const mockPan = vi.fn()
const mockZoom = vi.fn()
const mockOpenDialog = vi.fn()
beforeEach(() => {
vi.clearAllMocks()
})
it('should match snapshot', () => {
const { asFragment } = render(<ImageToolbar pan={mockPan} zoom={mockZoom} dialog={mockOpenDialog} />)
expect(asFragment()).toMatchSnapshot()
})
it('calls onPan with correct values when pan buttons are clicked', () => {
render(<ImageToolbar pan={mockPan} zoom={mockZoom} dialog={mockOpenDialog} />)
fireEvent.click(screen.getByRole('button', { name: 'preview.pan_up' }))
expect(mockPan).toHaveBeenCalledWith(0, -20)
fireEvent.click(screen.getByRole('button', { name: 'preview.pan_down' }))
expect(mockPan).toHaveBeenCalledWith(0, 20)
fireEvent.click(screen.getByRole('button', { name: 'preview.pan_left' }))
expect(mockPan).toHaveBeenCalledWith(-20, 0)
fireEvent.click(screen.getByRole('button', { name: 'preview.pan_right' }))
expect(mockPan).toHaveBeenCalledWith(20, 0)
})
it('calls onZoom with correct values when zoom buttons are clicked', () => {
render(<ImageToolbar pan={mockPan} zoom={mockZoom} dialog={mockOpenDialog} />)
fireEvent.click(screen.getByRole('button', { name: 'preview.zoom_in' }))
expect(mockZoom).toHaveBeenCalledWith(0.1)
fireEvent.click(screen.getByRole('button', { name: 'preview.zoom_out' }))
expect(mockZoom).toHaveBeenCalledWith(-0.1)
})
it('calls onReset with correct values when reset button is clicked', () => {
render(<ImageToolbar pan={mockPan} zoom={mockZoom} dialog={mockOpenDialog} />)
fireEvent.click(screen.getByRole('button', { name: 'preview.reset' }))
expect(mockPan).toHaveBeenCalledWith(0, 0, true)
expect(mockZoom).toHaveBeenCalledWith(1, true)
})
it('calls onOpenDialog when dialog button is clicked', () => {
render(<ImageToolbar pan={mockPan} zoom={mockZoom} dialog={mockOpenDialog} />)
fireEvent.click(screen.getByRole('button', { name: 'preview.dialog' }))
expect(mockOpenDialog).toHaveBeenCalled()
})
})

View File

@ -0,0 +1,259 @@
import { render, screen } from '@testing-library/react'
import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest'
import { MermaidPreview } from '..'
const mocks = vi.hoisted(() => ({
useMermaid: vi.fn(),
useDebouncedRender: vi.fn(),
ImagePreviewLayout: vi.fn(({ children, loading, error, enableToolbar, source }) => (
<div data-testid="image-preview-layout" data-source={source}>
{enableToolbar && <div data-testid="toolbar">Toolbar</div>}
{loading && <div data-testid="loading">Loading...</div>}
{error && <div data-testid="error">{error}</div>}
<div data-testid="preview-content">{children}</div>
</div>
))
}))
// Mock hooks
vi.mock('@renderer/hooks/useMermaid', () => ({
useMermaid: () => mocks.useMermaid()
}))
vi.mock('@renderer/components/Preview/ImagePreviewLayout', () => ({
default: mocks.ImagePreviewLayout
}))
vi.mock('@renderer/components/Preview/hooks/useDebouncedRender', () => ({
useDebouncedRender: mocks.useDebouncedRender
}))
// Mock nanoid
vi.mock('@reduxjs/toolkit', () => ({
nanoid: () => 'test-id-123456'
}))
describe('MermaidPreview', () => {
const mermaidCode = 'graph TD\nA-->B'
const mockContainerRef = { current: document.createElement('div') }
const mockMermaid = {
parse: vi.fn(),
render: vi.fn()
}
// Helper function to create mock useDebouncedRender return value
const createMockHookReturn = (overrides = {}) => ({
containerRef: mockContainerRef,
error: null,
isLoading: false,
triggerRender: vi.fn(),
cancelRender: vi.fn(),
clearError: vi.fn(),
setLoading: vi.fn(),
...overrides
})
beforeEach(() => {
// Setup default mocks
mocks.useMermaid.mockReturnValue({
mermaid: mockMermaid,
isLoading: false,
error: null
})
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn())
mockMermaid.parse.mockResolvedValue(true)
mockMermaid.render.mockResolvedValue({
svg: '<svg class="flowchart" viewBox="0 0 100 100"><g>test diagram</g></svg>'
})
// Mock MutationObserver
global.MutationObserver = vi.fn().mockImplementation(() => ({
observe: vi.fn(),
disconnect: vi.fn(),
takeRecords: vi.fn()
}))
})
afterEach(() => {
vi.clearAllMocks()
vi.restoreAllMocks()
})
describe('basic rendering', () => {
it('should match snapshot', () => {
const { container } = render(<MermaidPreview enableToolbar>{mermaidCode}</MermaidPreview>)
expect(container).toMatchSnapshot()
})
it('should handle valid mermaid content', () => {
render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument()
expect(mocks.useDebouncedRender).toHaveBeenCalledWith(
mermaidCode,
expect.any(Function),
expect.objectContaining({
debounceDelay: 300,
shouldRender: expect.any(Function)
})
)
})
it('should handle empty content', () => {
render(<MermaidPreview>{''}</MermaidPreview>)
expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument()
expect(mocks.useDebouncedRender).toHaveBeenCalledWith('', expect.any(Function), expect.any(Object))
})
})
describe('loading state', () => {
it('should show loading when useMermaid is loading', () => {
mocks.useMermaid.mockReturnValue({
mermaid: mockMermaid,
isLoading: true,
error: null
})
render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
expect(screen.getByTestId('loading')).toBeInTheDocument()
})
it('should show loading when useDebouncedRender is loading', () => {
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ isLoading: true }))
render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
expect(screen.getByTestId('loading')).toBeInTheDocument()
})
it('should not show loading when both are not loading', () => {
render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
expect(screen.queryByTestId('loading')).not.toBeInTheDocument()
})
})
describe('error handling', () => {
it('should show error from useMermaid', () => {
const mermaidError = 'Mermaid initialization failed'
mocks.useMermaid.mockReturnValue({
mermaid: mockMermaid,
isLoading: false,
error: mermaidError
})
render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
const errorElement = screen.getByTestId('error')
expect(errorElement).toBeInTheDocument()
expect(errorElement).toHaveTextContent(mermaidError)
})
it('should show error from useDebouncedRender', () => {
const renderError = 'Diagram rendering failed'
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: renderError }))
render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
const errorElement = screen.getByTestId('error')
expect(errorElement).toBeInTheDocument()
expect(errorElement).toHaveTextContent(renderError)
})
it('should prioritize useMermaid error over render error', () => {
const mermaidError = 'Mermaid initialization failed'
const renderError = 'Diagram rendering failed'
mocks.useMermaid.mockReturnValue({
mermaid: mockMermaid,
isLoading: false,
error: mermaidError
})
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: renderError }))
render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
const errorElement = screen.getByTestId('error')
expect(errorElement).toHaveTextContent(mermaidError)
})
})
describe('ref forwarding', () => {
it('should forward ref to ImagePreviewLayout', () => {
const ref = { current: null }
render(<MermaidPreview ref={ref}>{mermaidCode}</MermaidPreview>)
expect(mocks.ImagePreviewLayout).toHaveBeenCalledWith(expect.objectContaining({ ref }), undefined)
})
})
describe('visibility detection', () => {
it('should observe parent elements up to fold className', () => {
// Create a DOM structure that simulates MessageGroup fold layout
const foldContainer = document.createElement('div')
foldContainer.className = 'fold selected'
const messageWrapper = document.createElement('div')
messageWrapper.className = 'message-wrapper'
const codeBlock = document.createElement('div')
codeBlock.className = 'code-block'
foldContainer.appendChild(messageWrapper)
messageWrapper.appendChild(codeBlock)
document.body.appendChild(foldContainer)
try {
render(<MermaidPreview>{mermaidCode}</MermaidPreview>, {
container: codeBlock
})
const observerInstance = (global.MutationObserver as Mock).mock.results[0]?.value
expect(observerInstance.observe).toHaveBeenCalled()
} finally {
// Cleanup
document.body.removeChild(foldContainer)
}
})
it('should handle visibility changes and trigger re-render', () => {
const mockTriggerRender = vi.fn()
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ triggerRender: mockTriggerRender }))
const { container } = render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
// Get the MutationObserver callback
const observerCallback = (global.MutationObserver as Mock).mock.calls[0][0]
// Mock the container element to be initially hidden
const mermaidElement = container.querySelector('.mermaid')
Object.defineProperty(mermaidElement, 'offsetParent', {
get: () => null, // Hidden
configurable: true
})
// Simulate MutationObserver detecting visibility change
observerCallback([])
// Now make it visible
Object.defineProperty(mermaidElement, 'offsetParent', {
get: () => document.body, // Visible
configurable: true
})
// Simulate another MutationObserver callback for visibility change
observerCallback([])
// The visibility change should have been detected and component should be ready to re-render
// We verify the component structure is correct for potential re-rendering
expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument()
expect(mermaidElement).toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,169 @@
import PlantUmlPreview from '@renderer/components/Preview/PlantUmlPreview'
import { render, screen } from '@testing-library/react'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
// Use vi.hoisted to manage mocks
const mocks = vi.hoisted(() => ({
ImagePreviewLayout: vi.fn(({ children, loading, error, enableToolbar, source }) => (
<div data-testid="image-preview-layout" data-source={source}>
{enableToolbar && <div data-testid="toolbar">Toolbar</div>}
{loading && <div data-testid="loading">Loading...</div>}
{error && <div data-testid="error">{error}</div>}
<div data-testid="preview-content">{children}</div>
</div>
)),
renderSvgInShadowHost: vi.fn(),
useDebouncedRender: vi.fn(),
logger: {
warn: vi.fn()
}
}))
vi.mock('@renderer/components/Preview/ImagePreviewLayout', () => ({
default: mocks.ImagePreviewLayout
}))
vi.mock('@renderer/components/Preview/utils', () => ({
renderSvgInShadowHost: mocks.renderSvgInShadowHost
}))
vi.mock('@renderer/components/Preview/hooks/useDebouncedRender', () => ({
useDebouncedRender: mocks.useDebouncedRender
}))
describe('PlantUmlPreview', () => {
const diagram = '@startuml\nA -> B\n@enduml'
const mockContainerRef = { current: document.createElement('div') }
// Helper function to create mock useDebouncedRender return value
const createMockHookReturn = (overrides = {}) => ({
containerRef: mockContainerRef,
error: null,
isLoading: false,
triggerRender: vi.fn(),
cancelRender: vi.fn(),
clearError: vi.fn(),
setLoading: vi.fn(),
...overrides
})
beforeEach(() => {
// Setup default successful state
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn())
})
afterEach(() => {
vi.clearAllMocks()
})
describe('basic rendering', () => {
it('should match snapshot', () => {
const { container } = render(<PlantUmlPreview enableToolbar>{diagram}</PlantUmlPreview>)
expect(container).toMatchSnapshot()
})
it('should handle valid plantuml diagram', () => {
render(<PlantUmlPreview>{diagram}</PlantUmlPreview>)
// Component should render without throwing
expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument()
expect(mocks.useDebouncedRender).toHaveBeenCalledWith(
diagram,
expect.any(Function),
expect.objectContaining({ debounceDelay: 300 })
)
})
it('should handle empty content', () => {
render(<PlantUmlPreview>{''}</PlantUmlPreview>)
// Component should render without throwing
expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument()
expect(mocks.useDebouncedRender).toHaveBeenCalledWith('', expect.any(Function), expect.any(Object))
})
})
describe('loading state', () => {
it('should show loading indicator when rendering', () => {
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ isLoading: true }))
render(<PlantUmlPreview>{diagram}</PlantUmlPreview>)
expect(screen.getByTestId('loading')).toBeInTheDocument()
})
it('should not show loading indicator when not rendering', () => {
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ isLoading: false }))
render(<PlantUmlPreview>{diagram}</PlantUmlPreview>)
expect(screen.queryByTestId('loading')).not.toBeInTheDocument()
})
})
describe('error handling', () => {
it('should show network error message', () => {
const networkError = 'Network Error: Unable to connect to PlantUML server. Please check your network connection.'
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: networkError }))
render(<PlantUmlPreview>{diagram}</PlantUmlPreview>)
const errorElement = screen.getByTestId('error')
expect(errorElement).toBeInTheDocument()
expect(errorElement).toHaveTextContent(networkError)
})
it('should show syntax error message for invalid diagram', () => {
const syntaxError =
'Diagram rendering failed (400): This is likely due to a syntax error in the diagram. Please check your code.'
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: syntaxError }))
render(<PlantUmlPreview>{diagram}</PlantUmlPreview>)
const errorElement = screen.getByTestId('error')
expect(errorElement).toBeInTheDocument()
expect(errorElement).toHaveTextContent(syntaxError)
})
it('should show server error message', () => {
const serverError =
'Diagram rendering failed (503): The PlantUML server is temporarily unavailable. Please try again later.'
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: serverError }))
render(<PlantUmlPreview>{diagram}</PlantUmlPreview>)
const errorElement = screen.getByTestId('error')
expect(errorElement).toBeInTheDocument()
expect(errorElement).toHaveTextContent(serverError)
})
it('should show generic error message for other errors', () => {
const genericError = "Diagram rendering failed, server returned: 418 I'm a teapot"
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: genericError }))
render(<PlantUmlPreview>{diagram}</PlantUmlPreview>)
const errorElement = screen.getByTestId('error')
expect(errorElement).toBeInTheDocument()
expect(errorElement).toHaveTextContent(genericError)
})
it('should not show error when rendering is successful', () => {
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: null }))
render(<PlantUmlPreview>{diagram}</PlantUmlPreview>)
expect(screen.queryByTestId('error')).not.toBeInTheDocument()
})
})
describe('ref forwarding', () => {
it('should forward ref to ImagePreviewLayout', () => {
const ref = { current: null }
render(<PlantUmlPreview ref={ref}>{diagram}</PlantUmlPreview>)
// The ref should be passed to ImagePreviewLayout
expect(mocks.ImagePreviewLayout).toHaveBeenCalledWith(expect.objectContaining({ ref }), undefined)
})
})
})

View File

@ -0,0 +1,149 @@
import SvgPreview from '@renderer/components/Preview/SvgPreview'
import { render, screen } from '@testing-library/react'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
// Use vi.hoisted to manage mocks
const mocks = vi.hoisted(() => ({
ImagePreviewLayout: vi.fn(({ children, loading, error, enableToolbar, source }) => (
<div data-testid="image-preview-layout" data-source={source}>
{enableToolbar && <div data-testid="toolbar">Toolbar</div>}
{loading && <div data-testid="loading">Loading...</div>}
{error && <div data-testid="error">{error}</div>}
<div data-testid="preview-content">{children}</div>
</div>
)),
renderSvgInShadowHost: vi.fn(),
useDebouncedRender: vi.fn()
}))
vi.mock('@renderer/components/Preview/ImagePreviewLayout', () => ({
default: mocks.ImagePreviewLayout
}))
vi.mock('@renderer/components/Preview/utils', () => ({
renderSvgInShadowHost: mocks.renderSvgInShadowHost
}))
vi.mock('@renderer/components/Preview/hooks/useDebouncedRender', () => ({
useDebouncedRender: mocks.useDebouncedRender
}))
describe('SvgPreview', () => {
const svgContent = '<svg><rect width="100" height="100" /></svg>'
const mockContainerRef = { current: document.createElement('div') }
// Helper function to create mock useDebouncedRender return value
const createMockHookReturn = (overrides = {}) => ({
containerRef: mockContainerRef,
error: null,
isLoading: false,
triggerRender: vi.fn(),
cancelRender: vi.fn(),
clearError: vi.fn(),
setLoading: vi.fn(),
...overrides
})
beforeEach(() => {
// Setup default successful state
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn())
})
afterEach(() => {
vi.clearAllMocks()
})
describe('basic rendering', () => {
it('should match snapshot', () => {
const { container } = render(<SvgPreview enableToolbar>{svgContent}</SvgPreview>)
expect(container).toMatchSnapshot()
})
it('should handle valid svg content', () => {
render(<SvgPreview>{svgContent}</SvgPreview>)
// Component should render without throwing
expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument()
expect(mocks.useDebouncedRender).toHaveBeenCalledWith(
svgContent,
expect.any(Function),
expect.objectContaining({ debounceDelay: 300 })
)
})
it('should handle empty content', () => {
render(<SvgPreview>{''}</SvgPreview>)
// Component should render without throwing
expect(screen.getByTestId('image-preview-layout')).toBeInTheDocument()
expect(mocks.useDebouncedRender).toHaveBeenCalledWith('', expect.any(Function), expect.any(Object))
})
})
describe('loading state', () => {
it('should show loading indicator when rendering', () => {
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ isLoading: true }))
render(<SvgPreview>{svgContent}</SvgPreview>)
expect(screen.getByTestId('loading')).toBeInTheDocument()
})
it('should not show loading indicator when not rendering', () => {
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ isLoading: false }))
render(<SvgPreview>{svgContent}</SvgPreview>)
expect(screen.queryByTestId('loading')).not.toBeInTheDocument()
})
})
describe('error handling', () => {
it('should show error message when rendering fails', () => {
const errorMessage = 'Invalid SVG content'
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: errorMessage }))
render(<SvgPreview>{svgContent}</SvgPreview>)
const errorElement = screen.getByTestId('error')
expect(errorElement).toBeInTheDocument()
expect(errorElement).toHaveTextContent(errorMessage)
})
it('should not show error when rendering is successful', () => {
mocks.useDebouncedRender.mockReturnValue(createMockHookReturn({ error: null }))
render(<SvgPreview>{svgContent}</SvgPreview>)
expect(screen.queryByTestId('error')).not.toBeInTheDocument()
})
})
describe('custom styling', () => {
it('should use custom className when provided', () => {
render(<SvgPreview className="custom-svg-class">{svgContent}</SvgPreview>)
const content = screen.getByTestId('preview-content')
const svgContainer = content.querySelector('.custom-svg-class')
expect(svgContainer).toBeInTheDocument()
})
it('should use default className when not provided', () => {
render(<SvgPreview>{svgContent}</SvgPreview>)
const content = screen.getByTestId('preview-content')
const svgContainer = content.querySelector('.svg-preview.special-preview')
expect(svgContainer).toBeInTheDocument()
})
})
describe('ref forwarding', () => {
it('should forward ref to ImagePreviewLayout', () => {
const ref = { current: null }
render(<SvgPreview ref={ref}>{svgContent}</SvgPreview>)
// The ref should be passed to ImagePreviewLayout
expect(mocks.ImagePreviewLayout).toHaveBeenCalledWith(expect.objectContaining({ ref }), undefined)
})
})
})

View File

@ -0,0 +1,30 @@
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
exports[`GraphvizPreview > basic rendering > should match snapshot 1`] = `
.c0 {
overflow: auto;
position: relative;
width: 100%;
height: 100%;
}
<div>
<div
data-source="graphviz"
data-testid="image-preview-layout"
>
<div
data-testid="toolbar"
>
Toolbar
</div>
<div
data-testid="preview-content"
>
<div
class="c0 graphviz special-preview"
/>
</div>
</div>
</div>
`;

Some files were not shown because too many files have changed in this diff Show More