| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605 |
- """
- 第三方平台推送模块
- 将批次信息推送到第三方平台接口
- """
- import os
- import time
- import json
- import logging
- import threading
- import queue
- import cv2
- import requests
- from typing import Optional, Dict, Any, List, Callable
- from dataclasses import dataclass
- from datetime import datetime
- from pathlib import Path
# Module-level logger shared by the helper functions and ThirdPartyPusher in this file.
logger = logging.getLogger(__name__)
- def _normalize_timestamp(ts: float) -> float:
- """统一时间戳为秒。CaptureUploader 使用毫秒,其他可能使用秒。"""
- if ts > 1e12:
- return ts / 1000.0
- return ts
def _legacy_image_size(path: Optional[str]) -> List[int]:
    """Return ``[width, height]`` of the image at *path*, or ``[0, 0]`` when unknown."""
    if not path or not os.path.exists(path):
        return [0, 0]
    try:
        img = cv2.imread(path)
    except Exception as exc:
        # Best effort: without a size, positions fall back to pixel coordinates.
        logger.debug(f"[第三方平台] 读取图片尺寸失败: {path}: {exc}")
        return [0, 0]
    if img is None:
        # cv2.imread signals unreadable/unsupported files by returning None.
        return [0, 0]
    height, width = img.shape[:2]
    return [width, height]


def _legacy_bbox(raw: Any) -> List[int]:
    """
    Return ``[x1, y1, x2, y2]`` as ints.

    Accepts a sequence ``[x1, y1, x2, y2, ...]`` or a dict with
    ``x1``/``y1``/``x2``/``y2`` keys; a missing or malformed bbox becomes
    ``[0, 0, 0, 0]`` (the same default used when the key is absent) instead of
    aborting the conversion of the whole batch.
    """
    try:
        if isinstance(raw, dict):
            values = [raw["x1"], raw["y1"], raw["x2"], raw["y2"]]
        else:
            values = [raw[0], raw[1], raw[2], raw[3]]
        return [int(float(v)) for v in values]
    except (TypeError, ValueError, IndexError, KeyError):
        return [0, 0, 0, 0]


def _legacy_ptz_position(raw: Any) -> Dict[str, Any]:
    """Return ``{"pan", "tilt", "zoom"}`` from *raw*; missing/None values default to 0/0/1."""
    source = raw if isinstance(raw, dict) else {}
    position = {}
    for key, default in (("pan", 0), ("tilt", 0), ("zoom", 1)):
        value = source.get(key)
        position[key] = default if value is None else value
    return position


def _convert_to_legacy_batch_info(new_info: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert a batch_info produced by the new CaptureUploader into the legacy
    PairedImageSaver layout required by the business platform.

    Output fields:
      - batch_id / device_id / project_id / timestamp / datetime
      - total_persons / ptz_images_count
      - panorama: local_path / local_path_original / oss_url / oss_url_original
      - persons: person_index / position (x, y normalized) / bbox / confidence /
                 ptz_position / ptz_bbox / ptz_image_saved /
                 ptz_image_path / ptz_image_original_path /
                 ptz_oss_url / ptz_oss_url_original
      - upload_status

    Args:
        new_info: New-style batch_info. Keys read: timestamp, image_urls,
            image_paths ([original, marked]), camera_type, ptz_position,
            detections, batch_id, device_id, project_id.

    Returns:
        A new dict in the legacy format; ``new_info`` is not modified.
        ``position`` is the bbox centre divided by the original image size;
        when that size cannot be read it falls back to raw pixel coordinates.
    """
    raw_ts = new_info.get("timestamp")
    # A present-but-None timestamp is treated like a missing one.
    normalized_ts = _normalize_timestamp(time.time() if raw_ts is None else raw_ts)
    urls = new_info.get("image_urls") or {}
    image_paths = new_info.get("image_paths") or []
    is_ptz = new_info.get("camera_type", "panorama") == "ptz"

    # image_paths convention: [original, marked]; marked falls back to original.
    original_path = image_paths[0] if image_paths else None
    marked_path = image_paths[1] if len(image_paths) > 1 else original_path
    oss_url_original = urls.get("original") or None
    oss_url_marked = urls.get("marked") or oss_url_original

    # Original image size, used to normalize person centres to 0~1.
    img_w, img_h = _legacy_image_size(original_path)

    # PTZ detections use the PTZ position reported with the batch; panorama
    # detections always report the neutral position 0/0/1.
    ptz_position = _legacy_ptz_position(new_info.get("ptz_position") if is_ptz else None)

    # The same image/URL fields are attached to every person (the single
    # batch image is reused as each person's close-up), so the file-existence
    # check runs once instead of once per person.
    shared_image_fields = {
        "ptz_image_saved": bool(marked_path and os.path.exists(marked_path)),
        "ptz_image_path": marked_path,
        "ptz_image_original_path": original_path,
        "ptz_oss_url": oss_url_marked,
        "ptz_oss_url_original": oss_url_original,
    }

    persons = []
    for index, det in enumerate(new_info.get("detections") or []):
        x1, y1, x2, y2 = _legacy_bbox(det.get("bbox"))
        box = {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
        cx = (x1 + x2) / 2.0
        cy = (y1 + y2) / 2.0
        persons.append({
            "person_index": index,
            "position": {
                "x": round(cx / img_w, 4) if img_w else round(cx, 4),
                "y": round(cy / img_h, 4) if img_h else round(cy, 4),
            },
            "bbox": box,
            "confidence": float(det.get("confidence") or 0.0),
            # The platform requires numeric pan/tilt/zoom on every person.
            "ptz_position": dict(ptz_position),
            # The detection box doubles as the PTZ box.
            "ptz_bbox": dict(box),
            **shared_image_fields,
        })

    return {
        "batch_id": new_info.get("batch_id", ""),
        "device_id": new_info.get("device_id", ""),
        "project_id": new_info.get("project_id", ""),
        "timestamp": normalized_ts,
        "datetime": datetime.fromtimestamp(normalized_ts).isoformat(),
        "total_persons": len(persons),
        "ptz_images_count": len(persons),
        "panorama": {
            "local_path": marked_path,
            "local_path_original": original_path,
            "oss_url": oss_url_marked,
            "oss_url_original": oss_url_original,
        },
        "persons": persons,
        "upload_status": {
            "panorama_uploaded": bool(oss_url_marked or oss_url_original),
            "panorama_original_uploaded": bool(oss_url_original),
            "all_ptz_uploaded": is_ptz and bool(oss_url_marked),
        },
    }
@dataclass
class BatchReport:
    """A single batch report queued for (or sent to) the third-party platform."""
    # Identifiers copied out of batch_info, used for logging and callbacks.
    batch_id: str
    device_id: str
    project_id: str
    # Timestamp taken from batch_info; may be seconds or milliseconds
    # (it is passed through _normalize_timestamp before being sent).
    timestamp: float
    batch_info: Dict[str, Any]  # full content of batch_info.json
    local_path: Optional[str] = None  # local path of batch_info.json, if known
class ThirdPartyPusher:
    """
    Third-party platform pusher.

    Batch reports queued by ``report_batch`` are POSTed to the configured
    platform by a background daemon thread, which enforces a minimum interval
    between pushes, retries failed attempts, keeps success/failure counters and
    invokes optional callbacks. ``report_batch_sync`` sends immediately on the
    caller's thread instead (without touching the counters or callbacks).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Read the pusher configuration.

        Args:
            config: Third-party platform config dict. When omitted or empty,
                ``THIRD_PARTY_CONFIG`` from the project ``config`` module is
                used. ``DEVICE_CONFIG`` is always read from that module.
        """
        from config import THIRD_PARTY_CONFIG, DEVICE_CONFIG

        self.config = config or THIRD_PARTY_CONFIG
        self.device_config = DEVICE_CONFIG

        # Feature switch
        self.enabled = self.config.get('enabled', False)

        # Platform settings
        self.platform_type = self.config.get('platform_type', 'custom')
        self.base_url = self.config.get('base_url', '')
        self.api_version = self.config.get('api_version', 'v1')

        # Authentication settings. NOTE: _get_auth_headers() currently sends no
        # auth headers; these keys are still read so existing configs load.
        self.auth_type = self.config.get('auth_type', 'none')
        self.api_key = self.config.get('api_key', '')
        self.api_secret = self.config.get('api_secret', '')
        self.oauth2_config = self.config.get('oauth2') or {}

        # Endpoint paths, joined onto base_url (absolute http(s) URLs are used as-is).
        self.endpoints = self.config.get('endpoints') or {}
        self.batch_report_url = self.endpoints.get('batch_report', '/api/batch/report')
        self.heartbeat_url = self.endpoints.get('heartbeat', '/api/device/heartbeat')

        # Push control
        self.push_interval = self.config.get('push_interval', 1.0)  # min seconds between queued pushes
        self.retry_count = self.config.get('retry_count', 3)  # total attempts per report (at least 1 is made)
        self.retry_delay = self.config.get('retry_delay', 2.0)  # seconds to wait between attempts
        self.timeout = self.config.get('timeout', 10)  # per-request timeout in seconds
        self.data_format = self.config.get('data_format', 'json')  # 'json' body, anything else = form fields
        self.include_images = self.config.get('include_images', False)  # not referenced elsewhere in this class
        # TLS certificate verification passed to requests as ``verify`` for
        # report and heartbeat requests. Defaults to False to keep the previous
        # behaviour; set True (or a CA-bundle path) to verify certificates.
        self.verify_ssl = self.config.get('verify_ssl', False)

        # Cached OAuth2 token
        self._access_token = None
        self._token_expires_at = 0

        # Pending reports consumed by the worker thread
        self.report_queue = queue.Queue()

        # Worker thread
        self.running = False
        self.worker_thread = None

        # Counters (guarded by stats_lock)
        self.stats = {
            'total_reports': 0,
            'success_reports': 0,
            'failed_reports': 0,
        }
        self.stats_lock = threading.Lock()

        # Callbacks, each called with the BatchReport
        self.on_report_success: Optional[Callable] = None
        self.on_report_failed: Optional[Callable] = None

        # Wall-clock time the last queued report finished sending
        self.last_report_time = 0

        if self.enabled:
            logger.info(f"[第三方平台] 推送器初始化完成: {self.base_url}")

    def start(self):
        """Start the background worker (no-op when disabled or already running)."""
        if not self.enabled:
            logger.info("[第三方平台] 推送器未启用")
            return

        if self.running:
            return

        self.running = True
        # Assigned before start() so the new thread recognises itself as the
        # current worker; an older worker left over from a stop() whose join
        # timed out sees it has been replaced and exits after its current item.
        self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self.worker_thread.start()
        logger.info("[第三方平台] 推送器已启动")

    def stop(self):
        """Signal the worker to exit and wait up to 5 seconds for it; queued reports stay in the queue."""
        self.running = False
        worker = self.worker_thread
        # Joining the current thread would raise (e.g. stop() called from a callback).
        if worker is not None and worker is not threading.current_thread():
            worker.join(timeout=5)
            if worker.is_alive():
                logger.warning("[第三方平台] 工作线程未在 5 秒内退出, 将在当前任务结束后退出")
        logger.info("[第三方平台] 推送器已停止")

    def _worker_loop(self):
        """Process queued reports until stopped or replaced by a newer worker thread."""
        me = threading.current_thread()
        while self.running and self.worker_thread is me:
            try:
                report = self.report_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            try:
                self._process_report(report)
            except Exception as e:
                logger.error(f"[第三方平台] 处理上报错误: {e}", exc_info=True)
            finally:
                # Keep Queue.join() usable by callers waiting for the backlog.
                self.report_queue.task_done()

    def _get_auth_headers(self) -> Dict[str, str]:
        """
        Return authentication headers for platform requests.

        The current third-party API needs no custom headers (sending them led
        to HTTP 422), so this deliberately returns an empty dict and the
        auth_type / api_key / oauth2 settings are not applied.
        """
        return {}

    def _get_oauth2_token(self) -> Optional[str]:
        """
        Return a cached or freshly fetched OAuth2 client-credentials token.

        The cached token is reused until 60 seconds before it expires.

        Returns:
            The access token, or None when the config is incomplete or the
            request fails.
        """
        if self._access_token and time.time() < self._token_expires_at - 60:
            return self._access_token

        token_url = self.oauth2_config.get('token_url', '')
        client_id = self.oauth2_config.get('client_id', '')
        client_secret = self.oauth2_config.get('client_secret', '')
        scope = self.oauth2_config.get('scope', '')

        if not all([token_url, client_id, client_secret]):
            logger.error("[第三方平台] OAuth2 配置不完整")
            return None

        try:
            data = {
                'grant_type': 'client_credentials',
                'client_id': client_id,
                'client_secret': client_secret,
            }
            if scope:
                data['scope'] = scope

            response = requests.post(token_url, data=data, timeout=self.timeout)

            if response.status_code != 200:
                logger.error(f"[第三方平台] OAuth2 Token 获取失败: {response.status_code}")
                return None

            result = response.json()
            token = result.get('access_token')
            if not token:
                logger.error("[第三方平台] OAuth2 响应中缺少 access_token")
                return None
            try:
                # Some servers return expires_in as a string.
                expires_in = float(result.get('expires_in', 3600))
            except (TypeError, ValueError):
                expires_in = 3600.0
            self._access_token = token
            self._token_expires_at = time.time() + expires_in
            logger.info("[第三方平台] OAuth2 Token 获取成功")
            return self._access_token

        except Exception as e:
            logger.error(f"[第三方平台] OAuth2 Token 请求异常: {e}")
            return None

    def _process_report(self, report: "BatchReport"):
        """Send one queued report, respecting push_interval, then update stats and fire callbacks."""
        elapsed = time.time() - self.last_report_time
        if elapsed < self.push_interval:
            time.sleep(self.push_interval - elapsed)

        success = self._send_batch_report(report)

        with self.stats_lock:
            self.stats['total_reports'] += 1
            if success:
                self.stats['success_reports'] += 1
            else:
                self.stats['failed_reports'] += 1

        self.last_report_time = time.time()

        # Callback errors are logged and never propagate into the worker.
        if success and self.on_report_success:
            try:
                self.on_report_success(report)
            except Exception as e:
                logger.error(f"[第三方平台] 成功回调执行错误: {e}")
        elif not success and self.on_report_failed:
            try:
                self.on_report_failed(report)
            except Exception as e:
                logger.error(f"[第三方平台] 失败回调执行错误: {e}")

    @staticmethod
    def _join_url(base_url: str, path: str) -> str:
        """
        Join *base_url* and an endpoint *path* with exactly one ``/`` between them.

        An empty *path* returns *base_url*; an absolute ``http(s)://`` *path*
        is returned unchanged.
        """
        if not path:
            return base_url
        if path.startswith(('http://', 'https://')) or not base_url:
            return path
        return f"{base_url.rstrip('/')}/{path.lstrip('/')}"

    @staticmethod
    def _to_form_fields(payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Prepare *payload* for form encoding.

        requests form-encodes iterable values element by element, so a nested
        dict would be sent as just its keys. Dict/list/tuple values are
        therefore serialized to JSON strings; other values pass through.
        """
        return {
            key: json.dumps(value) if isinstance(value, (dict, list, tuple)) else value
            for key, value in payload.items()
        }

    @staticmethod
    def _is_accepted(result: Dict[str, Any]) -> bool:
        """Return True when a decoded 200 response body signals that the platform accepted the report."""
        status = result.get('status', '')
        message = result.get('message')
        if not isinstance(message, str):
            message = ''
        return (
            result.get('code') == 200
            or result.get('success') == True  # noqa: E712 - also accepts 1, as before
            or status in ('pending', 'success', 'accepted')
            or '请求已接收' in message
            or message == 'accepted'
        )

    @staticmethod
    def _log_response_text(response) -> None:
        """Log up to 500 characters of the response body (best effort)."""
        try:
            logger.warning(f"[第三方平台] 响应内容: {response.text[:500]}")
        except Exception:
            pass

    def _post_batch_once(self, url: str, request_kwargs: Dict[str, Any],
                         headers: Dict[str, str], report: "BatchReport",
                         attempt: int, attempts: int) -> bool:
        """Make one POST attempt for *report*; log the outcome and return True on success."""
        try:
            response = requests.post(
                url,
                headers=headers,
                timeout=self.timeout,
                verify=self.verify_ssl,
                **request_kwargs
            )
        except requests.exceptions.Timeout:
            logger.warning(f"[第三方平台] 请求超时 (尝试 {attempt}/{attempts})")
            return False
        except Exception as e:
            logger.error(f"[第三方平台] 请求异常 (尝试 {attempt}/{attempts}): {e}")
            return False

        if response.status_code != 200:
            logger.warning(f"[第三方平台] 批次上报失败: HTTP {response.status_code}")
            self._log_response_text(response)
            return False

        try:
            result = response.json()
        except ValueError as e:
            logger.warning(f"[第三方平台] 批次上报失败: 响应不是合法 JSON ({e})")
            self._log_response_text(response)
            return False

        if isinstance(result, dict) and self._is_accepted(result):
            logger.info(f"[第三方平台] 批次上报成功: {report.batch_id}, task_id={result.get('task_id')}")
            return True

        error_msg = result.get('msg') if isinstance(result, dict) else None
        logger.warning(f"[第三方平台] 批次上报失败: {error_msg or '未知错误'}")
        logger.warning(f"[第三方平台] 响应内容: {str(result)[:500]}")
        return False

    def _send_batch_report(self, report: "BatchReport") -> bool:
        """
        Send a batch report, retrying up to ``retry_count`` attempts in total.

        Args:
            report: Batch report to send.

        Returns:
            bool: True once the platform accepts the report; False when
            base_url is missing, the payload cannot be built, or every
            attempt fails.
        """
        if not self.base_url:
            logger.error("[第三方平台] 未配置 base_url")
            return False

        url = self._join_url(self.base_url, self.batch_report_url)

        # Malformed batch_info can make payload building raise; report that as
        # a failed send so stats and the failure callback still run.
        try:
            payload = self._build_payload(report)
        except Exception as e:
            logger.error(f"[第三方平台] 构建上报数据失败: {report.batch_id}, {e}")
            return False

        if self.data_format == 'json':
            request_kwargs = {'json': payload}
        else:
            request_kwargs = {'data': self._to_form_fields(payload)}
        headers = self._get_auth_headers()

        # retry_count counts total attempts; always make at least one.
        attempts = max(1, int(self.retry_count))
        for attempt in range(1, attempts + 1):
            if self._post_batch_once(url, request_kwargs, headers, report, attempt, attempts):
                return True
            if attempt < attempts:
                time.sleep(self.retry_delay)

        logger.error(f"[第三方平台] 批次上报最终失败: {report.batch_id}")
        return False

    def _build_payload(self, report: "BatchReport") -> Dict[str, Any]:
        """
        Build the request body for *report* according to ``platform_type``.

        - ``'jtjai'``: ``createTime``/``addr`` plus JSON-encoded ``ext1``
          (image URL list) and ``ext2`` (batch summary).
        - anything else: the legacy PairedImageSaver layout produced by
          ``_convert_to_legacy_batch_info``, with ``timestamp`` in seconds.

        Args:
            report: Batch report to encode.

        Returns:
            Dict: request body.
        """
        batch_info = report.batch_info
        normalized_ts = _normalize_timestamp(report.timestamp)
        if self.platform_type == 'jtjai':
            # Prefer the OSS URL, otherwise fall back to the local image path.
            urls = batch_info.get('image_urls') or {}
            image_url = (urls.get('original') or urls.get('marked')
                         or (batch_info.get('image_paths') or [None])[0])
            detections = batch_info.get('detections') or []
            return {
                'createTime': datetime.fromtimestamp(normalized_ts).strftime("%Y-%m-%d %H:%M:%S"),
                'addr': f"设备{report.device_id}批次上报",
                'ext1': json.dumps([image_url]),
                'ext2': json.dumps({
                    'batchId': report.batch_id,
                    'deviceId': report.device_id,
                    'projectId': report.project_id,
                    'totalPersons': len(detections),
                    'ptzImagesCount': 1 if batch_info.get('camera_type') == 'ptz' else 0,
                    'persons': detections,
                    'imageUrls': urls,
                }),
            }

        # custom / other platforms: convert the new batch_info back to the old
        # field names so the original analysis platform can parse
        # panorama / total_persons / persons.
        payload = _convert_to_legacy_batch_info(batch_info)
        # Always send the timestamp in seconds so the platform parses it correctly.
        payload['timestamp'] = normalized_ts
        return payload

    def _make_report(self, batch_info: Dict[str, Any],
                     local_path: Optional[str]) -> "BatchReport":
        """Wrap *batch_info* in a BatchReport; a missing or None timestamp becomes now."""
        ts = batch_info.get('timestamp')
        return BatchReport(
            batch_id=batch_info.get('batch_id', ''),
            device_id=batch_info.get('device_id', ''),
            project_id=batch_info.get('project_id', ''),
            timestamp=time.time() if ts is None else ts,
            batch_info=batch_info,
            local_path=local_path
        )

    def report_batch(self, batch_info: Dict[str, Any], local_path: Optional[str] = None):
        """
        Queue a batch report for the background worker (no-op when disabled).

        Reports from every camera type (panorama or ptz) are accepted; the
        intended flow is: person detected -> uploaded to OSS -> pushed here.

        Args:
            batch_info: Content of batch_info.json.
            local_path: Local path of batch_info.json (optional).
        """
        if not self.enabled:
            return
        self.report_queue.put(self._make_report(batch_info, local_path))

    def report_batch_sync(self, batch_info: Dict[str, Any],
                          local_path: Optional[str] = None) -> bool:
        """
        Send a batch report immediately on the calling thread.

        Unlike the queued path, this does not update stats, wait for
        push_interval, or invoke callbacks.

        Args:
            batch_info: Content of batch_info.json.
            local_path: Local path of batch_info.json (optional).

        Returns:
            bool: True on success; False when disabled or sending failed.
        """
        if not self.enabled:
            return False
        return self._send_batch_report(self._make_report(batch_info, local_path))

    def send_heartbeat(self) -> bool:
        """
        Send a single heartbeat for this device.

        Returns:
            bool: True on HTTP 200; False when disabled, unconfigured, or failed.
        """
        if not self.enabled or not self.heartbeat_url:
            return False
        if not self.base_url:
            logger.warning("[第三方平台] 未配置 base_url, 跳过心跳")
            return False

        url = self._join_url(self.base_url, self.heartbeat_url)

        payload = {
            'deviceId': self.device_config.get('device_id', ''),
            'projectId': self.device_config.get('project_id', ''),
            'timestamp': time.time(),
            'status': 'online',
        }

        headers = self._get_auth_headers()

        try:
            response = requests.post(
                url,
                json=payload,
                headers=headers,
                timeout=self.timeout,
                verify=self.verify_ssl
            )
        except Exception as e:
            logger.error(f"[第三方平台] 心跳发送异常: {e}")
            return False

        if response.status_code == 200:
            logger.debug("[第三方平台] 心跳发送成功")
            return True
        logger.warning(f"[第三方平台] 心跳发送失败: HTTP {response.status_code}")
        return False

    def set_callbacks(self, on_success: Optional[Callable] = None,
                      on_failed: Optional[Callable] = None):
        """
        Set the callbacks invoked with the BatchReport after each queued send.

        Args:
            on_success: Called after a successful report.
            on_failed: Called after a failed report.
        """
        self.on_report_success = on_success
        self.on_report_failed = on_failed

    def get_stats(self) -> Dict[str, int]:
        """Return a snapshot copy of the report counters."""
        with self.stats_lock:
            return self.stats.copy()

    def is_enabled(self) -> bool:
        """Return whether pushing is enabled by configuration."""
        return self.enabled
# Process-wide singleton state used by get_third_party_pusher / reset_third_party_pusher.
_third_party_pusher_instance: Optional[ThirdPartyPusher] = None
# Guards creation and reset of the singleton instance.
_third_party_pusher_lock = threading.Lock()
def get_third_party_pusher(config: Dict[str, Any] = None) -> ThirdPartyPusher:
    """
    Return the process-wide ThirdPartyPusher, creating it on first use.

    The unlocked fast path returns an existing instance; creation re-checks
    under ``_third_party_pusher_lock`` so concurrent first calls build only
    one instance.

    Args:
        config: Third-party platform config, used only when the instance is
            created; later calls ignore it and return the existing instance.

    Returns:
        The shared ThirdPartyPusher instance.
    """
    global _third_party_pusher_instance
    if _third_party_pusher_instance is not None:
        return _third_party_pusher_instance
    with _third_party_pusher_lock:
        if _third_party_pusher_instance is None:
            _third_party_pusher_instance = ThirdPartyPusher(config)
    return _third_party_pusher_instance
def reset_third_party_pusher():
    """
    Drop the shared ThirdPartyPusher and stop it.

    The instance is detached while holding ``_third_party_pusher_lock`` and
    stopped only after the lock is released. ``stop()`` may block while
    joining the worker thread. Detaching first means concurrent
    ``get_third_party_pusher()`` callers immediately create a fresh instance
    instead of receiving the one that is being stopped (whose queue would no
    longer be processed), and they are not blocked on the lock during the join.
    """
    global _third_party_pusher_instance
    with _third_party_pusher_lock:
        instance = _third_party_pusher_instance
        _third_party_pusher_instance = None
    if instance is not None:
        instance.stop()
|