# test_yolo11.py
  1. import cv2
  2. import numpy as np
  3. import argparse
  4. import requests
  5. import json
  6. import time
  7. import os
  8. from PIL import Image, ImageDraw, ImageFont
  9. from ultralytics import YOLO
  10. def put_text_chinese(img, text, position, font_size=20, color=(255, 0, 0)):
  11. img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
  12. draw = ImageDraw.Draw(img_pil)
  13. font_path = "Alibaba_PuHuiTi_2.0_35_Thin_35_Thin.ttf"
  14. try:
  15. font = ImageFont.truetype(font_path, font_size)
  16. except:
  17. try:
  18. font = ImageFont.truetype("MiSans-Thin.ttf", font_size)
  19. except:
  20. font = ImageFont.load_default()
  21. color_rgb = (color[2], color[1], color[0])
  22. draw.text(position, text, font=font, fill=color_rgb)
  23. img_cv2 = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
  24. x, y = position
  25. text_width = draw.textlength(text, font=font)
  26. text_height = font_size
  27. img[y:y+text_height, x:x+int(text_width)] = img_cv2[y:y+text_height, x:x+int(text_width)]
  28. # 上传图片到OSS
  29. def upload_image(image_path):
  30. try:
  31. import http.client
  32. import mimetypes
  33. from codecs import encode
  34. # 获取文件名
  35. filename = os.path.basename(image_path)
  36. # 创建连接
  37. conn = http.client.HTTPSConnection("jtjai.device.wenhq.top", 8583)
  38. # 准备multipart/form-data
  39. boundary = 'wL36Yn8afVp8Ag7AmP8qZ0SA4n1v9T'
  40. dataList = []
  41. dataList.append(encode('--' + boundary))
  42. dataList.append(encode('Content-Disposition: form-data; name=file; filename={0}'.format(filename)))
  43. # 猜测文件类型
  44. fileType = mimetypes.guess_type(image_path)[0] or 'application/octet-stream'
  45. dataList.append(encode('Content-Type: {}'.format(fileType)))
  46. dataList.append(encode(''))
  47. # 读取文件内容
  48. with open(image_path, 'rb') as f:
  49. dataList.append(f.read())
  50. dataList.append(encode('--'+boundary+'--'))
  51. dataList.append(encode(''))
  52. # 构建请求体
  53. body = b'\r\n'.join(dataList)
  54. # 构建请求头
  55. headers = {
  56. 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
  57. 'Accept': '*/*',
  58. 'Host': 'jtjai.device.wenhq.top:8583',
  59. 'Connection': 'keep-alive',
  60. 'Content-Type': 'multipart/form-data; boundary={}'.format(boundary)
  61. }
  62. # 发送请求
  63. conn.request("POST", "/api/resource/oss/upload", body, headers)
  64. res = conn.getresponse()
  65. data = res.read()
  66. # 解析响应
  67. if res.status == 200:
  68. result = json.loads(data.decode("utf-8"))
  69. if result.get('code') == 200:
  70. return result.get('data', {}).get('purl')
  71. print(f"上传图片失败: {data.decode('utf-8')}")
  72. except Exception as e:
  73. print(f"上传图片异常: {e}")
  74. return None
  75. # 创建事件
  76. def create_event(addr, purl):
  77. try:
  78. url = "https://jtjai.device.wenhq.top:8583/api/system/event"
  79. create_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
  80. data = {
  81. "createTime": create_time,
  82. "addr": addr,
  83. "ext1": json.dumps([purl]),
  84. "ext2": json.dumps({"lx":"工地安全"})
  85. }
  86. response = requests.post(url, json=data, verify=False)
  87. if response.status_code == 200:
  88. result = response.json()
  89. if result.get('code') == 200:
  90. print(f"事件创建成功: {addr}")
  91. return True
  92. print(f"创建事件失败: {response.text}")
  93. except Exception as e:
  94. print(f"创建事件异常: {e}")
  95. return False
  96. # 解析命令行参数
  97. parser = argparse.ArgumentParser(description='YOLO11 目标检测')
  98. parser.add_argument('--input', type=str, default='b30090c8d3e9bf75f97b2f51a4b3cdd2.jpg', help='输入图片或视频路径')
  99. parser.add_argument('--output', type=str, default='', help='输出结果路径(不提供则不保存)')
  100. parser.add_argument('--type', type=str, default='image', choices=['image', 'video'], help='输入类型')
  101. parser.add_argument('--conf', type=float, default=0.5, help='置信度阈值')
  102. parser.add_argument('--fps', type=int, default=2, help='每秒处理的帧数')
  103. args = parser.parse_args()
  104. # 加载官方 YOLO11n 模型
  105. model = YOLO('yolo11m_safety.pt')
  106. # 处理输入
  107. input_path = args.input
  108. output_path = args.output
  109. input_type = args.type
  110. conf_threshold = args.conf
  111. process_fps = args.fps
  112. # 标签映射
  113. label_map = {0: '安全帽', 4: '安全衣', 3: '人'}
  114. print(f"使用置信度阈值: {conf_threshold}")
  115. print(f"每秒处理帧数: {process_fps}")
  116. # 事件上传时间控制
  117. last_upload_time = 0
  118. upload_interval = 2 # 2秒
  119. if input_type == 'image':
  120. print("正在进行检测...")
  121. results = model(input_path)
  122. print(results)
  123. # 读取图片
  124. image = cv2.imread(input_path)
  125. if image is None:
  126. print(f"无法读取图片: {input_path}")
  127. exit(1)
  128. # 处理检测结果并绘制边界框
  129. print("正在绘制检测结果...")
  130. for result in results:
  131. boxes = result.boxes
  132. for box in boxes:
  133. cls = int(box.cls[0])
  134. conf = float(box.conf[0])
  135. # 根据类别设置不同的置信度阈值
  136. # 人: > 0.8, 安全帽/反光衣: > 0.5
  137. if cls == 3: # 人
  138. conf_threshold_item = 0.8
  139. else: # 安全帽或反光衣
  140. conf_threshold_item = 0.5
  141. if conf > conf_threshold_item and cls in label_map:
  142. x1, y1, x2, y2 = box.xyxy[0].tolist()
  143. x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
  144. print(f" {cls} detected: confidence={conf:.2f}, box=[{x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}]")
  145. # 绘制边界框
  146. cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 2)
  147. # 绘制标签和置信度
  148. text = f"{label_map[cls]}: {conf:.2f}"
  149. text_y = max(15, y1 - 20)
  150. put_text_chinese(image, text, (x1, text_y), font_size=20, color=(255, 0, 0))
  151. # 保存结果
  152. if output_path:
  153. cv2.imwrite(output_path, image)
  154. print(f"检测结果已保存到: {output_path}")
  155. else:
  156. print("未指定输出路径,跳过保存")
  157. # 显示结果
  158. cv2.imshow('YOLO11 Detection Result', image)
  159. cv2.waitKey(0)
  160. cv2.destroyAllWindows()
  161. elif input_type == 'video':
  162. # 打开视频文件或RTSP流
  163. def open_stream():
  164. cap = cv2.VideoCapture(input_path)
  165. if not cap.isOpened():
  166. print(f"无法打开视频或RTSP流: {input_path}")
  167. return None
  168. print(f"成功打开: {input_path}")
  169. return cap
  170. cap = open_stream()
  171. if not cap:
  172. exit(1)
  173. # 获取视频信息
  174. fps = cap.get(cv2.CAP_PROP_FPS)
  175. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  176. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  177. # 检查是否为RTSP流
  178. is_rtsp = input_path.startswith('rtsp://')
  179. if is_rtsp:
  180. print(f"RTSP流信息: {width}x{height}")
  181. # 对于RTSP流,使用固定帧率
  182. if fps <= 0:
  183. fps = 30.0
  184. print(f"使用帧率: {fps:.2f} FPS")
  185. else:
  186. print(f"视频信息: {width}x{height}, {fps:.2f} FPS")
  187. # 处理视频帧
  188. frame_count = 0
  189. out = None
  190. # 计算帧间隔
  191. frame_interval = int(round(fps / process_fps)) if fps > 0 else 1
  192. print(f"帧间隔: {frame_interval} 帧")
  193. # 只有当指定了输出路径时才创建视频写入对象
  194. if output_path:
  195. # 创建视频写入对象
  196. fourcc = cv2.VideoWriter_fourcc(*'mp4v')
  197. out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
  198. print(f"将保存结果到: {output_path}")
  199. else:
  200. print("未指定输出路径,跳过保存")
  201. # 存储检测结果
  202. last_detections = []
  203. # 重连计数器
  204. reconnect_count = 0
  205. max_reconnects = 10
  206. try:
  207. while True:
  208. # 检查连接状态
  209. if not cap or not cap.isOpened():
  210. print(f"连接已断开,尝试重新连接... ({reconnect_count}/{max_reconnects})")
  211. # 清理旧连接
  212. if cap:
  213. cap.release()
  214. # 尝试重新连接
  215. cap = open_stream()
  216. if not cap:
  217. reconnect_count += 1
  218. if reconnect_count > max_reconnects:
  219. print("达到最大重连次数,退出")
  220. break
  221. # 等待一段时间后重试
  222. import time
  223. time.sleep(2)
  224. continue
  225. # 重置重连计数器
  226. reconnect_count = 0
  227. # 重新获取视频信息
  228. fps = cap.get(cv2.CAP_PROP_FPS)
  229. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  230. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  231. frame_interval = int(round(fps / process_fps)) if fps > 0 else 1
  232. print(f"重连成功,新的帧间隔: {frame_interval} 帧")
  233. # 如果需要保存,重新创建视频写入对象
  234. if output_path and not out:
  235. fourcc = cv2.VideoWriter_fourcc(*'mp4v')
  236. out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
  237. print(f"重新创建视频写入对象")
  238. # 读取帧
  239. ret, frame = cap.read()
  240. if not ret:
  241. print("读取帧失败,可能流已断开")
  242. # 强制重新连接
  243. continue
  244. frame_count += 1
  245. # 只有当帧号是帧间隔的倍数时才进行检测
  246. if frame_count % frame_interval == 0:
  247. print(f"处理第 {frame_count} 帧...")
  248. try:
  249. # 进行检测
  250. results = model(frame)
  251. # 存储检测结果
  252. last_detections = []
  253. # 存储人的检测结果,用于后续检查
  254. person_detections = []
  255. # 存储安全帽和安全衣的检测结果
  256. helmet_detections = []
  257. safety_clothes_detections = []
  258. for result in results:
  259. boxes = result.boxes
  260. for box in boxes:
  261. cls = int(box.cls[0])
  262. conf = float(box.conf[0])
  263. # 根据类别设置不同的置信度阈值
  264. # 人: > 0.8, 安全帽/反光衣: > 0.5
  265. if cls == 3: # 人
  266. conf_threshold_item = 0.8
  267. else: # 安全帽或反光衣
  268. conf_threshold_item = 0.5
  269. if conf > conf_threshold_item and cls in label_map:
  270. x1, y1, x2, y2 = box.xyxy[0].tolist()
  271. x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
  272. last_detections.append((cls, conf, x1, y1, x2, y2))
  273. # 分类存储检测结果
  274. if cls == 3: # 人
  275. person_detections.append((x1, y1, x2, y2, conf))
  276. elif cls == 0: # 安全帽
  277. helmet_detections.append((x1, y1, x2, y2, conf))
  278. elif cls == 4: # 安全衣
  279. safety_clothes_detections.append((x1, y1, x2, y2, conf))
  280. # 标记是否需要告警
  281. need_alert = False
  282. alert_addr = None
  283. # 检查每个人是否戴了安全帽和安全衣
  284. for person_x1, person_y1, person_x2, person_y2, person_conf in person_detections:
  285. # 检查是否戴安全帽
  286. has_helmet = False
  287. for helmet_x1, helmet_y1, helmet_x2, helmet_y2, helmet_conf in helmet_detections:
  288. # 简单的重叠检测:安全帽中心是否在人框内
  289. helmet_center_x = (helmet_x1 + helmet_x2) / 2
  290. helmet_center_y = (helmet_y1 + helmet_y2) / 2
  291. if (helmet_center_x >= person_x1 and helmet_center_x <= person_x2 and
  292. helmet_center_y >= person_y1 and helmet_center_y <= person_y2):
  293. has_helmet = True
  294. break
  295. # 检查是否穿安全衣
  296. has_safety_clothes = False
  297. for clothes_x1, clothes_y1, clothes_x2, clothes_y2, clothes_conf in safety_clothes_detections:
  298. # 简单的重叠检测:安全衣与人体有重叠
  299. overlap_x1 = max(person_x1, clothes_x1)
  300. overlap_y1 = max(person_y1, clothes_y1)
  301. overlap_x2 = min(person_x2, clothes_x2)
  302. overlap_y2 = min(person_y2, clothes_y2)
  303. if overlap_x1 < overlap_x2 and overlap_y1 < overlap_y2:
  304. has_safety_clothes = True
  305. break
  306. # 标记是否需要告警
  307. if not has_helmet or not has_safety_clothes:
  308. need_alert = True
  309. # 准备告警信息
  310. if not has_helmet and not has_safety_clothes:
  311. alert_addr = "反光衣和安全帽都没戴"
  312. elif not has_helmet:
  313. alert_addr = "未戴安全帽"
  314. else:
  315. alert_addr = "未穿反光衣"
  316. print(f"警告: {alert_addr},置信度: {person_conf:.2f}")
  317. except Exception as e:
  318. print(f"检测过程中出错: {e}")
  319. # 继续处理,使用上一次的检测结果
  320. # 绘制上一次的检测结果
  321. for cls, conf, x1, y1, x2, y2 in last_detections:
  322. # 绘制边界框
  323. cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
  324. # 绘制标签和置信度
  325. text = f"{label_map[cls]}: {conf:.2f}"
  326. text_y = max(15, y1 - 20)
  327. put_text_chinese(frame, text, (x1, text_y), font_size=20, color=(255, 0, 0))
  328. # 检查是否需要告警并上传图片
  329. # 上传的必要条件是先识别到人
  330. if 'person_detections' in locals() and person_detections and 'need_alert' in locals() and need_alert and alert_addr:
  331. # 检查是否在2秒内已经上传过
  332. current_time = time.time()
  333. if current_time - last_upload_time >= upload_interval:
  334. print(f"检测到人,触发告警上传")
  335. # 保存带标签的告警帧
  336. temp_image_path = f"alert_frame_{frame_count}.jpg"
  337. cv2.imwrite(temp_image_path, frame)
  338. # 上传图片
  339. purl = upload_image(temp_image_path)
  340. if purl:
  341. # 创建事件
  342. create_event(alert_addr, purl)
  343. # 更新最后上传时间
  344. last_upload_time = current_time
  345. # 清理临时文件
  346. if os.path.exists(temp_image_path):
  347. os.remove(temp_image_path)
  348. else:
  349. print(f"2秒内已上传过事件,跳过本次上传")
  350. # 只有当指定了输出路径时才写入处理后的帧
  351. if out:
  352. try:
  353. out.write(frame)
  354. except Exception as e:
  355. print(f"写入视频失败: {e}")
  356. # 显示处理后的帧
  357. try:
  358. cv2.imshow('YOLO11 Detection Result', frame)
  359. # 按 'q' 键退出
  360. if cv2.waitKey(1) & 0xFF == ord('q'):
  361. break
  362. except Exception as e:
  363. print(f"显示帧失败: {e}")
  364. # 继续处理
  365. finally:
  366. # 清理资源
  367. cap.release()
  368. if out:
  369. out.release()
  370. cv2.destroyAllWindows()
  371. if output_path:
  372. print(f"视频处理完成,结果已保存到: {output_path}")
  373. else:
  374. print("视频处理完成")