diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py
index 120fb73cdb..c0fefef3d0 100644
--- a/api/core/file/file_manager.py
+++ b/api/core/file/file_manager.py
@@ -104,6 +104,8 @@ def download(f: File, /):
     ):
         return _download_file_content(f.storage_key)
     elif f.transfer_method == FileTransferMethod.REMOTE_URL:
+        if f.remote_url is None:
+            raise ValueError("Missing file remote_url")
         response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
         response.raise_for_status()
         return response.content
@@ -134,6 +136,8 @@ def _download_file_content(path: str, /):
 def _get_encoded_string(f: File, /):
     match f.transfer_method:
         case FileTransferMethod.REMOTE_URL:
+            if f.remote_url is None:
+                raise ValueError("Missing file remote_url")
             response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
             response.raise_for_status()
             data = response.content
diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py
index 128c64ff2c..ddccfbaf45 100644
--- a/api/core/helper/ssrf_proxy.py
+++ b/api/core/helper/ssrf_proxy.py
@@ -4,8 +4,10 @@ Proxy requests to avoid SSRF
 
 import logging
 import time
+from typing import Any, TypeAlias
 
 import httpx
+from pydantic import TypeAdapter, ValidationError
 
 from configs import dify_config
 from core.helper.http_client_pooling import get_pooled_http_client
@@ -18,6 +20,9 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
 BACKOFF_FACTOR = 0.5
 STATUS_FORCELIST = [429, 500, 502, 503, 504]
 
+Headers: TypeAlias = dict[str, str]
+_HEADERS_ADAPTER = TypeAdapter(Headers)
+
 _SSL_VERIFIED_POOL_KEY = "ssrf:verified"
 _SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
 _SSRF_CLIENT_LIMITS = httpx.Limits(
@@ -76,7 +81,7 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
     )
 
 
-def _get_user_provided_host_header(headers: dict | None) -> str | None:
+def _get_user_provided_host_header(headers: Headers | None) -> str | None:
     """
     Extract the user-provided Host header from the headers dict.
 
@@ -92,7 +97,7 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
     return None
 
 
-def _inject_trace_headers(headers: dict | None) -> dict:
+def _inject_trace_headers(headers: Headers | None) -> Headers:
     """
     Inject W3C traceparent header for distributed tracing.
 
@@ -125,7 +130,7 @@ def _inject_trace_headers(headers: dict | None) -> dict:
     return headers
 
 
-def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     # Convert requests-style allow_redirects to httpx-style follow_redirects
     if "allow_redirects" in kwargs:
         allow_redirects = kwargs.pop("allow_redirects")
@@ -142,10 +147,15 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
 
     # prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI
     verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
+    if not isinstance(verify_option, bool):
+        raise ValueError("ssl_verify must be a boolean")
     client = _get_ssrf_client(verify_option)
 
     # Inject traceparent header for distributed tracing (when OTEL is not enabled)
-    headers = kwargs.get("headers") or {}
+    try:
+        headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {})
+    except ValidationError as e:
+        raise ValueError("headers must be a mapping of string keys to string values") from e
     headers = _inject_trace_headers(headers)
     kwargs["headers"] = headers
 
@@ -198,25 +208,25 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
     raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
 
 
-def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def get(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("GET", url, max_retries=max_retries, **kwargs)
 
 
-def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def post(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("POST", url, max_retries=max_retries, **kwargs)
 
 
-def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def put(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("PUT", url, max_retries=max_retries, **kwargs)
 
 
-def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def patch(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("PATCH", url, max_retries=max_retries, **kwargs)
 
 
-def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("DELETE", url, max_retries=max_retries, **kwargs)
 
 
-def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("HEAD", url, max_retries=max_retries, **kwargs)
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index 511f5a698d..1ddbfc5864 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -1,4 +1,7 @@
-"""Abstract interface for document loader implementations."""
+"""Word (.docx) document extractor used for RAG ingestion.
+
+Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
+""" import logging import mimetypes @@ -8,7 +11,6 @@ import tempfile import uuid from urllib.parse import urlparse -import httpx from docx import Document as DocxDocument from docx.oxml.ns import qn from docx.text.run import Run @@ -44,7 +46,7 @@ class WordExtractor(BaseExtractor): # If the file is a web path, download it to a temporary file, and use that if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path): - response = httpx.get(self.file_path, timeout=None) + response = ssrf_proxy.get(self.file_path) if response.status_code != 200: response.close() @@ -55,6 +57,7 @@ class WordExtractor(BaseExtractor): self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115 try: self.temp_file.write(response.content) + self.temp_file.flush() finally: response.close() self.file_path = self.temp_file.name diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index f9e59a5f05..0792ada194 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,7 +1,9 @@ """Primarily used for testing merged cell scenarios""" +import io import os import tempfile +from pathlib import Path from types import SimpleNamespace from docx import Document @@ -56,6 +58,42 @@ def test_parse_row(): assert extractor._parse_row(row, {}, 3) == gt[idx] +def test_init_downloads_via_ssrf_proxy(monkeypatch): + doc = Document() + doc.add_paragraph("hello") + buf = io.BytesIO() + doc.save(buf) + docx_bytes = buf.getvalue() + + calls: list[tuple[str, object]] = [] + + class FakeResponse: + status_code = 200 + content = docx_bytes + + def close(self) -> None: + calls.append(("close", None)) + + def fake_get(url: str, **kwargs): + calls.append(("get", (url, kwargs))) + return FakeResponse() + + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get)) + + extractor = WordExtractor("https://example.com/test.docx", "tenant_id", "user_id") + try: + assert calls + assert calls[0][0] == "get" + url, kwargs = calls[0][1] + assert url == "https://example.com/test.docx" + assert kwargs.get("timeout") is None + assert extractor.web_path == "https://example.com/test.docx" + assert extractor.file_path != extractor.web_path + assert Path(extractor.file_path).read_bytes() == docx_bytes + finally: + extractor.temp_file.close() + + def test_extract_images_from_docx(monkeypatch): external_bytes = b"ext-bytes" internal_bytes = b"int-bytes"