fix(api): refactors the SQL LIKE pattern escaping logic to use a centralized utility function, ensuring consistent and secure handling of special characters across all database queries. (#30450)

Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
NeatGuyCoding 2026-01-06 09:56:30 +08:00 committed by GitHub
parent de6262784c
commit 615c313f80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 648 additions and 36 deletions

View File

@ -348,10 +348,13 @@ class CompletionConversationApi(Resource):
)
if args.keyword:
from libs.helper import escape_like_pattern
escaped_keyword = escape_like_pattern(args.keyword)
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike(f"%{args.keyword}%"),
Message.answer.ilike(f"%{args.keyword}%"),
Message.query.ilike(f"%{escaped_keyword}%", escape="\\"),
Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"),
)
)
@ -460,7 +463,10 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args.keyword:
keyword_filter = f"%{args.keyword}%"
from libs.helper import escape_like_pattern
escaped_keyword = escape_like_pattern(args.keyword)
keyword_filter = f"%{escaped_keyword}%"
query = (
query.join(
Message,
@ -469,11 +475,11 @@ class ChatConversationApi(Resource):
.join(subquery, subquery.c.conversation_id == Conversation.id)
.where(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
Conversation.name.ilike(keyword_filter),
Conversation.introduction.ilike(keyword_filter),
subquery.c.from_end_user_session_id.ilike(keyword_filter),
Message.query.ilike(keyword_filter, escape="\\"),
Message.answer.ilike(keyword_filter, escape="\\"),
Conversation.name.ilike(keyword_filter, escape="\\"),
Conversation.introduction.ilike(keyword_filter, escape="\\"),
subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"),
),
)
.group_by(Conversation.id)

View File

@ -30,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
@ -145,6 +146,8 @@ class DatasetDocumentSegmentListApi(Resource):
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
if keyword:
# Escape special characters in keyword to prevent SQL injection via LIKE wildcards
escaped_keyword = escape_like_pattern(keyword)
# Search in both content and keywords fields
# Use database-specific methods for JSON array search
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
@ -156,15 +159,15 @@ class DatasetDocumentSegmentListApi(Resource):
.scalar_subquery()
),
",",
).ilike(f"%{keyword}%")
).ilike(f"%{escaped_keyword}%", escape="\\")
else:
# MySQL: Cast JSON to string for pattern matching
# MySQL stores Chinese text directly in JSON without Unicode escaping
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\")
query = query.where(
or_(
DocumentSegment.content.ilike(f"%{keyword}%"),
DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"),
keywords_condition,
)
)

View File

@ -984,9 +984,11 @@ class ClickzettaVector(BaseVector):
# No need for dataset_id filter since each dataset has its own table
# Use simple quote escaping for LIKE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
# Escape special characters for LIKE clause to prevent SQL injection
from libs.helper import escape_like_pattern
escaped_query = escape_like_pattern(query).replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
where_clause = " AND ".join(filter_clauses)
search_sql = f"""

View File

@ -287,11 +287,15 @@ class IrisVector(BaseVector):
cursor.execute(sql, (query,))
else:
# Fallback to LIKE search (inefficient for large datasets)
query_pattern = f"%{query}%"
# Escape special characters for LIKE clause to prevent SQL injection
from libs.helper import escape_like_pattern
escaped_query = escape_like_pattern(query)
query_pattern = f"%{escaped_query}%"
sql = f"""
SELECT TOP {top_k} id, text, meta
FROM {self.schema}.{self.table_name}
WHERE text LIKE ?
WHERE text LIKE ? ESCAPE '\\'
"""
cursor.execute(sql, (query_pattern,))

View File

@ -1198,18 +1198,24 @@ class DatasetRetrieval:
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
from libs.helper import escape_like_pattern
match condition:
case "contains":
filters.append(json_field.like(f"%{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"%{escaped_value}%", escape="\\"))
case "not contains":
filters.append(json_field.notlike(f"%{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\"))
case "start with":
filters.append(json_field.like(f"{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"{escaped_value}%", escape="\\"))
case "end with":
filters.append(json_field.like(f"%{value}"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"%{escaped_value}", escape="\\"))
case "is" | "=":
if isinstance(value, str):

View File

@ -32,6 +32,38 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def escape_like_pattern(pattern: str) -> str:
"""
Escape special characters in a string for safe use in SQL LIKE patterns.
This function escapes the special characters used in SQL LIKE patterns:
- Backslash (\\) -> \\
- Percent (%) -> \\%
- Underscore (_) -> \\_
The escaped pattern can then be safely used in SQL LIKE queries with the
ESCAPE '\\' clause to prevent SQL injection via LIKE wildcards.
Args:
pattern: The string pattern to escape
Returns:
Escaped string safe for use in SQL LIKE queries
Examples:
>>> escape_like_pattern("50% discount")
'50\\% discount'
>>> escape_like_pattern("test_data")
'test\\_data'
>>> escape_like_pattern("path\\to\\file")
'path\\\\to\\\\file'
"""
if not pattern:
return pattern
# Escape backslash first, then percent and underscore
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
"""
Extract tenant_id from Account or EndUser object.

View File

@ -137,13 +137,16 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
if keyword:
from libs.helper import escape_like_pattern
escaped_keyword = escape_like_pattern(keyword)
stmt = (
select(MessageAnnotation)
.where(MessageAnnotation.app_id == app_id)
.where(
or_(
MessageAnnotation.question.ilike(f"%{keyword}%"),
MessageAnnotation.content.ilike(f"%{keyword}%"),
MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"),
MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"),
)
)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())

View File

@ -55,8 +55,11 @@ class AppService:
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):
from libs.helper import escape_like_pattern
name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%"))
escaped_name = escape_like_pattern(name)
filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
# Check if tag_ids is not empty to avoid WHERE false condition
if args.get("tag_ids") and len(args["tag_ids"]) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])

View File

@ -218,7 +218,9 @@ class ConversationService:
# Apply variable_name filter if provided
if variable_name:
# Filter using JSON extraction to match variable names case-insensitively
escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
from libs.helper import escape_like_pattern
escaped_variable_name = escape_like_pattern(variable_name)
# Filter using JSON extraction to match variable names case-insensitively
if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
stmt = stmt.where(

View File

@ -144,7 +144,8 @@ class DatasetService:
query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
if search:
query = query.where(Dataset.name.ilike(f"%{search}%"))
escaped_search = helper.escape_like_pattern(search)
query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\"))
# Check if tag_ids is not empty to avoid WHERE false condition
if tag_ids and len(tag_ids) > 0:
@ -3423,7 +3424,8 @@ class SegmentService:
.order_by(ChildChunk.position.asc())
)
if keyword:
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
escaped_keyword = helper.escape_like_pattern(keyword)
query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\"))
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@classmethod
@ -3456,7 +3458,8 @@ class SegmentService:
query = query.where(DocumentSegment.status.in_(status_list))
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
escaped_keyword = helper.escape_like_pattern(keyword)
query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"))
query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

View File

@ -35,7 +35,10 @@ class ExternalDatasetService:
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
from libs.helper import escape_like_pattern
escaped_search = escape_like_pattern(search)
query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\"))
external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False

View File

@ -19,7 +19,10 @@ class TagService:
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
)
if keyword:
query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
from libs.helper import escape_like_pattern
escaped_keyword = escape_like_pattern(keyword)
query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
return results

View File

@ -86,12 +86,19 @@ class WorkflowAppService:
# Join to workflow run for filtering when needed.
if keyword:
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
from libs.helper import escape_like_pattern
# Escape special characters in keyword to prevent SQL injection via LIKE wildcards
escaped_keyword = escape_like_pattern(keyword[:30])
keyword_like_val = f"%{escaped_keyword}%"
keyword_conditions = [
WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val),
WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"),
WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"),
# filter keyword by end user session id if created by end user role
and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)),
and_(
WorkflowRun.created_by_role == "end_user",
EndUser.session_id.ilike(keyword_like_val, escape="\\"),
),
]
# filter keyword by workflow run id

View File

@ -444,6 +444,78 @@ class TestAnnotationService:
assert total == 1
assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content
def test_get_annotation_list_by_app_id_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies
):
r"""
Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention.
This test verifies:
- Special characters (%, _, \) in keyword are properly escaped
- Search treats special characters as literal characters, not wildcards
- SQL injection via LIKE wildcards is prevented
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotations with special characters in content
annotation_with_percent = {
"question": "Question with 50% discount",
"answer": "Answer about 50% discount offer",
}
AppAnnotationService.insert_app_annotation_directly(annotation_with_percent, app.id)
annotation_with_underscore = {
"question": "Question with test_data",
"answer": "Answer about test_data value",
}
AppAnnotationService.insert_app_annotation_directly(annotation_with_underscore, app.id)
annotation_with_backslash = {
"question": "Question with path\\to\\file",
"answer": "Answer about path\\to\\file location",
}
AppAnnotationService.insert_app_annotation_directly(annotation_with_backslash, app.id)
# Create annotation that should NOT match (contains % but as part of different text)
annotation_no_match = {
"question": "Question with 100% different",
"answer": "Answer about 100% different content",
}
AppAnnotationService.insert_app_annotation_directly(annotation_no_match, app.id)
# Test 1: Search with % character - should find exact match only
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app.id, page=1, limit=10, keyword="50%"
)
assert total == 1
assert len(annotation_list) == 1
assert "50%" in annotation_list[0].question or "50%" in annotation_list[0].content
# Test 2: Search with _ character - should find exact match only
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app.id, page=1, limit=10, keyword="test_data"
)
assert total == 1
assert len(annotation_list) == 1
assert "test_data" in annotation_list[0].question or "test_data" in annotation_list[0].content
# Test 3: Search with \ character - should find exact match only
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app.id, page=1, limit=10, keyword="path\\to\\file"
)
assert total == 1
assert len(annotation_list) == 1
assert "path\\to\\file" in annotation_list[0].question or "path\\to\\file" in annotation_list[0].content
# Test 4: Search with % should NOT match 100% (verifies escaping works)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app.id, page=1, limit=10, keyword="50%"
)
# Should only find the 50% annotation, not the 100% one
assert total == 1
assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list)
def test_get_annotation_list_by_app_id_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -7,7 +7,9 @@ from constants.model_template import default_app_templates
from models import Account
from models.model import App, Site
from services.account_service import AccountService, TenantService
from services.app_service import AppService
# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
class TestAppService:
@ -71,6 +73,9 @@ class TestAppService:
}
# Create app
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -109,6 +114,9 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Test different app modes
@ -159,6 +167,9 @@ class TestAppService:
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
created_app = app_service.create_app(tenant.id, app_args, account)
@ -194,6 +205,9 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Create multiple apps
@ -245,6 +259,9 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Create apps with different modes
@ -315,6 +332,9 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Create an app
@ -392,6 +412,9 @@ class TestAppService:
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -458,6 +481,9 @@ class TestAppService:
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -508,6 +534,9 @@ class TestAppService:
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -562,6 +591,9 @@ class TestAppService:
"icon_background": "#74B9FF",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -617,6 +649,9 @@ class TestAppService:
"icon_background": "#A29BFE",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -672,6 +707,9 @@ class TestAppService:
"icon_background": "#FD79A8",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -720,6 +758,9 @@ class TestAppService:
"icon_background": "#E17055",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -768,6 +809,9 @@ class TestAppService:
"icon_background": "#00B894",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -826,6 +870,9 @@ class TestAppService:
"icon_background": "#6C5CE7",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -862,6 +909,9 @@ class TestAppService:
"icon_background": "#FDCB6E",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -899,6 +949,9 @@ class TestAppService:
"icon_background": "#E84393",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -947,8 +1000,132 @@ class TestAppService:
"icon_background": "#D63031",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Attempt to create app with invalid mode
with pytest.raises(ValueError, match="invalid mode value"):
app_service.create_app(tenant.id, app_args, account)
def test_get_apps_with_special_characters_in_name(
self, db_session_with_containers, mock_external_service_dependencies
):
r"""
Test app retrieval with special characters in name search to verify SQL injection prevention.
This test verifies:
- Special characters (%, _, \) in name search are properly escaped
- Search treats special characters as literal characters, not wildcards
- SQL injection via LIKE wildcards is prevented
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Create apps with special characters in names
app_with_percent = app_service.create_app(
tenant.id,
{
"name": "App with 50% discount",
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
},
account,
)
app_with_underscore = app_service.create_app(
tenant.id,
{
"name": "test_data_app",
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
},
account,
)
app_with_backslash = app_service.create_app(
tenant.id,
{
"name": "path\\to\\app",
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
},
account,
)
# Create app that should NOT match
app_no_match = app_service.create_app(
tenant.id,
{
"name": "100% different",
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
},
account,
)
# Test 1: Search with % character
args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
assert paginated_apps is not None
assert paginated_apps.total == 1
assert len(paginated_apps.items) == 1
assert paginated_apps.items[0].name == "App with 50% discount"
# Test 2: Search with _ character
args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
assert paginated_apps is not None
assert paginated_apps.total == 1
assert len(paginated_apps.items) == 1
assert paginated_apps.items[0].name == "test_data_app"
# Test 3: Search with \ character
args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
assert paginated_apps is not None
assert paginated_apps.total == 1
assert len(paginated_apps.items) == 1
assert paginated_apps.items[0].name == "path\\to\\app"
# Test 4: Search with % should NOT match 100% (verifies escaping works)
args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
assert paginated_apps is not None
assert paginated_apps.total == 1
assert all("50%" in app.name for app in paginated_apps.items)

View File

@ -1,3 +1,4 @@
import uuid
from unittest.mock import create_autospec, patch
import pytest
@ -312,6 +313,85 @@ class TestTagService:
result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent")
assert len(result_no_match) == 0
def test_get_tags_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies
):
r"""
Test tag retrieval with special characters in keyword to verify SQL injection prevention.
This test verifies:
- Special characters (%, _, \) in keyword are properly escaped
- Search treats special characters as literal characters, not wildcards
- SQL injection via LIKE wildcards is prevented
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
from extensions.ext_database import db
# Create tags with special characters in names
tag_with_percent = Tag(
name="50% discount",
type="app",
tenant_id=tenant.id,
created_by=account.id,
)
tag_with_percent.id = str(uuid.uuid4())
db.session.add(tag_with_percent)
tag_with_underscore = Tag(
name="test_data_tag",
type="app",
tenant_id=tenant.id,
created_by=account.id,
)
tag_with_underscore.id = str(uuid.uuid4())
db.session.add(tag_with_underscore)
tag_with_backslash = Tag(
name="path\\to\\tag",
type="app",
tenant_id=tenant.id,
created_by=account.id,
)
tag_with_backslash.id = str(uuid.uuid4())
db.session.add(tag_with_backslash)
# Create tag that should NOT match
tag_no_match = Tag(
name="100% different",
type="app",
tenant_id=tenant.id,
created_by=account.id,
)
tag_no_match.id = str(uuid.uuid4())
db.session.add(tag_no_match)
db.session.commit()
# Act & Assert: Test 1 - Search with % character
result = TagService.get_tags("app", tenant.id, keyword="50%")
assert len(result) == 1
assert result[0].name == "50% discount"
# Test 2 - Search with _ character
result = TagService.get_tags("app", tenant.id, keyword="test_data")
assert len(result) == 1
assert result[0].name == "test_data_tag"
# Test 3 - Search with \ character
result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag")
assert len(result) == 1
assert result[0].name == "path\\to\\tag"
# Test 4 - Search with % should NOT match 100% (verifies escaping works)
result = TagService.get_tags("app", tenant.id, keyword="50%")
assert len(result) == 1
assert all("50%" in item.name for item in result)
def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test tag retrieval when no tags exist.

View File

@ -10,7 +10,9 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from services.account_service import AccountService, TenantService
from services.app_service import AppService
# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
from services.workflow_app_service import WorkflowAppService
@ -86,6 +88,9 @@ class TestWorkflowAppService:
"api_rpm": 10,
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -147,6 +152,9 @@ class TestWorkflowAppService:
"api_rpm": 10,
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -308,6 +316,156 @@ class TestWorkflowAppService:
assert result_no_match["total"] == 0
assert len(result_no_match["data"]) == 0
def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies
):
r"""
Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention.
This test verifies:
- Special characters (%, _) in keyword are properly escaped
- Search treats special characters as literal characters, not wildcards
- SQL injection via LIKE wildcards is prevented
"""
# Arrange: Create test data
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account)
from extensions.ext_database import db
service = WorkflowAppService()
# Test 1: Search with % character
workflow_run_1 = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
type="workflow",
triggered_from="app-run",
version="1.0.0",
graph=json.dumps({"nodes": [], "edges": []}),
status="succeeded",
inputs=json.dumps({"search_term": "50% discount", "input2": "other_value"}),
outputs=json.dumps({"result": "50% discount applied", "status": "success"}),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(workflow_run_1)
db.session.flush()
workflow_app_log_1 = WorkflowAppLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_1.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
workflow_app_log_1.id = str(uuid.uuid4())
workflow_app_log_1.created_at = datetime.now(UTC)
db.session.add(workflow_app_log_1)
db.session.commit()
result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
)
# Should find the workflow_run_1 entry
assert result["total"] >= 1
assert len(result["data"]) >= 1
assert any(log.workflow_run_id == workflow_run_1.id for log in result["data"])
# Test 2: Search with _ character
workflow_run_2 = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
type="workflow",
triggered_from="app-run",
version="1.0.0",
graph=json.dumps({"nodes": [], "edges": []}),
status="succeeded",
inputs=json.dumps({"search_term": "test_data_value", "input2": "other_value"}),
outputs=json.dumps({"result": "test_data_value found", "status": "success"}),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(workflow_run_2)
db.session.flush()
workflow_app_log_2 = WorkflowAppLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_2.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
workflow_app_log_2.id = str(uuid.uuid4())
workflow_app_log_2.created_at = datetime.now(UTC)
db.session.add(workflow_app_log_2)
db.session.commit()
result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20
)
# Should find the workflow_run_2 entry
assert result["total"] >= 1
assert len(result["data"]) >= 1
assert any(log.workflow_run_id == workflow_run_2.id for log in result["data"])
# Test 3: Search with % should NOT match 100% (verifies escaping works correctly)
workflow_run_4 = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
type="workflow",
triggered_from="app-run",
version="1.0.0",
graph=json.dumps({"nodes": [], "edges": []}),
status="succeeded",
inputs=json.dumps({"search_term": "100% different", "input2": "other_value"}),
outputs=json.dumps({"result": "100% different result", "status": "success"}),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(workflow_run_4)
db.session.flush()
workflow_app_log_4 = WorkflowAppLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_4.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
workflow_app_log_4.id = str(uuid.uuid4())
workflow_app_log_4.created_at = datetime.now(UTC)
db.session.add(workflow_app_log_4)
db.session.commit()
result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
)
# Should only find the 50% entry (workflow_run_1), not the 100% entry (workflow_run_4)
# This verifies that escaping works correctly - 50% should not match 100%
assert result["total"] >= 1
assert len(result["data"]) >= 1
# Verify that we found workflow_run_1 (50% discount) but not workflow_run_4 (100% different)
found_run_ids = [log.workflow_run_id for log in result["data"]]
assert workflow_run_1.id in found_run_ids
assert workflow_run_4.id not in found_run_ids
def test_get_paginate_workflow_app_logs_with_status_filter(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -1,6 +1,6 @@
import pytest
from libs.helper import extract_tenant_id
from libs.helper import escape_like_pattern, extract_tenant_id
from models.account import Account
from models.model import EndUser
@ -63,3 +63,51 @@ class TestExtractTenantId:
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
extract_tenant_id(dict_user)
class TestEscapeLikePattern:
"""Test cases for the escape_like_pattern utility function."""
def test_escape_percent_character(self):
"""Test escaping percent character."""
result = escape_like_pattern("50% discount")
assert result == "50\\% discount"
def test_escape_underscore_character(self):
"""Test escaping underscore character."""
result = escape_like_pattern("test_data")
assert result == "test\\_data"
def test_escape_backslash_character(self):
"""Test escaping backslash character."""
result = escape_like_pattern("path\\to\\file")
assert result == "path\\\\to\\\\file"
def test_escape_combined_special_characters(self):
"""Test escaping multiple special characters together."""
result = escape_like_pattern("file_50%\\path")
assert result == "file\\_50\\%\\\\path"
def test_escape_empty_string(self):
"""Test escaping empty string returns empty string."""
result = escape_like_pattern("")
assert result == ""
def test_escape_none_handling(self):
"""Test escaping None returns None (falsy check handles it)."""
# The function checks `if not pattern`, so None is falsy and returns as-is
result = escape_like_pattern(None)
assert result is None
def test_escape_normal_string_no_change(self):
"""Test that normal strings without special characters are unchanged."""
result = escape_like_pattern("normal text")
assert result == "normal text"
def test_escape_order_matters(self):
"""Test that backslash is escaped first to prevent double escaping."""
# If we escape % first, then escape \, we might get wrong results
# This test ensures the order is correct: \ first, then % and _
result = escape_like_pattern("test\\%_value")
# Should be: test\\\%\_value
assert result == "test\\\\\\%\\_value"