safety_detector.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. """
  2. 施工现场安全行为检测模块
  3. 使用 YOLO11 模型检测人员、安全帽、反光衣
  4. 判断是否存在违规行为(未戴安全帽、未穿反光衣)
  5. """
  6. import cv2
  7. import numpy as np
  8. from typing import Optional, List, Tuple, Dict, Any
  9. from dataclasses import dataclass
  10. from enum import Enum
  11. class SafetyViolationType(Enum):
  12. """安全违规类型"""
  13. NO_HELMET = "未戴安全帽" # 未戴安全帽
  14. NO_SAFETY_VEST = "未穿反光衣" # 未穿反光衣
  15. NO_BOTH = "反光衣和安全帽都没戴" # 都没有
  16. @dataclass
  17. class SafetyDetection:
  18. """安全检测结果"""
  19. # 基础信息
  20. class_id: int # 类别ID
  21. class_name: str # 类别名称
  22. confidence: float # 置信度
  23. bbox: Tuple[int, int, int, int] # 边界框 (x1, y1, x2, y2)
  24. center: Tuple[int, int] # 中心点坐标
  25. track_id: Optional[int] = None # 跟踪ID
  26. @dataclass
  27. class PersonSafetyStatus:
  28. """人员安全状态"""
  29. track_id: int # 跟踪ID
  30. person_bbox: Tuple[int, int, int, int] # 人体边界框
  31. person_conf: float # 人体置信度
  32. has_helmet: bool = False # 是否戴安全帽
  33. helmet_conf: float = 0.0 # 安全帽置信度
  34. has_safety_vest: bool = False # 是否穿反光衣
  35. vest_conf: float = 0.0 # 反光衣置信度
  36. is_violation: bool = False # 是否违规
  37. violation_types: List[SafetyViolationType] = None # 违规类型列表
  38. def __post_init__(self):
  39. if self.violation_types is None:
  40. self.violation_types = []
  41. def check_violation(self) -> bool:
  42. """检查是否违规"""
  43. self.violation_types = []
  44. if not self.has_helmet and not self.has_safety_vest:
  45. self.violation_types.append(SafetyViolationType.NO_BOTH)
  46. elif not self.has_helmet:
  47. self.violation_types.append(SafetyViolationType.NO_HELMET)
  48. elif not self.has_safety_vest:
  49. self.violation_types.append(SafetyViolationType.NO_SAFETY_VEST)
  50. self.is_violation = len(self.violation_types) > 0
  51. return self.is_violation
  52. def get_violation_desc(self) -> str:
  53. """获取违规描述"""
  54. if not self.is_violation:
  55. return ""
  56. if SafetyViolationType.NO_BOTH in self.violation_types:
  57. return "反光衣和安全帽都没戴"
  58. elif SafetyViolationType.NO_HELMET in self.violation_types:
  59. return "未戴安全帽"
  60. elif SafetyViolationType.NO_SAFETY_VEST in self.violation_types:
  61. return "未穿反光衣"
  62. return ""
  63. class SafetyDetector:
  64. """
  65. 施工现场安全检测器
  66. 使用 YOLO11 检测人员、安全帽、反光衣
  67. """
  68. # 类别映射 (根据 yolo11m_safety.pt 模型的训练标签)
  69. # 0: 安全帽, 3: 人, 4: 安全衣/反光衣
  70. CLASS_MAP = {
  71. 0: '安全帽',
  72. 3: '人',
  73. 4: '反光衣'
  74. }
  75. # 反向映射
  76. CLASS_ID_MAP = {
  77. 'helmet': 0,
  78. 'person': 3,
  79. 'safety_vest': 4
  80. }
  81. def __init__(self, model_path: str = None, use_gpu: bool = True,
  82. conf_threshold: float = 0.5, person_threshold: float = 0.8):
  83. """
  84. 初始化安全检测器
  85. Args:
  86. model_path: 模型路径,默认使用 yolo11m_safety.pt
  87. use_gpu: 是否使用 GPU
  88. conf_threshold: 一般物品置信度阈值 (安全帽、反光衣)
  89. person_threshold: 人员检测置信度阈值
  90. """
  91. self.model = None
  92. self.model_path = model_path or '/home/wen/dsh/yolo/yolo11m_safety.pt'
  93. self.use_gpu = use_gpu
  94. self.device = 'cuda:0' if use_gpu else 'cpu'
  95. # 置信度阈值
  96. self.conf_threshold = conf_threshold
  97. self.person_threshold = person_threshold
  98. # 加载模型
  99. self._load_model()
  100. def _load_model(self):
  101. """加载 YOLO11 安全检测模型"""
  102. try:
  103. from ultralytics import YOLO
  104. self.model = YOLO(self.model_path)
  105. # 预热模型
  106. dummy = np.zeros((640, 640, 3), dtype=np.uint8)
  107. self.model(dummy, device=self.device, verbose=False)
  108. print(f"安全检测模型加载成功: {self.model_path} (device={self.device})")
  109. except ImportError:
  110. raise ImportError("未安装 ultralytics,请运行: pip install ultralytics")
  111. except Exception as e:
  112. raise RuntimeError(f"加载模型失败: {e}")
  113. def detect(self, frame: np.ndarray) -> List[SafetyDetection]:
  114. """
  115. 检测画面中的安全相关对象
  116. Args:
  117. frame: 输入图像
  118. Returns:
  119. 检测结果列表
  120. """
  121. if self.model is None or frame is None:
  122. return []
  123. results = []
  124. try:
  125. detections = self.model(frame, device=self.device, verbose=False)
  126. for det in detections:
  127. boxes = det.boxes
  128. if boxes is None:
  129. continue
  130. for i in range(len(boxes)):
  131. # 获取类别
  132. cls_id = int(boxes.cls[i])
  133. # 只处理我们关心的类别
  134. if cls_id not in self.CLASS_MAP:
  135. continue
  136. cls_name = self.CLASS_MAP[cls_id]
  137. conf = float(boxes.conf[i])
  138. # 根据类别设置不同的置信度阈值
  139. threshold = self.person_threshold if cls_id == 3 else self.conf_threshold
  140. if conf < threshold:
  141. continue
  142. # 获取边界框
  143. xyxy = boxes.xyxy[i].cpu().numpy()
  144. x1, y1, x2, y2 = map(int, xyxy)
  145. # 过滤过小的检测框
  146. width = x2 - x1
  147. height = y2 - y1
  148. if width < 10 or height < 10:
  149. continue
  150. # 计算中心点
  151. center_x = (x1 + x2) // 2
  152. center_y = (y1 + y2) // 2
  153. detection = SafetyDetection(
  154. class_id=cls_id,
  155. class_name=cls_name,
  156. confidence=conf,
  157. bbox=(x1, y1, x2, y2),
  158. center=(center_x, center_y)
  159. )
  160. results.append(detection)
  161. except Exception as e:
  162. print(f"检测错误: {e}")
  163. return results
  164. def check_safety(self, frame: np.ndarray,
  165. detections: List[SafetyDetection] = None) -> List[PersonSafetyStatus]:
  166. """
  167. 检查人员安全状态
  168. Args:
  169. frame: 输入图像
  170. detections: 检测结果,如果为 None 则自动检测
  171. Returns:
  172. 人员安全状态列表
  173. """
  174. if detections is None:
  175. detections = self.detect(frame)
  176. # 分类检测结果
  177. persons = []
  178. helmets = []
  179. vests = []
  180. for det in detections:
  181. if det.class_id == 3: # 人
  182. persons.append(det)
  183. elif det.class_id == 0: # 安全帽
  184. helmets.append(det)
  185. elif det.class_id == 4: # 反光衣
  186. vests.append(det)
  187. # 检查每个人员的安全状态
  188. results = []
  189. for person in persons:
  190. status = PersonSafetyStatus(
  191. track_id=person.track_id or 0,
  192. person_bbox=person.bbox,
  193. person_conf=person.confidence
  194. )
  195. px1, py1, px2, py2 = person.bbox
  196. # 检查是否戴安全帽
  197. # 安全帽应该在人体上方区域(头部附近)
  198. for helmet in helmets:
  199. hx1, hy1, hx2, hy2 = helmet.bbox
  200. # 检查安全帽是否在人体框内
  201. helmet_center_x = (hx1 + hx2) / 2
  202. helmet_center_y = (hy1 + hy2) / 2
  203. # 安全帽中心在人体框内,且在人体上半部分
  204. if (hx1 >= px1 and hx2 <= px2 and
  205. helmet_center_y >= py1 and
  206. helmet_center_y <= py1 + (py2 - py1) * 0.5):
  207. status.has_helmet = True
  208. status.helmet_conf = helmet.confidence
  209. break
  210. # 检查是否穿反光衣
  211. # 反光衣应该与人体有重叠
  212. for vest in vests:
  213. vx1, vy1, vx2, vy2 = vest.bbox
  214. # 计算重叠区域
  215. overlap_x1 = max(px1, vx1)
  216. overlap_y1 = max(py1, vy1)
  217. overlap_x2 = min(px2, vx2)
  218. overlap_y2 = min(py2, vy2)
  219. # 如果有重叠
  220. if overlap_x1 < overlap_x2 and overlap_y1 < overlap_y2:
  221. # 计算重叠面积占比
  222. overlap_area = (overlap_x2 - overlap_x1) * (overlap_y2 - overlap_y1)
  223. vest_area = (vx2 - vx1) * (vy2 - vy1)
  224. overlap_ratio = overlap_area / vest_area if vest_area > 0 else 0
  225. # 重叠比例超过30%认为穿了反光衣
  226. if overlap_ratio > 0.3:
  227. status.has_safety_vest = True
  228. status.vest_conf = vest.confidence
  229. break
  230. # 检查是否违规
  231. status.check_violation()
  232. results.append(status)
  233. return results
  234. def detect_with_tracking(self, frame: np.ndarray,
  235. prev_tracks: Dict[int, Tuple[int, int]] = None,
  236. max_disappeared: int = 30) -> Tuple[List[SafetyDetection], Dict[int, Tuple[int, int]]]:
  237. """
  238. 带跟踪的检测
  239. Args:
  240. frame: 输入图像
  241. prev_tracks: 上一帧的跟踪状态 {track_id: center}
  242. max_disappeared: 最大消失帧数
  243. Returns:
  244. (检测结果列表, 当前跟踪状态)
  245. """
  246. detections = self.detect(frame)
  247. if prev_tracks is None:
  248. prev_tracks = {}
  249. # 简单的质心跟踪
  250. # 这里只对人体进行跟踪
  251. persons = [d for d in detections if d.class_id == 3]
  252. # 分配跟踪ID
  253. if len(persons) > 0:
  254. if len(prev_tracks) == 0:
  255. # 初始化
  256. for i, person in enumerate(persons):
  257. person.track_id = i + 1
  258. else:
  259. # 匹配
  260. used_ids = set()
  261. for person in persons:
  262. # 找最近的已跟踪对象
  263. min_dist = float('inf')
  264. best_id = None
  265. for track_id, center in prev_tracks.items():
  266. if track_id in used_ids:
  267. continue
  268. dist = np.sqrt((person.center[0] - center[0])**2 +
  269. (person.center[1] - center[1])**2)
  270. if dist < min_dist:
  271. min_dist = dist
  272. best_id = track_id
  273. if best_id is not None and min_dist < 100: # 距离阈值
  274. person.track_id = best_id
  275. used_ids.add(best_id)
  276. else:
  277. # 新ID
  278. new_id = max(prev_tracks.keys(), default=0) + 1
  279. person.track_id = new_id
  280. # 更新跟踪状态
  281. new_tracks = {}
  282. for person in persons:
  283. if person.track_id is not None:
  284. new_tracks[person.track_id] = person.center
  285. return detections, new_tracks
  286. def draw_safety_result(frame: np.ndarray,
  287. detections: List[SafetyDetection],
  288. status_list: List[PersonSafetyStatus]) -> np.ndarray:
  289. """
  290. 在图像上绘制安全检测结果
  291. Args:
  292. frame: 输入图像
  293. detections: 检测结果
  294. status_list: 人员安全状态
  295. Returns:
  296. 绘制后的图像
  297. """
  298. result = frame.copy()
  299. # 绘制检测框
  300. for det in detections:
  301. x1, y1, x2, y2 = det.bbox
  302. # 根据类别选择颜色
  303. if det.class_id == 3: # 人
  304. color = (0, 255, 0) # 绿色
  305. elif det.class_id == 0: # 安全帽
  306. color = (255, 165, 0) # 橙色
  307. elif det.class_id == 4: # 反光衣
  308. color = (0, 165, 255) # 黄色
  309. else:
  310. color = (255, 255, 255)
  311. cv2.rectangle(result, (x1, y1), (x2, y2), color, 2)
  312. # 绘制标签
  313. label = f"{det.class_name}: {det.conf:.2f}"
  314. cv2.putText(result, label, (x1, y1 - 5),
  315. cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
  316. # 绘制安全状态
  317. for status in status_list:
  318. x1, y1, x2, y2 = status.person_bbox
  319. if status.is_violation:
  320. # 违规 - 红色警告
  321. color = (0, 0, 255)
  322. text = status.get_violation_desc()
  323. cv2.rectangle(result, (x1, y1), (x2, y2), color, 3)
  324. cv2.putText(result, text, (x1, y2 + 20),
  325. cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
  326. else:
  327. # 正常 - 显示安全标识
  328. color = (0, 255, 0)
  329. text = "安全装备齐全"
  330. cv2.putText(result, text, (x1, y2 + 20),
  331. cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
  332. return result
  333. class LLMSafetyDetector:
  334. """
  335. 基于大模型的安全检测器
  336. 结合 YOLO 检测和大模型判断
  337. """
  338. def __init__(self, yolo_model_path: str = None,
  339. llm_config: Dict[str, Any] = None,
  340. use_gpu: bool = True,
  341. use_llm: bool = True):
  342. """
  343. 初始化检测器
  344. Args:
  345. yolo_model_path: YOLO 模型路径
  346. llm_config: 大模型配置
  347. use_gpu: 是否使用 GPU
  348. use_llm: 是否使用大模型判断
  349. """
  350. # YOLO 检测器
  351. self.yolo_detector = SafetyDetector(
  352. model_path=yolo_model_path,
  353. use_gpu=use_gpu
  354. )
  355. # 大模型分析器
  356. self.use_llm = use_llm
  357. self.llm_analyzer = None
  358. if use_llm:
  359. try:
  360. from llm_service import SafetyAnalyzer, NumberRecognizer
  361. self.llm_analyzer = SafetyAnalyzer(llm_config)
  362. self.number_recognizer = NumberRecognizer(llm_config)
  363. print("大模型安全分析器初始化成功")
  364. except ImportError:
  365. print("未找到 llm_service 模块,将使用规则判断")
  366. self.use_llm = False
  367. except Exception as e:
  368. print(f"大模型初始化失败: {e},将使用规则判断")
  369. self.use_llm = False
  370. def detect(self, frame: np.ndarray) -> List[SafetyDetection]:
  371. """
  372. YOLO 检测
  373. Args:
  374. frame: 输入图像
  375. Returns:
  376. 检测结果列表
  377. """
  378. return self.yolo_detector.detect(frame)
  379. def check_safety(self, frame: np.ndarray,
  380. detections: List[SafetyDetection] = None,
  381. use_llm: bool = None) -> List[PersonSafetyStatus]:
  382. """
  383. 检查人员安全状态
  384. Args:
  385. frame: 输入图像
  386. detections: YOLO 检测结果
  387. use_llm: 是否使用大模型(覆盖默认设置)
  388. Returns:
  389. 人员安全状态列表
  390. """
  391. # 先用 YOLO 检测
  392. if detections is None:
  393. detections = self.yolo_detector.detect(frame)
  394. # 规则判断
  395. rule_status_list = self.yolo_detector.check_safety(frame, detections)
  396. # 如果不使用大模型,直接返回规则判断结果
  397. should_use_llm = use_llm if use_llm is not None else self.use_llm
  398. if not should_use_llm or self.llm_analyzer is None:
  399. return rule_status_list
  400. # 使用大模型对每个人员进行判断
  401. llm_status_list = []
  402. for status in rule_status_list:
  403. # 裁剪人员区域
  404. x1, y1, x2, y2 = status.person_bbox
  405. margin = 10
  406. x1 = max(0, x1 - margin)
  407. y1 = max(0, y1 - margin)
  408. x2 = min(frame.shape[1], x2 + margin)
  409. y2 = min(frame.shape[0], y2 + margin)
  410. person_image = frame[y1:y2, x1:x2]
  411. # 调用大模型分析
  412. try:
  413. llm_result = self.llm_analyzer.check_person_safety(person_image)
  414. # 更新状态
  415. if llm_result.get('success', False):
  416. status.has_helmet = llm_result.get('has_helmet', False)
  417. status.has_safety_vest = llm_result.get('has_vest', False)
  418. # 重新检查违规
  419. status.check_violation()
  420. # 如果大模型判断有违规,使用大模型的描述
  421. if status.is_violation and llm_result.get('violation_desc'):
  422. # 更新违规类型
  423. desc = llm_result.get('violation_desc', '')
  424. if '安全帽' in desc and '反光' in desc:
  425. status.violation_types = [SafetyViolationType.NO_BOTH]
  426. elif '安全帽' in desc:
  427. status.violation_types = [SafetyViolationType.NO_HELMET]
  428. elif '反光' in desc:
  429. status.violation_types = [SafetyViolationType.NO_SAFETY_VEST]
  430. except Exception as e:
  431. print(f"大模型分析失败: {e}")
  432. llm_status_list.append(status)
  433. return llm_status_list
  434. def recognize_number(self, frame: np.ndarray,
  435. person_bbox: Tuple[int, int, int, int]) -> Dict[str, Any]:
  436. """
  437. 识别人员编号
  438. Args:
  439. frame: 输入图像
  440. person_bbox: 人员边界框
  441. Returns:
  442. 编号识别结果
  443. """
  444. if self.number_recognizer is None:
  445. return {'number': None, 'success': False}
  446. # 裁剪人员区域
  447. x1, y1, x2, y2 = person_bbox
  448. person_image = frame[y1:y2, x1:x2]
  449. return self.number_recognizer.recognize_person_number(person_image)
  450. def detect_with_tracking(self, frame: np.ndarray,
  451. prev_tracks: Dict[int, Tuple[int, int]] = None,
  452. max_disappeared: int = 30) -> Tuple[List[SafetyDetection], Dict[int, Tuple[int, int]]]:
  453. """
  454. 带跟踪的检测
  455. Args:
  456. frame: 输入图像
  457. prev_tracks: 上一帧的跟踪状态
  458. max_disappeared: 最大消失帧数
  459. Returns:
  460. (检测结果列表, 当前跟踪状态)
  461. """
  462. return self.yolo_detector.detect_with_tracking(frame, prev_tracks, max_disappeared)