- #!/usr/bin/env python3
- import cv2
- import numpy as np
- import argparse
- import requests
- import json
- import time
- import os
- from PIL import Image, ImageDraw, ImageFont
- from rknnlite.api import RKNNLite
- import onnxruntime as ort
- from dataclasses import dataclass
- from typing import List, Tuple, Optional
- import subprocess
# Pin this process to the CPUs in mask 0xff0 (cores 4-11 — presumably the
# RK3588's big cores; confirm against the board's CPU topology).
# subprocess.run with an argument list avoids spawning a shell; a missing
# taskset binary must not abort startup (os.system also tolerated it).
try:
    subprocess.run(["taskset", "-p", "0xff0", str(os.getpid())], check=False)
except FileNotFoundError:
    print("taskset not found; CPU affinity not set")
@dataclass
class Detection:
    """One detected object, in original-image pixel coordinates."""
    # Model class id (per BaseDetector.LABEL_MAP: 0=helmet, 4=safety vest, 3=person)
    class_id: int
    # Human-readable label looked up from the class id
    class_name: str
    # Best class score, in [0, 1]
    confidence: float
    # (x1, y1, x2, y2) corner box, clamped to the image bounds
    bbox: Tuple[int, int, int, int]
def nms(dets, iou_threshold=0.45):
    """Greedy non-maximum suppression.

    Accepts any objects exposing `.bbox` (x1, y1, x2, y2) and `.confidence`;
    returns the surviving subset, highest-confidence first. Boxes whose IoU
    with an already-kept box exceeds `iou_threshold` are discarded.
    """
    if not dets:
        return []

    coords = np.asarray([d.bbox for d in dets], dtype=float)
    scores = np.asarray([d.confidence for d in dets], dtype=float)
    x1, y1, x2, y2 = coords[:, 0], coords[:, 1], coords[:, 2], coords[:, 3]

    # +1 keeps the classic inclusive-pixel area convention.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    remaining = scores.argsort()[::-1]  # candidate indices, best first

    kept = []
    while remaining.size:
        top = remaining[0]
        kept.append(top)
        rest = remaining[1:]

        # Intersection of the top box against every remaining candidate.
        ix1 = np.maximum(x1[top], x1[rest])
        iy1 = np.maximum(y1[top], y1[rest])
        ix2 = np.minimum(x2[top], x2[rest])
        iy2 = np.minimum(y2[top], y2[rest])
        inter = np.maximum(0.0, ix2 - ix1 + 1) * np.maximum(0.0, iy2 - iy1 + 1)

        iou = inter / (areas[top] + areas[rest] - inter)
        remaining = rest[iou <= iou_threshold]

    return [dets[i] for i in kept]
class BaseDetector:
    """Shared pre/post-processing for the YOLO safety model.

    Subclasses implement `detect` for a specific runtime (RKNN / ONNX) and
    may override `release` to free runtime resources.
    """

    # class id -> display label; ids absent here (1, 2) are ignored.
    LABEL_MAP = {0: '安全帽', 4: '安全衣', 3: '人'}

    def __init__(self):
        self.input_size = (640, 640)  # (height, width) fed to the model
        self.num_classes = 5

    def letterbox(self, image):
        """Resize `image` preserving aspect ratio onto a gray (114) canvas,
        centered.

        Returns:
            (canvas, scale, pad_w, pad_h, orig_h, orig_w) — everything
            `postprocess` needs to map boxes back to the original image.
        """
        h0, w0 = image.shape[:2]
        ih, iw = self.input_size
        scale = min(iw / w0, ih / h0)
        new_w, new_h = int(w0 * scale), int(h0 * scale)
        pad_w = (iw - new_w) // 2
        pad_h = (ih - new_h) // 2
        resized = cv2.resize(image, (new_w, new_h))
        canvas = np.full((ih, iw, 3), 114, dtype=np.uint8)
        canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
        return canvas, scale, pad_w, pad_h, h0, w0

    def postprocess(self, outputs, scale, pad_w, pad_h, h0, w0, conf_threshold_map):
        """Decode raw model output into NMS-filtered Detection objects.

        Expected layout (YOLOv8/v11 style, no objectness column):
        (4 + num_classes, num_anchors); rows 0-3 are x_center, y_center,
        width, height in 640-px letterbox space, rows 4+ are per-class
        scores already in [0, 1].

        NOTE: an unreachable duplicate of this decode loop (using a
        letterbox-ignoring rescale) used to sit after the return — removed.
        """
        dets = []
        if not outputs:
            return dets

        output = outputs[0]
        if len(output.shape) == 3:  # squeeze batch dim: (1, C, N) -> (C, N)
            output = output[0]

        num_boxes = output.shape[1]
        for i in range(num_boxes):
            x_center = float(output[0, i])
            y_center = float(output[1, i])
            width = float(output[2, i])
            height = float(output[3, i])

            class_probs = output[4:4 + self.num_classes, i]
            best_class = int(np.argmax(class_probs))
            confidence = float(class_probs[best_class])

            # Skip classes we do not report on (ids missing from LABEL_MAP).
            if best_class not in self.LABEL_MAP:
                continue

            if confidence < conf_threshold_map.get(best_class, 0.5):
                continue

            # Undo letterbox: remove padding, then rescale to original image.
            x1 = int(((x_center - width / 2) - pad_w) / scale)
            y1 = int(((y_center - height / 2) - pad_h) / scale)
            x2 = int(((x_center + width / 2) - pad_w) / scale)
            y2 = int(((y_center + height / 2) - pad_h) / scale)

            # Clamp to image bounds.
            x1 = max(0, min(w0, x1))
            y1 = max(0, min(h0, y1))
            x2 = max(0, min(w0, x2))
            y2 = max(0, min(h0, y2))

            dets.append(Detection(
                class_id=best_class,
                class_name=self.LABEL_MAP[best_class],
                confidence=confidence,
                bbox=(x1, y1, x2, y2),
            ))

        return nms(dets, iou_threshold=0.45)

    def detect(self, image, conf_threshold_map):
        """Run inference on one BGR frame; implemented by subclasses."""
        raise NotImplementedError

    def release(self):
        """Free runtime resources; no-op by default."""
        pass
class RKNNDetector(BaseDetector):
    """Detector backed by the Rockchip NPU runtime - NHWC input (1, H, W, C)."""

    def __init__(self, model_path: str):
        super().__init__()
        self.rknn = RKNNLite()

        if self.rknn.load_rknn(model_path) != 0:
            print("[ERROR] load_rknn failed")
            exit(-1)

        if self.rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0_1_2) != 0:
            print("[ERROR] init_runtime failed")
            exit(-1)

    def detect(self, image, conf_threshold_map):
        """Letterbox, convert BGR->RGB float in [0, 1], infer, decode."""
        canvas, scale, pad_w, pad_h, h0, w0 = self.letterbox(image)
        rgb = canvas[..., ::-1].astype(np.float32) / 255.0
        batch = np.expand_dims(rgb, axis=0)  # NHWC: (1, 640, 640, 3)
        raw = self.rknn.inference(inputs=[batch])
        return self.postprocess(raw, scale, pad_w, pad_h, h0, w0, conf_threshold_map)

    def release(self):
        """Tear down the NPU runtime."""
        self.rknn.release()
class ONNXDetector(BaseDetector):
    """Detector backed by onnxruntime - NCHW input (1, C, H, W)."""

    def __init__(self, model_path: str):
        super().__init__()
        self.session = ort.InferenceSession(model_path)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def detect(self, image, conf_threshold_map):
        """Letterbox, convert BGR->RGB float in [0, 1] CHW, infer, decode."""
        canvas, scale, pad_w, pad_h, h0, w0 = self.letterbox(image)
        rgb = canvas[..., ::-1].astype(np.float32) / 255.0
        chw = np.transpose(rgb, (2, 0, 1))
        batch = chw[np.newaxis, ...]  # NCHW: (1, 3, 640, 640)
        raw = self.session.run([self.output_name], {self.input_name: batch})
        return self.postprocess(raw, scale, pad_w, pad_h, h0, w0, conf_threshold_map)
def create_detector(model_path: str):
    """Instantiate the detector matching the model file extension.

    '.rknn' -> RKNNDetector, '.onnx' -> ONNXDetector; anything else
    prints an error and exits the process.
    """
    suffix = os.path.splitext(model_path)[1].lower()
    if suffix == '.rknn':
        print("使用 RKNN 模型")
        return RKNNDetector(model_path)
    if suffix == '.onnx':
        print("使用 ONNX 模型")
        return ONNXDetector(model_path)
    print("不支持的模型格式")
    exit(-1)
def put_text_chinese(img, text, position, font_size=20, color=(255, 0, 0)):
    """Draw CJK-capable text onto a BGR image in place.

    cv2.putText cannot render Chinese, so the text is rasterized with PIL
    and only the text rectangle is copied back into `img`.

    Args:
        img: BGR uint8 image; modified in place.
        text: string to draw (may contain Chinese).
        position: (x, y) top-left corner of the text.
        font_size: font pixel size.
        color: BGR color tuple.
    """
    img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img_pil)
    font_path = "Alibaba_PuHuiTi_2.0_35_Thin_35_Thin.ttf"
    # Font fallback chain: preferred TTF -> MiSans -> PIL bitmap default.
    # truetype raises OSError when the font file cannot be opened; the
    # previous bare `except:` also swallowed KeyboardInterrupt/SystemExit.
    try:
        font = ImageFont.truetype(font_path, font_size)
    except OSError:
        try:
            font = ImageFont.truetype("MiSans-Thin.ttf", font_size)
        except OSError:
            font = ImageFont.load_default()

    color_rgb = (color[2], color[1], color[0])  # BGR -> RGB for PIL
    draw.text(position, text, font=font, fill=color_rgb)

    img_cv2 = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

    # Copy only the rendered text rectangle back into the original frame;
    # identical slices on both arrays keep the shapes consistent even when
    # numpy clamps them at the image border.
    x, y = position
    text_width = draw.textlength(text, font=font)
    text_height = font_size
    img[y:y + text_height, x:x + int(text_width)] = img_cv2[y:y + text_height, x:x + int(text_width)]
def upload_image(image_path):
    """Upload an image file to the OSS endpoint via multipart/form-data.

    Returns the public URL ('purl' field) on success, otherwise None.
    Best-effort: any failure is printed and swallowed.
    """
    try:
        import http.client
        import mimetypes
        from codecs import encode

        filename = os.path.basename(image_path)

        conn = http.client.HTTPSConnection("jtjai.device.wenhq.top", 8583)

        # Hand-built multipart body with a fixed boundary; a single part
        # named 'file' carries the image bytes.
        boundary = 'wL36Yn8afVp8Ag7AmP8qZ0SA4n1v9T'
        dataList = []
        dataList.append(encode('--' + boundary))
        dataList.append(encode('Content-Disposition: form-data; name=file; filename={0}'.format(filename)))

        fileType = mimetypes.guess_type(image_path)[0] or 'application/octet-stream'
        dataList.append(encode('Content-Type: {}'.format(fileType)))
        dataList.append(encode(''))  # blank line separates part headers from content

        with open(image_path, 'rb') as f:
            dataList.append(f.read())

        dataList.append(encode('--'+boundary+'--'))  # closing boundary
        dataList.append(encode(''))

        # Parts are joined with CRLF per the multipart wire format.
        body = b'\r\n'.join(dataList)

        headers = {
            'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
            'Accept': '*/*',
            'Host': 'jtjai.device.wenhq.top:8583',
            'Connection': 'keep-alive',
            'Content-Type': 'multipart/form-data; boundary={}'.format(boundary)
        }

        conn.request("POST", "/api/resource/oss/upload", body, headers)
        res = conn.getresponse()
        data = res.read()

        # Success requires both HTTP 200 and an application-level code == 200.
        if res.status == 200:
            result = json.loads(data.decode("utf-8"))
            if result.get('code') == 200:
                return result.get('data', {}).get('purl')
        print(f"上传图片失败: {data.decode('utf-8')}")
    except Exception as e:
        print(f"上传图片异常: {e}")
    return None
def create_event(addr, purl):
    """Create a safety-violation event record on the backend.

    Args:
        addr: human-readable alert description (Chinese).
        purl: public URL of the uploaded snapshot.

    Returns:
        True on success, False otherwise; failures are printed, not raised.
    """
    try:
        url = "https://jtjai.device.wenhq.top:8583/api/system/event"
        create_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        data = {
            "createTime": create_time,
            "addr": addr,
            "ext1": json.dumps([purl]),  # snapshot URLs, JSON-encoded array string
            "ext2": json.dumps({"lx":"工地安全"})  # event category tag
        }
        # NOTE(review): verify=False disables TLS certificate checking —
        # presumably to accept a self-signed cert; confirm before production.
        response = requests.post(url, json=data, verify=False)
        if response.status_code == 200:
            result = response.json()
            if result.get('code') == 200:
                print(f"事件创建成功: {addr}")
                return True
        print(f"创建事件失败: {response.text}")
    except Exception as e:
        print(f"创建事件异常: {e}")
    return False
def check_safety_equipment(detections):
    """Check every detected person for helmet and safety-vest compliance.

    A person "has a helmet" when some helmet box's center lies inside the
    person box, and "has a vest" when some vest box overlaps the person box.

    Returns:
        (need_alert, alert_addr, person_detections) where alert_addr is the
        message for the last non-compliant person (None if all compliant)
        and person_detections is a list of (x1, y1, x2, y2, confidence).
    """
    persons, helmets, clothes = [], [], []
    for det in detections:
        record = (*det.bbox, det.confidence)
        if det.class_id == 3:
            persons.append(record)
        elif det.class_id == 0:
            helmets.append(record)
        elif det.class_id == 4:
            clothes.append(record)

    def _helmet_inside(px1, py1, px2, py2):
        # Helmet counts if its center point falls within the person box.
        for hx1, hy1, hx2, hy2, _ in helmets:
            cx = (hx1 + hx2) / 2
            cy = (hy1 + hy2) / 2
            if px1 <= cx <= px2 and py1 <= cy <= py2:
                return True
        return False

    def _clothes_overlap(px1, py1, px2, py2):
        # Vest counts if it has any strictly-positive overlap with the person.
        for cx1, cy1, cx2, cy2, _ in clothes:
            if max(px1, cx1) < min(px2, cx2) and max(py1, cy1) < min(py2, cy2):
                return True
        return False

    need_alert = False
    alert_addr = None

    for px1, py1, px2, py2, pconf in persons:
        has_helmet = _helmet_inside(px1, py1, px2, py2)
        has_clothes = _clothes_overlap(px1, py1, px2, py2)
        if has_helmet and has_clothes:
            continue
        need_alert = True
        if not has_helmet and not has_clothes:
            alert_addr = "反光衣和安全帽都没戴"
        elif not has_helmet:
            alert_addr = "未戴安全帽"
        else:
            alert_addr = "未穿反光衣"
        print(f"警告: {alert_addr},置信度: {pconf:.2f}")

    return need_alert, alert_addr, persons
class RTSPCapture:
    """Pull frames from an RTSP stream, run safety detection at a reduced
    rate, overlay results on every frame, restream to RTMP through ffmpeg,
    and upload rate-limited alert snapshots to the backend."""

    def __init__(self, rtsp_url, model_path, rtmp_url, fps=2):
        self.rtsp_url = rtsp_url
        self.rtmp_url = rtmp_url
        self.det = create_detector(model_path)
        self.cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)  # keep stream latency low
        self.rtmp_pipe = None
        self.process_fps = fps  # detection passes per second
        # Per-class confidence thresholds: 3=person, 0=helmet, 4=safety vest.
        self.conf_threshold_map = {3: 0.8, 0: 0.5, 4: 0.5}
        self.last_upload_time = 0
        self.upload_interval = 2  # minimum seconds between alert uploads

    def start_rtmp(self):
        """Spawn an ffmpeg child that reads raw BGR frames from stdin and
        pushes them to the RTMP endpoint."""
        w = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = self.cap.get(cv2.CAP_PROP_FPS) or 25  # 0/None -> assume 25
        command = [
            'ffmpeg',
            '-y',
            '-f', 'rawvideo',
            '-pix_fmt', 'bgr24',
            '-s', f'{w}x{h}',
            '-r', str(fps),
            '-i', '-',
            '-c:v', 'libx264',
            '-preset', 'ultrafast',
            '-tune', 'zerolatency',
            '-f', 'flv',
            self.rtmp_url
        ]
        self.rtmp_pipe = subprocess.Popen(
            command,
            stdin=subprocess.PIPE,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL
        )

    def run(self):
        """Main loop: read frames, detect every `frame_interval`-th frame,
        draw the latest detections on every frame, alert/upload when a
        non-compliant person is present, and forward frames to RTMP.
        Press 'q' in the preview window to quit."""
        self.start_rtmp()
        frame_count = 0

        fps = self.cap.get(cv2.CAP_PROP_FPS) or 25
        # max(1, ...) prevents a zero interval — and the resulting
        # modulo-by-zero crash below — when process_fps >= stream fps.
        frame_interval = max(1, int(round(fps / self.process_fps))) if fps > 0 else 1
        print(f"帧间隔: {frame_interval} 帧")

        # Detection results are reused on frames between detection passes.
        last_dets = []
        last_need_alert = False
        last_alert_addr = None
        last_person_detections = []

        while True:
            ret, frame = self.cap.read()
            if not ret:
                break
            frame_count += 1

            if frame_count % frame_interval == 0:
                try:
                    last_dets = self.det.detect(frame, self.conf_threshold_map)
                    print(last_dets)

                    last_need_alert, last_alert_addr, last_person_detections = check_safety_equipment(last_dets)

                    if last_dets:
                        print(f"[Frame {frame_count}] 检测到 {len(last_dets)} 个目标")
                        for d in last_dets:
                            print(f" {d.class_name}: conf={d.confidence:.2f}, box={d.bbox}")
                except Exception as e:
                    print(f"检测过程中出错: {e}")

            # Draw the most recent detections on every frame.
            for d in last_dets:
                x1, y1, x2, y2 = d.bbox
                cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
                text = f"{d.class_name}: {d.confidence:.2f}"
                text_y = max(15, y1 - 20)
                put_text_chinese(frame, text, (x1, text_y), font_size=20, color=(255, 0, 0))

            if last_person_detections and last_need_alert and last_alert_addr:
                current_time = time.time()
                # Rate-limit uploads to one per upload_interval seconds.
                if current_time - self.last_upload_time >= self.upload_interval:
                    print(f"检测到人,触发告警上传")
                    temp_image_path = f"alert_frame_{frame_count}.jpg"
                    cv2.imwrite(temp_image_path, frame)

                    purl = upload_image(temp_image_path)
                    if purl:
                        create_event(last_alert_addr, purl)
                    self.last_upload_time = current_time

                    if os.path.exists(temp_image_path):
                        os.remove(temp_image_path)

            if self.rtmp_pipe:
                try:
                    self.rtmp_pipe.stdin.write(frame.tobytes())
                except (OSError, ValueError):
                    # ffmpeg died or the pipe closed — keep the loop alive.
                    pass

            cv2.imshow("RK3588 工地安全检测", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        self.cap.release()
        self.det.release()
        # Shut down the ffmpeg child cleanly (it was previously leaked).
        if self.rtmp_pipe:
            try:
                self.rtmp_pipe.stdin.close()
                self.rtmp_pipe.wait(timeout=5)
            except Exception:
                self.rtmp_pipe.kill()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # CLI entry point. (The redundant in-block `import argparse` was
    # removed — argparse is already imported at module top.)
    parser = argparse.ArgumentParser()
    parser.add_argument("--rtsp", required=True)
    parser.add_argument("--model", default="yolo11m_safety.rknn")
    parser.add_argument("--rtmp", required=True)
    parser.add_argument("--fps", type=int, default=2, help="每秒处理的帧数")
    args = parser.parse_args()
    cap = RTSPCapture(args.rtsp, args.model, args.rtmp, args.fps)
    cap.run()