safety_detector.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824
  1. """
  2. 施工现场安全行为检测模块
  3. 使用 YOLO11 模型检测人员、安全帽、反光衣
  4. 判断是否存在违规行为(未戴安全帽、未穿反光衣)
  5. 支持两种模型格式:
  6. - YOLO (.pt/.onnx): 使用 ultralytics 库
  7. - RKNN (.rknn): 使用 rknnlite 库 (RK3588 平台)
  8. """
  9. import cv2
  10. import numpy as np
  11. from typing import Optional, List, Tuple, Dict, Any
  12. from dataclasses import dataclass
  13. from enum import Enum
  14. import os
  15. # ============================================
  16. # RKNN 模型支持
  17. # ============================================
  18. @dataclass
  19. class Detection:
  20. """检测结果 (用于 RKNN 模型)"""
  21. class_id: int
  22. class_name: str
  23. confidence: float
  24. bbox: Tuple[int, int, int, int]
  25. def nms(dets, iou_threshold=0.45):
  26. """非极大值抑制"""
  27. if len(dets) == 0:
  28. return []
  29. boxes = np.array([[d.bbox[0], d.bbox[1], d.bbox[2], d.bbox[3], d.confidence] for d in dets])
  30. x1 = boxes[:, 0]
  31. y1 = boxes[:, 1]
  32. x2 = boxes[:, 2]
  33. y2 = boxes[:, 3]
  34. scores = boxes[:, 4]
  35. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  36. order = scores.argsort()[::-1]
  37. keep = []
  38. while order.size > 0:
  39. i = order[0]
  40. keep.append(i)
  41. xx1 = np.maximum(x1[i], x1[order[1:]])
  42. yy1 = np.maximum(y1[i], y1[order[1:]])
  43. xx2 = np.minimum(x2[i], x2[order[1:]])
  44. yy2 = np.minimum(y2[i], y2[order[1:]])
  45. w = np.maximum(0.0, xx2 - xx1 + 1)
  46. h = np.maximum(0.0, yy2 - yy1 + 1)
  47. inter = w * h
  48. ovr = inter / (areas[i] + areas[order[1:]] - inter)
  49. inds = np.where(ovr <= iou_threshold)[0]
  50. order = order[inds + 1]
  51. return [dets[i] for i in keep]
  52. class BaseDetector:
  53. """检测器基类 (用于 RKNN/ONNX 模型)"""
  54. # 类别映射: 0: 安全帽, 3: 人, 4: 反光衣
  55. LABEL_MAP = {0: '安全帽', 4: '安全衣', 3: '人'}
  56. def __init__(self):
  57. self.input_size = (640, 640)
  58. self.num_classes = 5
  59. def letterbox(self, image):
  60. """Letterbox 预处理,保持宽高比"""
  61. h0, w0 = image.shape[:2]
  62. ih, iw = self.input_size
  63. scale = min(iw / w0, ih / h0)
  64. new_w, new_h = int(w0 * scale), int(h0 * scale)
  65. pad_w = (iw - new_w) // 2
  66. pad_h = (ih - new_h) // 2
  67. resized = cv2.resize(image, (new_w, new_h))
  68. canvas = np.full((ih, iw, 3), 114, dtype=np.uint8)
  69. canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
  70. return canvas, scale, pad_w, pad_h, h0, w0
  71. def postprocess(self, outputs, scale, pad_w, pad_h, h0, w0, conf_threshold_map):
  72. """后处理"""
  73. dets = []
  74. if not outputs:
  75. return dets
  76. output = outputs[0]
  77. if len(output.shape) == 3:
  78. output = output[0]
  79. num_boxes = output.shape[1]
  80. for i in range(num_boxes):
  81. x_center = float(output[0, i])
  82. y_center = float(output[1, i])
  83. width = float(output[2, i])
  84. height = float(output[3, i])
  85. class_probs = output[4:4+self.num_classes, i]
  86. best_class = int(np.argmax(class_probs))
  87. confidence = float(class_probs[best_class])
  88. if best_class not in self.LABEL_MAP:
  89. continue
  90. conf_threshold = conf_threshold_map.get(best_class, 0.5)
  91. if confidence < conf_threshold:
  92. continue
  93. # 移除 padding 并缩放到原始图像尺寸
  94. x1 = int(((x_center - width / 2) - pad_w) / scale)
  95. y1 = int(((y_center - height / 2) - pad_h) / scale)
  96. x2 = int(((x_center + width / 2) - pad_w) / scale)
  97. y2 = int(((y_center + height / 2) - pad_h) / scale)
  98. x1 = max(0, min(w0, x1))
  99. y1 = max(0, min(h0, y1))
  100. x2 = max(0, min(w0, x2))
  101. y2 = max(0, min(h0, y2))
  102. det = Detection(
  103. class_id=best_class,
  104. class_name=self.LABEL_MAP[best_class],
  105. confidence=confidence,
  106. bbox=(x1, y1, x2, y2)
  107. )
  108. dets.append(det)
  109. dets = nms(dets, iou_threshold=0.45)
  110. return dets
  111. def detect(self, image, conf_threshold_map):
  112. raise NotImplementedError
  113. def release(self):
  114. pass
  115. class RKNNDetector(BaseDetector):
  116. """RKNN 检测器 - 使用 NHWC 输入格式 (1, H, W, C)"""
  117. def __init__(self, model_path: str):
  118. super().__init__()
  119. self.model_path = model_path
  120. self.rknn = None
  121. try:
  122. from rknnlite.api import RKNNLite
  123. self.rknn = RKNNLite()
  124. except ImportError:
  125. raise ImportError("未安装 rknnlite,请运行: pip install rknnlite2 或参考 testrk3588/setup_rknn.sh")
  126. ret = self.rknn.load_rknn(model_path)
  127. if ret != 0:
  128. raise RuntimeError(f"加载 RKNN 模型失败: {model_path}")
  129. ret = self.rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0_1_2)
  130. if ret != 0:
  131. raise RuntimeError(f"初始化 RKNN 运行时失败")
  132. print(f"RKNN 模型加载成功: {model_path}")
  133. def detect(self, image, conf_threshold_map):
  134. canvas, scale, pad_w, pad_h, h0, w0 = self.letterbox(image)
  135. # RKNN 期望 NHWC (1, H, W, C), RGB, 归一化 0-1
  136. img = canvas[..., ::-1].astype(np.float32) / 255.0
  137. blob = img[None, ...] # (1, 640, 640, 3)
  138. outs = self.rknn.inference(inputs=[blob])
  139. return self.postprocess(outs, scale, pad_w, pad_h, h0, w0, conf_threshold_map)
  140. def release(self):
  141. if self.rknn:
  142. self.rknn.release()
  143. self.rknn = None
  144. class ONNXDetector(BaseDetector):
  145. """ONNX 检测器 - 使用 NCHW 输入格式 (1, C, H, W)"""
  146. def __init__(self, model_path: str):
  147. super().__init__()
  148. self.model_path = model_path
  149. try:
  150. import onnxruntime as ort
  151. self.session = ort.InferenceSession(model_path)
  152. self.input_name = self.session.get_inputs()[0].name
  153. self.output_name = self.session.get_outputs()[0].name
  154. print(f"ONNX 模型加载成功: {model_path}")
  155. except ImportError:
  156. raise ImportError("未安装 onnxruntime,请运行: pip install onnxruntime")
  157. except Exception as e:
  158. raise RuntimeError(f"加载 ONNX 模型失败: {e}")
  159. def detect(self, image, conf_threshold_map):
  160. canvas, scale, pad_w, pad_h, h0, w0 = self.letterbox(image)
  161. # ONNX 期望 NCHW (1, C, H, W), RGB, 归一化 0-1
  162. img = canvas[..., ::-1].astype(np.float32) / 255.0
  163. img = img.transpose(2, 0, 1)
  164. blob = img[None, ...] # (1, 3, 640, 640)
  165. outs = self.session.run([self.output_name], {self.input_name: blob})
  166. return self.postprocess(outs, scale, pad_w, pad_h, h0, w0, conf_threshold_map)
  167. def release(self):
  168. self.session = None
  169. def create_detector(model_path: str):
  170. """
  171. 创建检测器工厂函数
  172. Args:
  173. model_path: 模型路径 (.rknn, .onnx, .pt)
  174. Returns:
  175. 检测器实例
  176. """
  177. ext = os.path.splitext(model_path)[1].lower()
  178. if ext == '.rknn':
  179. print(f"使用 RKNN 模型: {model_path}")
  180. return RKNNDetector(model_path)
  181. elif ext == '.onnx':
  182. print(f"使用 ONNX 模型: {model_path}")
  183. return ONNXDetector(model_path)
  184. elif ext == '.pt':
  185. print(f"使用 YOLO 模型: {model_path}")
  186. return None # YOLO 使用原来的 SafetyDetector
  187. else:
  188. raise ValueError(f"不支持的模型格式: {ext}")
  189. # ============================================
  190. # 原有 YOLO 安全检测器
  191. # ============================================
  192. class SafetyViolationType(Enum):
  193. """安全违规类型"""
  194. NO_HELMET = "未戴安全帽" # 未戴安全帽
  195. NO_SAFETY_VEST = "未穿反光衣" # 未穿反光衣
  196. NO_BOTH = "反光衣和安全帽都没戴" # 都没有
  197. @dataclass
  198. class SafetyDetection:
  199. """安全检测结果"""
  200. # 基础信息
  201. class_id: int # 类别ID
  202. class_name: str # 类别名称
  203. confidence: float # 置信度
  204. bbox: Tuple[int, int, int, int] # 边界框 (x1, y1, x2, y2)
  205. center: Tuple[int, int] # 中心点坐标
  206. track_id: Optional[int] = None # 跟踪ID
  207. @dataclass
  208. class PersonSafetyStatus:
  209. """人员安全状态"""
  210. track_id: int # 跟踪ID
  211. person_bbox: Tuple[int, int, int, int] # 人体边界框
  212. person_conf: float # 人体置信度
  213. has_helmet: bool = False # 是否戴安全帽
  214. helmet_conf: float = 0.0 # 安全帽置信度
  215. has_safety_vest: bool = False # 是否穿反光衣
  216. vest_conf: float = 0.0 # 反光衣置信度
  217. is_violation: bool = False # 是否违规
  218. violation_types: List[SafetyViolationType] = None # 违规类型列表
  219. def __post_init__(self):
  220. if self.violation_types is None:
  221. self.violation_types = []
  222. def check_violation(self) -> bool:
  223. """检查是否违规"""
  224. self.violation_types = []
  225. if not self.has_helmet and not self.has_safety_vest:
  226. self.violation_types.append(SafetyViolationType.NO_BOTH)
  227. elif not self.has_helmet:
  228. self.violation_types.append(SafetyViolationType.NO_HELMET)
  229. elif not self.has_safety_vest:
  230. self.violation_types.append(SafetyViolationType.NO_SAFETY_VEST)
  231. self.is_violation = len(self.violation_types) > 0
  232. return self.is_violation
  233. def get_violation_desc(self) -> str:
  234. """获取违规描述"""
  235. if not self.is_violation:
  236. return ""
  237. if SafetyViolationType.NO_BOTH in self.violation_types:
  238. return "反光衣和安全帽都没戴"
  239. elif SafetyViolationType.NO_HELMET in self.violation_types:
  240. return "未戴安全帽"
  241. elif SafetyViolationType.NO_SAFETY_VEST in self.violation_types:
  242. return "未穿反光衣"
  243. return ""
  244. class SafetyDetector:
  245. """
  246. 施工现场安全检测器
  247. 使用 YOLO11 检测人员、安全帽、反光衣
  248. """
  249. CLASS_MAP = {
  250. 0: '安全帽',
  251. 3: '人',
  252. 4: '反光衣'
  253. }
  254. CLASS_ID_MAP = {
  255. 'helmet': 0,
  256. 'person': 3,
  257. 'safety_vest': 4
  258. }
  259. def __init__(self, model_path: str = None, use_gpu: bool = True,
  260. conf_threshold: float = 0.5, person_threshold: float = 0.8,
  261. model_type: str = 'auto'):
  262. """
  263. 初始化安全检测器
  264. Args:
  265. model_path: 模型路径,默认使用 yolo11m_safety.pt 或 .rknn
  266. use_gpu: 是否使用 GPU (仅 YOLO 模型有效)
  267. conf_threshold: 一般物品置信度阈值 (安全帽、反光衣)
  268. person_threshold: 人员检测置信度阈值
  269. model_type: 模型类型 ('auto', 'yolo', 'rknn', 'onnx')
  270. """
  271. self.model = None
  272. self.rknn_detector = None
  273. self.model_type = model_type
  274. # 根据扩展名自动判断模型类型
  275. if model_path:
  276. ext = os.path.splitext(model_path)[1].lower()
  277. if ext == '.rknn':
  278. self.model_type = 'rknn'
  279. elif ext == '.onnx':
  280. self.model_type = 'onnx'
  281. elif ext == '.pt':
  282. self.model_type = 'yolo'
  283. self.model_path = model_path
  284. self.use_gpu = use_gpu
  285. self.device = 'cuda:0' if use_gpu else 'cpu'
  286. self.conf_threshold = conf_threshold
  287. self.person_threshold = person_threshold
  288. self._load_model()
  289. def _load_model(self):
  290. """加载检测模型"""
  291. if self.model_type == 'rknn':
  292. self._load_rknn_model()
  293. elif self.model_type == 'onnx':
  294. self._load_onnx_model()
  295. else:
  296. self._load_yolo_model()
  297. def _load_rknn_model(self):
  298. """加载 RKNN 模型"""
  299. if not self.model_path:
  300. raise ValueError("RKNN 模型需要指定 model_path")
  301. try:
  302. self.rknn_detector = RKNNDetector(self.model_path)
  303. print(f"RKNN 安全检测模型加载成功: {self.model_path}")
  304. except ImportError as e:
  305. raise ImportError(f"rknnlite 未安装: {e}")
  306. except Exception as e:
  307. raise RuntimeError(f"加载 RKNN 模型失败: {e}")
  308. def _load_onnx_model(self):
  309. """加载 ONNX 模型"""
  310. if not self.model_path:
  311. raise ValueError("ONNX 模型需要指定 model_path")
  312. try:
  313. self.rknn_detector = ONNXDetector(self.model_path)
  314. print(f"ONNX 安全检测模型加载成功: {self.model_path}")
  315. except ImportError as e:
  316. raise ImportError(f"onnxruntime 未安装: {e}")
  317. except Exception as e:
  318. raise RuntimeError(f"加载 ONNX 模型失败: {e}")
  319. def _load_yolo_model(self):
  320. """加载 YOLO11 安全检测模型"""
  321. try:
  322. from ultralytics import YOLO
  323. if not self.model_path:
  324. self.model_path = '/home/wen/dsh/yolo/yolo11m_safety.pt'
  325. self.model = YOLO(self.model_path)
  326. dummy = np.zeros((640, 640, 3), dtype=np.uint8)
  327. self.model(dummy, device=self.device, verbose=False)
  328. print(f"YOLO 安全检测模型加载成功: {self.model_path} (device={self.device})")
  329. except ImportError:
  330. raise ImportError("未安装 ultralytics,请运行: pip install ultralytics")
  331. except Exception as e:
  332. raise RuntimeError(f"加载 YOLO 模型失败: {e}")
  333. def detect(self, frame: np.ndarray) -> List[SafetyDetection]:
  334. """
  335. 检测画面中的安全相关对象
  336. Args:
  337. frame: 输入图像
  338. Returns:
  339. 检测结果列表
  340. """
  341. if frame is None:
  342. return []
  343. if self.rknn_detector is not None:
  344. return self._detect_rknn(frame)
  345. else:
  346. return self._detect_yolo(frame)
  347. def _detect_rknn(self, frame: np.ndarray) -> List[SafetyDetection]:
  348. """使用 RKNN/ONNX 模型检测"""
  349. results = []
  350. try:
  351. conf_threshold_map = {
  352. 3: self.person_threshold,
  353. 0: self.conf_threshold,
  354. 4: self.conf_threshold
  355. }
  356. detections = self.rknn_detector.detect(frame, conf_threshold_map)
  357. for det in detections:
  358. x1, y1, x2, y2 = det.bbox
  359. center_x = (x1 + x2) // 2
  360. center_y = (y1 + y2) // 2
  361. safety_det = SafetyDetection(
  362. class_id=det.class_id,
  363. class_name=det.class_name,
  364. confidence=det.confidence,
  365. bbox=det.bbox,
  366. center=(center_x, center_y)
  367. )
  368. results.append(safety_det)
  369. except Exception as e:
  370. print(f"RKNN 检测错误: {e}")
  371. return results
  372. def _detect_yolo(self, frame: np.ndarray) -> List[SafetyDetection]:
  373. """使用 YOLO 模型检测"""
  374. results = []
  375. try:
  376. detections = self.model(frame, device=self.device, verbose=False)
  377. for det in detections:
  378. boxes = det.boxes
  379. if boxes is None:
  380. continue
  381. for i in range(len(boxes)):
  382. cls_id = int(boxes.cls[i])
  383. if cls_id not in self.CLASS_MAP:
  384. continue
  385. cls_name = self.CLASS_MAP[cls_id]
  386. conf = float(boxes.conf[i])
  387. threshold = self.person_threshold if cls_id == 3 else self.conf_threshold
  388. if conf < threshold:
  389. continue
  390. xyxy = boxes.xyxy[i].cpu().numpy()
  391. x1, y1, x2, y2 = map(int, xyxy)
  392. width = x2 - x1
  393. height = y2 - y1
  394. if width < 10 or height < 10:
  395. continue
  396. center_x = (x1 + x2) // 2
  397. center_y = (y1 + y2) // 2
  398. detection = SafetyDetection(
  399. class_id=cls_id,
  400. class_name=cls_name,
  401. confidence=conf,
  402. bbox=(x1, y1, x2, y2),
  403. center=(center_x, center_y)
  404. )
  405. results.append(detection)
  406. except Exception as e:
  407. print(f"YOLO 检测错误: {e}")
  408. return results
  409. def release(self):
  410. """释放模型资源"""
  411. if self.rknn_detector:
  412. self.rknn_detector.release()
  413. self.rknn_detector = None
  414. self.model = None
  415. def check_safety(self, frame: np.ndarray,
  416. detections: List[SafetyDetection] = None) -> List[PersonSafetyStatus]:
  417. """
  418. 检查人员安全状态
  419. Args:
  420. frame: 输入图像
  421. detections: 检测结果,如果为 None 则自动检测
  422. Returns:
  423. 人员安全状态列表
  424. """
  425. if detections is None:
  426. detections = self.detect(frame)
  427. # 分类检测结果
  428. persons = []
  429. helmets = []
  430. vests = []
  431. for det in detections:
  432. if det.class_id == 3: # 人
  433. persons.append(det)
  434. elif det.class_id == 0: # 安全帽
  435. helmets.append(det)
  436. elif det.class_id == 4: # 反光衣
  437. vests.append(det)
  438. # 检查每个人员的安全状态
  439. results = []
  440. for person in persons:
  441. status = PersonSafetyStatus(
  442. track_id=person.track_id or 0,
  443. person_bbox=person.bbox,
  444. person_conf=person.confidence
  445. )
  446. px1, py1, px2, py2 = person.bbox
  447. # 检查是否戴安全帽
  448. # 安全帽应该在人体上方区域(头部附近)
  449. for helmet in helmets:
  450. hx1, hy1, hx2, hy2 = helmet.bbox
  451. # 检查安全帽是否在人体框内
  452. helmet_center_x = (hx1 + hx2) / 2
  453. helmet_center_y = (hy1 + hy2) / 2
  454. # 安全帽中心在人体框内,且在人体上半部分
  455. if (hx1 >= px1 and hx2 <= px2 and
  456. helmet_center_y >= py1 and
  457. helmet_center_y <= py1 + (py2 - py1) * 0.5):
  458. status.has_helmet = True
  459. status.helmet_conf = helmet.confidence
  460. break
  461. # 检查是否穿反光衣
  462. # 反光衣应该与人体有重叠
  463. for vest in vests:
  464. vx1, vy1, vx2, vy2 = vest.bbox
  465. # 计算重叠区域
  466. overlap_x1 = max(px1, vx1)
  467. overlap_y1 = max(py1, vy1)
  468. overlap_x2 = min(px2, vx2)
  469. overlap_y2 = min(py2, vy2)
  470. # 如果有重叠
  471. if overlap_x1 < overlap_x2 and overlap_y1 < overlap_y2:
  472. # 计算重叠面积占比
  473. overlap_area = (overlap_x2 - overlap_x1) * (overlap_y2 - overlap_y1)
  474. vest_area = (vx2 - vx1) * (vy2 - vy1)
  475. overlap_ratio = overlap_area / vest_area if vest_area > 0 else 0
  476. # 重叠比例超过30%认为穿了反光衣
  477. if overlap_ratio > 0.3:
  478. status.has_safety_vest = True
  479. status.vest_conf = vest.confidence
  480. break
  481. # 检查是否违规
  482. status.check_violation()
  483. results.append(status)
  484. return results
  485. # 轨迹追踪已禁用 - detect_with_tracking 方法已移除
  486. def draw_safety_result(frame: np.ndarray,
  487. detections: List[SafetyDetection],
  488. status_list: List[PersonSafetyStatus]) -> np.ndarray:
  489. """
  490. 在图像上绘制安全检测结果
  491. Args:
  492. frame: 输入图像
  493. detections: 检测结果
  494. status_list: 人员安全状态
  495. Returns:
  496. 绘制后的图像
  497. """
  498. result = frame.copy()
  499. # 绘制检测框
  500. for det in detections:
  501. x1, y1, x2, y2 = det.bbox
  502. # 根据类别选择颜色
  503. if det.class_id == 3: # 人
  504. color = (0, 255, 0) # 绿色
  505. elif det.class_id == 0: # 安全帽
  506. color = (255, 165, 0) # 橙色
  507. elif det.class_id == 4: # 反光衣
  508. color = (0, 165, 255) # 黄色
  509. else:
  510. color = (255, 255, 255)
  511. cv2.rectangle(result, (x1, y1), (x2, y2), color, 2)
  512. # 绘制标签
  513. label = f"{det.class_name}: {det.conf:.2f}"
  514. cv2.putText(result, label, (x1, y1 - 5),
  515. cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
  516. # 绘制安全状态
  517. for status in status_list:
  518. x1, y1, x2, y2 = status.person_bbox
  519. if status.is_violation:
  520. # 违规 - 红色警告
  521. color = (0, 0, 255)
  522. text = status.get_violation_desc()
  523. cv2.rectangle(result, (x1, y1), (x2, y2), color, 3)
  524. cv2.putText(result, text, (x1, y2 + 20),
  525. cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
  526. else:
  527. # 正常 - 显示安全标识
  528. color = (0, 255, 0)
  529. text = "安全装备齐全"
  530. cv2.putText(result, text, (x1, y2 + 20),
  531. cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
  532. return result
  533. class LLMSafetyDetector:
  534. """
  535. 基于大模型的安全检测器
  536. 结合 YOLO 检测和大模型判断
  537. """
  538. def __init__(self, yolo_model_path: str = None,
  539. llm_config: Dict[str, Any] = None,
  540. use_gpu: bool = True,
  541. use_llm: bool = True,
  542. model_type: str = 'auto'):
  543. """
  544. 初始化检测器
  545. Args:
  546. yolo_model_path: 模型路径 (.pt, .rknn, .onnx)
  547. llm_config: 大模型配置
  548. use_gpu: 是否使用 GPU (仅 YOLO 模型有效)
  549. use_llm: 是否使用大模型判断
  550. model_type: 模型类型 ('auto', 'yolo', 'rknn', 'onnx')
  551. """
  552. # 安全检测器 (支持 YOLO/RKNN/ONNX)
  553. self.yolo_detector = SafetyDetector(
  554. model_path=yolo_model_path,
  555. use_gpu=use_gpu,
  556. model_type=model_type
  557. )
  558. # 大模型分析器
  559. self.use_llm = use_llm
  560. self.llm_analyzer = None
  561. if use_llm:
  562. try:
  563. from llm_service import SafetyAnalyzer, NumberRecognizer
  564. self.llm_analyzer = SafetyAnalyzer(llm_config)
  565. self.number_recognizer = NumberRecognizer(llm_config)
  566. print("大模型安全分析器初始化成功")
  567. except ImportError:
  568. print("未找到 llm_service 模块,将使用规则判断")
  569. self.use_llm = False
  570. except Exception as e:
  571. print(f"大模型初始化失败: {e},将使用规则判断")
  572. self.use_llm = False
  573. def detect(self, frame: np.ndarray) -> List[SafetyDetection]:
  574. """
  575. YOLO 检测
  576. Args:
  577. frame: 输入图像
  578. Returns:
  579. 检测结果列表
  580. """
  581. return self.yolo_detector.detect(frame)
  582. def check_safety(self, frame: np.ndarray,
  583. detections: List[SafetyDetection] = None,
  584. use_llm: bool = None) -> List[PersonSafetyStatus]:
  585. """
  586. 检查人员安全状态
  587. Args:
  588. frame: 输入图像
  589. detections: YOLO 检测结果
  590. use_llm: 是否使用大模型(覆盖默认设置)
  591. Returns:
  592. 人员安全状态列表
  593. """
  594. # 先用 YOLO 检测
  595. if detections is None:
  596. detections = self.yolo_detector.detect(frame)
  597. # 规则判断
  598. rule_status_list = self.yolo_detector.check_safety(frame, detections)
  599. # 如果不使用大模型,直接返回规则判断结果
  600. should_use_llm = use_llm if use_llm is not None else self.use_llm
  601. if not should_use_llm or self.llm_analyzer is None:
  602. return rule_status_list
  603. # 使用大模型对每个人员进行判断
  604. llm_status_list = []
  605. for status in rule_status_list:
  606. # 裁剪人员区域
  607. x1, y1, x2, y2 = status.person_bbox
  608. margin = 10
  609. x1 = max(0, x1 - margin)
  610. y1 = max(0, y1 - margin)
  611. x2 = min(frame.shape[1], x2 + margin)
  612. y2 = min(frame.shape[0], y2 + margin)
  613. person_image = frame[y1:y2, x1:x2]
  614. # 调用大模型分析
  615. try:
  616. llm_result = self.llm_analyzer.check_person_safety(person_image)
  617. # 更新状态
  618. if llm_result.get('success', False):
  619. status.has_helmet = llm_result.get('has_helmet', False)
  620. status.has_safety_vest = llm_result.get('has_vest', False)
  621. # 重新检查违规
  622. status.check_violation()
  623. # 如果大模型判断有违规,使用大模型的描述
  624. if status.is_violation and llm_result.get('violation_desc'):
  625. # 更新违规类型
  626. desc = llm_result.get('violation_desc', '')
  627. if '安全帽' in desc and '反光' in desc:
  628. status.violation_types = [SafetyViolationType.NO_BOTH]
  629. elif '安全帽' in desc:
  630. status.violation_types = [SafetyViolationType.NO_HELMET]
  631. elif '反光' in desc:
  632. status.violation_types = [SafetyViolationType.NO_SAFETY_VEST]
  633. except Exception as e:
  634. print(f"大模型分析失败: {e}")
  635. llm_status_list.append(status)
  636. return llm_status_list
  637. def recognize_number(self, frame: np.ndarray,
  638. person_bbox: Tuple[int, int, int, int]) -> Dict[str, Any]:
  639. """
  640. 识别人员编号
  641. Args:
  642. frame: 输入图像
  643. person_bbox: 人员边界框
  644. Returns:
  645. 编号识别结果
  646. """
  647. if self.number_recognizer is None:
  648. return {'number': None, 'success': False}
  649. # 裁剪人员区域
  650. x1, y1, x2, y2 = person_bbox
  651. person_image = frame[y1:y2, x1:x2]
  652. return self.number_recognizer.recognize_person_number(person_image)
  653. # 轨迹追踪已禁用 - detect_with_tracking 方法已移除