160 lines
4.5 KiB
Python
160 lines
4.5 KiB
Python
|
|
"""OSS/URL CSV source (v2).

- Validates the incoming URL to reduce SSRF risk (allowlist + IP checks).
- Downloads the CSV to a local temporary file for analysis.

This module is intentionally small and dependency-light.
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import ipaddress
|
|||
|
|
import logging
|
|||
|
|
import os
|
|||
|
|
import socket
|
|||
|
|
import tempfile
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from typing import Optional
|
|||
|
|
from urllib.parse import urlsplit
|
|||
|
|
|
|||
|
|
import requests
|
|||
|
|
|
|||
|
|
from app.core.config import settings
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass(frozen=True)
class DownloadedCsv:
    """Immutable record describing a CSV file that was fetched to local disk."""

    # Path of the temporary file holding the downloaded bytes.
    local_path: str
    # Hostname parsed from the source URL.
    source_host: str
    # Basename of the URL path, or "data.csv" when the path has none.
    source_name: str
    # Value of the response's ETag header, if the server sent one.
    etag: Optional[str] = None
    # Value of the response's Last-Modified header, if the server sent one.
    last_modified: Optional[str] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class UrlValidationError(ValueError):
    """Raised when a source URL is rejected, or when a download yields no content."""
|
|||
|
|
|
|||
|
|
|
|||
|
|
# RFC 6598 shared address space (carrier-grade NAT). ``ipaddress`` does not
# report it as private, yet it is not publicly routable and some cloud
# metadata endpoints live in it (e.g. 100.100.100.200 on Alibaba Cloud).
_SHARED_ADDRESS_SPACE = ipaddress.ip_network("100.64.0.0/10")


def _is_ip_allowed(ip_str: str) -> bool:
    """Return True if a resolved IP address may be contacted.

    Args:
        ip_str: Textual IPv4/IPv6 address, optionally carrying an IPv6 zone
            id (``fe80::1%eth0``) as ``getaddrinfo`` may return it.

    Returns:
        True when ``settings.V2_ALLOW_PRIVATE_NETWORKS`` is enabled, or when
        the address is not loopback, private, link-local, multicast,
        unspecified, reserved, or in the 100.64.0.0/10 shared address space.
        False otherwise, including when ``ip_str`` cannot be parsed.
    """
    # Drop any zone id so parsing does not depend on the Python version's
    # support for scoped IPv6 addresses.
    bare_ip = ip_str.partition("%")[0]
    try:
        ip = ipaddress.ip_address(bare_ip)
    except ValueError:
        # Fail closed: an address we cannot parse cannot be vetted.
        return False

    if settings.V2_ALLOW_PRIVATE_NETWORKS:
        return True

    # An IPv4-mapped IPv6 address (::ffff:a.b.c.d) reaches the embedded IPv4
    # host, so the embedded address has to pass the same checks as well.
    candidates = [ip]
    mapped = getattr(ip, "ipv4_mapped", None)
    if mapped is not None:
        candidates.append(mapped)

    # Block loopback/link-local/private/multicast/unspecified/reserved/CGNAT
    for candidate in candidates:
        if (
            candidate.is_loopback
            or candidate.is_private
            or candidate.is_link_local
            or candidate.is_multicast
            or candidate.is_unspecified
            or candidate.is_reserved
            or candidate in _SHARED_ADDRESS_SPACE
        ):
            return False

    return True
|
|||
|
|
|
|||
|
|
|
|||
|
|
def validate_source_url(source_url: str) -> tuple[str, str]:
    """Validate URL and return (host, source_name).

    Checks, in order: non-empty string, http/https scheme (http only when
    ``settings.V2_ALLOW_HTTP`` is set), presence of a host, absence of
    userinfo, optional host allowlist (``settings.V2_ALLOWED_HOSTS``), and
    finally that every address the host resolves to passes ``_is_ip_allowed``.

    Args:
        source_url: The URL to validate.

    Returns:
        A ``(host, source_name)`` tuple, where ``source_name`` is the basename
        of the URL path or ``"data.csv"`` when the path has no basename.

    Raises:
        UrlValidationError: If any check fails, including malformed URLs and
            hostnames that cannot be resolved.

    Note:
        The DNS lookup here is separate from the one performed when the URL
        is later fetched, so this check alone does not rule out DNS rebinding.
    """
    if not source_url or not isinstance(source_url, str):
        raise UrlValidationError("source_url 不能为空")

    # urlsplit() raises a plain ValueError on malformed input (for example an
    # unbalanced IPv6 bracket); surface it as the module's own error type.
    try:
        parts = urlsplit(source_url)
    except ValueError as e:
        raise UrlValidationError(f"URL 格式无效: {e}") from e

    if parts.scheme not in {"https", "http"}:
        raise UrlValidationError("仅支持 http/https URL")

    if parts.scheme == "http" and not settings.V2_ALLOW_HTTP:
        raise UrlValidationError("不允许 http;请使用 https 或开启 V2_ALLOW_HTTP")

    if not parts.netloc:
        raise UrlValidationError("URL 缺少 host")

    # Disallow URLs with userinfo
    if "@" in parts.netloc:
        raise UrlValidationError("URL 不允许包含用户名/密码")

    try:
        host = parts.hostname
    except ValueError as e:
        raise UrlValidationError(f"URL 格式无效: {e}") from e
    if not host:
        raise UrlValidationError("无法解析 URL host")

    # Optional allowlist
    if settings.V2_ALLOWED_HOSTS:
        allowed = {h.lower() for h in settings.V2_ALLOWED_HOSTS}
        if host.lower() not in allowed:
            raise UrlValidationError(f"host 不在白名单: {host}")

    # Resolve host -> IP and block private/loopback, unless explicitly allowed.
    # getaddrinfo() can also raise UnicodeError when the host cannot be
    # IDNA-encoded (e.g. an over-long or empty label), so catch that too.
    try:
        addr_info = socket.getaddrinfo(host, None)
    except (socket.gaierror, UnicodeError) as e:
        raise UrlValidationError(f"DNS 解析失败: {host} ({e})") from e

    for family, _type, _proto, _canonname, sockaddr in addr_info:
        if family not in (socket.AF_INET, socket.AF_INET6):
            continue
        ip_str = str(sockaddr[0])
        if not ip_str:
            continue
        try:
            ip_ok = _is_ip_allowed(ip_str)
        except ValueError:
            # Fail closed on an address that cannot be parsed.
            ip_ok = False
        if not ip_ok:
            raise UrlValidationError(f"host 解析到不允许的 IP: {ip_str}")

    source_name = os.path.basename(parts.path) or "data.csv"
    return host, source_name
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _fetch_with_validated_redirects(session, url, *, timeout, max_redirects):
    """GET *url*, validating every redirect target before following it.

    Returns the first non-redirect streaming response. Raises
    ``UrlValidationError`` if a redirect target is rejected and
    ``requests.exceptions.TooManyRedirects`` once ``max_redirects`` is exceeded.
    """
    from urllib.parse import urljoin

    current_url = url
    for _ in range(max_redirects + 1):
        resp = session.get(
            current_url,
            stream=True,
            timeout=timeout,
            allow_redirects=False,
        )
        if not resp.is_redirect:
            return resp

        # is_redirect is only true when a Location header is present.
        location = resp.headers.get("Location")
        resp.close()
        next_url = urljoin(current_url, location)
        # Re-run the full SSRF validation on the redirect target.
        validate_source_url(next_url)
        current_url = next_url

    raise requests.exceptions.TooManyRedirects(
        f"Exceeded {max_redirects} redirects"
    )


def download_csv_to_tempfile(
    source_url: str,
    *,
    suffix: str = ".csv",
    max_redirects: int = 10,
) -> DownloadedCsv:
    """Download URL content to a temp file and return local path + meta.

    The URL is validated with ``validate_source_url`` before any request is
    made. Redirects are followed manually so that each redirect target goes
    through the same validation; otherwise an allowed URL could redirect to
    an internal address.

    Args:
        source_url: URL of the CSV to fetch.
        suffix: Suffix for the temporary file name.
        max_redirects: Maximum number of redirects to follow.

    Returns:
        A ``DownloadedCsv`` whose host and source name come from the original
        ``source_url`` and whose ETag / Last-Modified come from the final
        response. The caller owns the file at ``local_path``.

    Raises:
        UrlValidationError: If the URL (or a redirect target) is rejected, or
            if the downloaded body is empty.
        requests.RequestException: On HTTP or transport errors, including
            non-2xx statuses and too many redirects.
    """
    host, source_name = validate_source_url(source_url)

    # Create temp file inside configured TEMP_DIR for easier ops/observability
    settings.TEMP_DIR.mkdir(parents=True, exist_ok=True)
    tmp = tempfile.NamedTemporaryFile(
        mode="wb",
        suffix=suffix,
        dir=str(settings.TEMP_DIR),
        delete=False,
    )

    try:
        timeout = (
            settings.V2_CONNECT_TIMEOUT_SECONDS,
            settings.V2_DOWNLOAD_TIMEOUT_SECONDS,
        )
        # Leaving the ``with tmp`` block flushes and closes the file; since
        # delete=False it stays on disk for the caller.
        with tmp, requests.Session() as session:
            with _fetch_with_validated_redirects(
                session,
                source_url,
                timeout=timeout,
                max_redirects=max_redirects,
            ) as resp:
                resp.raise_for_status()
                etag = resp.headers.get("ETag")
                last_modified = resp.headers.get("Last-Modified")

                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        tmp.write(chunk)

        if os.path.getsize(tmp.name) <= 0:
            raise UrlValidationError("下载内容为空")

        return DownloadedCsv(
            local_path=tmp.name,
            source_host=host,
            source_name=source_name,
            etag=etag,
            last_modified=last_modified,
        )

    except BaseException:
        # Best-effort cleanup so failed or interrupted downloads do not leave
        # stray files in TEMP_DIR; the original exception is always re-raised.
        try:
            tmp.close()
        except OSError:
            pass
        try:
            os.unlink(tmp.name)
        except OSError:
            pass
        raise
|