From ba04fe42fcf861b408a2f1d5fb03ee3aab02c549 Mon Sep 17 00:00:00 2001 From: hello-dd-code Date: Fri, 3 Apr 2026 16:06:28 +0800 Subject: [PATCH] Add maxlaw PC spider and shared proxy limiter --- common_sites/dls.py | 17 +- common_sites/dls_fresh.py | 6 +- common_sites/dls_pc.py | 438 ++++++++++++++++++++++++++++++++++++++ common_sites/findlaw.py | 4 +- common_sites/hualv.py | 7 +- common_sites/lawtime.py | 4 +- common_sites/six4365.py | 7 +- common_sites/start.sh | 21 +- utils/rate_limiter.py | 225 +++++++++++++++----- 9 files changed, 651 insertions(+), 78 deletions(-) create mode 100644 common_sites/dls_pc.py diff --git a/common_sites/dls.py b/common_sites/dls.py index 06d4a01..4cca085 100644 --- a/common_sites/dls.py +++ b/common_sites/dls.py @@ -24,7 +24,7 @@ from request.proxy_config import get_proxies, report_proxy_status urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) from Db import Db -from utils.rate_limiter import wait_for_request +from utils.rate_limiter import request_slot DOMAIN = "大律师" LIST_TEMPLATE = "https://m.maxlaw.cn/law/{pinyin}?page={page}" @@ -108,17 +108,16 @@ class DlsSpider: def _get(self, url: str, max_retries: int = 3, headers: Optional[Dict[str, str]] = None) -> Optional[str]: """发送 GET 请求,带重试机制""" - wait_for_request() - for attempt in range(max_retries): try: # 使用更长的超时时间,分别设置连接和读取超时 - resp = self.session.get( - url, - timeout=(10, 30), # (connect_timeout, read_timeout) - verify=False, - headers=headers, - ) + with request_slot(): + resp = self.session.get( + url, + timeout=(10, 30), # (connect_timeout, read_timeout) + verify=False, + headers=headers, + ) status_code = resp.status_code content = resp.text resp.close() diff --git a/common_sites/dls_fresh.py b/common_sites/dls_fresh.py index d4a4347..0f7a36c 100644 --- a/common_sites/dls_fresh.py +++ b/common_sites/dls_fresh.py @@ -22,7 +22,7 @@ if project_root not in sys.path: sys.path.append(project_root) from request.requests_client import RequestClientError, RequestsClient -from utils.rate_limiter import wait_for_request +from utils.rate_limiter import request_slot from Db import Db urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @@ -107,9 +107,9 @@ class DlsFreshCrawler: def _get_text(self, url: str, timeout: int = 20, max_retries: int = 3) -> str: last_error: Optional[Exception] = None for attempt in range(max_retries): - wait_for_request() try: - resp = self.client.get_text(url, timeout=timeout, verify=False) + with request_slot(): + resp = self.client.get_text(url, timeout=timeout, verify=False) code = resp.status_code if code == 403: if attempt < max_retries - 1: diff --git a/common_sites/dls_pc.py b/common_sites/dls_pc.py new file mode 100644 index 0000000..1dfa113 --- /dev/null +++ b/common_sites/dls_pc.py @@ -0,0 +1,438 @@ +import json +import os +import random +import re +import sys +import time +from typing import Dict, List, Optional, Set, Tuple +from urllib.parse import urljoin + +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +request_dir = os.path.join(project_root, "request") +if request_dir not in sys.path: + sys.path.insert(0, request_dir) +if project_root not in sys.path: + sys.path.append(project_root) + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry +import urllib3 +from bs4 import BeautifulSoup + +from request.proxy_config import get_proxies, report_proxy_status +from utils.rate_limiter import request_slot +from Db import Db + 
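The dls.py and dls_fresh.py hunks above move the throttle from a single wait_for_request() call before the retry loop to a request_slot() context manager around each attempt, and the new dls_pc.py below follows the same pattern. The difference matters: the old shape throttled only the first attempt and never bounded in-flight requests, while the new shape charges every retry against the shared budget and releases the concurrency lease even when the call raises. A minimal sketch of the two shapes, assuming a plain requests.Session (fetch_old and fetch_new are illustrative names, not part of the patch):

import requests

from utils.rate_limiter import request_slot, wait_for_request


def fetch_old(session: requests.Session, url: str) -> requests.Response:
    wait_for_request()  # old shape: throttles only the start; retries bypass the limiter
    return session.get(url, timeout=(10, 30), verify=False)


def fetch_new(session: requests.Session, url: str) -> requests.Response:
    # New shape: each call pays for its own slot, and request_slot()'s
    # finally-block returns the lease even if session.get raises.
    with request_slot():
        return session.get(url, timeout=(10, 30), verify=False)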
+urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +DOMAIN = "大律师" +SITE_BASE = "https://www.maxlaw.cn" +LIST_URL_TEMPLATE = SITE_BASE + "/law/{pinyin}?page={page}" +PROVINCE_API = "https://js.maxlaw.cn/js/ajax/common/getprovice.js" +CITY_API_TEMPLATE = "https://js.maxlaw.cn/js/ajax/common/getcity_{province_id}.js" + +PHONE_RE = re.compile(r"1[3-9]\d{9}") +REPLY_RE = re.compile(r"已回复[::]?\s*(\d+)") +AREA_PREFIX_RE = re.compile(r"^[A-Za-z]\s*") + + +def normalize_phone(text: str) -> str: + compact = re.sub(r"\D", "", text or "") + match = PHONE_RE.search(compact) + return match.group(0) if match else "" + + +def clean_area_name(text: str) -> str: + value = AREA_PREFIX_RE.sub("", (text or "").strip()) + return value.strip() + + +def normalize_region_text(text: str) -> str: + value = (text or "").strip() + value = value.replace("\xa0", " ") + value = value.replace("-", "-").replace("—", "-").replace("–", "-") + value = re.sub(r"\s*-\s*", "-", value) + value = re.sub(r"\s+", "", value) + return value + + +class DlsPcSpider: + def __init__(self, db_connection): + self.db = db_connection + self.session = self._build_session() + self.max_pages = int(os.getenv("MAXLAW_PC_MAX_PAGES", "100")) + self.areas = self._load_areas() + + def _build_session(self) -> requests.Session: + report_proxy_status() + session = requests.Session() + session.trust_env = False + proxies = get_proxies() + if proxies: + session.proxies.update(proxies) + else: + session.proxies.clear() + + retries = Retry( + total=3, + backoff_factor=1, + status_forcelist=(429, 500, 502, 503, 504), + allowed_methods=frozenset(["GET"]), + raise_on_status=False, + ) + adapter = HTTPAdapter(max_retries=retries) + session.mount("https://", adapter) + session.mount("http://", adapter) + session.headers.update({ + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/136.0.0.0 Safari/537.36" + ), + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + "Connection": "close", + }) + return session + + def _refresh_session(self) -> None: + try: + self.session.close() + except Exception: + pass + self.session = self._build_session() + + def _get(self, url: str, max_retries: int = 3, headers: Optional[Dict[str, str]] = None) -> Optional[str]: + for attempt in range(max_retries): + try: + with request_slot(): + resp = self.session.get(url, timeout=(10, 25), verify=False, headers=headers) + status_code = resp.status_code + text = resp.text + resp.close() + if status_code == 403: + if attempt < max_retries - 1: + wait_time = 2 ** attempt + random.uniform(0.3, 1.0) + print(f"403被拦截,{wait_time:.1f}秒后重试 ({attempt + 1}/{max_retries}): {url}") + self._refresh_session() + time.sleep(wait_time) + continue + print(f"请求失败 {url}: 403 Forbidden") + return None + if status_code >= 400: + raise requests.exceptions.HTTPError(f"{status_code} Error: {url}") + return text + except requests.exceptions.RequestException as exc: + if attempt < max_retries - 1: + wait_time = 2 ** attempt + random.uniform(0.3, 1.0) + print(f"请求失败,{wait_time:.1f}秒后重试 ({attempt + 1}/{max_retries}): {url} -> {exc}") + time.sleep(wait_time) + continue + print(f"请求失败 {url}: {exc}") + return None + return None + + def _get_json(self, url: str) -> Optional[Dict]: + text = self._get(url) + if not text: + return None + try: + return json.loads(text.strip().lstrip("\ufeff")) + except ValueError as exc: + print(f"解析JSON失败 {url}: {exc}") + return None 
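The module-level helpers at the top of dls_pc.py do all of the text normalization the list parser below relies on. Some REPL-style checks of how I read the regexes (the sample inputs are made up for illustration, not fixtures shipped with the patch):

assert normalize_phone("电话:138-1234-5678") == "13812345678"  # \D stripped first, then PHONE_RE (1[3-9]\d{9})
assert normalize_phone("010-8888-7777") == ""                  # landline digits never match PHONE_RE
assert clean_area_name("B北京") == "北京"                       # AREA_PREFIX_RE drops the leading index letter
assert normalize_region_text("北京 - 朝阳区") == "北京-朝阳区"   # dash variants unified, whitespace removed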
+ + def _load_areas(self) -> List[Dict[str, str]]: + areas = self._load_areas_from_site() + if areas: + print(f"[大律师PC] 地区来源: site, 地区数: {len(areas)}") + return areas + + areas = self._load_areas_from_db() + if areas: + print(f"[大律师PC] 地区来源: db, 地区数: {len(areas)}") + return areas + + print("[大律师PC] 无地区数据") + return [] + + def _load_areas_from_site(self) -> List[Dict[str, str]]: + data = self._get_json(PROVINCE_API) + if not data or str(data.get("status")) != "1": + return [] + + result: List[Dict[str, str]] = [] + seen_pinyin: Set[str] = set() + + for province in data.get("ds", []) or []: + province_id = province.get("id") + province_name = clean_area_name(province.get("name", "")) + province_pinyin = (province.get("py_code") or "").strip() + + city_rows = [] + if province_id: + city_data = self._get_json(CITY_API_TEMPLATE.format(province_id=province_id)) + if city_data and str(city_data.get("status")) == "1": + city_rows = city_data.get("ds", []) or [] + + if not city_rows and province_pinyin and province_pinyin not in seen_pinyin: + seen_pinyin.add(province_pinyin) + result.append({ + "province": province_name, + "city": province_name, + "pinyin": province_pinyin, + }) + continue + + for city in city_rows: + city_name = clean_area_name(city.get("name", "")) + city_pinyin = (city.get("py_code") or "").strip() + if not city_pinyin or city_pinyin in seen_pinyin: + continue + seen_pinyin.add(city_pinyin) + result.append({ + "province": province_name, + "city": city_name or province_name, + "pinyin": city_pinyin, + }) + + return result + + def _load_areas_from_db(self) -> List[Dict[str, str]]: + tables = ("area_new", "area", "area2") + last_error = None + for table in tables: + try: + rows = self.db.select_data( + table, + "province, city, pinyin", + "domain='maxlaw' AND level=2", + ) or [] + except Exception as exc: + last_error = exc + continue + + if rows: + return rows + + if last_error: + print(f"[大律师PC] 加载数据库地区失败: {last_error}") + return [] + + def _existing_phones(self, phones: List[str]) -> Set[str]: + if not phones: + return set() + existing: Set[str] = set() + cur = self.db.db.cursor() + try: + chunk_size = 500 + for i in range(0, len(phones), chunk_size): + chunk = phones[i:i + chunk_size] + placeholders = ",".join(["%s"] * len(chunk)) + sql = f"SELECT phone FROM lawyer WHERE domain=%s AND phone IN ({placeholders})" + cur.execute(sql, [DOMAIN, *chunk]) + for row in cur.fetchall(): + existing.add(row[0]) + finally: + cur.close() + return existing + + def _build_list_url(self, pinyin: str, page: int) -> str: + return LIST_URL_TEMPLATE.format(pinyin=pinyin, page=page) + + def _parse_location_line( + self, + text: str, + fallback_province: str, + fallback_city: str, + ) -> Tuple[str, str, str]: + raw = (text or "").replace("\xa0", " ") + raw = re.sub(r"\s+", " ", raw).strip() + if not raw: + return fallback_province, fallback_city or fallback_province, "" + + parts = raw.split(" ", 1) + area_text = parts[0].strip() + law_firm = parts[1].strip() if len(parts) > 1 else "" + + province = fallback_province + city = fallback_city or fallback_province + if "-" in area_text: + area_parts = [item.strip() for item in area_text.split("-", 1)] + if area_parts[0]: + province = area_parts[0] + if len(area_parts) > 1 and area_parts[1]: + city = area_parts[1] + elif area_text: + province = area_text + city = area_text + + return province, city, law_firm + + def _extract_page_region(self, soup: BeautifulSoup) -> str: + button = soup.select_one(".filter .filter-btn") + if button: + return 
normalize_region_text(button.get_text(" ", strip=True)) + title = soup.select_one(".findLawyer-title h1") + if title: + return normalize_region_text(title.get_text(strip=True).replace("律师", "")) + return "" + + def _page_matches_area(self, soup: BeautifulSoup, province: str, city: str) -> Tuple[bool, str]: + current_region = self._extract_page_region(soup) + if not current_region: + return True, current_region + if "全国" in current_region: + return False, current_region + + norm_province = normalize_region_text(province) + norm_city = normalize_region_text(city or province) + + if norm_city and norm_city != norm_province: + matched = norm_province in current_region and norm_city in current_region + else: + matched = norm_province in current_region + + if matched: + return True, current_region + + title = soup.select_one(".findLawyer-title h1") + title_text = "" + if title: + title_text = normalize_region_text(title.get_text(strip=True).replace("律师", "")) + + if norm_city and norm_city != norm_province: + matched = norm_city in title_text + else: + matched = norm_province in title_text + + return matched, current_region or title_text + + def _parse_list(self, html: str, province: str, city: str, list_url: str, area_pinyin: str) -> Tuple[bool, int, int]: + soup = BeautifulSoup(html, "html.parser") + matched, current_region = self._page_matches_area(soup, province, city) + if not matched: + print(f" 页面地区不匹配,停止分页: 目标={province}-{city} 当前={current_region or '未知'}") + return False, 0, 0 + + cards = [] + seen_page_phone: Set[str] = set() + + for item in soup.select("ul.findLawyer-list > li.clearfix"): + name_link = item.select_one(".findLawyer-list-detail-name a[href]") + phone_tag = item.select_one(".findLawyer-list-detail-name span") + if not name_link or not phone_tag: + continue + + phone = normalize_phone(phone_tag.get_text(" ", strip=True)) + if not phone or phone in seen_page_phone: + continue + seen_page_phone.add(phone) + + name = name_link.get_text(strip=True) + detail_url = urljoin(SITE_BASE, name_link.get("href", "").strip()) + + location_tag = item.select_one(".findLawyer-list-detail-the") + card_province, card_city, law_firm = self._parse_location_line( + location_tag.get_text(" ", strip=True) if location_tag else "", + province, + city, + ) + + specialties = [] + for dd in item.select(".findLawyer-list-detail-fields dd"): + text = dd.get_text(strip=True) + if text: + specialties.append(text) + + reply_count = None + reply_tag = item.select_one(".findLawyer-list-detail-other a") + if reply_tag: + match = REPLY_RE.search(reply_tag.get_text(" ", strip=True)) + if match: + reply_count = int(match.group(1)) + + cards.append({ + "name": name, + "law_firm": law_firm, + "province": card_province or province, + "city": card_city or city or province, + "phone": phone, + "url": detail_url, + "domain": DOMAIN, + "create_time": int(time.time()), + "params": json.dumps({ + "area_pinyin": area_pinyin, + "source": list_url, + "specialties": specialties, + "reply_count": reply_count, + }, ensure_ascii=False), + }) + + if not cards: + return True, 0, 0 + + phones = [item["phone"] for item in cards if item.get("phone")] + existing = self._existing_phones(phones) + inserted = 0 + + for item in cards: + phone = item.get("phone") + if not phone: + continue + if phone in existing: + print(f" -- 已存在: {item['name']} ({phone})") + continue + try: + self.db.insert_data("lawyer", item) + inserted += 1 + print(f" -> 新增: {item['name']} ({phone})") + except Exception as exc: + print(f" 插入失败 {item.get('url')}: 
{exc}") + + return True, inserted, len(cards) + + def run(self): + print("启动大律师 PC 站采集...") + if not self.areas: + print("无地区数据") + return + + for area in self.areas: + province = (area.get("province") or "").strip() + city = (area.get("city") or province).strip() + pinyin = (area.get("pinyin") or "").strip() + if not province or not pinyin: + continue + + area_label = province if not city or city == province else f"{province}-{city}" + print(f"采集地区: {area_label} ({pinyin})") + + for page in range(1, self.max_pages + 1): + list_url = self._build_list_url(pinyin, page) + print(f" 第 {page} 页: {list_url}") + html = self._get(list_url, headers={"Referer": SITE_BASE + "/law"}) + if not html: + break + + page_ok, inserted, parsed_count = self._parse_list(html, province, city, list_url, pinyin) + if not page_ok: + break + if parsed_count == 0: + print(" 当前页无律师卡片,停止") + break + + if inserted == 0: + print(" 当前页无新增数据") + + time.sleep(0.5) + + print("大律师 PC 站采集完成") + + +if __name__ == "__main__": + with Db() as db: + spider = DlsPcSpider(db) + spider.run() diff --git a/common_sites/findlaw.py b/common_sites/findlaw.py index 2496037..750c311 100644 --- a/common_sites/findlaw.py +++ b/common_sites/findlaw.py @@ -16,6 +16,7 @@ if project_root not in sys.path: import requests from request.proxy_config import get_proxies, report_proxy_status from Db import Db +from utils.rate_limiter import request_slot DOMAIN = "找法网" LIST_TEMPLATE = "https://m.findlaw.cn/{pinyin}/q_lawyer/p{page}?ajax=1&order=0&sex=-1" @@ -59,7 +60,8 @@ class FindlawSpider: headers = {"Referer": referer} for attempt in range(max_retries): try: - resp = self.session.get(url, timeout=15, verify=verify, headers=headers) + with request_slot(): + resp = self.session.get(url, timeout=15, verify=verify, headers=headers) status_code = resp.status_code text = resp.text resp.close() diff --git a/common_sites/hualv.py b/common_sites/hualv.py index 8eaa30d..3b6180f 100644 --- a/common_sites/hualv.py +++ b/common_sites/hualv.py @@ -20,6 +20,7 @@ from request.proxy_config import get_proxies, report_proxy_status from Db import Db from config import HEADERS +from utils.rate_limiter import request_slot LIST_URL = "https://m.66law.cn/findlawyer/rpc/loadlawyerlist/" DOMAIN = "华律" @@ -100,7 +101,8 @@ class HualvSpider: def _post(self, data: Dict[str, str], max_retries: int = 3) -> Optional[Dict]: for attempt in range(max_retries): try: - resp = self.session.post(LIST_URL, data=data, timeout=20, verify=False) + with request_slot(): + resp = self.session.post(LIST_URL, data=data, timeout=20, verify=False) status_code = resp.status_code text = resp.text resp.close() @@ -272,7 +274,8 @@ class HualvSpider: def _get_detail(self, url: str, max_retries: int = 3) -> Optional[str]: for attempt in range(max_retries): try: - resp = self.session.get(url, timeout=15, verify=False) + with request_slot(): + resp = self.session.get(url, timeout=15, verify=False) status_code = resp.status_code text = resp.text resp.close() diff --git a/common_sites/lawtime.py b/common_sites/lawtime.py index 2ce89aa..e1618a2 100644 --- a/common_sites/lawtime.py +++ b/common_sites/lawtime.py @@ -26,6 +26,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) from Db import Db from config import LAWTIME_CONFIG +from utils.rate_limiter import request_slot LIST_BASE = "https://m.lawtime.cn/{pinyin}/lawyer/?page={page}" DETAIL_BASE = "https://m.lawtime.cn" @@ -123,7 +124,8 @@ class LawtimeSpider: def _get_with_session(self, session: requests.Session, url: str, max_retries: int = 
3, is_thread: bool = False) -> Optional[str]: for attempt in range(max_retries): try: - resp = session.get(url, timeout=15, verify=False) + with request_slot(): + resp = session.get(url, timeout=15, verify=False) status_code = resp.status_code text = resp.text resp.close() diff --git a/common_sites/six4365.py b/common_sites/six4365.py index 9fb3651..470adc4 100644 --- a/common_sites/six4365.py +++ b/common_sites/six4365.py @@ -23,6 +23,7 @@ from request.proxy_config import get_proxies, report_proxy_status urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) from Db import Db +from utils.rate_limiter import request_slot DOMAIN = "律图" LIST_URL = "https://m.64365.com/findLawyer/rpc/FindLawyer/LawyerRecommend/" @@ -144,7 +145,8 @@ class Six4365Spider: def _post(self, payload: Dict[str, str], max_retries: int = 3) -> Optional[str]: for attempt in range(max_retries): try: - resp = self.session.post(LIST_URL, data=payload, timeout=10, verify=False) + with request_slot(): + resp = self.session.post(LIST_URL, data=payload, timeout=10, verify=False) status_code = resp.status_code text = resp.text resp.close() @@ -301,7 +303,8 @@ class Six4365Spider: session = self._get_thread_session() for attempt in range(max_retries): try: - resp = session.get(url, timeout=10, verify=False) + with request_slot(): + resp = session.get(url, timeout=10, verify=False) status_code = resp.status_code text = resp.text resp.close() diff --git a/common_sites/start.sh b/common_sites/start.sh index a4d2a7d..f117cbc 100755 --- a/common_sites/start.sh +++ b/common_sites/start.sh @@ -5,9 +5,20 @@ set -euo pipefail cd "$(dirname "$0")" echo "使用 request/proxy_settings.json 读取代理配置" +export PROXY_MAX_REQUESTS_PER_SECOND="${PROXY_MAX_REQUESTS_PER_SECOND:-5}" -nohup python ../common_sites/dls.py > dls.log 2>&1 & # 大律师 -nohup python ../common_sites/findlaw.py > findlaw.log 2>&1 & # 找法网 -nohup python ../common_sites/lawtime.py > lawtime.log 2>&1 & # 法律快车 -nohup python ../common_sites/six4365.py > six4365.log 2>&1 & # 律图 -nohup python ../common_sites/hualv.py > hualv.log 2>&1 & # 华律 +start_job() { + local script="$1" + local log_file="$2" + local label="$3" + nohup python "../common_sites/${script}" > "${log_file}" 2>&1 & + echo "启动 ${label}: ${script} -> ${log_file}" + sleep 1 +} + +start_job "dls.py" "dls.log" "大律师" +start_job "dls_pc.py" "dls_pc.log" "大律师PC站" +start_job "findlaw.py" "findlaw.log" "找法网" +start_job "lawtime.py" "lawtime.log" "法律快车" +start_job "six4365.py" "six4365.log" "律图" +start_job "hualv.py" "hualv.log" "华律" diff --git a/utils/rate_limiter.py b/utils/rate_limiter.py index f3fd9ea..23ef0b5 100644 --- a/utils/rate_limiter.py +++ b/utils/rate_limiter.py @@ -1,76 +1,191 @@ """ 全局请求速率限制器 -确保代理每秒不超过5次请求 + +默认按“所有爬虫进程共享一个桶”来限流,避免 `bash start.sh` +同时启动多个进程时,每个进程各自 5 次/秒,叠加后把代理冲爆。 """ +from contextlib import contextmanager +import json +import os +import tempfile import time import threading -from collections import deque +from pathlib import Path +from uuid import uuid4 + +import fcntl class RateLimiter: """ - 令牌桶算法实现的速率限制器 + 基于文件锁的跨进程滑动窗口限流器。 + + - 同一台机器上的多个 Python 进程会共享同一个状态文件 + - 同一个进程内的多个线程也会一起走这个限流器 """ - def __init__(self, max_requests_per_second: int = 5): - """ - 初始化速率限制器 - - Args: - max_requests_per_second: 每秒最大请求数 - """ - self.max_requests = max_requests_per_second - self.requests = deque() - self.lock = threading.RLock() - - def acquire(self): - """ - 获取请求权限,如果需要则等待 - """ - with self.lock: - now = time.time() - - # 清理超过1秒的请求记录 - while self.requests and now - self.requests[0] >= 1.0: - 
self.requests.popleft() - - # 如果当前请求数已达上限,等待 - if len(self.requests) >= self.max_requests: - # 计算需要等待的时间 - wait_time = 1.0 - (now - self.requests[0]) - if wait_time > 0: - time.sleep(wait_time) - return self.acquire() # 递归调用以重新检查 - - # 记录这次请求 - self.requests.append(now) - + + def __init__( + self, + max_requests_per_second: int = 5, + window_seconds: float = 1.0, + state_file: str | None = None, + ): + self.max_requests = max(1, int(max_requests_per_second)) + self.max_concurrent = max( + 1, + int(os.getenv("PROXY_MAX_CONCURRENT_REQUESTS", str(self.max_requests))), + ) + self.window_seconds = max(0.1, float(window_seconds)) + self.lease_seconds = max( + 5.0, + float(os.getenv("PROXY_REQUEST_LEASE_SECONDS", "120")), + ) + default_state = os.path.join( + tempfile.gettempdir(), + "lawyers_proxy_rate_limiter.json", + ) + self.state_file = Path( + state_file or os.getenv("PROXY_RATE_LIMIT_FILE", default_state) + ) + self.lock_file = self.state_file.with_suffix(self.state_file.suffix + ".lock") + self._thread_lock = threading.RLock() + self.state_file.parent.mkdir(parents=True, exist_ok=True) + self.lock_file.parent.mkdir(parents=True, exist_ok=True) + + def _load_state(self) -> dict: + if not self.state_file.exists(): + return {"timestamps": [], "leases": {}} + try: + raw = self.state_file.read_text(encoding="utf-8").strip() + if not raw: + return {"timestamps": [], "leases": {}} + data = json.loads(raw) + if isinstance(data, list): + return { + "timestamps": [float(item) for item in data], + "leases": {}, + } + if not isinstance(data, dict): + return {"timestamps": [], "leases": {}} + timestamps = data.get("timestamps", []) or [] + leases = data.get("leases", {}) or {} + return { + "timestamps": [float(item) for item in timestamps], + "leases": {str(key): float(value) for key, value in leases.items()}, + } + except Exception: + return {"timestamps": [], "leases": {}} + + def _save_state(self, state: dict) -> None: + payload = json.dumps(state, ensure_ascii=False) + self.state_file.write_text(payload, encoding="utf-8") + + def _normalize_state(self, state: dict, now: float) -> dict: + timestamps = [ + float(ts) + for ts in (state.get("timestamps", []) or []) + if now - float(ts) < self.window_seconds + ] + leases = { + str(key): float(value) + for key, value in (state.get("leases", {}) or {}).items() + if now - float(value) < self.lease_seconds + } + return {"timestamps": timestamps, "leases": leases} + + def acquire(self) -> None: + token = None + while True: + token = self.try_acquire_slot() + if token: + self.release(token) + return + time.sleep(0.05) + + def try_acquire_slot(self) -> str | None: + while True: + wait_time = 0.0 + with self._thread_lock: + with open(self.lock_file, "a+", encoding="utf-8") as lock_fp: + fcntl.flock(lock_fp.fileno(), fcntl.LOCK_EX) + now = time.time() + state = self._normalize_state(self._load_state(), now) + timestamps = state["timestamps"] + leases = state["leases"] + + if len(timestamps) < self.max_requests and len(leases) < self.max_concurrent: + token = uuid4().hex + timestamps.append(now) + leases[token] = now + self._save_state(state) + fcntl.flock(lock_fp.fileno(), fcntl.LOCK_UN) + return token + + wait_candidates = [] + if len(timestamps) >= self.max_requests and timestamps: + wait_candidates.append(self.window_seconds - (now - timestamps[0])) + if len(leases) >= self.max_concurrent: + wait_candidates.append(0.05) + wait_time = max(0.05, min([item for item in wait_candidates if item > 0] or [0.05])) + fcntl.flock(lock_fp.fileno(), fcntl.LOCK_UN) + + 
time.sleep(wait_time) + + def release(self, token: str | None) -> None: + if not token: + return + with self._thread_lock: + with open(self.lock_file, "a+", encoding="utf-8") as lock_fp: + fcntl.flock(lock_fp.fileno(), fcntl.LOCK_EX) + now = time.time() + state = self._normalize_state(self._load_state(), now) + leases = state["leases"] + if token in leases: + leases.pop(token, None) + self._save_state(state) + else: + self._save_state(state) + fcntl.flock(lock_fp.fileno(), fcntl.LOCK_UN) + def can_make_request(self) -> bool: - """ - 检查是否可以立即发起请求(非阻塞) - """ - with self.lock: - now = time.time() - - # 清理超过1秒的请求记录 - while self.requests and now - self.requests[0] >= 1.0: - self.requests.popleft() - - return len(self.requests) < self.max_requests + with self._thread_lock: + with open(self.lock_file, "a+", encoding="utf-8") as lock_fp: + fcntl.flock(lock_fp.fileno(), fcntl.LOCK_EX) + now = time.time() + state = self._normalize_state(self._load_state(), now) + self._save_state(state) + allowed = ( + len(state["timestamps"]) < self.max_requests + and len(state["leases"]) < self.max_concurrent + ) + fcntl.flock(lock_fp.fileno(), fcntl.LOCK_UN) + return allowed -# 全局速率限制器实例 -global_rate_limiter = RateLimiter(max_requests_per_second=5) +global_rate_limiter = RateLimiter( + max_requests_per_second=int(os.getenv("PROXY_MAX_REQUESTS_PER_SECOND", "5")) +) def wait_for_request(): - """ - 等待直到可以发起请求 - """ + """等待直到可以发起请求。""" global_rate_limiter.acquire() def can_request_now() -> bool: - """ - 检查是否可以立即发起请求 - """ + """检查是否可以立即发起请求。""" return global_rate_limiter.can_make_request() + + +@contextmanager +def request_slot(): + """ + 申请一个跨进程共享的请求槽位,请求结束后自动释放。 + + 这样既能限制“每秒启动多少请求”,也能限制“同时在飞多少请求”。 + """ + token = global_rate_limiter.try_acquire_slot() + try: + yield + finally: + global_rate_limiter.release(token)
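As I read the new limiter, every process that imports utils.rate_limiter shares one sliding window and one lease table through the JSON state file (lawyers_proxy_rate_limiter.json under the system temp dir by default, overridable via PROXY_RATE_LIMIT_FILE), serialized by an fcntl lock, which also means the module is POSIX-only. A hypothetical smoke test, not part of the patch, that should show the budget being shared across threads (and across processes, if you run two copies side by side):

import os
import threading
import time

# The env vars are read when utils.rate_limiter is imported, so set them first.
os.environ.setdefault("PROXY_MAX_REQUESTS_PER_SECOND", "5")

from utils.rate_limiter import request_slot  # assumes the project root is on sys.path


def worker(count: int) -> None:
    for _ in range(count):
        with request_slot():  # blocks until a window slot and a lease are both free
            pass  # the real spiders issue session.get/post here


start = time.time()
threads = [threading.Thread(target=worker, args=(5,)) for _ in range(4)]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

# 20 slots at 5 per sliding second should take roughly 3 to 4 seconds end to end.
print(f"elapsed: {time.time() - start:.1f}s")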