Json-Python-Server/app/services/oss_csv_source.py

160 lines
4.5 KiB
Python
Raw Normal View History

2026-01-29 18:18:32 +08:00
"""OSS/URL CSV source (v2).
- Validates incoming URL to reduce SSRF risk (allowlist + IP checks)
- Downloads CSV to a local temporary file for analysis
This module is intentionally small and dependency-light.
"""
from __future__ import annotations
import ipaddress
import logging
import os
import socket
import tempfile
from dataclasses import dataclass
from typing import Optional
from urllib.parse import urlsplit
import requests
from app.core.config import settings
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class DownloadedCsv:
    """Immutable description of a CSV file that was fetched to local disk."""

    # Filesystem path of the downloaded copy.
    local_path: str
    # Host component of the URL the file was requested from.
    source_host: str
    # File name taken from the URL path.
    source_name: str
    # Value of the HTTP ``ETag`` response header, if the server sent one.
    etag: Optional[str] = None
    # Value of the HTTP ``Last-Modified`` response header, if the server sent one.
    last_modified: Optional[str] = None
class UrlValidationError(ValueError):
    """Raised when a source URL is rejected, or its download turns out empty."""
def _is_ip_allowed(ip_str: str) -> bool:
    """Return whether a resolved IP address may be contacted.

    Args:
        ip_str: Textual IPv4 or IPv6 address, e.g. as returned by
            ``socket.getaddrinfo``. A trailing IPv6 zone id (``%eth0``) is
            ignored.

    Returns:
        ``True`` if ``settings.V2_ALLOW_PRIVATE_NETWORKS`` is enabled, or if
        the address is a globally routable unicast address; ``False``
        otherwise.

    Raises:
        ValueError: If ``ip_str`` is not a valid IP address.
    """
    # getaddrinfo may append a zone id to link-local IPv6 addresses. Older
    # Python versions cannot parse it, and it does not affect classification.
    ip = ipaddress.ip_address(ip_str.split("%", 1)[0])
    if settings.V2_ALLOW_PRIVATE_NETWORKS:
        return True

    # Block loopback/link-local/private/multicast/unspecified/reserved.
    # ``is_private`` alone misses the shared address space 100.64.0.0/10,
    # which hosts cloud-internal endpoints (e.g. the Alibaba Cloud ECS
    # metadata service at 100.100.100.200), so anything that is not globally
    # routable is rejected as well. The explicit flags stay because an
    # address can be "global" and still unsuitable (e.g. IPv4 multicast).
    blocked = (
        ip.is_loopback
        or ip.is_private
        or ip.is_link_local
        or ip.is_multicast
        or ip.is_unspecified
        or ip.is_reserved
        or not ip.is_global
    )
    return not blocked
def validate_source_url(source_url: str) -> tuple[str, str]:
    """Validate URL and return (host, source_name).

    Checks, in order: the value is a non-empty string, the scheme is
    http/https (http only when ``settings.V2_ALLOW_HTTP`` is set), a host is
    present, no userinfo is embedded, the host is in
    ``settings.V2_ALLOWED_HOSTS`` when that allowlist is non-empty, and every
    address the host resolves to is accepted by ``_is_ip_allowed``.

    NOTE: the DNS check is performed at validation time only. Whatever later
    connects to the URL resolves the name again, so this reduces SSRF risk but
    does not by itself defend against DNS rebinding.

    Args:
        source_url: The URL to validate.

    Returns:
        ``(host, source_name)`` where ``host`` is the parsed hostname and
        ``source_name`` is the last path segment, or ``"data.csv"`` when the
        path has no file name.

    Raises:
        UrlValidationError: If any check fails, the URL cannot be parsed, or
            the host cannot be resolved.
    """
    if not source_url or not isinstance(source_url, str):
        raise UrlValidationError("source_url 不能为空")

    # urlsplit() raises a plain ValueError on malformed input (for example an
    # unterminated IPv6 bracket); report it as a validation failure so callers
    # only need to handle one exception type.
    try:
        parts = urlsplit(source_url)
        host = parts.hostname
    except ValueError as e:
        raise UrlValidationError(f"URL 解析失败: {e}") from e

    if parts.scheme not in {"https", "http"}:
        raise UrlValidationError("仅支持 http/https URL")
    if parts.scheme == "http" and not settings.V2_ALLOW_HTTP:
        raise UrlValidationError("不允许 http请使用 https 或开启 V2_ALLOW_HTTP")
    if not parts.netloc:
        raise UrlValidationError("URL 缺少 host")
    # Disallow URLs with userinfo
    if "@" in parts.netloc:
        raise UrlValidationError("URL 不允许包含用户名/密码")
    if not host:
        raise UrlValidationError("无法解析 URL host")

    # Optional allowlist (case-insensitive exact match)
    if settings.V2_ALLOWED_HOSTS:
        allowed_hosts = {h.lower() for h in settings.V2_ALLOWED_HOSTS}
        if host.lower() not in allowed_hosts:
            raise UrlValidationError(f"host 不在白名单: {host}")

    # Resolve host -> IP and block private/loopback, unless explicitly allowed.
    # getaddrinfo() IDNA-encodes the host and raises UnicodeError (not
    # gaierror) when that fails, e.g. for an empty or over-long label.
    try:
        addr_info = socket.getaddrinfo(host, None)
    except (socket.gaierror, UnicodeError) as e:
        raise UrlValidationError(f"DNS 解析失败: {host} ({e})") from e

    for family, _type, _proto, _canonname, sockaddr in addr_info:
        if family not in (socket.AF_INET, socket.AF_INET6):
            continue
        ip_str = str(sockaddr[0])
        if not ip_str:
            continue
        try:
            ip_ok = _is_ip_allowed(ip_str)
        except ValueError:
            # Fail closed: an address we cannot classify is not trusted.
            ip_ok = False
        if not ip_ok:
            raise UrlValidationError(f"host 解析到不允许的 IP: {ip_str}")

    source_name = os.path.basename(parts.path) or "data.csv"
    return host, source_name
def download_csv_to_tempfile(
    source_url: str,
    *,
    suffix: str = ".csv",
    max_redirects: int = 30,
) -> DownloadedCsv:
    """Download URL content to a temp file and return local path + meta.

    The URL is validated with ``validate_source_url`` before any request is
    made. Redirects are followed manually so that every hop is validated too;
    letting the HTTP client follow them automatically would allow a public URL
    to redirect to an internal address and bypass the validation.

    Args:
        source_url: URL of the CSV to fetch.
        suffix: File-name suffix for the temporary file.
        max_redirects: Maximum number of redirects to follow. The default
            matches the limit ``requests`` applies when it follows redirects
            itself.

    Returns:
        A ``DownloadedCsv`` whose ``local_path`` points at the downloaded file
        inside ``settings.TEMP_DIR``. ``source_host`` and ``source_name`` are
        derived from ``source_url`` (not from any redirect target). The caller
        owns the file and is responsible for deleting it.

    Raises:
        UrlValidationError: If ``source_url`` or a redirect target fails
            validation, or if the downloaded body is empty.
        requests.exceptions.TooManyRedirects: If more than ``max_redirects``
            redirects are encountered.
        requests.exceptions.RequestException: On connection errors, timeouts
            or an HTTP error status.
    """
    # Local import: the module-level imports only bring in ``urlsplit``.
    from urllib.parse import urljoin

    host, source_name = validate_source_url(source_url)

    # Create temp file inside configured TEMP_DIR for easier ops/observability
    settings.TEMP_DIR.mkdir(parents=True, exist_ok=True)
    tmp = tempfile.NamedTemporaryFile(
        mode="wb",
        suffix=suffix,
        dir=str(settings.TEMP_DIR),
        delete=False,
    )
    try:
        timeout = (settings.V2_CONNECT_TIMEOUT_SECONDS, settings.V2_DOWNLOAD_TIMEOUT_SECONDS)
        current_url = source_url
        redirects_left = max_redirects
        while True:
            with requests.get(
                current_url,
                stream=True,
                timeout=timeout,
                allow_redirects=False,
            ) as resp:
                if resp.is_redirect:
                    if redirects_left <= 0:
                        raise requests.exceptions.TooManyRedirects(
                            f"Exceeded {max_redirects} redirects."
                        )
                    redirects_left -= 1
                    # Location may be relative; resolve it against the URL
                    # that produced it, then validate before requesting it.
                    current_url = urljoin(current_url, resp.headers["Location"])
                    validate_source_url(current_url)
                    continue

                resp.raise_for_status()
                etag = resp.headers.get("ETag")
                last_modified = resp.headers.get("Last-Modified")
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        tmp.write(chunk)
            break

        # close() flushes buffered data, so the size check below is accurate.
        tmp.close()
        if os.path.getsize(tmp.name) <= 0:
            raise UrlValidationError("下载内容为空")
        return DownloadedCsv(
            local_path=tmp.name,
            source_host=host,
            source_name=source_name,
            etag=etag,
            last_modified=last_modified,
        )
    except BaseException:
        # Best-effort cleanup on any failure (including interruption) so a
        # partial download is not left behind. Cleanup problems are logged but
        # never mask the original error, which is re-raised below.
        try:
            tmp.close()
        except Exception:
            logger.debug("Failed to close temp file %s", tmp.name, exc_info=True)
        try:
            os.unlink(tmp.name)
        except Exception:
            logger.warning("Failed to remove temp file %s", tmp.name, exc_info=True)
        raise