third_party_pusher.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
"""
Third-party platform push module.

Pushes batch information to a third-party platform API.
"""
  5. import os
  6. import time
  7. import json
  8. import logging
  9. import threading
  10. import queue
  11. import requests
  12. from typing import Optional, Dict, Any, List, Callable
  13. from dataclasses import dataclass
  14. from datetime import datetime
  15. from pathlib import Path
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
  17. @dataclass
  18. class BatchReport:
  19. """批次上报数据"""
  20. batch_id: str
  21. device_id: str
  22. project_id: str
  23. timestamp: float
  24. batch_info: Dict[str, Any] # batch_info.json 的完整内容
  25. local_path: Optional[str] = None # batch_info.json 本地路径
  26. class ThirdPartyPusher:
  27. """
  28. 第三方平台推送器
  29. 负责将批次信息推送到配置的第三方平台接口
  30. """
  31. def __init__(self, config: Dict[str, Any] = None):
  32. """
  33. 初始化第三方平台推送器
  34. Args:
  35. config: 第三方平台配置字典
  36. """
  37. from config import THIRD_PARTY_CONFIG, DEVICE_CONFIG
  38. self.config = config or THIRD_PARTY_CONFIG
  39. self.device_config = DEVICE_CONFIG
  40. # 功能开关
  41. self.enabled = self.config.get('enabled', False)
  42. # 平台配置
  43. self.platform_type = self.config.get('platform_type', 'custom')
  44. self.base_url = self.config.get('base_url', '')
  45. self.api_version = self.config.get('api_version', 'v1')
  46. # 认证配置
  47. self.auth_type = self.config.get('auth_type', 'none')
  48. self.api_key = self.config.get('api_key', '')
  49. self.api_secret = self.config.get('api_secret', '')
  50. self.oauth2_config = self.config.get('oauth2', {})
  51. # 接口路径
  52. self.endpoints = self.config.get('endpoints', {})
  53. self.batch_report_url = self.endpoints.get('batch_report', '/api/batch/report')
  54. self.heartbeat_url = self.endpoints.get('heartbeat', '/api/device/heartbeat')
  55. # 推送控制
  56. self.push_interval = self.config.get('push_interval', 1.0)
  57. self.retry_count = self.config.get('retry_count', 3)
  58. self.retry_delay = self.config.get('retry_delay', 2.0)
  59. self.timeout = self.config.get('timeout', 10)
  60. self.data_format = self.config.get('data_format', 'json')
  61. self.include_images = self.config.get('include_images', False)
  62. # OAuth2 Token
  63. self._access_token = None
  64. self._token_expires_at = 0
  65. # 上报队列
  66. self.report_queue = queue.Queue()
  67. # 工作线程
  68. self.running = False
  69. self.worker_thread = None
  70. # 统计
  71. self.stats = {
  72. 'total_reports': 0,
  73. 'success_reports': 0,
  74. 'failed_reports': 0,
  75. }
  76. self.stats_lock = threading.Lock()
  77. # 回调
  78. self.on_report_success: Optional[Callable] = None
  79. self.on_report_failed: Optional[Callable] = None
  80. # 最后上报时间
  81. self.last_report_time = 0
  82. if self.enabled:
  83. logger.info(f"[第三方平台] 推送器初始化完成: {self.base_url}")
  84. def start(self):
  85. """启动推送器"""
  86. if not self.enabled:
  87. logger.info("[第三方平台] 推送器未启用")
  88. return
  89. if self.running:
  90. return
  91. self.running = True
  92. self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
  93. self.worker_thread.start()
  94. logger.info("[第三方平台] 推送器已启动")
  95. def stop(self):
  96. """停止推送器"""
  97. self.running = False
  98. if self.worker_thread:
  99. self.worker_thread.join(timeout=5)
  100. logger.info("[第三方平台] 推送器已停止")
  101. def _worker_loop(self):
  102. """工作线程循环"""
  103. while self.running:
  104. try:
  105. report = self.report_queue.get(timeout=1.0)
  106. self._process_report(report)
  107. except queue.Empty:
  108. continue
  109. except Exception as e:
  110. logger.error(f"[第三方平台] 处理上报错误: {e}")
  111. def _get_auth_headers(self) -> Dict[str, str]:
  112. """获取认证请求头(当前第三方接口不需要自定义 header,返回空避免 422)"""
  113. return {}
  114. def _get_oauth2_token(self) -> Optional[str]:
  115. """获取 OAuth2 Token"""
  116. # 检查现有 token 是否有效
  117. if self._access_token and time.time() < self._token_expires_at - 60:
  118. return self._access_token
  119. # 重新获取 token
  120. token_url = self.oauth2_config.get('token_url', '')
  121. client_id = self.oauth2_config.get('client_id', '')
  122. client_secret = self.oauth2_config.get('client_secret', '')
  123. scope = self.oauth2_config.get('scope', '')
  124. if not all([token_url, client_id, client_secret]):
  125. logger.error("[第三方平台] OAuth2 配置不完整")
  126. return None
  127. try:
  128. data = {
  129. 'grant_type': 'client_credentials',
  130. 'client_id': client_id,
  131. 'client_secret': client_secret,
  132. }
  133. if scope:
  134. data['scope'] = scope
  135. response = requests.post(token_url, data=data, timeout=self.timeout)
  136. if response.status_code == 200:
  137. result = response.json()
  138. self._access_token = result.get('access_token')
  139. expires_in = result.get('expires_in', 3600)
  140. self._token_expires_at = time.time() + expires_in
  141. logger.info("[第三方平台] OAuth2 Token 获取成功")
  142. return self._access_token
  143. else:
  144. logger.error(f"[第三方平台] OAuth2 Token 获取失败: {response.status_code}")
  145. return None
  146. except Exception as e:
  147. logger.error(f"[第三方平台] OAuth2 Token 请求异常: {e}")
  148. return None
  149. def _process_report(self, report: BatchReport):
  150. """处理单个上报任务"""
  151. # 检查推送间隔
  152. current_time = time.time()
  153. time_since_last = current_time - self.last_report_time
  154. if time_since_last < self.push_interval:
  155. time.sleep(self.push_interval - time_since_last)
  156. success = self._send_batch_report(report)
  157. with self.stats_lock:
  158. self.stats['total_reports'] += 1
  159. if success:
  160. self.stats['success_reports'] += 1
  161. else:
  162. self.stats['failed_reports'] += 1
  163. self.last_report_time = time.time()
  164. # 触发回调
  165. if success and self.on_report_success:
  166. try:
  167. self.on_report_success(report)
  168. except Exception as e:
  169. logger.error(f"[第三方平台] 成功回调执行错误: {e}")
  170. elif not success and self.on_report_failed:
  171. try:
  172. self.on_report_failed(report)
  173. except Exception as e:
  174. logger.error(f"[第三方平台] 失败回调执行错误: {e}")
  175. def _send_batch_report(self, report: BatchReport) -> bool:
  176. """
  177. 发送批次上报请求
  178. Args:
  179. report: 批次上报数据
  180. Returns:
  181. bool: 是否成功
  182. """
  183. if not self.base_url:
  184. logger.error("[第三方平台] 未配置 base_url")
  185. return False
  186. url = f"{self.base_url}{self.batch_report_url}"
  187. # 构建请求数据
  188. payload = self._build_payload(report)
  189. headers = self._get_auth_headers()
  190. for attempt in range(self.retry_count):
  191. try:
  192. if self.data_format == 'json':
  193. response = requests.post(
  194. url,
  195. json=payload,
  196. headers=headers,
  197. timeout=self.timeout,
  198. verify=False
  199. )
  200. else:
  201. response = requests.post(
  202. url,
  203. data=payload,
  204. headers=headers,
  205. timeout=self.timeout,
  206. verify=False
  207. )
  208. if response.status_code == 200:
  209. result = response.json()
  210. status = result.get('status', '')
  211. message = result.get('message', '')
  212. if (result.get('code') == 200 or
  213. result.get('success') == True or
  214. status in ('pending', 'success', 'accepted') or
  215. message == 'accepted'):
  216. logger.info(f"[第三方平台] 批次上报成功: {report.batch_id}, task_id={result.get('task_id')}")
  217. return True
  218. else:
  219. logger.warning(f"[第三方平台] 批次上报失败: {result.get('msg', '未知错误')}")
  220. try:
  221. logger.warning(f"[第三方平台] 响应内容: {str(result)[:500]}")
  222. except Exception:
  223. pass
  224. else:
  225. logger.warning(f"[第三方平台] 批次上报失败: HTTP {response.status_code}")
  226. try:
  227. logger.warning(f"[第三方平台] 响应内容: {response.text[:500]}")
  228. except Exception:
  229. pass
  230. if attempt < self.retry_count - 1:
  231. time.sleep(self.retry_delay)
  232. except requests.exceptions.Timeout:
  233. logger.warning(f"[第三方平台] 请求超时 (尝试 {attempt + 1}/{self.retry_count})")
  234. if attempt < self.retry_count - 1:
  235. time.sleep(self.retry_delay)
  236. except Exception as e:
  237. logger.error(f"[第三方平台] 请求异常 (尝试 {attempt + 1}/{self.retry_count}): {e}")
  238. if attempt < self.retry_count - 1:
  239. time.sleep(self.retry_delay)
  240. logger.error(f"[第三方平台] 批次上报最终失败: {report.batch_id}")
  241. return False
  242. def _build_payload(self, report: BatchReport) -> Dict[str, Any]:
  243. """
  244. 构建上报请求体
  245. Args:
  246. report: 批次上报数据
  247. Returns:
  248. Dict: 请求体字典
  249. """
  250. batch_info = report.batch_info
  251. # 根据平台类型调整格式
  252. if self.platform_type == 'jtjai':
  253. # jtjai 平台特定格式
  254. payload = {
  255. 'createTime': datetime.fromtimestamp(report.timestamp).strftime("%Y-%m-%d %H:%M:%S"),
  256. 'addr': f"设备{report.device_id}批次上报",
  257. 'ext1': json.dumps([batch_info.get('panorama', {}).get('oss_url')]),
  258. 'ext2': json.dumps({
  259. 'batchId': report.batch_id,
  260. 'deviceId': report.device_id,
  261. 'projectId': report.project_id,
  262. 'totalPersons': batch_info.get('total_persons', 0),
  263. 'ptzImagesCount': batch_info.get('ptz_images_count', 0),
  264. 'persons': batch_info.get('persons', []),
  265. })
  266. }
  267. else:
  268. # custom / 其他平台:原样发送 batch_info(snake_case)
  269. payload = dict(batch_info)
  270. return payload
  271. def report_batch(self, batch_info: Dict[str, Any], local_path: Optional[str] = None):
  272. """
  273. 上报批次信息
  274. Args:
  275. batch_info: batch_info.json 的字典内容
  276. local_path: batch_info.json 的本地文件路径(可选)
  277. """
  278. if not self.enabled:
  279. return
  280. report = BatchReport(
  281. batch_id=batch_info.get('batch_id', ''),
  282. device_id=batch_info.get('device_id', ''),
  283. project_id=batch_info.get('project_id', ''),
  284. timestamp=batch_info.get('timestamp', time.time()),
  285. batch_info=batch_info,
  286. local_path=local_path
  287. )
  288. self.report_queue.put(report)
  289. def report_batch_sync(self, batch_info: Dict[str, Any],
  290. local_path: Optional[str] = None) -> bool:
  291. """
  292. 同步上报批次信息
  293. Args:
  294. batch_info: batch_info.json 的字典内容
  295. local_path: batch_info.json 的本地文件路径(可选)
  296. Returns:
  297. bool: 是否成功
  298. """
  299. if not self.enabled:
  300. return False
  301. report = BatchReport(
  302. batch_id=batch_info.get('batch_id', ''),
  303. device_id=batch_info.get('device_id', ''),
  304. project_id=batch_info.get('project_id', ''),
  305. timestamp=batch_info.get('timestamp', time.time()),
  306. batch_info=batch_info,
  307. local_path=local_path
  308. )
  309. return self._send_batch_report(report)
  310. def send_heartbeat(self) -> bool:
  311. """
  312. 发送心跳
  313. Returns:
  314. bool: 是否成功
  315. """
  316. if not self.enabled or not self.heartbeat_url:
  317. return False
  318. url = f"{self.base_url}{self.heartbeat_url}"
  319. payload = {
  320. 'deviceId': self.device_config.get('device_id', ''),
  321. 'projectId': self.device_config.get('project_id', ''),
  322. 'timestamp': time.time(),
  323. 'status': 'online',
  324. }
  325. headers = self._get_auth_headers()
  326. try:
  327. response = requests.post(
  328. url,
  329. json=payload,
  330. headers=headers,
  331. timeout=self.timeout,
  332. verify=False
  333. )
  334. if response.status_code == 200:
  335. logger.debug("[第三方平台] 心跳发送成功")
  336. return True
  337. else:
  338. logger.warning(f"[第三方平台] 心跳发送失败: HTTP {response.status_code}")
  339. return False
  340. except Exception as e:
  341. logger.error(f"[第三方平台] 心跳发送异常: {e}")
  342. return False
  343. def set_callbacks(self, on_success: Callable = None, on_failed: Callable = None):
  344. """
  345. 设置回调函数
  346. Args:
  347. on_success: 上报成功回调
  348. on_failed: 上报失败回调
  349. """
  350. self.on_report_success = on_success
  351. self.on_report_failed = on_failed
  352. def get_stats(self) -> Dict[str, int]:
  353. """获取统计信息"""
  354. with self.stats_lock:
  355. return self.stats.copy()
  356. def is_enabled(self) -> bool:
  357. """检查是否启用"""
  358. return self.enabled
# Global singleton: the shared ThirdPartyPusher instance (None until created)
# and the lock that guards its creation and reset.
_third_party_pusher_instance: Optional[ThirdPartyPusher] = None
_third_party_pusher_lock = threading.Lock()
  362. def get_third_party_pusher(config: Dict[str, Any] = None) -> ThirdPartyPusher:
  363. """
  364. 获取第三方平台推送器实例(单例模式,线程安全)
  365. Args:
  366. config: 第三方平台配置
  367. Returns:
  368. ThirdPartyPusher 实例
  369. """
  370. global _third_party_pusher_instance
  371. if _third_party_pusher_instance is None:
  372. with _third_party_pusher_lock:
  373. if _third_party_pusher_instance is None:
  374. _third_party_pusher_instance = ThirdPartyPusher(config)
  375. return _third_party_pusher_instance
  376. def reset_third_party_pusher():
  377. """重置第三方平台推送器实例"""
  378. global _third_party_pusher_instance
  379. with _third_party_pusher_lock:
  380. if _third_party_pusher_instance is not None:
  381. _third_party_pusher_instance.stop()
  382. _third_party_pusher_instance = None