#!/usr/bin/env python3
"""Construction-site safety monitor for RK3588 boards.

Pulls frames from an RTSP camera, runs a YOLO v8/v11 safety model
(helmet / safety vest / person) through either the RKNN NPU runtime or
ONNX Runtime, overlays detections on the frame, restreams the annotated
video to RTMP via ffmpeg, and uploads an alert image + event record to a
remote API whenever a person is missing a helmet or a reflective vest.
"""
import cv2
import numpy as np
import argparse
import requests
import json
import sys
import time
import os
from PIL import Image, ImageDraw, ImageFont
from rknnlite.api import RKNNLite
import onnxruntime as ort
from dataclasses import dataclass
from typing import List, Tuple, Optional
import subprocess

# Pin this process to the big cores (affinity mask 0xff0) so inference
# latency stays predictable on the RK3588's big.LITTLE CPU.
os.system("taskset -p 0xff0 %d" % os.getpid())


@dataclass
class Detection:
    """One detected object, with its box in original-image pixel coordinates."""
    class_id: int
    class_name: str
    confidence: float
    bbox: Tuple[int, int, int, int]  # (x1, y1, x2, y2)


def nms(dets, iou_threshold=0.45):
    """Greedy non-maximum suppression over a list of :class:`Detection`.

    Repeatedly keeps the highest-confidence remaining box and discards any
    box whose IoU with it exceeds ``iou_threshold``.

    Args:
        dets: list of Detection objects.
        iou_threshold: overlap above which a lower-scoring box is dropped.

    Returns:
        The surviving Detection objects, ordered by descending confidence.
    """
    if len(dets) == 0:
        return []
    boxes = np.array([[d.bbox[0], d.bbox[1], d.bbox[2], d.bbox[3], d.confidence]
                      for d in dets])
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    scores = boxes[:, 4]
    # +1 follows the classic pixel-inclusive area convention.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # Intersection of the current best box with all remaining boxes.
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)
        inds = np.where(ovr <= iou_threshold)[0]
        order = order[inds + 1]
    return [dets[i] for i in keep]


class BaseDetector:
    """Shared letterbox preprocessing and YOLO postprocessing.

    Subclasses implement :meth:`detect` for a specific inference backend.
    """

    # Model class id -> display label. The model has 5 classes; ids 1 and 2
    # are intentionally ignored (not present in this map).
    LABEL_MAP = {0: '安全帽', 4: '安全衣', 3: '人'}

    def __init__(self):
        self.input_size = (640, 640)  # (height, width) fed to the network
        self.num_classes = 5

    def letterbox(self, image):
        """Resize ``image`` to fit the model input while preserving aspect.

        The image is scaled to fit inside ``self.input_size`` and centered on
        a gray (114) canvas.

        Returns:
            (canvas, scale, pad_w, pad_h, h0, w0) where scale/pads allow
            mapping network coordinates back to the original image.
        """
        h0, w0 = image.shape[:2]
        ih, iw = self.input_size
        scale = min(iw / w0, ih / h0)
        new_w, new_h = int(w0 * scale), int(h0 * scale)
        pad_w = (iw - new_w) // 2
        pad_h = (ih - new_h) // 2
        resized = cv2.resize(image, (new_w, new_h))
        canvas = np.full((ih, iw, 3), 114, dtype=np.uint8)
        canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
        return canvas, scale, pad_w, pad_h, h0, w0

    def postprocess(self, outputs, scale, pad_w, pad_h, h0, w0, conf_threshold_map):
        """Decode raw YOLO v8/v11 output into Detection objects.

        Expected output layout (after squeezing the batch dim):
        shape ``(4 + num_classes, num_anchors)`` — rows 0-3 are
        x_center/y_center/width/height in letterboxed pixel space (0-640),
        the remaining rows are per-class scores already in [0, 1]
        (no objectness column in v8/v11).

        Args:
            outputs: raw backend outputs (list of arrays).
            scale, pad_w, pad_h: letterbox parameters from :meth:`letterbox`.
            h0, w0: original image height/width for clamping.
            conf_threshold_map: per-class-id confidence threshold
                (default 0.5 for unlisted classes).

        Returns:
            NMS-filtered list of Detection objects in original-image coords.
        """
        dets = []
        if not outputs:
            return dets
        output = outputs[0]
        if len(output.shape) == 3:
            output = output[0]  # drop batch dim -> (4+nc, num_anchors)
        num_boxes = output.shape[1]
        for i in range(num_boxes):
            # Coordinates are already in pixel space (0-640); no sigmoid.
            x_center = float(output[0, i])
            y_center = float(output[1, i])
            width = float(output[2, i])
            height = float(output[3, i])
            class_probs = output[4:4 + self.num_classes, i]
            best_class = int(np.argmax(class_probs))
            confidence = float(class_probs[best_class])
            if best_class not in self.LABEL_MAP:
                continue  # class we deliberately do not report
            conf_threshold = conf_threshold_map.get(best_class, 0.5)
            if confidence < conf_threshold:
                continue
            # Undo the letterbox: remove padding, then divide by the scale.
            x1 = int(((x_center - width / 2) - pad_w) / scale)
            y1 = int(((y_center - height / 2) - pad_h) / scale)
            x2 = int(((x_center + width / 2) - pad_w) / scale)
            y2 = int(((y_center + height / 2) - pad_h) / scale)
            # Clamp to the original image bounds.
            x1 = max(0, min(w0, x1))
            y1 = max(0, min(h0, y1))
            x2 = max(0, min(w0, x2))
            y2 = max(0, min(h0, y2))
            det = Detection(
                class_id=best_class,
                class_name=self.LABEL_MAP[best_class],
                confidence=confidence,
                bbox=(x1, y1, x2, y2)
            )
            dets.append(det)
        # NOTE(review): an unreachable duplicate decode loop (using a naive
        # w0/640 scaling instead of the letterbox mapping) previously
        # followed this return; it has been removed as dead code.
        dets = nms(dets, iou_threshold=0.45)
        return dets

    def detect(self, image, conf_threshold_map):
        """Run inference on a BGR image; implemented by backend subclasses."""
        raise NotImplementedError

    def release(self):
        """Release backend resources; no-op by default."""
        pass


class RKNNDetector(BaseDetector):
    """RKNN detector - uses NHWC input format (1, H, W, C)."""

    def __init__(self, model_path: str):
        super().__init__()
        self.rknn = RKNNLite()
        ret = self.rknn.load_rknn(model_path)
        if ret != 0:
            print("[ERROR] load_rknn failed")
            sys.exit(-1)
        # Use all three NPU cores for maximum throughput.
        ret = self.rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0_1_2)
        if ret != 0:
            print("[ERROR] init_runtime failed")
            sys.exit(-1)

    def detect(self, image, conf_threshold_map):
        canvas, scale, pad_w, pad_h, h0, w0 = self.letterbox(image)
        # RKNN expects NHWC (1, H, W, C), RGB, normalized 0-1.
        img = canvas[..., ::-1].astype(np.float32) / 255.0  # BGR -> RGB
        blob = img[None, ...]  # (1, 640, 640, 3)
        outs = self.rknn.inference(inputs=[blob])
        return self.postprocess(outs, scale, pad_w, pad_h, h0, w0, conf_threshold_map)

    def release(self):
        self.rknn.release()


class ONNXDetector(BaseDetector):
    """ONNX detector - uses NCHW input format (1, C, H, W)."""

    def __init__(self, model_path: str):
        super().__init__()
        self.session = ort.InferenceSession(model_path)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def detect(self, image, conf_threshold_map):
        canvas, scale, pad_w, pad_h, h0, w0 = self.letterbox(image)
        # ONNX expects NCHW (1, C, H, W), RGB, normalized 0-1.
        img = canvas[..., ::-1].astype(np.float32) / 255.0  # BGR -> RGB
        img = img.transpose(2, 0, 1)  # HWC -> CHW
        blob = img[None, ...]  # (1, 3, 640, 640)
        outs = self.session.run([self.output_name], {self.input_name: blob})
        return self.postprocess(outs, scale, pad_w, pad_h, h0, w0, conf_threshold_map)


def create_detector(model_path: str):
    """Instantiate the detector backend matching the model file extension."""
    ext = os.path.splitext(model_path)[1].lower()
    if ext == '.rknn':
        print("使用 RKNN 模型")
        return RKNNDetector(model_path)
    elif ext == '.onnx':
        print("使用 ONNX 模型")
        return ONNXDetector(model_path)
    else:
        print("不支持的模型格式")
        sys.exit(-1)


def put_text_chinese(img, text, position, font_size=20, color=(255, 0, 0)):
    """Draw CJK-capable text onto a BGR OpenCV image in place.

    cv2.putText cannot render Chinese glyphs, so the text is drawn with
    Pillow on an RGB copy and only the text region is copied back into
    ``img`` (numpy slicing clips both sides identically at image edges).

    Args:
        img: BGR uint8 image, modified in place.
        text: string to draw (may contain CJK characters).
        position: (x, y) of the text's top-left corner.
        font_size: pixel size of the font.
        color: BGR color tuple.
    """
    img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img_pil)
    font_path = "Alibaba_PuHuiTi_2.0_35_Thin_35_Thin.ttf"
    # Fall back through bundled fonts to Pillow's built-in default.
    try:
        font = ImageFont.truetype(font_path, font_size)
    except Exception:
        try:
            font = ImageFont.truetype("MiSans-Thin.ttf", font_size)
        except Exception:
            font = ImageFont.load_default()
    color_rgb = (color[2], color[1], color[0])  # BGR -> RGB
    draw.text(position, text, font=font, fill=color_rgb)
    img_cv2 = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    x, y = position
    text_width = draw.textlength(text, font=font)
    text_height = font_size
    # Copy back only the rectangle containing the rendered text.
    img[y:y + text_height, x:x + int(text_width)] = \
        img_cv2[y:y + text_height, x:x + int(text_width)]


def upload_image(image_path):
    """Upload an image to the OSS endpoint via multipart/form-data.

    Returns:
        The public URL (``purl``) of the uploaded file on success,
        otherwise ``None``. Never raises; failures are printed.
    """
    try:
        import http.client
        import mimetypes
        from codecs import encode
        filename = os.path.basename(image_path)
        conn = http.client.HTTPSConnection("jtjai.device.wenhq.top", 8583)
        boundary = 'wL36Yn8afVp8Ag7AmP8qZ0SA4n1v9T'
        dataList = []
        dataList.append(encode('--' + boundary))
        dataList.append(encode(
            'Content-Disposition: form-data; name=file; filename={0}'.format(filename)))
        fileType = mimetypes.guess_type(image_path)[0] or 'application/octet-stream'
        dataList.append(encode('Content-Type: {}'.format(fileType)))
        dataList.append(encode(''))
        with open(image_path, 'rb') as f:
            dataList.append(f.read())
        dataList.append(encode('--' + boundary + '--'))
        dataList.append(encode(''))
        body = b'\r\n'.join(dataList)
        headers = {
            'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
            'Accept': '*/*',
            'Host': 'jtjai.device.wenhq.top:8583',
            'Connection': 'keep-alive',
            'Content-Type': 'multipart/form-data; boundary={}'.format(boundary)
        }
        conn.request("POST", "/api/resource/oss/upload", body, headers)
        res = conn.getresponse()
        data = res.read()
        if res.status == 200:
            result = json.loads(data.decode("utf-8"))
            if result.get('code') == 200:
                return result.get('data', {}).get('purl')
        print(f"上传图片失败: {data.decode('utf-8')}")
    except Exception as e:
        print(f"上传图片异常: {e}")
    return None


def create_event(addr, purl):
    """Report a safety-violation event to the backend.

    Args:
        addr: human-readable violation description (used as event address).
        purl: public URL of the alert snapshot returned by upload_image().

    Returns:
        True if the backend acknowledged the event, False otherwise.
        Never raises; failures are printed.
    """
    try:
        url = "https://jtjai.device.wenhq.top:8583/api/system/event"
        create_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        data = {
            "createTime": create_time,
            "addr": addr,
            "ext1": json.dumps([purl]),
            "ext2": json.dumps({"lx": "工地安全"})
        }
        # SECURITY(review): verify=False disables TLS certificate checking;
        # acceptable only for this internal self-signed endpoint — confirm.
        response = requests.post(url, json=data, verify=False)
        if response.status_code == 200:
            result = response.json()
            if result.get('code') == 200:
                print(f"事件创建成功: {addr}")
                return True
        print(f"创建事件失败: {response.text}")
    except Exception as e:
        print(f"创建事件异常: {e}")
    return False


def check_safety_equipment(detections):
    """Check every detected person for helmet and reflective vest.

    A helmet counts if its center lies inside the person's box; a vest
    counts if its box overlaps the person's box at all.

    Args:
        detections: list of Detection objects from a single frame.

    Returns:
        (need_alert, alert_addr, person_detections) where alert_addr
        describes the LAST violating person found (or None) and
        person_detections is a list of (x1, y1, x2, y2, conf) tuples.
    """
    person_detections = []
    helmet_detections = []
    safety_clothes_detections = []
    for det in detections:
        x1, y1, x2, y2 = det.bbox
        if det.class_id == 3:
            person_detections.append((x1, y1, x2, y2, det.confidence))
        elif det.class_id == 0:
            helmet_detections.append((x1, y1, x2, y2, det.confidence))
        elif det.class_id == 4:
            safety_clothes_detections.append((x1, y1, x2, y2, det.confidence))
    need_alert = False
    alert_addr = None
    for person_x1, person_y1, person_x2, person_y2, person_conf in person_detections:
        # Helmet: its center point must fall inside the person box.
        has_helmet = False
        for helmet_x1, helmet_y1, helmet_x2, helmet_y2, helmet_conf in helmet_detections:
            helmet_center_x = (helmet_x1 + helmet_x2) / 2
            helmet_center_y = (helmet_y1 + helmet_y2) / 2
            if (helmet_center_x >= person_x1 and helmet_center_x <= person_x2 and
                    helmet_center_y >= person_y1 and helmet_center_y <= person_y2):
                has_helmet = True
                break
        # Vest: any positive-area intersection with the person box.
        has_safety_clothes = False
        for clothes_x1, clothes_y1, clothes_x2, clothes_y2, clothes_conf in safety_clothes_detections:
            overlap_x1 = max(person_x1, clothes_x1)
            overlap_y1 = max(person_y1, clothes_y1)
            overlap_x2 = min(person_x2, clothes_x2)
            overlap_y2 = min(person_y2, clothes_y2)
            if overlap_x1 < overlap_x2 and overlap_y1 < overlap_y2:
                has_safety_clothes = True
                break
        if not has_helmet or not has_safety_clothes:
            need_alert = True
            if not has_helmet and not has_safety_clothes:
                alert_addr = "反光衣和安全帽都没戴"
            elif not has_helmet:
                alert_addr = "未戴安全帽"
            else:
                alert_addr = "未穿反光衣"
            print(f"警告: {alert_addr},置信度: {person_conf:.2f}")
    return need_alert, alert_addr, person_detections


class RTSPCapture:
    """Main pipeline: RTSP capture -> detection -> overlay -> RTMP restream.

    Detection runs on a subsampled schedule (``fps`` processed frames per
    second); every frame is annotated with the latest results and pushed
    to the RTMP output, so the stream stays smooth while inference is slow.
    """

    def __init__(self, rtsp_url, model_path, rtmp_url, fps=2):
        self.rtsp_url = rtsp_url
        self.rtmp_url = rtmp_url
        self.det = create_detector(model_path)
        self.cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)  # keep latency low
        self.rtmp_pipe = None
        self.process_fps = fps
        # Per-class confidence thresholds: person=0.8, helmet=0.5, vest=0.5.
        self.conf_threshold_map = {3: 0.8, 0: 0.5, 4: 0.5}
        self.last_upload_time = 0
        self.upload_interval = 2  # seconds between alert uploads (rate limit)

    def start_rtmp(self):
        """Spawn an ffmpeg child that encodes raw BGR frames to RTMP."""
        w = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = self.cap.get(cv2.CAP_PROP_FPS) or 25
        command = [
            'ffmpeg', '-y',
            '-f', 'rawvideo',
            '-pix_fmt', 'bgr24',
            '-s', f'{w}x{h}',
            '-r', str(fps),
            '-i', '-',
            '-c:v', 'libx264',
            '-preset', 'ultrafast',
            '-tune', 'zerolatency',
            '-f', 'flv',
            self.rtmp_url
        ]
        self.rtmp_pipe = subprocess.Popen(
            command, stdin=subprocess.PIPE,
            stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
        )

    def run(self):
        """Main loop; blocks until the stream ends or 'q' is pressed."""
        self.start_rtmp()
        frame_count = 0
        fps = self.cap.get(cv2.CAP_PROP_FPS) or 25
        # Clamp to >= 1 so a requested fps above the stream fps cannot
        # produce a zero interval (ZeroDivisionError in the modulo below).
        frame_interval = max(1, int(round(fps / self.process_fps))) if fps > 0 else 1
        print(f"帧间隔: {frame_interval} 帧")
        last_dets = []
        last_need_alert = False
        last_alert_addr = None
        last_person_detections = []
        while True:
            ret, frame = self.cap.read()
            if not ret:
                break
            frame_count += 1
            if frame_count % frame_interval == 0:
                try:
                    last_dets = self.det.detect(frame, self.conf_threshold_map)
                    print(last_dets)  # debug dump of raw detections
                    last_need_alert, last_alert_addr, last_person_detections = \
                        check_safety_equipment(last_dets)
                    if last_dets:
                        print(f"[Frame {frame_count}] 检测到 {len(last_dets)} 个目标")
                        for d in last_dets:
                            print(f"  {d.class_name}: conf={d.confidence:.2f}, box={d.bbox}")
                except Exception as e:
                    print(f"检测过程中出错: {e}")
            # Draw the most recent detections on every frame.
            for d in last_dets:
                x1, y1, x2, y2 = d.bbox
                cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
                text = f"{d.class_name}: {d.confidence:.2f}"
                text_y = max(15, y1 - 20)  # keep the label inside the frame
                put_text_chinese(frame, text, (x1, text_y), font_size=20, color=(255, 0, 0))
            # Upload an alert snapshot, rate-limited to one per interval.
            if last_person_detections and last_need_alert and last_alert_addr:
                current_time = time.time()
                if current_time - self.last_upload_time >= self.upload_interval:
                    print(f"检测到人,触发告警上传")
                    temp_image_path = f"alert_frame_{frame_count}.jpg"
                    cv2.imwrite(temp_image_path, frame)
                    purl = upload_image(temp_image_path)
                    if purl:
                        create_event(last_alert_addr, purl)
                        self.last_upload_time = current_time
                    if os.path.exists(temp_image_path):
                        os.remove(temp_image_path)
            # Best-effort push to ffmpeg; a dead pipe must not kill the loop.
            if self.rtmp_pipe:
                try:
                    self.rtmp_pipe.stdin.write(frame.tobytes())
                except Exception:
                    pass
            cv2.imshow("RK3588 工地安全检测", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        self.cap.release()
        self.det.release()
        # Cleanly shut down the ffmpeg child (previously leaked).
        if self.rtmp_pipe:
            try:
                self.rtmp_pipe.stdin.close()
                self.rtmp_pipe.wait(timeout=5)
            except Exception:
                pass
        cv2.destroyAllWindows()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--rtsp", required=True)
    parser.add_argument("--model", default="yolo11m_safety.rknn")
    parser.add_argument("--rtmp", required=True)
    parser.add_argument("--fps", type=int, default=2, help="每秒处理的帧数")
    args = parser.parse_args()
    cap = RTSPCapture(args.rtsp, args.model, args.rtmp, args.fps)
    cap.run()