Просмотр исходного кода

fix(event_pusher): use base_url for uploads, add cleanup tests, improve types

wenhongquan 3 недель назад
Родитель
Сommit
8580f07510

+ 22 - 54
dual_camera_system/event_pusher.py

@@ -10,13 +10,11 @@ import threading
 import queue
 import tempfile
 import requests
-import http.client
 import mimetypes
 from typing import Optional, Dict, Any, List
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
-from codecs import encode
 
 import cv2
 import numpy as np
@@ -40,7 +38,7 @@ class SafetyEvent:
     confidence: float = 0.0               # 置信度
     location: str = ""                    # 位置信息
     timestamp: float = 0.0                # 时间戳
-    extra: Dict[str, Any] = None          # 额外信息
+    extra: Optional[Dict[str, Any]] = None  # 额外信息
     
     def __post_init__(self):
         if self.timestamp == 0.0:
@@ -68,7 +66,7 @@ class EventPusher:
     负责将安全事件推送到业务平台
     """
     
-    def __init__(self, config: Dict[str, Any] = None):
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
         """
         初始化事件推送器
         
@@ -219,7 +217,7 @@ class EventPusher:
                 except Exception:
                     pass
     
-    def push_tracking_capture(self, batch_time: float, captures: List[dict]):
+    def push_tracking_capture(self, batch_time: float, captures: List[Dict[str, Any]]) -> Optional[requests.Response]:
         """
         推送一轮多目标跟踪抓拍事件
         
@@ -242,7 +240,7 @@ class EventPusher:
         url = f"{self.base_url}{self.event_url}"
         return self._post(url, payload)
     
-    def _post(self, url: str, json_data: dict):
+    def _post(self, url: str, json_data: Dict[str, Any]) -> Optional[requests.Response]:
         """
         发送 POST 请求
         
@@ -298,8 +296,8 @@ class EventPusher:
                 if os.path.exists(event.image_path):
                     try:
                         os.remove(event.image_path)
-                    except:
-                        pass
+                    except Exception as e:
+                        print(f"清理临时文件失败: {e}")
             
             # 创建事件
             success = self._create_event(event)
@@ -332,64 +330,34 @@ class EventPusher:
         if not os.path.exists(image_path):
             return None
         
+        filename = os.path.basename(image_path)
+        content_type = mimetypes.guess_type(image_path)[0] or 'image/jpeg'
+        url = f"{self.base_url}{self.upload_url}"
+        
         for attempt in range(self.retry_count):
-            conn = None
             try:
-                filename = os.path.basename(image_path)
-
-                # 创建连接
-                if self.use_https:
-                    conn = http.client.HTTPSConnection(self.api_host, self.api_port)
-                else:
-                    conn = http.client.HTTPConnection(self.api_host, self.api_port)
-                
-                # 准备 multipart/form-data
-                boundary = f'wL36Yn8afVp8Ag7AmP8qZ0SA4n1v9T{int(time.time())}'
-                dataList = []
-                dataList.append(encode(f'--{boundary}'))
-                dataList.append(encode(f'Content-Disposition: form-data; name=file; filename={filename}'))
-                
-                # 文件类型
-                fileType = mimetypes.guess_type(image_path)[0] or 'image/jpeg'
-                dataList.append(encode(f'Content-Type: {fileType}'))
-                dataList.append(encode(''))
-                
-                # 读取文件
                 with open(image_path, 'rb') as f:
-                    dataList.append(f.read())
+                    files = {'file': (filename, f, content_type)}
+                    response = requests.post(
+                        url,
+                        files=files,
+                        headers={'User-Agent': 'SafetySystem/1.0'},
+                        verify=False,
+                        timeout=10
+                    )
                 
-                dataList.append(encode(f'--{boundary}--'))
-                dataList.append(encode(''))
-                
-                body = b'\r\n'.join(dataList)
-                
-                headers = {
-                    'User-Agent': 'SafetySystem/1.0',
-                    'Accept': '*/*',
-                    'Host': f'{self.api_host}:{self.api_port}',
-                    'Connection': 'keep-alive',
-                    'Content-Type': f'multipart/form-data; boundary={boundary}'
-                }
-                
-                conn.request("POST", self.upload_url, body, headers)
-                res = conn.getresponse()
-                data = res.read()
-
-                if res.status == 200:
-                    result = json.loads(data.decode("utf-8"))
+                if response.status_code == 200:
+                    result = response.json()
                     if result.get('code') == 200:
                         return result.get('data', {}).get('purl')
                     else:
                         print(f"上传失败: {result.get('msg', '未知错误')}")
                 else:
-                    print(f"上传失败: HTTP {res.status}")
+                    print(f"上传失败: HTTP {response.status_code}")
             except Exception as e:
                 print(f"上传异常 (尝试 {attempt + 1}/{self.retry_count}): {e}")
                 if attempt < self.retry_count - 1:
                     time.sleep(self.retry_delay)
-            finally:
-                if conn:
-                    conn.close()
         
         return None
     
@@ -458,7 +426,7 @@ class EventListener:
     监听业务平台的指令(如语音播放指令)
     """
     
-    def __init__(self, config: Dict[str, Any] = None):
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
         """
         初始化事件监听器
         

+ 41 - 0
dual_camera_system/tests/test_event_pusher_upload.py

@@ -48,5 +48,46 @@ def test_push_tracking_capture(monkeypatch):
         captures=[{"track_id": 1, "ptz_image_url": "url1"}]
     )
 
+    assert captured["url"] == "http://localhost/api/system/event"
     assert captured["json"]["eventType"] == "TRACKING_CAPTURE"
     assert captured["json"]["data"]["captureCount"] == 1
+
+
+def test_upload_numpy_image_cleans_temp_file_on_success(monkeypatch):
+    config = {"device_id": "test-device", "base_url": "http://localhost"}
+    pusher = EventPusher(config)
+    temp_paths = []
+
+    def fake_upload(path):
+        temp_paths.append(path)
+        assert os.path.exists(path)
+        return "http://example.com/image.jpg"
+
+    monkeypatch.setattr(pusher, "_upload_image", fake_upload)
+
+    img = np.zeros((100, 100, 3), dtype=np.uint8)
+    url = pusher.upload_numpy_image(img)
+
+    assert url == "http://example.com/image.jpg"
+    assert len(temp_paths) == 1
+    assert not os.path.exists(temp_paths[0])
+
+
+def test_upload_numpy_image_cleans_temp_file_on_upload_failure(monkeypatch):
+    config = {"device_id": "test-device", "base_url": "http://localhost"}
+    pusher = EventPusher(config)
+    temp_paths = []
+
+    def fake_upload(path):
+        temp_paths.append(path)
+        assert os.path.exists(path)
+        raise RuntimeError("upload failed")
+
+    monkeypatch.setattr(pusher, "_upload_image", fake_upload)
+
+    img = np.zeros((100, 100, 3), dtype=np.uint8)
+    url = pusher.upload_numpy_image(img)
+
+    assert url is None
+    assert len(temp_paths) == 1
+    assert not os.path.exists(temp_paths[0])