mirror of
https://github.com/langgenius/dify.git
synced 2026-01-13 21:57:48 +08:00
fix(api): refactors the SQL LIKE pattern escaping logic to use a centralized utility function, ensuring consistent and secure handling of special characters across all database queries. (#30450)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
de6262784c
commit
615c313f80
@ -348,10 +348,13 @@ class CompletionConversationApi(Resource):
|
||||
)
|
||||
|
||||
if args.keyword:
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_keyword = escape_like_pattern(args.keyword)
|
||||
query = query.join(Message, Message.conversation_id == Conversation.id).where(
|
||||
or_(
|
||||
Message.query.ilike(f"%{args.keyword}%"),
|
||||
Message.answer.ilike(f"%{args.keyword}%"),
|
||||
Message.query.ilike(f"%{escaped_keyword}%", escape="\\"),
|
||||
Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"),
|
||||
)
|
||||
)
|
||||
|
||||
@ -460,7 +463,10 @@ class ChatConversationApi(Resource):
|
||||
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
|
||||
if args.keyword:
|
||||
keyword_filter = f"%{args.keyword}%"
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_keyword = escape_like_pattern(args.keyword)
|
||||
keyword_filter = f"%{escaped_keyword}%"
|
||||
query = (
|
||||
query.join(
|
||||
Message,
|
||||
@ -469,11 +475,11 @@ class ChatConversationApi(Resource):
|
||||
.join(subquery, subquery.c.conversation_id == Conversation.id)
|
||||
.where(
|
||||
or_(
|
||||
Message.query.ilike(keyword_filter),
|
||||
Message.answer.ilike(keyword_filter),
|
||||
Conversation.name.ilike(keyword_filter),
|
||||
Conversation.introduction.ilike(keyword_filter),
|
||||
subquery.c.from_end_user_session_id.ilike(keyword_filter),
|
||||
Message.query.ilike(keyword_filter, escape="\\"),
|
||||
Message.answer.ilike(keyword_filter, escape="\\"),
|
||||
Conversation.name.ilike(keyword_filter, escape="\\"),
|
||||
Conversation.introduction.ilike(keyword_filter, escape="\\"),
|
||||
subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"),
|
||||
),
|
||||
)
|
||||
.group_by(Conversation.id)
|
||||
|
||||
@ -30,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.helper import escape_like_pattern
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
@ -145,6 +146,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
|
||||
|
||||
if keyword:
|
||||
# Escape special characters in keyword to prevent SQL injection via LIKE wildcards
|
||||
escaped_keyword = escape_like_pattern(keyword)
|
||||
# Search in both content and keywords fields
|
||||
# Use database-specific methods for JSON array search
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
@ -156,15 +159,15 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
.scalar_subquery()
|
||||
),
|
||||
",",
|
||||
).ilike(f"%{keyword}%")
|
||||
).ilike(f"%{escaped_keyword}%", escape="\\")
|
||||
else:
|
||||
# MySQL: Cast JSON to string for pattern matching
|
||||
# MySQL stores Chinese text directly in JSON without Unicode escaping
|
||||
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
|
||||
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\")
|
||||
|
||||
query = query.where(
|
||||
or_(
|
||||
DocumentSegment.content.ilike(f"%{keyword}%"),
|
||||
DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"),
|
||||
keywords_condition,
|
||||
)
|
||||
)
|
||||
|
||||
@ -984,9 +984,11 @@ class ClickzettaVector(BaseVector):
|
||||
|
||||
# No need for dataset_id filter since each dataset has its own table
|
||||
|
||||
# Use simple quote escaping for LIKE clause
|
||||
escaped_query = query.replace("'", "''")
|
||||
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
|
||||
# Escape special characters for LIKE clause to prevent SQL injection
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_query = escape_like_pattern(query).replace("'", "''")
|
||||
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
|
||||
where_clause = " AND ".join(filter_clauses)
|
||||
|
||||
search_sql = f"""
|
||||
|
||||
@ -287,11 +287,15 @@ class IrisVector(BaseVector):
|
||||
cursor.execute(sql, (query,))
|
||||
else:
|
||||
# Fallback to LIKE search (inefficient for large datasets)
|
||||
query_pattern = f"%{query}%"
|
||||
# Escape special characters for LIKE clause to prevent SQL injection
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_query = escape_like_pattern(query)
|
||||
query_pattern = f"%{escaped_query}%"
|
||||
sql = f"""
|
||||
SELECT TOP {top_k} id, text, meta
|
||||
FROM {self.schema}.{self.table_name}
|
||||
WHERE text LIKE ?
|
||||
WHERE text LIKE ? ESCAPE '\\'
|
||||
"""
|
||||
cursor.execute(sql, (query_pattern,))
|
||||
|
||||
|
||||
@ -1198,18 +1198,24 @@ class DatasetRetrieval:
|
||||
|
||||
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
|
||||
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.like(f"%{escaped_value}%", escape="\\"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\"))
|
||||
|
||||
case "start with":
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.like(f"{escaped_value}%", escape="\\"))
|
||||
|
||||
case "end with":
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.like(f"%{escaped_value}", escape="\\"))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
|
||||
@ -32,6 +32,38 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def escape_like_pattern(pattern: str) -> str:
    r"""
    Escape SQL LIKE metacharacters so that ``pattern`` is matched literally.

    Backslash is used as the escape character and the following replacements
    are applied:

    - Backslash (\) -> \\
    - Percent (%) -> \%
    - Underscore (_) -> \_

    The result is meant to be embedded in a LIKE/ILIKE pattern that is issued
    together with an ``ESCAPE '\'`` clause (``escape="\\"`` in SQLAlchemy), so
    that user-supplied ``%`` and ``_`` cannot act as wildcards. This function
    only neutralises LIKE wildcards; it does not quote the value for inclusion
    in raw SQL text, so bind parameters (or dialect quoting) are still needed.

    Args:
        pattern: The string to escape. Falsy values (e.g. ``""``) are returned
            unchanged.

    Returns:
        The escaped string, safe to wrap with ``%`` wildcards by the caller.

    Examples:
        >>> escape_like_pattern("50% discount")
        '50\\% discount'
        >>> escape_like_pattern("test_data")
        'test\\_data'
        >>> escape_like_pattern("path\\to\\file")
        'path\\\\to\\\\file'
    """
    if not pattern:
        return pattern
    # Backslash must be escaped first; otherwise the backslashes introduced for
    # "%" and "_" would themselves be doubled by the backslash replacement.
    return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
|
||||
|
||||
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
|
||||
"""
|
||||
Extract tenant_id from Account or EndUser object.
|
||||
|
||||
@ -137,13 +137,16 @@ class AppAnnotationService:
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
if keyword:
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_keyword = escape_like_pattern(keyword)
|
||||
stmt = (
|
||||
select(MessageAnnotation)
|
||||
.where(MessageAnnotation.app_id == app_id)
|
||||
.where(
|
||||
or_(
|
||||
MessageAnnotation.question.ilike(f"%{keyword}%"),
|
||||
MessageAnnotation.content.ilike(f"%{keyword}%"),
|
||||
MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"),
|
||||
MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"),
|
||||
)
|
||||
)
|
||||
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
|
||||
|
||||
@ -55,8 +55,11 @@ class AppService:
|
||||
if args.get("is_created_by_me", False):
|
||||
filters.append(App.created_by == user_id)
|
||||
if args.get("name"):
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
name = args["name"][:30]
|
||||
filters.append(App.name.ilike(f"%{name}%"))
|
||||
escaped_name = escape_like_pattern(name)
|
||||
filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
|
||||
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||
if args.get("tag_ids") and len(args["tag_ids"]) > 0:
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
|
||||
|
||||
@ -218,7 +218,9 @@ class ConversationService:
|
||||
# Apply variable_name filter if provided
|
||||
if variable_name:
|
||||
# Filter using JSON extraction to match variable names case-insensitively
|
||||
escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_variable_name = escape_like_pattern(variable_name)
|
||||
# Filter using JSON extraction to match variable names case-insensitively
|
||||
if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
|
||||
stmt = stmt.where(
|
||||
|
||||
@ -144,7 +144,8 @@ class DatasetService:
|
||||
query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
|
||||
|
||||
if search:
|
||||
query = query.where(Dataset.name.ilike(f"%{search}%"))
|
||||
escaped_search = helper.escape_like_pattern(search)
|
||||
query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\"))
|
||||
|
||||
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||
if tag_ids and len(tag_ids) > 0:
|
||||
@ -3423,7 +3424,8 @@ class SegmentService:
|
||||
.order_by(ChildChunk.position.asc())
|
||||
)
|
||||
if keyword:
|
||||
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
|
||||
escaped_keyword = helper.escape_like_pattern(keyword)
|
||||
query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\"))
|
||||
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
|
||||
@classmethod
|
||||
@ -3456,7 +3458,8 @@ class SegmentService:
|
||||
query = query.where(DocumentSegment.status.in_(status_list))
|
||||
|
||||
if keyword:
|
||||
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
|
||||
escaped_keyword = helper.escape_like_pattern(keyword)
|
||||
query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"))
|
||||
|
||||
query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
|
||||
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
|
||||
@ -35,7 +35,10 @@ class ExternalDatasetService:
|
||||
.order_by(ExternalKnowledgeApis.created_at.desc())
|
||||
)
|
||||
if search:
|
||||
query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_search = escape_like_pattern(search)
|
||||
query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\"))
|
||||
|
||||
external_knowledge_apis = db.paginate(
|
||||
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
|
||||
|
||||
@ -19,7 +19,10 @@ class TagService:
|
||||
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
|
||||
)
|
||||
if keyword:
|
||||
query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_keyword = escape_like_pattern(keyword)
|
||||
query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
|
||||
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
|
||||
results: list = query.order_by(Tag.created_at.desc()).all()
|
||||
return results
|
||||
|
||||
@ -86,12 +86,19 @@ class WorkflowAppService:
|
||||
# Join to workflow run for filtering when needed.
|
||||
|
||||
if keyword:
|
||||
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
# Escape special characters in keyword to prevent SQL injection via LIKE wildcards
|
||||
escaped_keyword = escape_like_pattern(keyword[:30])
|
||||
keyword_like_val = f"%{escaped_keyword}%"
|
||||
keyword_conditions = [
|
||||
WorkflowRun.inputs.ilike(keyword_like_val),
|
||||
WorkflowRun.outputs.ilike(keyword_like_val),
|
||||
WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"),
|
||||
WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"),
|
||||
# filter keyword by end user session id if created by end user role
|
||||
and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)),
|
||||
and_(
|
||||
WorkflowRun.created_by_role == "end_user",
|
||||
EndUser.session_id.ilike(keyword_like_val, escape="\\"),
|
||||
),
|
||||
]
|
||||
|
||||
# filter keyword by workflow run id
|
||||
|
||||
@ -444,6 +444,78 @@ class TestAnnotationService:
|
||||
assert total == 1
|
||||
assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content
|
||||
|
||||
def test_get_annotation_list_by_app_id_with_special_characters_in_keyword(
    self, db_session_with_containers, mock_external_service_dependencies
):
    r"""
    Verify that LIKE metacharacters (%, _, \) in ``keyword`` are matched literally.

    Besides the annotations that contain the metacharacters, decoy annotations
    are inserted that would only be returned if a metacharacter in the keyword
    were interpreted by the database instead of being escaped. The expected
    totals therefore fail when the keyword is passed to LIKE unescaped.
    """
    app, _ = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

    annotations = [
        # Annotations that literally contain the special characters.
        ("Question with 50% discount", "Answer about 50% discount offer"),
        ("Question with test_data", "Answer about test_data value"),
        ("Question with path\\to\\file", "Answer about path\\to\\file location"),
        # Contains "%" but not the searched text "50%".
        ("Question with 100% different", "Answer about 100% different content"),
        # Decoys: "%50%%" (unescaped "50%") would match "500".
        ("Question with 500 items", "Answer about 500 items in stock"),
        # Decoys: "_" treated as a single-character wildcard would match "testXdata".
        ("Question with testXdata", "Answer about testXdata value"),
        # Decoy for the backslash search: presumably matched when "\" is consumed
        # as the LIKE escape character instead of being searched for literally.
        ("Question with pathtofile", "Answer about pathtofile location"),
    ]
    for question, answer in annotations:
        AppAnnotationService.insert_app_annotation_directly({"question": question, "answer": answer}, app.id)

    def search(keyword):
        return AppAnnotationService.get_annotation_list_by_app_id(app.id, page=1, limit=10, keyword=keyword)

    def mentions(item, text):
        return text in (item.question or "") or text in (item.content or "")

    # Test 1: "%" inside the keyword is literal - "500 items" must not match
    annotation_list, total = search("50%")
    assert total == 1
    assert len(annotation_list) == 1
    assert mentions(annotation_list[0], "50%")

    # Test 2: "_" inside the keyword is literal - "testXdata" must not match
    annotation_list, total = search("test_data")
    assert total == 1
    assert len(annotation_list) == 1
    assert mentions(annotation_list[0], "test_data")

    # Test 3: "\" inside the keyword is literal - "pathtofile" must not match
    annotation_list, total = search("path\\to\\file")
    assert total == 1
    assert len(annotation_list) == 1
    assert mentions(annotation_list[0], "path\\to\\file")

    # Test 4: a keyword consisting only of "%" matches just the annotations
    # containing a percent sign, not every annotation
    annotation_list, total = search("%")
    assert total == 2
    assert len(annotation_list) == 2
    assert all(mentions(item, "%") for item in annotation_list)

    # Test 5: a keyword consisting only of "_" matches just the annotation
    # containing an underscore
    annotation_list, total = search("_")
    assert total == 1
    assert len(annotation_list) == 1
    assert mentions(annotation_list[0], "test_data")
|
||||
|
||||
def test_get_annotation_list_by_app_id_app_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
|
||||
@ -7,7 +7,9 @@ from constants.model_template import default_app_templates
|
||||
from models import Account
|
||||
from models.model import App, Site
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
# Delay import of AppService to avoid circular dependency
|
||||
# from services.app_service import AppService
|
||||
|
||||
|
||||
class TestAppService:
|
||||
@ -71,6 +73,9 @@ class TestAppService:
|
||||
}
|
||||
|
||||
# Create app
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -109,6 +114,9 @@ class TestAppService:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Test different app modes
|
||||
@ -159,6 +167,9 @@ class TestAppService:
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
created_app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -194,6 +205,9 @@ class TestAppService:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create multiple apps
|
||||
@ -245,6 +259,9 @@ class TestAppService:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create apps with different modes
|
||||
@ -315,6 +332,9 @@ class TestAppService:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create an app
|
||||
@ -392,6 +412,9 @@ class TestAppService:
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -458,6 +481,9 @@ class TestAppService:
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -508,6 +534,9 @@ class TestAppService:
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -562,6 +591,9 @@ class TestAppService:
|
||||
"icon_background": "#74B9FF",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -617,6 +649,9 @@ class TestAppService:
|
||||
"icon_background": "#A29BFE",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -672,6 +707,9 @@ class TestAppService:
|
||||
"icon_background": "#FD79A8",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -720,6 +758,9 @@ class TestAppService:
|
||||
"icon_background": "#E17055",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -768,6 +809,9 @@ class TestAppService:
|
||||
"icon_background": "#00B894",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -826,6 +870,9 @@ class TestAppService:
|
||||
"icon_background": "#6C5CE7",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -862,6 +909,9 @@ class TestAppService:
|
||||
"icon_background": "#FDCB6E",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -899,6 +949,9 @@ class TestAppService:
|
||||
"icon_background": "#E84393",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -947,8 +1000,132 @@ class TestAppService:
|
||||
"icon_background": "#D63031",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Attempt to create app with invalid mode
|
||||
with pytest.raises(ValueError, match="invalid mode value"):
|
||||
app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
def test_get_apps_with_special_characters_in_name(
    self, db_session_with_containers, mock_external_service_dependencies
):
    r"""
    Verify that LIKE metacharacters (%, _, \) in the ``name`` filter are matched literally.

    Decoy apps are created whose names would only be returned if a
    metacharacter in the search term were interpreted by the database instead
    of being escaped, so the expected totals fail when escaping is missing.
    """
    fake = Faker()

    # Create account and tenant first
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
    tenant = account.current_tenant

    # Import here to avoid circular dependency
    from services.app_service import AppService

    app_service = AppService()

    app_names = [
        # Names that literally contain the special characters.
        "App with 50% discount",
        "test_data_app",
        "path\\to\\app",
        # Contains "%" but not the searched text "50%".
        "100% different",
        # Decoys: matched only if "%" / "_" act as wildcards.
        "App with 500 items",
        "testXdata app",
        # Decoy for the backslash search: presumably matched when "\" is consumed
        # as the LIKE escape character instead of being searched for literally.
        "pathtoapp",
    ]
    for app_name in app_names:
        app_service.create_app(
            tenant.id,
            {
                "name": app_name,
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            },
            account,
        )

    def search(name):
        args = {"name": name, "mode": "chat", "page": 1, "limit": 10}
        paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
        assert paginated_apps is not None
        return paginated_apps

    # Test 1: "%" inside the name is literal - "App with 500 items" must not match
    paginated_apps = search("50%")
    assert paginated_apps.total == 1
    assert len(paginated_apps.items) == 1
    assert paginated_apps.items[0].name == "App with 50% discount"

    # Test 2: "_" inside the name is literal - "testXdata app" must not match
    paginated_apps = search("test_data")
    assert paginated_apps.total == 1
    assert len(paginated_apps.items) == 1
    assert paginated_apps.items[0].name == "test_data_app"

    # Test 3: "\" inside the name is literal - "pathtoapp" must not match
    paginated_apps = search("path\\to\\app")
    assert paginated_apps.total == 1
    assert len(paginated_apps.items) == 1
    assert paginated_apps.items[0].name == "path\\to\\app"

    # Test 4: a search term consisting only of "%" matches just the apps whose
    # name contains a percent sign, not every app
    paginated_apps = search("%")
    assert paginated_apps.total == 2
    assert {app.name for app in paginated_apps.items} == {"App with 50% discount", "100% different"}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from unittest.mock import create_autospec, patch
|
||||
|
||||
import pytest
|
||||
@ -312,6 +313,85 @@ class TestTagService:
|
||||
result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent")
|
||||
assert len(result_no_match) == 0
|
||||
|
||||
def test_get_tags_with_special_characters_in_keyword(
    self, db_session_with_containers, mock_external_service_dependencies
):
    r"""
    Verify that LIKE metacharacters (%, _, \) in ``keyword`` are matched literally.

    Decoy tags are created whose names would only be returned if a
    metacharacter in the keyword were interpreted by the database instead of
    being escaped, so the expected result sizes fail when escaping is missing.
    """
    # Arrange: Create test data
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )

    from extensions.ext_database import db

    tag_names = [
        # Names that literally contain the special characters.
        "50% discount",
        "test_data_tag",
        "path\\to\\tag",
        # Contains "%" but not the searched text "50%".
        "100% different",
        # Decoys: matched only if "%" / "_" act as wildcards.
        "500 items",
        "testXdata tag",
        # Decoy for the backslash search: presumably matched when "\" is consumed
        # as the LIKE escape character instead of being searched for literally.
        "pathtotag",
    ]
    for tag_name in tag_names:
        tag = Tag(
            name=tag_name,
            type="app",
            tenant_id=tenant.id,
            created_by=account.id,
        )
        tag.id = str(uuid.uuid4())
        db.session.add(tag)

    db.session.commit()

    # Act & Assert: Test 1 - "%" inside the keyword is literal
    result = TagService.get_tags("app", tenant.id, keyword="50%")
    assert len(result) == 1
    assert result[0].name == "50% discount"

    # Test 2 - "_" inside the keyword is literal
    result = TagService.get_tags("app", tenant.id, keyword="test_data")
    assert len(result) == 1
    assert result[0].name == "test_data_tag"

    # Test 3 - "\" inside the keyword is literal
    result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag")
    assert len(result) == 1
    assert result[0].name == "path\\to\\tag"

    # Test 4 - a keyword consisting only of "%" matches just the tags whose
    # name contains a percent sign, not every tag
    result = TagService.get_tags("app", tenant.id, keyword="%")
    assert len(result) == 2
    assert {item.name for item in result} == {"50% discount", "100% different"}

    # Test 5 - a keyword consisting only of "_" matches just the tag whose
    # name contains an underscore
    result = TagService.get_tags("app", tenant.id, keyword="_")
    assert len(result) == 1
    assert result[0].name == "test_data_tag"
|
||||
|
||||
def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test tag retrieval when no tags exist.
|
||||
|
||||
@ -10,7 +10,9 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
|
||||
from models.enums import CreatorUserRole
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
# Delay import of AppService to avoid circular dependency
|
||||
# from services.app_service import AppService
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
||||
@ -86,6 +88,9 @@ class TestWorkflowAppService:
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -147,6 +152,9 @@ class TestWorkflowAppService:
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -308,6 +316,156 @@ class TestWorkflowAppService:
|
||||
assert result_no_match["total"] == 0
|
||||
assert len(result_no_match["data"]) == 0
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword(
    self, db_session_with_containers, mock_external_service_dependencies
):
    r"""
    Test workflow app logs pagination when the keyword contains LIKE wildcard characters.

    This test verifies:
    - Special characters (%, _) in keyword are properly escaped
    - Search treats special characters as literal characters, not wildcards
    - A keyword such as "50%" does not match an unrelated row such as "100% different"
    """
    # Arrange: Create the app, account and workflow every run/log below hangs off.
    app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
    workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account)

    from extensions.ext_database import db

    service = WorkflowAppService()

    def add_run_with_log(search_term, result_text):
        """Persist one succeeded WorkflowRun plus its WorkflowAppLog and return the run."""
        workflow_run = WorkflowRun(
            id=str(uuid.uuid4()),
            tenant_id=app.tenant_id,
            app_id=app.id,
            workflow_id=workflow.id,
            type="workflow",
            triggered_from="app-run",
            version="1.0.0",
            graph=json.dumps({"nodes": [], "edges": []}),
            status="succeeded",
            inputs=json.dumps({"search_term": search_term, "input2": "other_value"}),
            outputs=json.dumps({"result": result_text, "status": "success"}),
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=account.id,
            created_at=datetime.now(UTC),
        )
        db.session.add(workflow_run)
        # Flush before building the log entry, which references the run id.
        db.session.flush()

        workflow_app_log = WorkflowAppLog(
            tenant_id=app.tenant_id,
            app_id=app.id,
            workflow_id=workflow.id,
            workflow_run_id=workflow_run.id,
            created_from="service-api",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=account.id,
        )
        workflow_app_log.id = str(uuid.uuid4())
        workflow_app_log.created_at = datetime.now(UTC)
        db.session.add(workflow_app_log)
        db.session.commit()
        return workflow_run

    def search(keyword):
        """Run the paginated log query for this app with the given keyword."""
        return service.get_paginate_workflow_app_logs(
            session=db_session_with_containers, app_model=app, keyword=keyword, page=1, limit=20
        )

    # Test 1: Search with % character
    run_50_percent = add_run_with_log("50% discount", "50% discount applied")

    result = search("50%")
    # Should find the "50% discount" entry
    assert result["total"] >= 1
    assert len(result["data"]) >= 1
    assert any(log.workflow_run_id == run_50_percent.id for log in result["data"])

    # Test 2: Search with _ character
    run_underscore = add_run_with_log("test_data_value", "test_data_value found")

    result = search("test_data")
    # Should find the "test_data_value" entry
    assert result["total"] >= 1
    assert len(result["data"]) >= 1
    assert any(log.workflow_run_id == run_underscore.id for log in result["data"])

    # Test 3: Search with % should NOT match 100% (verifies escaping works correctly)
    run_100_percent = add_run_with_log("100% different", "100% different result")

    result = search("50%")
    # Only the "50% discount" run may be returned. If the trailing % were treated
    # as a wildcard-friendly character in a broken escape, the "100% different"
    # run could leak into the result set.
    assert result["total"] >= 1
    assert len(result["data"]) >= 1
    found_run_ids = {log.workflow_run_id for log in result["data"]}
    assert run_50_percent.id in found_run_ids
    assert run_100_percent.id not in found_run_ids
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_status_filter(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from libs.helper import extract_tenant_id
|
||||
from libs.helper import escape_like_pattern, extract_tenant_id
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
@ -63,3 +63,51 @@ class TestExtractTenantId:
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(dict_user)
|
||||
|
||||
|
||||
class TestEscapeLikePattern:
    """Unit tests for escape_like_pattern, the helper that escapes LIKE wildcards."""

    def test_escape_percent_character(self):
        """A percent sign is prefixed with a backslash."""
        assert escape_like_pattern("50% discount") == "50\\% discount"

    def test_escape_underscore_character(self):
        """An underscore is prefixed with a backslash."""
        assert escape_like_pattern("test_data") == "test\\_data"

    def test_escape_backslash_character(self):
        """Each backslash in the input is doubled."""
        assert escape_like_pattern("path\\to\\file") == "path\\\\to\\\\file"

    def test_escape_combined_special_characters(self):
        """Underscore, percent and backslash are all escaped within one string."""
        assert escape_like_pattern("file_50%\\path") == "file\\_50\\%\\\\path"

    def test_escape_empty_string(self):
        """An empty input yields an empty output."""
        assert escape_like_pattern("") == ""

    def test_escape_none_handling(self):
        """None is expected to be handed back untouched rather than raising."""
        assert escape_like_pattern(None) is None

    def test_escape_normal_string_no_change(self):
        """Text without any special characters comes back unchanged."""
        assert escape_like_pattern("normal text") == "normal text"

    def test_escape_order_matters(self):
        """Backslashes must be escaped before % and _ so nothing is escaped twice."""
        # Input characters are: test\%_value
        # Expected characters:  test\\\%\_value
        # (the original backslash is doubled, then % and _ each get their own prefix).
        escaped = escape_like_pattern("test\\%_value")
        expected = "test\\\\\\%\\_value"
        assert escaped == expected
|
||||
|
||||
Loading…
Reference in New Issue
Block a user