  1. #!/usr/bin/env python3
  2. import cv2
  3. import numpy as np
  4. import argparse
  5. import requests
  6. import json
  7. import time
  8. import os
  9. from PIL import Image, ImageDraw, ImageFont
  10. from rknnlite.api import RKNNLite
  11. import onnxruntime as ort
  12. from dataclasses import dataclass
  13. from typing import List, Tuple, Optional
  14. import subprocess
  15. os.system("taskset -p 0xff0 %d" % os.getpid())
@dataclass
class Detection:
    """A single detected object in original-image pixel coordinates."""
    class_id: int  # raw model class index (see BaseDetector.LABEL_MAP)
    class_name: str  # human-readable label for class_id
    confidence: float  # best-class score, expected in [0, 1]
    bbox: Tuple[int, int, int, int]  # (x1, y1, x2, y2), clipped to the image
  22. def nms(dets, iou_threshold=0.45):
  23. if len(dets) == 0:
  24. return []
  25. boxes = np.array([[d.bbox[0], d.bbox[1], d.bbox[2], d.bbox[3], d.confidence] for d in dets])
  26. x1 = boxes[:, 0]
  27. y1 = boxes[:, 1]
  28. x2 = boxes[:, 2]
  29. y2 = boxes[:, 3]
  30. scores = boxes[:, 4]
  31. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  32. order = scores.argsort()[::-1]
  33. keep = []
  34. while order.size > 0:
  35. i = order[0]
  36. keep.append(i)
  37. xx1 = np.maximum(x1[i], x1[order[1:]])
  38. yy1 = np.maximum(y1[i], y1[order[1:]])
  39. xx2 = np.minimum(x2[i], x2[order[1:]])
  40. yy2 = np.minimum(y2[i], y2[order[1:]])
  41. w = np.maximum(0.0, xx2 - xx1 + 1)
  42. h = np.maximum(0.0, yy2 - yy1 + 1)
  43. inter = w * h
  44. ovr = inter / (areas[i] + areas[order[1:]] - inter)
  45. inds = np.where(ovr <= iou_threshold)[0]
  46. order = order[inds + 1]
  47. return [dets[i] for i in keep]
  48. class BaseDetector:
  49. LABEL_MAP = {0: '安全帽', 4: '安全衣', 3: '人'}
  50. def __init__(self):
  51. self.input_size = (640, 640)
  52. self.num_classes = 5
  53. def letterbox(self, image):
  54. h0, w0 = image.shape[:2]
  55. ih, iw = self.input_size
  56. scale = min(iw / w0, ih / h0)
  57. new_w, new_h = int(w0 * scale), int(h0 * scale)
  58. pad_w = (iw - new_w) // 2
  59. pad_h = (ih - new_h) // 2
  60. resized = cv2.resize(image, (new_w, new_h))
  61. canvas = np.full((ih, iw, 3), 114, dtype=np.uint8)
  62. canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
  63. return canvas, scale, pad_w, pad_h, h0, w0
  64. def postprocess(self, outputs, scale, pad_w, pad_h, h0, w0, conf_threshold_map):
  65. dets = []
  66. if not outputs:
  67. return dets
  68. output = outputs[0]
  69. if len(output.shape) == 3:
  70. output = output[0]
  71. num_boxes = output.shape[1]
  72. for i in range(num_boxes):
  73. x_center = float(output[0, i])
  74. y_center = float(output[1, i])
  75. width = float(output[2, i])
  76. height = float(output[3, i])
  77. class_probs = output[4:4+self.num_classes, i]
  78. best_class = int(np.argmax(class_probs))
  79. confidence = float(class_probs[best_class])
  80. if best_class not in self.LABEL_MAP:
  81. continue
  82. conf_threshold = conf_threshold_map.get(best_class, 0.5)
  83. if confidence < conf_threshold:
  84. continue
  85. # Remove padding and scale to original image
  86. x1 = int(((x_center - width / 2) - pad_w) / scale)
  87. y1 = int(((y_center - height / 2) - pad_h) / scale)
  88. x2 = int(((x_center + width / 2) - pad_w) / scale)
  89. y2 = int(((y_center + height / 2) - pad_h) / scale)
  90. x1 = max(0, min(w0, x1))
  91. y1 = max(0, min(h0, y1))
  92. x2 = max(0, min(w0, x2))
  93. y2 = max(0, min(h0, y2))
  94. det = Detection(
  95. class_id=best_class,
  96. class_name=self.LABEL_MAP[best_class],
  97. confidence=confidence,
  98. bbox=(x1, y1, x2, y2)
  99. )
  100. dets.append(det)
  101. dets = nms(dets, iou_threshold=0.45)
  102. return dets
  103. output = outputs[0]
  104. if len(output.shape) == 3:
  105. output = output[0]
  106. # Output shape: (4+nc, num_anchors) = (9, 8400)
  107. # Row 0-3: x_center, y_center, width, height (in pixel space 0-640)
  108. # Row 4-8: class scores (already sigmoid'd, 0-1 range)
  109. # NO objectness column in YOLO v8/v11
  110. num_boxes = output.shape[1]
  111. for i in range(num_boxes):
  112. # Coordinates are already in pixel space (0-640), NO sigmoid needed
  113. x_center = float(output[0, i])
  114. y_center = float(output[1, i])
  115. width = float(output[2, i])
  116. height = float(output[3, i])
  117. # Class scores are already sigmoid'd
  118. class_probs = output[4:4+self.num_classes, i]
  119. # Find best class and its confidence
  120. best_class = int(np.argmax(class_probs))
  121. confidence = float(class_probs[best_class])
  122. if best_class not in self.LABEL_MAP:
  123. continue
  124. conf_threshold = conf_threshold_map.get(best_class, 0.5)
  125. if confidence < conf_threshold:
  126. continue
  127. # Convert from center format to corner format and scale to original image
  128. x1 = int((x_center - width/2) * (w0/640))
  129. y1 = int((y_center - height/2) * (h0/640))
  130. x2 = int((x_center + width/2) * (w0/640))
  131. y2 = int((y_center + height/2) * (h0/640))
  132. x1 = max(0, x1)
  133. y1 = max(0, y1)
  134. x2 = min(w0, x2)
  135. y2 = min(h0, y2)
  136. det = Detection(
  137. class_id=best_class,
  138. class_name=self.LABEL_MAP[best_class],
  139. confidence=confidence,
  140. bbox=(x1, y1, x2, y2)
  141. )
  142. dets.append(det)
  143. dets = nms(dets, iou_threshold=0.45)
  144. return dets
  145. def detect(self, image, conf_threshold_map):
  146. raise NotImplementedError
  147. def release(self):
  148. pass
  149. class RKNNDetector(BaseDetector):
  150. """RKNN detector - uses NHWC input format (1, H, W, C)"""
  151. def __init__(self, model_path: str):
  152. super().__init__()
  153. self.rknn = RKNNLite()
  154. ret = self.rknn.load_rknn(model_path)
  155. if ret != 0:
  156. print("[ERROR] load_rknn failed")
  157. exit(-1)
  158. ret = self.rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0_1_2)
  159. if ret != 0:
  160. print("[ERROR] init_runtime failed")
  161. exit(-1)
  162. def detect(self, image, conf_threshold_map):
  163. canvas, scale, pad_w, pad_h, h0, w0 = self.letterbox(image)
  164. # RKNN expects NHWC (1, H, W, C), RGB, normalized 0-1
  165. img = canvas[..., ::-1].astype(np.float32) / 255.0
  166. blob = img[None, ...] # (1, 640, 640, 3)
  167. outs = self.rknn.inference(inputs=[blob])
  168. return self.postprocess(outs, scale, pad_w, pad_h, h0, w0, conf_threshold_map)
  169. def release(self):
  170. self.rknn.release()
  171. class ONNXDetector(BaseDetector):
  172. """ONNX detector - uses NCHW input format (1, C, H, W)"""
  173. def __init__(self, model_path: str):
  174. super().__init__()
  175. self.session = ort.InferenceSession(model_path)
  176. self.input_name = self.session.get_inputs()[0].name
  177. self.output_name = self.session.get_outputs()[0].name
  178. def detect(self, image, conf_threshold_map):
  179. canvas, scale, pad_w, pad_h, h0, w0 = self.letterbox(image)
  180. # ONNX expects NCHW (1, C, H, W), RGB, normalized 0-1
  181. img = canvas[..., ::-1].astype(np.float32) / 255.0
  182. img = img.transpose(2, 0, 1)
  183. blob = img[None, ...] # (1, 3, 640, 640)
  184. outs = self.session.run([self.output_name], {self.input_name: blob})
  185. return self.postprocess(outs, scale, pad_w, pad_h, h0, w0, conf_threshold_map)
  186. def create_detector(model_path: str):
  187. ext = os.path.splitext(model_path)[1].lower()
  188. if ext == '.rknn':
  189. print("使用 RKNN 模型")
  190. return RKNNDetector(model_path)
  191. elif ext == '.onnx':
  192. print("使用 ONNX 模型")
  193. return ONNXDetector(model_path)
  194. else:
  195. print("不支持的模型格式")
  196. exit(-1)
  197. def put_text_chinese(img, text, position, font_size=20, color=(255, 0, 0)):
  198. img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
  199. draw = ImageDraw.Draw(img_pil)
  200. font_path = "Alibaba_PuHuiTi_2.0_35_Thin_35_Thin.ttf"
  201. try:
  202. font = ImageFont.truetype(font_path, font_size)
  203. except:
  204. try:
  205. font = ImageFont.truetype("MiSans-Thin.ttf", font_size)
  206. except:
  207. font = ImageFont.load_default()
  208. color_rgb = (color[2], color[1], color[0])
  209. draw.text(position, text, font=font, fill=color_rgb)
  210. img_cv2 = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
  211. x, y = position
  212. text_width = draw.textlength(text, font=font)
  213. text_height = font_size
  214. img[y:y+text_height, x:x+int(text_width)] = img_cv2[y:y+text_height, x:x+int(text_width)]
  215. def upload_image(image_path):
  216. try:
  217. import http.client
  218. import mimetypes
  219. from codecs import encode
  220. filename = os.path.basename(image_path)
  221. conn = http.client.HTTPSConnection("jtjai.device.wenhq.top", 8583)
  222. boundary = 'wL36Yn8afVp8Ag7AmP8qZ0SA4n1v9T'
  223. dataList = []
  224. dataList.append(encode('--' + boundary))
  225. dataList.append(encode('Content-Disposition: form-data; name=file; filename={0}'.format(filename)))
  226. fileType = mimetypes.guess_type(image_path)[0] or 'application/octet-stream'
  227. dataList.append(encode('Content-Type: {}'.format(fileType)))
  228. dataList.append(encode(''))
  229. with open(image_path, 'rb') as f:
  230. dataList.append(f.read())
  231. dataList.append(encode('--'+boundary+'--'))
  232. dataList.append(encode(''))
  233. body = b'\r\n'.join(dataList)
  234. headers = {
  235. 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
  236. 'Accept': '*/*',
  237. 'Host': 'jtjai.device.wenhq.top:8583',
  238. 'Connection': 'keep-alive',
  239. 'Content-Type': 'multipart/form-data; boundary={}'.format(boundary)
  240. }
  241. conn.request("POST", "/api/resource/oss/upload", body, headers)
  242. res = conn.getresponse()
  243. data = res.read()
  244. if res.status == 200:
  245. result = json.loads(data.decode("utf-8"))
  246. if result.get('code') == 200:
  247. return result.get('data', {}).get('purl')
  248. print(f"上传图片失败: {data.decode('utf-8')}")
  249. except Exception as e:
  250. print(f"上传图片异常: {e}")
  251. return None
  252. def create_event(addr, purl):
  253. try:
  254. url = "https://jtjai.device.wenhq.top:8583/api/system/event"
  255. create_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
  256. data = {
  257. "createTime": create_time,
  258. "addr": addr,
  259. "ext1": json.dumps([purl]),
  260. "ext2": json.dumps({"lx":"工地安全"})
  261. }
  262. response = requests.post(url, json=data, verify=False)
  263. if response.status_code == 200:
  264. result = response.json()
  265. if result.get('code') == 200:
  266. print(f"事件创建成功: {addr}")
  267. return True
  268. print(f"创建事件失败: {response.text}")
  269. except Exception as e:
  270. print(f"创建事件异常: {e}")
  271. return False
  272. def check_safety_equipment(detections):
  273. person_detections = []
  274. helmet_detections = []
  275. safety_clothes_detections = []
  276. for det in detections:
  277. x1, y1, x2, y2 = det.bbox
  278. if det.class_id == 3:
  279. person_detections.append((x1, y1, x2, y2, det.confidence))
  280. elif det.class_id == 0:
  281. helmet_detections.append((x1, y1, x2, y2, det.confidence))
  282. elif det.class_id == 4:
  283. safety_clothes_detections.append((x1, y1, x2, y2, det.confidence))
  284. need_alert = False
  285. alert_addr = None
  286. for person_x1, person_y1, person_x2, person_y2, person_conf in person_detections:
  287. has_helmet = False
  288. for helmet_x1, helmet_y1, helmet_x2, helmet_y2, helmet_conf in helmet_detections:
  289. helmet_center_x = (helmet_x1 + helmet_x2) / 2
  290. helmet_center_y = (helmet_y1 + helmet_y2) / 2
  291. if (helmet_center_x >= person_x1 and helmet_center_x <= person_x2 and
  292. helmet_center_y >= person_y1 and helmet_center_y <= person_y2):
  293. has_helmet = True
  294. break
  295. has_safety_clothes = False
  296. for clothes_x1, clothes_y1, clothes_x2, clothes_y2, clothes_conf in safety_clothes_detections:
  297. overlap_x1 = max(person_x1, clothes_x1)
  298. overlap_y1 = max(person_y1, clothes_y1)
  299. overlap_x2 = min(person_x2, clothes_x2)
  300. overlap_y2 = min(person_y2, clothes_y2)
  301. if overlap_x1 < overlap_x2 and overlap_y1 < overlap_y2:
  302. has_safety_clothes = True
  303. break
  304. if not has_helmet or not has_safety_clothes:
  305. need_alert = True
  306. if not has_helmet and not has_safety_clothes:
  307. alert_addr = "反光衣和安全帽都没戴"
  308. elif not has_helmet:
  309. alert_addr = "未戴安全帽"
  310. else:
  311. alert_addr = "未穿反光衣"
  312. print(f"警告: {alert_addr},置信度: {person_conf:.2f}")
  313. return need_alert, alert_addr, person_detections
  314. class RTSPCapture:
  315. def __init__(self, rtsp_url, model_path, rtmp_url, fps=2):
  316. self.rtsp_url = rtsp_url
  317. self.rtmp_url = rtmp_url
  318. self.det = create_detector(model_path)
  319. self.cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
  320. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
  321. self.rtmp_pipe = None
  322. self.process_fps = fps
  323. self.conf_threshold_map = {3: 0.8, 0: 0.5, 4: 0.5}
  324. self.last_upload_time = 0
  325. self.upload_interval = 2
  326. def start_rtmp(self):
  327. w = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  328. h = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  329. fps = self.cap.get(cv2.CAP_PROP_FPS) or 25
  330. command = [
  331. 'ffmpeg',
  332. '-y',
  333. '-f', 'rawvideo',
  334. '-pix_fmt', 'bgr24',
  335. '-s', f'{w}x{h}',
  336. '-r', str(fps),
  337. '-i', '-',
  338. '-c:v', 'libx264',
  339. '-preset', 'ultrafast',
  340. '-tune', 'zerolatency',
  341. '-f', 'flv',
  342. self.rtmp_url
  343. ]
  344. self.rtmp_pipe = subprocess.Popen(
  345. command,
  346. stdin=subprocess.PIPE,
  347. stdout=subprocess.DEVNULL,
  348. stderr=subprocess.DEVNULL
  349. )
  350. def run(self):
  351. self.start_rtmp()
  352. frame_count = 0
  353. fps = self.cap.get(cv2.CAP_PROP_FPS) or 25
  354. frame_interval = int(round(fps / self.process_fps)) if fps > 0 else 1
  355. print(f"帧间隔: {frame_interval} 帧")
  356. last_dets = []
  357. last_need_alert = False
  358. last_alert_addr = None
  359. last_person_detections = []
  360. while True:
  361. ret, frame = self.cap.read()
  362. if not ret:
  363. break
  364. frame_count += 1
  365. if frame_count % frame_interval == 0:
  366. try:
  367. last_dets = self.det.detect(frame, self.conf_threshold_map)
  368. print(last_dets)
  369. last_need_alert, last_alert_addr, last_person_detections = check_safety_equipment(last_dets)
  370. if last_dets:
  371. print(f"[Frame {frame_count}] 检测到 {len(last_dets)} 个目标")
  372. for d in last_dets:
  373. print(f" {d.class_name}: conf={d.confidence:.2f}, box={d.bbox}")
  374. except Exception as e:
  375. print(f"检测过程中出错: {e}")
  376. for d in last_dets:
  377. x1, y1, x2, y2 = d.bbox
  378. cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
  379. text = f"{d.class_name}: {d.confidence:.2f}"
  380. text_y = max(15, y1 - 20)
  381. put_text_chinese(frame, text, (x1, text_y), font_size=20, color=(255, 0, 0))
  382. if last_person_detections and last_need_alert and last_alert_addr:
  383. current_time = time.time()
  384. if current_time - self.last_upload_time >= self.upload_interval:
  385. print(f"检测到人,触发告警上传")
  386. temp_image_path = f"alert_frame_{frame_count}.jpg"
  387. cv2.imwrite(temp_image_path, frame)
  388. purl = upload_image(temp_image_path)
  389. if purl:
  390. create_event(last_alert_addr, purl)
  391. self.last_upload_time = current_time
  392. if os.path.exists(temp_image_path):
  393. os.remove(temp_image_path)
  394. if self.rtmp_pipe:
  395. try:
  396. self.rtmp_pipe.stdin.write(frame.tobytes())
  397. except:
  398. pass
  399. cv2.imshow("RK3588 工地安全检测", frame)
  400. if cv2.waitKey(1) & 0xFF == ord('q'):
  401. break
  402. self.cap.release()
  403. self.det.release()
  404. cv2.destroyAllWindows()
  405. if __name__ == "__main__":
  406. import argparse
  407. parser = argparse.ArgumentParser()
  408. parser.add_argument("--rtsp", required=True)
  409. parser.add_argument("--model", default="yolo11m_safety.rknn")
  410. parser.add_argument("--rtmp", required=True)
  411. parser.add_argument("--fps", type=int, default=2, help="每秒处理的帧数")
  412. args = parser.parse_args()
  413. cap = RTSPCapture(args.rtsp, args.model, args.rtmp, args.fps)
  414. cap.run()