# main.py import os import json import logging from pathlib import Path from typing import Dict, Any, List import asyncio import httpx import aiofiles from fastapi import FastAPI, HTTPException, BackgroundTasks from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel import uvicorn # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) # 常量定义 DOWNLOADS_DIR = "downloads" MAX_FILENAME_LENGTH = 100 INVALID_FILENAME_CHARS = '<>:"/\\|?*' MAX_CONCURRENT_DOWNLOADS = 5 # 最大并发下载数 DOWNLOAD_TIMEOUT = 30 # 下载超时时间(秒) # FastAPI应用 app = FastAPI(title="eh-v2") # 全局变量用于跟踪下载状态 download_status: Dict[str, Dict[str, Any]] = {} # 数据模型 class SaveDataRequest(BaseModel): url: str title: str all_images: Dict[str, str] total_images: int class GalleryInfo(BaseModel): title: str path: str total_images: int downloaded_images: int class DownloadStatusResponse(BaseModel): status: str message: str downloaded: int total: int current_progress: float # 工具函数 def setup_downloads_directory() -> Path: """创建并返回下载目录路径""" downloads_path = Path(DOWNLOADS_DIR) downloads_path.mkdir(exist_ok=True) logger.info(f"下载目录已准备: {downloads_path.absolute()}") return downloads_path def sanitize_filename(filename: str) -> str: """清理文件名,移除非法字符并限制长度""" sanitized = filename for char in INVALID_FILENAME_CHARS: sanitized = sanitized.replace(char, '_') # 限制文件名长度 if len(sanitized) > MAX_FILENAME_LENGTH: sanitized = sanitized[:MAX_FILENAME_LENGTH] return sanitized def create_title_directory(base_path: Path, title: str) -> Path: """创建标题对应的目录""" safe_title = sanitize_filename(title) title_dir = base_path / safe_title title_dir.mkdir(exist_ok=True) logger.info(f"创建标题目录: {title_dir}") return title_dir async def save_data_to_file(file_path: Path, data: Dict[str, Any]) -> None: """异步保存数据到JSON文件""" async with aiofiles.open(file_path, 'w', encoding='utf-8') as f: await f.write(json.dumps(data, ensure_ascii=False, indent=2)) def get_all_galleries() -> List[GalleryInfo]: """获取所有画廊信息""" galleries = [] downloads_path = Path(DOWNLOADS_DIR) if not downloads_path.exists(): return galleries for gallery_dir in downloads_path.iterdir(): if gallery_dir.is_dir(): data_file = gallery_dir / "data.json" if data_file.exists(): try: with open(data_file, 'r', encoding='utf-8') as f: data = json.load(f) # 计算已下载的图片数量 downloaded_count = 0 if 'all_images' in data: for filename, url in data['all_images'].items(): image_path = gallery_dir / filename if image_path.exists(): downloaded_count += 1 galleries.append(GalleryInfo( title=data.get('title', gallery_dir.name), path=str(gallery_dir), total_images=data.get('total_images', 0), downloaded_images=downloaded_count )) except Exception as e: logger.error(f"读取画廊数据失败 {gallery_dir}: {e}") return galleries async def download_single_image(client: httpx.AsyncClient, url: str, file_path: Path, semaphore: asyncio.Semaphore) -> bool: """下载单张图片 - 精简版""" async with semaphore: try: logger.info(f"开始下载: {url}") if file_path.exists(): logger.info(f"文件已存在: {file_path}") return True # 第一步:获取中间页面 response = await client.get(url, timeout=DOWNLOAD_TIMEOUT) response.raise_for_status() # 第二步:提取真实图片URL import re match = re.search(r'img id="img" src="(.*?)"', response.text) if not match: logger.error(f"无法提取图片URL: {url}") return False real_img_url = match.group(1) logger.info(f"真实URL: {real_img_url}") # 第三步:下载图片 img_response = await client.get(real_img_url, timeout=DOWNLOAD_TIMEOUT) img_response.raise_for_status() # 保存图片 async with aiofiles.open(file_path, 'wb') as f: await f.write(img_response.content) logger.info(f"下载完成: {file_path}") return True except Exception as e: logger.error(f"下载失败 {url}: {e}") return False async def download_gallery_images(title: str) -> DownloadStatusResponse: """下载指定画廊的所有图片""" safe_title = sanitize_filename(title) gallery_path = downloads_path / safe_title data_file = gallery_path / "data.json" if not data_file.exists(): return DownloadStatusResponse( status="error", message="画廊数据文件不存在", downloaded=0, total=0, current_progress=0.0 ) try: # 读取画廊数据 async with aiofiles.open(data_file, 'r', encoding='utf-8') as f: content = await f.read() data = json.loads(content) all_images = data.get('all_images', {}) total_images = len(all_images) if total_images == 0: return DownloadStatusResponse( status="error", message="没有可下载的图片", downloaded=0, total=0, current_progress=0.0 ) # 初始化下载状态 download_status[title] = { "downloaded": 0, "total": total_images, "status": "downloading" } logger.info(f"开始下载画廊 '{title}',共 {total_images} 张图片") # 创建信号量限制并发数 semaphore = asyncio.Semaphore(MAX_CONCURRENT_DOWNLOADS) # 使用异步HTTP客户端 async with httpx.AsyncClient( headers={ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' }, follow_redirects=True ) as client: # 准备下载任务 tasks = [] for filename, url in all_images.items(): image_path = gallery_path / filename # 如果图片已存在,跳过下载但计入完成数量 if image_path.exists(): download_status[title]["downloaded"] += 1 continue task = download_single_image(client, url, image_path, semaphore) tasks.append(task) # 批量执行下载任务 if tasks: results = await asyncio.gather(*tasks, return_exceptions=True) # 统计成功下载的数量 successful_downloads = sum(1 for result in results if result is True) download_status[title]["downloaded"] += successful_downloads # 更新最终状态 downloaded_count = download_status[title]["downloaded"] progress = (downloaded_count / total_images) * 100 if downloaded_count == total_images: download_status[title]["status"] = "completed" message = f"下载完成!共下载 {downloaded_count}/{total_images} 张图片" logger.info(f"画廊 '{title}' {message}") else: download_status[title]["status"] = "partial" message = f"部分完成!下载 {downloaded_count}/{total_images} 张图片" logger.warning(f"画廊 '{title}' {message}") return DownloadStatusResponse( status="success", message=message, downloaded=downloaded_count, total=total_images, current_progress=progress ) except Exception as e: logger.error(f"下载画廊 '{title}' 时发生错误: {e}") download_status[title] = { "status": "error", "message": str(e) } return DownloadStatusResponse( status="error", message=f"下载失败: {str(e)}", downloaded=0, total=0, current_progress=0.0 ) async def download_all_pending_galleries(): """下载所有未完成的画廊""" galleries = get_all_galleries() pending_galleries = [g for g in galleries if g.downloaded_images < g.total_images] if not pending_galleries: logger.info("没有待下载的画廊") return logger.info(f"开始批量下载 {len(pending_galleries)} 个画廊") for gallery in pending_galleries: if gallery.downloaded_images < gallery.total_images: logger.info(f"开始下载画廊: {gallery.title}") result = await download_gallery_images(gallery.title) if result.status == "success": logger.info(f"画廊 '{gallery.title}' 下载完成: {result.message}") else: logger.error(f"画廊 '{gallery.title}' 下载失败: {result.message}") # 添加延迟避免请求过于频繁 await asyncio.sleep(1) logger.info("批量下载任务完成") # 初始化 downloads_path = setup_downloads_directory() # API路由 @app.get("/", response_class=HTMLResponse) async def read_gallery_manager(): """画廊管理页面""" return """
管理您的画廊下载任务
点击"读取文件夹"按钮加载数据