third_party_pusher.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. """
  2. 第三方平台推送模块
  3. 将批次信息推送到第三方平台接口
  4. """
  5. import os
  6. import time
  7. import json
  8. import logging
  9. import threading
  10. import queue
  11. import cv2
  12. import requests
  13. from typing import Optional, Dict, Any, List, Callable
  14. from dataclasses import dataclass
  15. from datetime import datetime
  16. from pathlib import Path
  17. logger = logging.getLogger(__name__)
  18. def _normalize_timestamp(ts: float) -> float:
  19. """统一时间戳为秒。CaptureUploader 使用毫秒,其他可能使用秒。"""
  20. if ts > 1e12:
  21. return ts / 1000.0
  22. return ts
  23. def _convert_to_legacy_batch_info(new_info: Dict[str, Any]) -> Dict[str, Any]:
  24. """
  25. 把新版 CaptureUploader 生成的 batch_info 转成业务平台要求的 PairedImageSaver 老格式。
  26. 参考字段:
  27. - batch_id / device_id / project_id / timestamp / datetime
  28. - total_persons / ptz_images_count
  29. - panorama: local_path / local_path_original / oss_url / oss_url_original
  30. - persons: person_index / position(x,y 归一化) / bbox / confidence /
  31. ptz_position / ptz_bbox / ptz_image_saved /
  32. ptz_image_path / ptz_image_original_path /
  33. ptz_oss_url / ptz_oss_url_original
  34. - upload_status
  35. """
  36. normalized_ts = _normalize_timestamp(new_info.get("timestamp", time.time()))
  37. urls = new_info.get("image_urls") or {}
  38. image_paths = new_info.get("image_paths") or []
  39. camera_type = new_info.get("camera_type", "panorama")
  40. is_ptz = camera_type == "ptz"
  41. # image_paths 约定:[original, marked]
  42. original_path = image_paths[0] if len(image_paths) > 0 else None
  43. marked_path = image_paths[1] if len(image_paths) > 1 else original_path
  44. oss_url_original = urls.get("original") or None
  45. oss_url_marked = urls.get("marked") or oss_url_original
  46. # 读取原图尺寸,用于把人体中心坐标归一化到 0~1
  47. img_w, img_h = 0, 0
  48. if original_path and os.path.exists(original_path):
  49. try:
  50. img = cv2.imread(original_path)
  51. if img is not None:
  52. img_h, img_w = img.shape[:2]
  53. except Exception:
  54. pass
  55. ptz_position = new_info.get("ptz_position") or {}
  56. persons = []
  57. for i, det in enumerate(new_info.get("detections") or []):
  58. bbox = det.get("bbox", [0, 0, 0, 0])
  59. x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
  60. cx = (x1 + x2) / 2.0
  61. cy = (y1 + y2) / 2.0
  62. position = {
  63. "x": round(cx / img_w, 4) if img_w else round(cx, 4),
  64. "y": round(cy / img_h, 4) if img_h else round(cy, 4),
  65. }
  66. person = {
  67. "person_index": i,
  68. "position": position,
  69. "bbox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
  70. "confidence": float(det.get("confidence", 0.0)),
  71. }
  72. # PTZ 流检测时,把同一张 PTZ 图作为每个人的特写图复用
  73. # 始终包含 ptz_position 字段,第三方平台要求必须有 pan/tilt/zoom 数值
  74. if is_ptz:
  75. # 使用实际检测时的 PTZ 位置(若无实际位置,用默认值 0/0/1)
  76. ptz_pan = ptz_position.get("pan") if isinstance(ptz_position, dict) else 0
  77. ptz_tilt = ptz_position.get("tilt") if isinstance(ptz_position, dict) else 0
  78. ptz_zoom = ptz_position.get("zoom") if isinstance(ptz_position, dict) else 1
  79. person["ptz_position"] = {
  80. "pan": ptz_pan if ptz_pan is not None else 0,
  81. "tilt": ptz_tilt if ptz_tilt is not None else 0,
  82. "zoom": ptz_zoom if ptz_zoom is not None else 1,
  83. }
  84. person["ptz_bbox"] = {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
  85. person["ptz_image_saved"] = bool(marked_path and os.path.exists(marked_path))
  86. person["ptz_image_path"] = marked_path
  87. person["ptz_image_original_path"] = original_path
  88. person["ptz_oss_url"] = oss_url_marked
  89. person["ptz_oss_url_original"] = oss_url_original
  90. else:
  91. # 全景相机检测时,复用当前检测框作为 ptz_bbox
  92. person["ptz_position"] = {
  93. "pan": 0, "tilt": 0, "zoom": 1,
  94. }
  95. person["ptz_bbox"] = {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
  96. person["ptz_image_saved"] = bool(marked_path and os.path.exists(marked_path))
  97. person["ptz_image_path"] = marked_path
  98. person["ptz_image_original_path"] = original_path
  99. person["ptz_oss_url"] = oss_url_marked
  100. person["ptz_oss_url_original"] = oss_url_original
  101. persons.append(person)
  102. ptz_uploaded = is_ptz and bool(oss_url_marked)
  103. legacy = {
  104. "batch_id": new_info.get("batch_id", ""),
  105. "device_id": new_info.get("device_id", ""),
  106. "project_id": new_info.get("project_id", ""),
  107. "timestamp": normalized_ts,
  108. "datetime": datetime.fromtimestamp(normalized_ts).isoformat(),
  109. "total_persons": len(persons),
  110. "ptz_images_count": len(persons),
  111. "panorama": {
  112. "local_path": marked_path,
  113. "local_path_original": original_path,
  114. "oss_url": oss_url_marked,
  115. "oss_url_original": oss_url_original,
  116. },
  117. "persons": persons,
  118. "upload_status": {
  119. "panorama_uploaded": bool(oss_url_marked or oss_url_original),
  120. "panorama_original_uploaded": bool(oss_url_original),
  121. "all_ptz_uploaded": ptz_uploaded,
  122. },
  123. }
  124. return legacy
@dataclass
class BatchReport:
    """One batch report, as queued and sent by ThirdPartyPusher."""
    batch_id: str
    device_id: str
    project_id: str
    # Taken from batch_info as-is; may be seconds or milliseconds
    # (the payload builder normalises it to seconds).
    timestamp: float
    batch_info: Dict[str, Any]  # full contents of batch_info.json
    local_path: Optional[str] = None  # local path of batch_info.json, if known
  134. class ThirdPartyPusher:
  135. """
  136. 第三方平台推送器
  137. 负责将批次信息推送到配置的第三方平台接口
  138. """
  139. def __init__(self, config: Dict[str, Any] = None):
  140. """
  141. 初始化第三方平台推送器
  142. Args:
  143. config: 第三方平台配置字典
  144. """
  145. from config import THIRD_PARTY_CONFIG, DEVICE_CONFIG
  146. self.config = config or THIRD_PARTY_CONFIG
  147. self.device_config = DEVICE_CONFIG
  148. # 功能开关
  149. self.enabled = self.config.get('enabled', False)
  150. # 平台配置
  151. self.platform_type = self.config.get('platform_type', 'custom')
  152. self.base_url = self.config.get('base_url', '')
  153. self.api_version = self.config.get('api_version', 'v1')
  154. # 认证配置
  155. self.auth_type = self.config.get('auth_type', 'none')
  156. self.api_key = self.config.get('api_key', '')
  157. self.api_secret = self.config.get('api_secret', '')
  158. self.oauth2_config = self.config.get('oauth2', {})
  159. # 接口路径
  160. self.endpoints = self.config.get('endpoints', {})
  161. self.batch_report_url = self.endpoints.get('batch_report', '/api/batch/report')
  162. self.heartbeat_url = self.endpoints.get('heartbeat', '/api/device/heartbeat')
  163. # 推送控制
  164. self.push_interval = self.config.get('push_interval', 1.0)
  165. self.retry_count = self.config.get('retry_count', 3)
  166. self.retry_delay = self.config.get('retry_delay', 2.0)
  167. self.timeout = self.config.get('timeout', 10)
  168. self.data_format = self.config.get('data_format', 'json')
  169. self.include_images = self.config.get('include_images', False)
  170. # OAuth2 Token
  171. self._access_token = None
  172. self._token_expires_at = 0
  173. # 上报队列
  174. self.report_queue = queue.Queue()
  175. # 工作线程
  176. self.running = False
  177. self.worker_thread = None
  178. # 统计
  179. self.stats = {
  180. 'total_reports': 0,
  181. 'success_reports': 0,
  182. 'failed_reports': 0,
  183. }
  184. self.stats_lock = threading.Lock()
  185. # 回调
  186. self.on_report_success: Optional[Callable] = None
  187. self.on_report_failed: Optional[Callable] = None
  188. # 最后上报时间
  189. self.last_report_time = 0
  190. if self.enabled:
  191. logger.info(f"[第三方平台] 推送器初始化完成: {self.base_url}")
  192. def start(self):
  193. """启动推送器"""
  194. if not self.enabled:
  195. logger.info("[第三方平台] 推送器未启用")
  196. return
  197. if self.running:
  198. return
  199. self.running = True
  200. self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
  201. self.worker_thread.start()
  202. logger.info("[第三方平台] 推送器已启动")
  203. def stop(self):
  204. """停止推送器"""
  205. self.running = False
  206. if self.worker_thread:
  207. self.worker_thread.join(timeout=5)
  208. logger.info("[第三方平台] 推送器已停止")
  209. def _worker_loop(self):
  210. """工作线程循环"""
  211. while self.running:
  212. try:
  213. report = self.report_queue.get(timeout=1.0)
  214. self._process_report(report)
  215. except queue.Empty:
  216. continue
  217. except Exception as e:
  218. logger.error(f"[第三方平台] 处理上报错误: {e}")
  219. def _get_auth_headers(self) -> Dict[str, str]:
  220. """获取认证请求头(当前第三方接口不需要自定义 header,返回空避免 422)"""
  221. return {}
  222. def _get_oauth2_token(self) -> Optional[str]:
  223. """获取 OAuth2 Token"""
  224. # 检查现有 token 是否有效
  225. if self._access_token and time.time() < self._token_expires_at - 60:
  226. return self._access_token
  227. # 重新获取 token
  228. token_url = self.oauth2_config.get('token_url', '')
  229. client_id = self.oauth2_config.get('client_id', '')
  230. client_secret = self.oauth2_config.get('client_secret', '')
  231. scope = self.oauth2_config.get('scope', '')
  232. if not all([token_url, client_id, client_secret]):
  233. logger.error("[第三方平台] OAuth2 配置不完整")
  234. return None
  235. try:
  236. data = {
  237. 'grant_type': 'client_credentials',
  238. 'client_id': client_id,
  239. 'client_secret': client_secret,
  240. }
  241. if scope:
  242. data['scope'] = scope
  243. response = requests.post(token_url, data=data, timeout=self.timeout)
  244. if response.status_code == 200:
  245. result = response.json()
  246. self._access_token = result.get('access_token')
  247. expires_in = result.get('expires_in', 3600)
  248. self._token_expires_at = time.time() + expires_in
  249. logger.info("[第三方平台] OAuth2 Token 获取成功")
  250. return self._access_token
  251. else:
  252. logger.error(f"[第三方平台] OAuth2 Token 获取失败: {response.status_code}")
  253. return None
  254. except Exception as e:
  255. logger.error(f"[第三方平台] OAuth2 Token 请求异常: {e}")
  256. return None
  257. def _process_report(self, report: BatchReport):
  258. """处理单个上报任务"""
  259. # 检查推送间隔
  260. current_time = time.time()
  261. time_since_last = current_time - self.last_report_time
  262. if time_since_last < self.push_interval:
  263. time.sleep(self.push_interval - time_since_last)
  264. success = self._send_batch_report(report)
  265. with self.stats_lock:
  266. self.stats['total_reports'] += 1
  267. if success:
  268. self.stats['success_reports'] += 1
  269. else:
  270. self.stats['failed_reports'] += 1
  271. self.last_report_time = time.time()
  272. # 触发回调
  273. if success and self.on_report_success:
  274. try:
  275. self.on_report_success(report)
  276. except Exception as e:
  277. logger.error(f"[第三方平台] 成功回调执行错误: {e}")
  278. elif not success and self.on_report_failed:
  279. try:
  280. self.on_report_failed(report)
  281. except Exception as e:
  282. logger.error(f"[第三方平台] 失败回调执行错误: {e}")
  283. def _send_batch_report(self, report: BatchReport) -> bool:
  284. """
  285. 发送批次上报请求
  286. Args:
  287. report: 批次上报数据
  288. Returns:
  289. bool: 是否成功
  290. """
  291. if not self.base_url:
  292. logger.error("[第三方平台] 未配置 base_url")
  293. return False
  294. url = f"{self.base_url}{self.batch_report_url}"
  295. # 构建请求数据
  296. payload = self._build_payload(report)
  297. headers = self._get_auth_headers()
  298. for attempt in range(self.retry_count):
  299. try:
  300. if self.data_format == 'json':
  301. response = requests.post(
  302. url,
  303. json=payload,
  304. headers=headers,
  305. timeout=self.timeout,
  306. verify=False
  307. )
  308. else:
  309. response = requests.post(
  310. url,
  311. data=payload,
  312. headers=headers,
  313. timeout=self.timeout,
  314. verify=False
  315. )
  316. if response.status_code == 200:
  317. result = response.json()
  318. status = result.get('status', '')
  319. message = result.get('message', '')
  320. if (result.get('code') == 200 or
  321. result.get('success') == True or
  322. status in ('pending', 'success', 'accepted') or
  323. '请求已接收' in message or
  324. message == 'accepted'):
  325. logger.info(f"[第三方平台] 批次上报成功: {report.batch_id}, task_id={result.get('task_id')}")
  326. return True
  327. else:
  328. logger.warning(f"[第三方平台] 批次上报失败: {result.get('msg', '未知错误')}")
  329. try:
  330. logger.warning(f"[第三方平台] 响应内容: {str(result)[:500]}")
  331. except Exception:
  332. pass
  333. else:
  334. logger.warning(f"[第三方平台] 批次上报失败: HTTP {response.status_code}")
  335. try:
  336. logger.warning(f"[第三方平台] 响应内容: {response.text[:500]}")
  337. except Exception:
  338. pass
  339. if attempt < self.retry_count - 1:
  340. time.sleep(self.retry_delay)
  341. except requests.exceptions.Timeout:
  342. logger.warning(f"[第三方平台] 请求超时 (尝试 {attempt + 1}/{self.retry_count})")
  343. if attempt < self.retry_count - 1:
  344. time.sleep(self.retry_delay)
  345. except Exception as e:
  346. logger.error(f"[第三方平台] 请求异常 (尝试 {attempt + 1}/{self.retry_count}): {e}")
  347. if attempt < self.retry_count - 1:
  348. time.sleep(self.retry_delay)
  349. logger.error(f"[第三方平台] 批次上报最终失败: {report.batch_id}")
  350. return False
  351. def _build_payload(self, report: BatchReport) -> Dict[str, Any]:
  352. """
  353. 构建上报请求体
  354. Args:
  355. report: 批次上报数据
  356. Returns:
  357. Dict: 请求体字典
  358. """
  359. batch_info = report.batch_info
  360. # 根据平台类型调整格式
  361. normalized_ts = _normalize_timestamp(report.timestamp)
  362. if self.platform_type == 'jtjai':
  363. # 优先取 OSS URL,否则用本地路径
  364. urls = batch_info.get('image_urls') or {}
  365. image_url = urls.get('original') or urls.get('marked') or (batch_info.get('image_paths') or [None])[0]
  366. payload = {
  367. 'createTime': datetime.fromtimestamp(normalized_ts).strftime("%Y-%m-%d %H:%M:%S"),
  368. 'addr': f"设备{report.device_id}批次上报",
  369. 'ext1': json.dumps([image_url]),
  370. 'ext2': json.dumps({
  371. 'batchId': report.batch_id,
  372. 'deviceId': report.device_id,
  373. 'projectId': report.project_id,
  374. 'totalPersons': len(batch_info.get('detections', [])),
  375. 'ptzImagesCount': 1 if batch_info.get('camera_type') == 'ptz' else 0,
  376. 'persons': batch_info.get('detections', []),
  377. 'imageUrls': urls,
  378. })
  379. }
  380. else:
  381. # custom / 其他平台:把新版 batch_info 转回老字段名后上报,
  382. # 兼容原人体分析平台对 panorama / total_persons / persons 的解析。
  383. payload = _convert_to_legacy_batch_info(batch_info)
  384. # 统一时间戳单位为秒,避免第三方解析错误
  385. payload['timestamp'] = normalized_ts
  386. return payload
  387. def report_batch(self, batch_info: Dict[str, Any], local_path: Optional[str] = None):
  388. """
  389. 上报批次信息
  390. Args:
  391. batch_info: batch_info.json 的字典内容
  392. local_path: batch_info.json 的本地文件路径(可选)
  393. """
  394. if not self.enabled:
  395. return
  396. # 接受所有相机类型的检测上报(panorama 或 ptz)
  397. # 业务流程:检测到人 → 上传 OSS → 推送第三方平台
  398. report = BatchReport(
  399. batch_id=batch_info.get('batch_id', ''),
  400. device_id=batch_info.get('device_id', ''),
  401. project_id=batch_info.get('project_id', ''),
  402. timestamp=batch_info.get('timestamp', time.time()),
  403. batch_info=batch_info,
  404. local_path=local_path
  405. )
  406. self.report_queue.put(report)
  407. def report_batch_sync(self, batch_info: Dict[str, Any],
  408. local_path: Optional[str] = None) -> bool:
  409. """
  410. 同步上报批次信息
  411. Args:
  412. batch_info: batch_info.json 的字典内容
  413. local_path: batch_info.json 的本地文件路径(可选)
  414. Returns:
  415. bool: 是否成功
  416. """
  417. if not self.enabled:
  418. return False
  419. report = BatchReport(
  420. batch_id=batch_info.get('batch_id', ''),
  421. device_id=batch_info.get('device_id', ''),
  422. project_id=batch_info.get('project_id', ''),
  423. timestamp=batch_info.get('timestamp', time.time()),
  424. batch_info=batch_info,
  425. local_path=local_path
  426. )
  427. return self._send_batch_report(report)
  428. def send_heartbeat(self) -> bool:
  429. """
  430. 发送心跳
  431. Returns:
  432. bool: 是否成功
  433. """
  434. if not self.enabled or not self.heartbeat_url:
  435. return False
  436. url = f"{self.base_url}{self.heartbeat_url}"
  437. payload = {
  438. 'deviceId': self.device_config.get('device_id', ''),
  439. 'projectId': self.device_config.get('project_id', ''),
  440. 'timestamp': time.time(),
  441. 'status': 'online',
  442. }
  443. headers = self._get_auth_headers()
  444. try:
  445. response = requests.post(
  446. url,
  447. json=payload,
  448. headers=headers,
  449. timeout=self.timeout,
  450. verify=False
  451. )
  452. if response.status_code == 200:
  453. logger.debug("[第三方平台] 心跳发送成功")
  454. return True
  455. else:
  456. logger.warning(f"[第三方平台] 心跳发送失败: HTTP {response.status_code}")
  457. return False
  458. except Exception as e:
  459. logger.error(f"[第三方平台] 心跳发送异常: {e}")
  460. return False
  461. def set_callbacks(self, on_success: Callable = None, on_failed: Callable = None):
  462. """
  463. 设置回调函数
  464. Args:
  465. on_success: 上报成功回调
  466. on_failed: 上报失败回调
  467. """
  468. self.on_report_success = on_success
  469. self.on_report_failed = on_failed
  470. def get_stats(self) -> Dict[str, int]:
  471. """获取统计信息"""
  472. with self.stats_lock:
  473. return self.stats.copy()
  474. def is_enabled(self) -> bool:
  475. """检查是否启用"""
  476. return self.enabled
# Process-wide singleton: created lazily by get_third_party_pusher() and
# cleared by reset_third_party_pusher(); both use _third_party_pusher_lock.
_third_party_pusher_instance: Optional[ThirdPartyPusher] = None
_third_party_pusher_lock = threading.Lock()
  480. def get_third_party_pusher(config: Dict[str, Any] = None) -> ThirdPartyPusher:
  481. """
  482. 获取第三方平台推送器实例(单例模式,线程安全)
  483. Args:
  484. config: 第三方平台配置
  485. Returns:
  486. ThirdPartyPusher 实例
  487. """
  488. global _third_party_pusher_instance
  489. if _third_party_pusher_instance is None:
  490. with _third_party_pusher_lock:
  491. if _third_party_pusher_instance is None:
  492. _third_party_pusher_instance = ThirdPartyPusher(config)
  493. return _third_party_pusher_instance
  494. def reset_third_party_pusher():
  495. """重置第三方平台推送器实例"""
  496. global _third_party_pusher_instance
  497. with _third_party_pusher_lock:
  498. if _third_party_pusher_instance is not None:
  499. _third_party_pusher_instance.stop()
  500. _third_party_pusher_instance = None