third_party_pusher.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. """
  2. 第三方平台推送模块
  3. 将批次信息推送到第三方平台接口
  4. """
  5. import os
  6. import time
  7. import json
  8. import logging
  9. import threading
  10. import queue
  11. import requests
  12. from typing import Optional, Dict, Any, List, Callable
  13. from dataclasses import dataclass
  14. from datetime import datetime
  15. from pathlib import Path
  16. logger = logging.getLogger(__name__)
  17. @dataclass
  18. class BatchReport:
  19. """批次上报数据"""
  20. batch_id: str
  21. device_id: str
  22. project_id: str
  23. timestamp: float
  24. batch_info: Dict[str, Any] # batch_info.json 的完整内容
  25. local_path: Optional[str] = None # batch_info.json 本地路径
  26. class ThirdPartyPusher:
  27. """
  28. 第三方平台推送器
  29. 负责将批次信息推送到配置的第三方平台接口
  30. """
  31. def __init__(self, config: Dict[str, Any] = None):
  32. """
  33. 初始化第三方平台推送器
  34. Args:
  35. config: 第三方平台配置字典
  36. """
  37. from config import THIRD_PARTY_CONFIG, DEVICE_CONFIG
  38. self.config = config or THIRD_PARTY_CONFIG
  39. self.device_config = DEVICE_CONFIG
  40. # 功能开关
  41. self.enabled = self.config.get('enabled', False)
  42. # 平台配置
  43. self.platform_type = self.config.get('platform_type', 'custom')
  44. self.base_url = self.config.get('base_url', '')
  45. self.api_version = self.config.get('api_version', 'v1')
  46. # 认证配置
  47. self.auth_type = self.config.get('auth_type', 'none')
  48. self.api_key = self.config.get('api_key', '')
  49. self.api_secret = self.config.get('api_secret', '')
  50. self.oauth2_config = self.config.get('oauth2', {})
  51. # 接口路径
  52. self.endpoints = self.config.get('endpoints', {})
  53. self.batch_report_url = self.endpoints.get('batch_report', '/api/batch/report')
  54. self.heartbeat_url = self.endpoints.get('heartbeat', '/api/device/heartbeat')
  55. # 推送控制
  56. self.push_interval = self.config.get('push_interval', 1.0)
  57. self.retry_count = self.config.get('retry_count', 3)
  58. self.retry_delay = self.config.get('retry_delay', 2.0)
  59. self.timeout = self.config.get('timeout', 10)
  60. self.data_format = self.config.get('data_format', 'json')
  61. self.include_images = self.config.get('include_images', False)
  62. # OAuth2 Token
  63. self._access_token = None
  64. self._token_expires_at = 0
  65. # 上报队列
  66. self.report_queue = queue.Queue()
  67. # 工作线程
  68. self.running = False
  69. self.worker_thread = None
  70. # 统计
  71. self.stats = {
  72. 'total_reports': 0,
  73. 'success_reports': 0,
  74. 'failed_reports': 0,
  75. }
  76. self.stats_lock = threading.Lock()
  77. # 回调
  78. self.on_report_success: Optional[Callable] = None
  79. self.on_report_failed: Optional[Callable] = None
  80. # 最后上报时间
  81. self.last_report_time = 0
  82. if self.enabled:
  83. logger.info(f"[第三方平台] 推送器初始化完成: {self.base_url}")
  84. def start(self):
  85. """启动推送器"""
  86. if not self.enabled:
  87. logger.info("[第三方平台] 推送器未启用")
  88. return
  89. if self.running:
  90. return
  91. self.running = True
  92. self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
  93. self.worker_thread.start()
  94. logger.info("[第三方平台] 推送器已启动")
  95. def stop(self):
  96. """停止推送器"""
  97. self.running = False
  98. if self.worker_thread:
  99. self.worker_thread.join(timeout=5)
  100. logger.info("[第三方平台] 推送器已停止")
  101. def _worker_loop(self):
  102. """工作线程循环"""
  103. while self.running:
  104. try:
  105. report = self.report_queue.get(timeout=1.0)
  106. self._process_report(report)
  107. except queue.Empty:
  108. continue
  109. except Exception as e:
  110. logger.error(f"[第三方平台] 处理上报错误: {e}")
  111. def _get_auth_headers(self) -> Dict[str, str]:
  112. """获取认证请求头"""
  113. headers = {
  114. 'Content-Type': 'application/json',
  115. 'Accept': 'application/json',
  116. }
  117. if self.auth_type == 'api_key':
  118. headers['X-API-Key'] = self.api_key
  119. if self.api_secret:
  120. headers['X-API-Secret'] = self.api_secret
  121. elif self.auth_type == 'oauth2':
  122. token = self._get_oauth2_token()
  123. if token:
  124. headers['Authorization'] = f'Bearer {token}'
  125. elif self.auth_type == 'basic':
  126. import base64
  127. credentials = base64.b64encode(f"{self.api_key}:{self.api_secret}".encode()).decode()
  128. headers['Authorization'] = f'Basic {credentials}'
  129. return headers
  130. def _get_oauth2_token(self) -> Optional[str]:
  131. """获取 OAuth2 Token"""
  132. # 检查现有 token 是否有效
  133. if self._access_token and time.time() < self._token_expires_at - 60:
  134. return self._access_token
  135. # 重新获取 token
  136. token_url = self.oauth2_config.get('token_url', '')
  137. client_id = self.oauth2_config.get('client_id', '')
  138. client_secret = self.oauth2_config.get('client_secret', '')
  139. scope = self.oauth2_config.get('scope', '')
  140. if not all([token_url, client_id, client_secret]):
  141. logger.error("[第三方平台] OAuth2 配置不完整")
  142. return None
  143. try:
  144. data = {
  145. 'grant_type': 'client_credentials',
  146. 'client_id': client_id,
  147. 'client_secret': client_secret,
  148. }
  149. if scope:
  150. data['scope'] = scope
  151. response = requests.post(token_url, data=data, timeout=self.timeout)
  152. if response.status_code == 200:
  153. result = response.json()
  154. self._access_token = result.get('access_token')
  155. expires_in = result.get('expires_in', 3600)
  156. self._token_expires_at = time.time() + expires_in
  157. logger.info("[第三方平台] OAuth2 Token 获取成功")
  158. return self._access_token
  159. else:
  160. logger.error(f"[第三方平台] OAuth2 Token 获取失败: {response.status_code}")
  161. return None
  162. except Exception as e:
  163. logger.error(f"[第三方平台] OAuth2 Token 请求异常: {e}")
  164. return None
  165. def _process_report(self, report: BatchReport):
  166. """处理单个上报任务"""
  167. # 检查推送间隔
  168. current_time = time.time()
  169. time_since_last = current_time - self.last_report_time
  170. if time_since_last < self.push_interval:
  171. time.sleep(self.push_interval - time_since_last)
  172. success = self._send_batch_report(report)
  173. with self.stats_lock:
  174. self.stats['total_reports'] += 1
  175. if success:
  176. self.stats['success_reports'] += 1
  177. else:
  178. self.stats['failed_reports'] += 1
  179. self.last_report_time = time.time()
  180. # 触发回调
  181. if success and self.on_report_success:
  182. try:
  183. self.on_report_success(report)
  184. except Exception as e:
  185. logger.error(f"[第三方平台] 成功回调执行错误: {e}")
  186. elif not success and self.on_report_failed:
  187. try:
  188. self.on_report_failed(report)
  189. except Exception as e:
  190. logger.error(f"[第三方平台] 失败回调执行错误: {e}")
  191. def _send_batch_report(self, report: BatchReport) -> bool:
  192. """
  193. 发送批次上报请求
  194. Args:
  195. report: 批次上报数据
  196. Returns:
  197. bool: 是否成功
  198. """
  199. if not self.base_url:
  200. logger.error("[第三方平台] 未配置 base_url")
  201. return False
  202. url = f"{self.base_url}{self.batch_report_url}"
  203. # 构建请求数据
  204. payload = self._build_payload(report)
  205. headers = self._get_auth_headers()
  206. for attempt in range(self.retry_count):
  207. try:
  208. if self.data_format == 'json':
  209. response = requests.post(
  210. url,
  211. json=payload,
  212. headers=headers,
  213. timeout=self.timeout,
  214. verify=False
  215. )
  216. else:
  217. response = requests.post(
  218. url,
  219. data=payload,
  220. headers=headers,
  221. timeout=self.timeout,
  222. verify=False
  223. )
  224. if response.status_code == 200:
  225. result = response.json()
  226. if result.get('code') == 200 or result.get('success') == True:
  227. logger.info(f"[第三方平台] 批次上报成功: {report.batch_id}")
  228. return True
  229. else:
  230. logger.warning(f"[第三方平台] 批次上报失败: {result.get('msg', '未知错误')}")
  231. else:
  232. logger.warning(f"[第三方平台] 批次上报失败: HTTP {response.status_code}")
  233. if attempt < self.retry_count - 1:
  234. time.sleep(self.retry_delay)
  235. except requests.exceptions.Timeout:
  236. logger.warning(f"[第三方平台] 请求超时 (尝试 {attempt + 1}/{self.retry_count})")
  237. if attempt < self.retry_count - 1:
  238. time.sleep(self.retry_delay)
  239. except Exception as e:
  240. logger.error(f"[第三方平台] 请求异常 (尝试 {attempt + 1}/{self.retry_count}): {e}")
  241. if attempt < self.retry_count - 1:
  242. time.sleep(self.retry_delay)
  243. logger.error(f"[第三方平台] 批次上报最终失败: {report.batch_id}")
  244. return False
  245. def _build_payload(self, report: BatchReport) -> Dict[str, Any]:
  246. """
  247. 构建上报请求体
  248. Args:
  249. report: 批次上报数据
  250. Returns:
  251. Dict: 请求体字典
  252. """
  253. batch_info = report.batch_info
  254. # 标准上报格式
  255. payload = {
  256. 'deviceId': report.device_id,
  257. 'projectId': report.project_id,
  258. 'batchId': report.batch_id,
  259. 'timestamp': report.timestamp,
  260. 'datetime': datetime.fromtimestamp(report.timestamp).isoformat(),
  261. 'totalPersons': batch_info.get('total_persons', 0),
  262. 'ptzImagesCount': batch_info.get('ptz_images_count', 0),
  263. 'panorama': batch_info.get('panorama', {}),
  264. 'persons': batch_info.get('persons', []),
  265. 'uploadStatus': batch_info.get('upload_status', {}),
  266. }
  267. # 根据平台类型调整格式
  268. if self.platform_type == 'jtjai':
  269. # jtjai 平台特定格式
  270. payload = {
  271. 'createTime': datetime.fromtimestamp(report.timestamp).strftime("%Y-%m-%d %H:%M:%S"),
  272. 'addr': f"设备{report.device_id}批次上报",
  273. 'ext1': json.dumps([batch_info.get('panorama', {}).get('oss_url')]),
  274. 'ext2': json.dumps({
  275. 'batchId': report.batch_id,
  276. 'deviceId': report.device_id,
  277. 'projectId': report.project_id,
  278. 'totalPersons': batch_info.get('total_persons', 0),
  279. 'ptzImagesCount': batch_info.get('ptz_images_count', 0),
  280. 'persons': batch_info.get('persons', []),
  281. })
  282. }
  283. return payload
  284. def report_batch(self, batch_info: Dict[str, Any], local_path: Optional[str] = None):
  285. """
  286. 上报批次信息
  287. Args:
  288. batch_info: batch_info.json 的字典内容
  289. local_path: batch_info.json 的本地文件路径(可选)
  290. """
  291. if not self.enabled:
  292. return
  293. report = BatchReport(
  294. batch_id=batch_info.get('batch_id', ''),
  295. device_id=batch_info.get('device_id', ''),
  296. project_id=batch_info.get('project_id', ''),
  297. timestamp=batch_info.get('timestamp', time.time()),
  298. batch_info=batch_info,
  299. local_path=local_path
  300. )
  301. self.report_queue.put(report)
  302. with self.stats_lock:
  303. self.stats['total_reports'] += 1
  304. def report_batch_sync(self, batch_info: Dict[str, Any],
  305. local_path: Optional[str] = None) -> bool:
  306. """
  307. 同步上报批次信息
  308. Args:
  309. batch_info: batch_info.json 的字典内容
  310. local_path: batch_info.json 的本地文件路径(可选)
  311. Returns:
  312. bool: 是否成功
  313. """
  314. if not self.enabled:
  315. return False
  316. report = BatchReport(
  317. batch_id=batch_info.get('batch_id', ''),
  318. device_id=batch_info.get('device_id', ''),
  319. project_id=batch_info.get('project_id', ''),
  320. timestamp=batch_info.get('timestamp', time.time()),
  321. batch_info=batch_info,
  322. local_path=local_path
  323. )
  324. return self._send_batch_report(report)
  325. def send_heartbeat(self) -> bool:
  326. """
  327. 发送心跳
  328. Returns:
  329. bool: 是否成功
  330. """
  331. if not self.enabled or not self.heartbeat_url:
  332. return False
  333. url = f"{self.base_url}{self.heartbeat_url}"
  334. payload = {
  335. 'deviceId': self.device_config.get('device_id', ''),
  336. 'projectId': self.device_config.get('project_id', ''),
  337. 'timestamp': time.time(),
  338. 'status': 'online',
  339. }
  340. headers = self._get_auth_headers()
  341. try:
  342. response = requests.post(
  343. url,
  344. json=payload,
  345. headers=headers,
  346. timeout=self.timeout,
  347. verify=False
  348. )
  349. if response.status_code == 200:
  350. logger.debug("[第三方平台] 心跳发送成功")
  351. return True
  352. else:
  353. logger.warning(f"[第三方平台] 心跳发送失败: HTTP {response.status_code}")
  354. return False
  355. except Exception as e:
  356. logger.error(f"[第三方平台] 心跳发送异常: {e}")
  357. return False
  358. def set_callbacks(self, on_success: Callable = None, on_failed: Callable = None):
  359. """
  360. 设置回调函数
  361. Args:
  362. on_success: 上报成功回调
  363. on_failed: 上报失败回调
  364. """
  365. self.on_report_success = on_success
  366. self.on_report_failed = on_failed
  367. def get_stats(self) -> Dict[str, int]:
  368. """获取统计信息"""
  369. with self.stats_lock:
  370. return self.stats.copy()
  371. def is_enabled(self) -> bool:
  372. """检查是否启用"""
  373. return self.enabled
  374. # 全局单例
  375. _third_party_pusher_instance: Optional[ThirdPartyPusher] = None
  376. def get_third_party_pusher(config: Dict[str, Any] = None) -> ThirdPartyPusher:
  377. """
  378. 获取第三方平台推送器实例(单例模式)
  379. Args:
  380. config: 第三方平台配置
  381. Returns:
  382. ThirdPartyPusher 实例
  383. """
  384. global _third_party_pusher_instance
  385. if _third_party_pusher_instance is None:
  386. _third_party_pusher_instance = ThirdPartyPusher(config)
  387. return _third_party_pusher_instance
  388. def reset_third_party_pusher():
  389. """重置第三方平台推送器实例"""
  390. global _third_party_pusher_instance
  391. if _third_party_pusher_instance is not None:
  392. _third_party_pusher_instance.stop()
  393. _third_party_pusher_instance = None