"""
Third-party platform push module.

Pushes batch information (the contents of ``batch_info.json``) to a configured
third-party platform HTTP API, either asynchronously through a background
worker thread or synchronously on the caller's thread.
"""
import os
import time
import json
import logging
import threading
import queue
import requests
from typing import Optional, Dict, Any, List, Callable
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

logger = logging.getLogger(__name__)


@dataclass
class BatchReport:
    """A single batch report to be delivered to the third-party platform."""
    batch_id: str
    device_id: str
    project_id: str
    timestamp: float
    batch_info: Dict[str, Any]  # full contents of batch_info.json
    local_path: Optional[str] = None  # local path of batch_info.json, if known


class ThirdPartyPusher:
    """
    Pushes batch reports to a configured third-party platform.

    Reports submitted through :meth:`report_batch` are queued and delivered by
    a background daemon thread started with :meth:`start`;
    :meth:`report_batch_sync` delivers on the calling thread instead.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialise the pusher from configuration.

        Args:
            config: Third-party platform configuration dict. When ``None`` (or
                empty) ``config.THIRD_PARTY_CONFIG`` is used instead.
        """
        from config import THIRD_PARTY_CONFIG, DEVICE_CONFIG

        self.config = config or THIRD_PARTY_CONFIG
        self.device_config = DEVICE_CONFIG

        # Feature switch
        self.enabled = self.config.get('enabled', False)

        # Platform settings
        self.platform_type = self.config.get('platform_type', 'custom')
        self.base_url = self.config.get('base_url', '')
        # NOTE(review): api_version is read but not used anywhere in this module.
        self.api_version = self.config.get('api_version', 'v1')

        # Authentication settings
        self.auth_type = self.config.get('auth_type', 'none')
        self.api_key = self.config.get('api_key', '')
        self.api_secret = self.config.get('api_secret', '')
        self.oauth2_config = self.config.get('oauth2', {})

        # Endpoint paths, joined onto base_url
        self.endpoints = self.config.get('endpoints', {})
        self.batch_report_url = self.endpoints.get('batch_report', '/api/batch/report')
        self.heartbeat_url = self.endpoints.get('heartbeat', '/api/device/heartbeat')

        # Push control
        self.push_interval = self.config.get('push_interval', 1.0)
        self.retry_count = self.config.get('retry_count', 3)
        self.retry_delay = self.config.get('retry_delay', 2.0)
        self.timeout = self.config.get('timeout', 10)
        self.data_format = self.config.get('data_format', 'json')
        self.include_images = self.config.get('include_images', False)
        # TLS certificate verification for report/heartbeat requests (passed to
        # requests' `verify`, so a CA-bundle path also works). Defaults to False
        # only to preserve the historical behaviour of always sending with
        # verify=False; enable it wherever the platform has a valid certificate.
        self.verify_ssl = self.config.get('verify_ssl', False)

        # Cached OAuth2 token and its absolute expiry (epoch seconds)
        self._access_token = None
        self._token_expires_at = 0

        # Queue of pending BatchReport objects consumed by the worker thread
        self.report_queue = queue.Queue()

        # Worker thread state
        self.running = False
        self.worker_thread = None

        # Delivery statistics, guarded by stats_lock
        self.stats = {
            'total_reports': 0,
            'success_reports': 0,
            'failed_reports': 0,
        }
        self.stats_lock = threading.Lock()

        # Optional callbacks, each invoked with the BatchReport
        self.on_report_success: Optional[Callable] = None
        self.on_report_failed: Optional[Callable] = None

        # Time (epoch seconds) the last queued report finished processing
        self.last_report_time = 0

        if self.enabled:
            logger.info(f"[第三方平台] 推送器初始化完成: {self.base_url}")

    def start(self):
        """Start the background worker thread (no-op if disabled or already running)."""
        if not self.enabled:
            logger.info("[第三方平台] 推送器未启用")
            return

        if self.running:
            return

        self.running = True
        self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self.worker_thread.start()
        logger.info("[第三方平台] 推送器已启动")

    def stop(self):
        """
        Signal the worker thread to stop and wait up to 5 seconds for it.

        NOTE(review): if the worker is still busy with a report after the 5 s
        join timeout it keeps running in the background; calling start() again
        before it exits would leave two workers consuming the queue.
        """
        self.running = False
        if self.worker_thread:
            self.worker_thread.join(timeout=5)
        logger.info("[第三方平台] 推送器已停止")

    def _worker_loop(self):
        """Consume queued reports until `running` is cleared."""
        while self.running:
            try:
                # Poll with a timeout so a stop() request is noticed within ~1 s.
                report = self.report_queue.get(timeout=1.0)
            except queue.Empty:
                continue

            try:
                self._process_report(report)
            except Exception as e:
                logger.error(f"[第三方平台] 处理上报错误: {e}")
            finally:
                # Keep Queue.join() usable: every successful get() is balanced.
                self.report_queue.task_done()

    def _get_auth_headers(self) -> Dict[str, str]:
        """
        Build request headers, including authentication for the configured auth_type.

        Supported auth types: 'api_key' (X-API-Key / X-API-Secret headers),
        'oauth2' (Bearer token) and 'basic' (HTTP Basic with api_key:api_secret).
        Any other value sends no authentication header.
        """
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
        }

        if self.auth_type == 'api_key':
            headers['X-API-Key'] = self.api_key
            if self.api_secret:
                headers['X-API-Secret'] = self.api_secret

        elif self.auth_type == 'oauth2':
            token = self._get_oauth2_token()
            if token:
                headers['Authorization'] = f'Bearer {token}'

        elif self.auth_type == 'basic':
            import base64
            credentials = base64.b64encode(f"{self.api_key}:{self.api_secret}".encode()).decode()
            headers['Authorization'] = f'Basic {credentials}'

        return headers

    def _get_oauth2_token(self) -> Optional[str]:
        """
        Return a cached or freshly fetched OAuth2 access token.

        Uses the client-credentials grant against ``oauth2.token_url``. A cached
        token is reused until 60 seconds before its recorded expiry.

        Returns:
            The access token, or ``None`` when the OAuth2 configuration is
            incomplete or the token request fails.
        """
        # Reuse the cached token while it has more than 60 s left.
        if self._access_token and time.time() < self._token_expires_at - 60:
            return self._access_token

        token_url = self.oauth2_config.get('token_url', '')
        client_id = self.oauth2_config.get('client_id', '')
        client_secret = self.oauth2_config.get('client_secret', '')
        scope = self.oauth2_config.get('scope', '')

        if not all([token_url, client_id, client_secret]):
            logger.error("[第三方平台] OAuth2 配置不完整")
            return None

        data = {
            'grant_type': 'client_credentials',
            'client_id': client_id,
            'client_secret': client_secret,
        }
        if scope:
            data['scope'] = scope

        try:
            response = requests.post(token_url, data=data, timeout=self.timeout)
            if response.status_code != 200:
                logger.error(f"[第三方平台] OAuth2 Token 获取失败: {response.status_code}")
                return None
            result = response.json()
        except Exception as e:
            logger.error(f"[第三方平台] OAuth2 Token 请求异常: {e}")
            return None

        access_token = result.get('access_token') if isinstance(result, dict) else None
        if not access_token:
            # A 200 without a token must not be cached or reported as success.
            logger.error("[第三方平台] OAuth2 Token 响应缺少 access_token")
            return None

        # expires_in may be missing or sent as a string; fall back to one hour.
        try:
            expires_in = float(result.get('expires_in', 3600))
        except (TypeError, ValueError):
            expires_in = 3600.0

        self._access_token = access_token
        self._token_expires_at = time.time() + expires_in
        logger.info("[第三方平台] OAuth2 Token 获取成功")
        return self._access_token

    def _process_report(self, report: BatchReport):
        """Send one queued report, honouring push_interval, then update stats and fire callbacks."""
        # Throttle: keep at least push_interval seconds between queued reports.
        time_since_last = time.time() - self.last_report_time
        if time_since_last < self.push_interval:
            time.sleep(self.push_interval - time_since_last)

        success = self._send_batch_report(report)

        with self.stats_lock:
            self.stats['total_reports'] += 1
            if success:
                self.stats['success_reports'] += 1
            else:
                self.stats['failed_reports'] += 1

        self.last_report_time = time.time()

        # Callback errors are logged and never propagate into the worker loop.
        if success:
            if self.on_report_success:
                try:
                    self.on_report_success(report)
                except Exception as e:
                    logger.error(f"[第三方平台] 成功回调执行错误: {e}")
        elif self.on_report_failed:
            try:
                self.on_report_failed(report)
            except Exception as e:
                logger.error(f"[第三方平台] 失败回调执行错误: {e}")

    def _build_url(self, path: str) -> str:
        """Join base_url and an endpoint path with exactly one '/' between them."""
        if not path:
            return self.base_url
        return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"

    def _send_batch_report(self, report: BatchReport) -> bool:
        """
        POST one batch report, retrying up to retry_count attempts.

        A report counts as delivered when the server answers HTTP 200 with a JSON
        object whose ``code`` is 200 or whose ``success`` equals True.

        Args:
            report: The batch report to send.

        Returns:
            bool: True if the platform accepted the report, otherwise False.
        """
        if not self.base_url:
            logger.error("[第三方平台] 未配置 base_url")
            return False

        url = self._build_url(self.batch_report_url)
        payload = self._build_payload(report)
        headers = self._get_auth_headers()

        if self.data_format == 'json':
            body = {'json': payload}
        else:
            # _get_auth_headers() always declares a JSON body; drop that so
            # requests can set the form-encoded Content-Type matching data=.
            headers.pop('Content-Type', None)
            body = {'data': payload}

        # Always make at least one attempt, even if retry_count is 0 or negative.
        attempts = max(1, int(self.retry_count))

        for attempt in range(attempts):
            try:
                response = requests.post(
                    url,
                    headers=headers,
                    timeout=self.timeout,
                    verify=self.verify_ssl,
                    **body,
                )

                if response.status_code == 200:
                    try:
                        result = response.json()
                    except ValueError:
                        result = None

                    # `== True` (not `is True`) is deliberate: it keeps accepting
                    # truthy numeric flags such as 1, exactly as before.
                    if isinstance(result, dict) and (
                        result.get('code') == 200 or result.get('success') == True  # noqa: E712
                    ):
                        logger.info(f"[第三方平台] 批次上报成功: {report.batch_id}")
                        return True

                    msg = result.get('msg', '未知错误') if isinstance(result, dict) else '未知错误'
                    logger.warning(f"[第三方平台] 批次上报失败: {msg}")
                else:
                    logger.warning(f"[第三方平台] 批次上报失败: HTTP {response.status_code}")

            except requests.exceptions.Timeout:
                logger.warning(f"[第三方平台] 请求超时 (尝试 {attempt + 1}/{attempts})")
            except Exception as e:
                logger.error(f"[第三方平台] 请求异常 (尝试 {attempt + 1}/{attempts}): {e}")

            # Back off before the next attempt (not after the last one).
            if attempt < attempts - 1:
                time.sleep(self.retry_delay)

        logger.error(f"[第三方平台] 批次上报最终失败: {report.batch_id}")
        return False

    def _build_payload(self, report: BatchReport) -> Dict[str, Any]:
        """
        Build the request body for a report according to platform_type.

        Args:
            report: The batch report to serialise.

        Returns:
            Dict: The request body. For platform_type 'jtjai' this is the
            createTime/addr/ext1/ext2 layout (ext1/ext2 are JSON strings);
            otherwise it is the standard camelCase layout.
        """
        batch_info = report.batch_info
        report_dt = datetime.fromtimestamp(report.timestamp)

        if self.platform_type == 'jtjai':
            # 'panorama' may be present but null; treat that as an empty dict.
            panorama = batch_info.get('panorama') or {}
            return {
                'createTime': report_dt.strftime("%Y-%m-%d %H:%M:%S"),
                'addr': f"设备{report.device_id}批次上报",
                'ext1': json.dumps([panorama.get('oss_url')]),
                'ext2': json.dumps({
                    'batchId': report.batch_id,
                    'deviceId': report.device_id,
                    'projectId': report.project_id,
                    'totalPersons': batch_info.get('total_persons', 0),
                    'ptzImagesCount': batch_info.get('ptz_images_count', 0),
                    'persons': batch_info.get('persons', []),
                }),
            }

        # Standard report format
        return {
            'deviceId': report.device_id,
            'projectId': report.project_id,
            'batchId': report.batch_id,
            'timestamp': report.timestamp,
            'datetime': report_dt.isoformat(),
            'totalPersons': batch_info.get('total_persons', 0),
            'ptzImagesCount': batch_info.get('ptz_images_count', 0),
            'panorama': batch_info.get('panorama', {}),
            'persons': batch_info.get('persons', []),
            'uploadStatus': batch_info.get('upload_status', {}),
        }

    @staticmethod
    def _make_report(batch_info: Dict[str, Any], local_path: Optional[str]) -> BatchReport:
        """Create a BatchReport from batch_info, defaulting a missing/null timestamp to now."""
        timestamp = batch_info.get('timestamp')
        if timestamp is None:
            timestamp = time.time()
        return BatchReport(
            batch_id=batch_info.get('batch_id', ''),
            device_id=batch_info.get('device_id', ''),
            project_id=batch_info.get('project_id', ''),
            timestamp=timestamp,
            batch_info=batch_info,
            local_path=local_path,
        )

    def report_batch(self, batch_info: Dict[str, Any], local_path: Optional[str] = None):
        """
        Queue batch information for asynchronous delivery by the worker thread.

        Does nothing when the pusher is disabled. Reports are only sent while
        the worker started by start() is running.

        Args:
            batch_info: Contents of batch_info.json as a dict.
            local_path: Local path of batch_info.json (optional).
        """
        if not self.enabled:
            return
        self.report_queue.put(self._make_report(batch_info, local_path))

    def report_batch_sync(self, batch_info: Dict[str, Any], local_path: Optional[str] = None) -> bool:
        """
        Send batch information immediately on the calling thread.

        Unlike queued reports, this neither updates stats nor fires callbacks.

        Args:
            batch_info: Contents of batch_info.json as a dict.
            local_path: Local path of batch_info.json (optional).

        Returns:
            bool: True if delivered; False if disabled or delivery failed.
        """
        if not self.enabled:
            return False
        return self._send_batch_report(self._make_report(batch_info, local_path))

    def send_heartbeat(self) -> bool:
        """
        Send a single 'online' heartbeat for the configured device.

        Returns:
            bool: True on HTTP 200; False if disabled, unconfigured, or on failure.
        """
        if not self.enabled or not self.heartbeat_url:
            return False

        if not self.base_url:
            logger.error("[第三方平台] 未配置 base_url")
            return False

        url = self._build_url(self.heartbeat_url)
        payload = {
            'deviceId': self.device_config.get('device_id', ''),
            'projectId': self.device_config.get('project_id', ''),
            'timestamp': time.time(),
            'status': 'online',
        }
        headers = self._get_auth_headers()

        try:
            response = requests.post(
                url,
                json=payload,
                headers=headers,
                timeout=self.timeout,
                verify=self.verify_ssl,
            )
            if response.status_code == 200:
                logger.debug("[第三方平台] 心跳发送成功")
                return True
            logger.warning(f"[第三方平台] 心跳发送失败: HTTP {response.status_code}")
            return False
        except Exception as e:
            logger.error(f"[第三方平台] 心跳发送异常: {e}")
            return False

    def set_callbacks(self, on_success: Optional[Callable] = None, on_failed: Optional[Callable] = None):
        """
        Set the callbacks fired after each queued report is processed.

        Args:
            on_success: Called with the BatchReport after a successful delivery.
            on_failed: Called with the BatchReport after a failed delivery.
        """
        self.on_report_success = on_success
        self.on_report_failed = on_failed

    def get_stats(self) -> Dict[str, int]:
        """Return a snapshot copy of the delivery statistics."""
        with self.stats_lock:
            return self.stats.copy()

    def is_enabled(self) -> bool:
        """Return whether pushing is enabled in the configuration."""
        return self.enabled


# Module-level singleton
_third_party_pusher_instance: Optional[ThirdPartyPusher] = None
_third_party_pusher_lock = threading.Lock()


def get_third_party_pusher(config: Optional[Dict[str, Any]] = None) -> ThirdPartyPusher:
    """
    Return the shared ThirdPartyPusher, creating it on first use.

    Creation is guarded by a lock with a re-check inside it, so concurrent first
    calls create only one instance. `config` is only used by the call that
    creates the instance; later calls ignore it.

    Args:
        config: Third-party platform configuration.

    Returns:
        The ThirdPartyPusher instance.
    """
    global _third_party_pusher_instance
    if _third_party_pusher_instance is None:
        with _third_party_pusher_lock:
            if _third_party_pusher_instance is None:
                _third_party_pusher_instance = ThirdPartyPusher(config)
    return _third_party_pusher_instance


def reset_third_party_pusher():
    """Stop and discard the shared pusher so the next get_third_party_pusher() builds a new one."""
    global _third_party_pusher_instance
    with _third_party_pusher_lock:
        if _third_party_pusher_instance is not None:
            _third_party_pusher_instance.stop()
            _third_party_pusher_instance = None