160 lines
4.5 KiB
Python
160 lines
4.5 KiB
Python
"""OSS/URL CSV source (v2).

- Validates the incoming URL to reduce SSRF risk (allowlist + IP checks)
- Downloads the CSV to a local temporary file for analysis

This module is intentionally small and dependency-light.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import ipaddress
|
||
import logging
|
||
import os
|
||
import socket
|
||
import tempfile
|
||
from dataclasses import dataclass
|
||
from typing import Optional
|
||
from urllib.parse import urlsplit
|
||
|
||
import requests
|
||
|
||
from app.core.config import settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass(frozen=True)
class DownloadedCsv:
    """Immutable record describing a CSV that was downloaded to local disk.

    Instances are frozen: attributes cannot be reassigned after construction.
    """

    # Path of the local temporary file that holds the downloaded bytes.
    local_path: str
    # Hostname parsed from the source URL.
    source_host: str
    # File name derived from the URL path.
    source_name: str
    # Value of the HTTP ``ETag`` response header, if the server sent one.
    etag: Optional[str] = None
    # Value of the HTTP ``Last-Modified`` response header, if the server sent one.
    last_modified: Optional[str] = None
|
||
|
||
|
||
class UrlValidationError(ValueError):
    """Raised when a source URL is rejected or its download yields no data.

    Subclasses ``ValueError`` so callers that already catch ``ValueError``
    keep working.
    """
|
||
|
||
|
||
def _is_ip_allowed(ip_str: str) -> bool:
    """Return whether a resolved IP address may be contacted.

    Args:
        ip_str: Textual IPv4/IPv6 address, optionally carrying an IPv6 zone
            suffix such as ``fe80::1%eth0``.

    Returns:
        ``False`` if the text is not a valid IP address. Otherwise ``True``
        when ``settings.V2_ALLOW_PRIVATE_NETWORKS`` is truthy, or when the
        address is globally routable and is not loopback, private,
        link-local, multicast, unspecified or reserved.
    """
    # Resolver output may carry an IPv6 zone ("%eth0"); older Python versions
    # cannot parse it, so strip it before handing the text to ipaddress.
    address_text = ip_str.partition("%")[0]
    try:
        ip = ipaddress.ip_address(address_text)
    except ValueError:
        # Fail closed: an address we cannot parse is never considered safe.
        return False

    if settings.V2_ALLOW_PRIVATE_NETWORKS:
        return True

    # Block loopback/link-local/private/multicast/unspecified/reserved
    if (
        ip.is_loopback
        or ip.is_private
        or ip.is_link_local
        or ip.is_multicast
        or ip.is_unspecified
        or ip.is_reserved
    ):
        return False

    # ``is_private`` is False for the 100.64.0.0/10 shared address space, but
    # ``is_global`` is False for it as well; requiring ``is_global`` closes
    # that gap (cloud metadata endpoints can live in that range).
    return ip.is_global
|
||
|
||
|
||
def validate_source_url(source_url: str) -> tuple[str, str]:
    """Validate a source URL and return ``(host, source_name)``.

    Checks, in order: non-empty string, http/https scheme (http only when
    ``settings.V2_ALLOW_HTTP`` is set), presence of a host, absence of
    userinfo, optional membership in ``settings.V2_ALLOWED_HOSTS``, and that
    every address the host resolves to passes ``_is_ip_allowed``.

    NOTE: the DNS lookup performed here is separate from the one made by the
    later HTTP request, so this check alone does not defend against DNS
    rebinding.

    Args:
        source_url: The URL to validate.

    Returns:
        A tuple of the parsed hostname and the base name of the URL path
        (``"data.csv"`` when the path has no base name).

    Raises:
        UrlValidationError: If any check fails, the URL cannot be parsed, or
            the host cannot be resolved.
    """
    if not source_url or not isinstance(source_url, str):
        raise UrlValidationError("source_url 不能为空")

    # urlsplit raises a plain ValueError for malformed netlocs (for example an
    # unterminated IPv6 bracket); surface it as the module's own error type.
    try:
        parts = urlsplit(source_url)
        host = parts.hostname
    except ValueError as e:
        raise UrlValidationError(f"URL 格式无效: {e}") from e

    if parts.scheme not in {"https", "http"}:
        raise UrlValidationError("仅支持 http/https URL")

    if parts.scheme == "http" and not settings.V2_ALLOW_HTTP:
        raise UrlValidationError("不允许 http;请使用 https 或开启 V2_ALLOW_HTTP")

    if not parts.netloc:
        raise UrlValidationError("URL 缺少 host")

    # Disallow URLs with userinfo
    if "@" in parts.netloc:
        raise UrlValidationError("URL 不允许包含用户名/密码")

    if not host:
        raise UrlValidationError("无法解析 URL host")

    # Optional allowlist
    if settings.V2_ALLOWED_HOSTS:
        allowed = {h.lower() for h in settings.V2_ALLOWED_HOSTS}
        if host.lower() not in allowed:
            raise UrlValidationError(f"host 不在白名单: {host}")

    # Resolve host -> IP and block private/loopback, unless explicitly allowed.
    try:
        addr_info = socket.getaddrinfo(host, None)
    except (socket.gaierror, UnicodeError) as e:
        # UnicodeError: the host name could not be IDNA-encoded for lookup.
        raise UrlValidationError(f"DNS 解析失败: {host} ({e})") from e

    # Keep each resolved IPv4/IPv6 address once, preserving resolver order.
    resolved_ips = list(
        dict.fromkeys(
            str(sockaddr[0])
            for family, _type, _proto, _canonname, sockaddr in addr_info
            if family in (socket.AF_INET, socket.AF_INET6)
        )
    )

    # Fail closed: with no address to inspect there is nothing to approve.
    if not resolved_ips:
        raise UrlValidationError(f"DNS 解析失败: {host} (无可用地址)")

    for ip_str in resolved_ips:
        try:
            ip_ok = _is_ip_allowed(ip_str)
        except ValueError:
            # An address that cannot be parsed is treated as not allowed.
            ip_ok = False
        if not ip_ok:
            raise UrlValidationError(f"host 解析到不允许的 IP: {ip_str}")

    source_name = os.path.basename(parts.path) or "data.csv"
    return host, source_name
|
||
|
||
|
||
def download_csv_to_tempfile(source_url: str, *, suffix: str = ".csv") -> DownloadedCsv:
    """Download URL content to a temp file and return local path + meta.

    The URL is validated with ``validate_source_url`` before any request is
    made. HTTP redirects are followed manually so that every redirect target
    goes through the same validation; letting the HTTP client follow them
    automatically would allow a validated public URL to bounce the request to
    an internal address.

    Args:
        source_url: The http(s) URL to fetch.
        suffix: File-name suffix for the temporary file.

    Returns:
        A ``DownloadedCsv`` holding the temp-file path, the host and file
        name derived from ``source_url``, and the ``ETag`` /
        ``Last-Modified`` headers of the final response (``None`` if absent).

    Raises:
        UrlValidationError: If ``source_url`` or a redirect target fails
            validation, or the downloaded body is empty.
        requests.RequestException: On network/HTTP failures, including an
            error status code or too many redirects.
    """
    from urllib.parse import urljoin

    # Upper bound on redirect hops followed before giving up.
    max_redirects = 10

    host, source_name = validate_source_url(source_url)

    # Create temp file inside configured TEMP_DIR for easier ops/observability
    settings.TEMP_DIR.mkdir(parents=True, exist_ok=True)
    tmp = tempfile.NamedTemporaryFile(
        mode="wb",
        suffix=suffix,
        dir=str(settings.TEMP_DIR),
        delete=False,
    )

    try:
        timeout = (settings.V2_CONNECT_TIMEOUT_SECONDS, settings.V2_DOWNLOAD_TIMEOUT_SECONDS)
        current_url = source_url
        redirects_followed = 0

        while True:
            with requests.get(
                current_url,
                stream=True,
                timeout=timeout,
                allow_redirects=False,
            ) as resp:
                if resp.is_redirect:
                    if redirects_followed >= max_redirects:
                        raise requests.TooManyRedirects(
                            f"Exceeded {max_redirects} redirects"
                        )
                    # Location may be relative; resolve it against the URL
                    # that produced it, then validate before following.
                    current_url = urljoin(current_url, resp.headers["Location"])
                    validate_source_url(current_url)
                    redirects_followed += 1
                    continue

                resp.raise_for_status()
                etag = resp.headers.get("ETag")
                last_modified = resp.headers.get("Last-Modified")

                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        tmp.write(chunk)
            break

        tmp.flush()
        tmp.close()

        if os.path.getsize(tmp.name) <= 0:
            raise UrlValidationError("下载内容为空")

        return DownloadedCsv(
            local_path=tmp.name,
            source_host=host,
            source_name=source_name,
            etag=etag,
            last_modified=last_modified,
        )

    except BaseException:
        # BaseException (not just Exception) so an interrupt does not leave a
        # partial file behind. Cleanup is best-effort: failures are logged and
        # the original error is re-raised.
        try:
            tmp.close()
        except Exception:
            logger.warning("Failed to close temp file %s", tmp.name, exc_info=True)
        try:
            os.unlink(tmp.name)
        except Exception:
            logger.warning("Failed to remove temp file %s", tmp.name, exc_info=True)
        raise
|