"""OSS/URL CSV source (v2). - Validates incoming URL to reduce SSRF risk (allowlist + IP checks) - Downloads CSV to a local temporary file for analysis This module is intentionally small and dependency-light. """ from __future__ import annotations import ipaddress import logging import os import socket import tempfile from dataclasses import dataclass from typing import Optional from urllib.parse import urlsplit import requests from app.core.config import settings logger = logging.getLogger(__name__) @dataclass(frozen=True) class DownloadedCsv: local_path: str source_host: str source_name: str etag: Optional[str] = None last_modified: Optional[str] = None class UrlValidationError(ValueError): pass def _is_ip_allowed(ip_str: str) -> bool: ip = ipaddress.ip_address(ip_str) if settings.V2_ALLOW_PRIVATE_NETWORKS: return True # Block loopback/link-local/private/multicast/unspecified/reserved if ( ip.is_loopback or ip.is_private or ip.is_link_local or ip.is_multicast or ip.is_unspecified or ip.is_reserved ): return False return True def validate_source_url(source_url: str) -> tuple[str, str]: """Validate URL and return (host, source_name).""" if not source_url or not isinstance(source_url, str): raise UrlValidationError("source_url 不能为空") parts = urlsplit(source_url) if parts.scheme not in {"https", "http"}: raise UrlValidationError("仅支持 http/https URL") if parts.scheme == "http" and not settings.V2_ALLOW_HTTP: raise UrlValidationError("不允许 http;请使用 https 或开启 V2_ALLOW_HTTP") if not parts.netloc: raise UrlValidationError("URL 缺少 host") # Disallow URLs with userinfo if "@" in parts.netloc: raise UrlValidationError("URL 不允许包含用户名/密码") host = parts.hostname if not host: raise UrlValidationError("无法解析 URL host") # Optional allowlist if settings.V2_ALLOWED_HOSTS: allowed = {h.lower() for h in settings.V2_ALLOWED_HOSTS} if host.lower() not in allowed: raise UrlValidationError(f"host 不在白名单: {host}") # Resolve host -> IP and block private/loopback, unless explicitly allowed. try: addr_info = socket.getaddrinfo(host, None) except socket.gaierror as e: raise UrlValidationError(f"DNS 解析失败: {host} ({e})") from e for family, _type, _proto, _canonname, sockaddr in addr_info: ip_str = None if family == socket.AF_INET: ip_str = str(sockaddr[0]) elif family == socket.AF_INET6: ip_str = str(sockaddr[0]) if ip_str and not _is_ip_allowed(ip_str): raise UrlValidationError(f"host 解析到不允许的 IP: {ip_str}") source_name = os.path.basename(parts.path) or "data.csv" return host, source_name def download_csv_to_tempfile(source_url: str, *, suffix: str = ".csv") -> DownloadedCsv: """Download URL content to a temp file and return local path + meta.""" host, source_name = validate_source_url(source_url) # Create temp file inside configured TEMP_DIR for easier ops/observability settings.TEMP_DIR.mkdir(exist_ok=True) tmp = tempfile.NamedTemporaryFile( mode="wb", suffix=suffix, dir=str(settings.TEMP_DIR), delete=False, ) try: timeout = (settings.V2_CONNECT_TIMEOUT_SECONDS, settings.V2_DOWNLOAD_TIMEOUT_SECONDS) with requests.get(source_url, stream=True, timeout=timeout) as resp: resp.raise_for_status() etag = resp.headers.get("ETag") last_modified = resp.headers.get("Last-Modified") for chunk in resp.iter_content(chunk_size=1024 * 1024): if not chunk: continue tmp.write(chunk) tmp.flush() tmp.close() if os.path.getsize(tmp.name) <= 0: raise UrlValidationError("下载内容为空") return DownloadedCsv( local_path=tmp.name, source_host=host, source_name=source_name, etag=etag, last_modified=last_modified, ) except Exception: try: tmp.close() except Exception: pass try: os.unlink(tmp.name) except Exception: pass raise