# ocr_recognizer.py
"""
OCR识别模块
负责人体分割和编号OCR识别
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np

from config import OCR_CONFIG, SEGMENTATION_CONFIG
  10. @dataclass
  11. class OCRResult:
  12. """OCR识别结果"""
  13. text: str # 识别文本
  14. confidence: float # 置信度
  15. bbox: Tuple[int, int, int, int] # 边界框
  16. location: str = "" # 位置描述 (如: "胸部", "背部")
  17. @dataclass
  18. class PersonInfo:
  19. """人员信息"""
  20. person_id: int # 人员ID
  21. person_bbox: Tuple[int, int, int, int] # 人体边界框
  22. number_text: Optional[str] = None # 编号文本
  23. number_confidence: float = 0.0 # 编号置信度
  24. number_location: str = "" # 编号位置
  25. ocr_results: List[OCRResult] = None # 所有OCR结果
  26. class PersonSegmenter:
  27. """
  28. 人体分割器 - 使用 RKNN YOLOv8 分割模型
  29. 将人体从背景中分割出来
  30. """
  31. def __init__(self, use_gpu: bool = True):
  32. """
  33. 初始化分割器
  34. Args:
  35. use_gpu: 是否使用GPU (RKNN使用NPU,此参数保留用于兼容)
  36. """
  37. self.use_gpu = use_gpu
  38. self.config = SEGMENTATION_CONFIG
  39. self.input_size = self.config.get('input_size', (640, 640))
  40. self.conf_threshold = self.config.get('conf_threshold', 0.5)
  41. self.rknn = None
  42. self._load_model()
  43. def _load_model(self):
  44. """加载 RKNN 分割模型"""
  45. try:
  46. from rknnlite.api import RKNNLite
  47. model_path = self.config.get('model_path', '/home/admin/dsh/testrk3588/yolov8n-seg.rknn')
  48. self.rknn = RKNNLite()
  49. ret = self.rknn.load_rknn(model_path)
  50. if ret != 0:
  51. print(f"[错误] 加载 RKNN 分割模型失败: {model_path}")
  52. self.rknn = None
  53. return
  54. # 初始化运行时,使用所有NPU核心
  55. ret = self.rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0_1_2)
  56. if ret != 0:
  57. print("[错误] 初始化 RKNN 运行时失败")
  58. self.rknn = None
  59. return
  60. print(f"成功加载 RKNN 人体分割模型: {model_path}")
  61. except ImportError:
  62. print("未安装 rknnlite,无法使用 RKNN 分割模型")
  63. self.rknn = None
  64. except Exception as e:
  65. print(f"加载分割模型失败: {e}")
  66. self.rknn = None
  67. def _letterbox(self, image: np.ndarray) -> tuple:
  68. """Letterbox 预处理,保持宽高比"""
  69. h0, w0 = image.shape[:2]
  70. ih, iw = self.input_size
  71. scale = min(iw / w0, ih / h0)
  72. new_w, new_h = int(w0 * scale), int(h0 * scale)
  73. pad_w = (iw - new_w) // 2
  74. pad_h = (ih - new_h) // 2
  75. resized = cv2.resize(image, (new_w, new_h))
  76. canvas = np.full((ih, iw, 3), 114, dtype=np.uint8)
  77. canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
  78. return canvas, scale, pad_w, pad_h, h0, w0
  79. def _postprocess_segmentation(self, outputs, scale, pad_w, pad_h, w0, h0):
  80. """
  81. 处理 YOLOv8 分割模型输出
  82. YOLOv8-seg 输出格式: [检测输出, 分割输出]
  83. - 检测输出: (1, 116, 8400) - 包含边界框、类别、掩码系数
  84. - 分割输出: (1, 32, 160, 160) - 原型掩码
  85. """
  86. if not outputs or len(outputs) < 2:
  87. return None
  88. # 解析输出
  89. det_output = outputs[0] # (1, 116, 8400) - 检测输出
  90. seg_output = outputs[1] # (1, 32, 160, 160) - 分割原型
  91. # 处理检测输出
  92. if len(det_output.shape) == 3:
  93. det_output = det_output[0] # (116, 8400)
  94. # YOLOv8-seg: 前 84 维是检测 (4 box + 80 classes),后 32 维是掩码系数
  95. num_anchors = det_output.shape[1]
  96. best_idx = -1
  97. best_conf = 0
  98. # 寻找最佳人体检测 (class 0 = person)
  99. for i in range(num_anchors):
  100. # 类别概率 (索引 4-84 是80个类别)
  101. class_probs = det_output[4:84, i]
  102. person_conf = float(class_probs[0]) # class 0 = person
  103. if person_conf > self.conf_threshold and person_conf > best_conf:
  104. best_conf = person_conf
  105. best_idx = i
  106. if best_idx < 0:
  107. return None
  108. # 获取掩码系数 (后32维)
  109. mask_coeffs = det_output[84:116, best_idx] # (32,)
  110. # 处理分割原型 (1, 32, 160, 160) -> (32, 160, 160)
  111. if len(seg_output.shape) == 4:
  112. seg_output = seg_output[0]
  113. # 计算最终掩码: mask = coeffs @ prototypes
  114. # seg_output: (32, 160, 160), mask_coeffs: (32,)
  115. mask = np.zeros((160, 160), dtype=np.float32)
  116. for i in range(32):
  117. mask += mask_coeffs[i] * seg_output[i]
  118. # Sigmoid 激活
  119. mask = 1 / (1 + np.exp(-mask))
  120. # 移除 padding 并缩放到原始尺寸
  121. mask = (mask > 0.5).astype(np.uint8) * 255
  122. # 裁剪掉 letterbox 添加的 padding
  123. mask_h, mask_w = mask.shape
  124. pad_h_mask = int(pad_h * mask_h / self.input_size[0]) # 160/640 = 0.25
  125. pad_w_mask = int(pad_w * mask_w / self.input_size[1])
  126. new_h_mask = int((mask_h - 2 * pad_h_mask))
  127. new_w_mask = int((mask_w - 2 * pad_w_mask))
  128. if new_h_mask > 0 and new_w_mask > 0:
  129. mask = mask[pad_h_mask:pad_h_mask+new_h_mask, pad_w_mask:pad_w_mask+new_w_mask]
  130. # 缩放到原始 ROI 尺寸
  131. mask = cv2.resize(mask, (w0, h0))
  132. return mask
  133. def segment_person(self, frame: np.ndarray,
  134. person_bbox: Tuple[int, int, int, int]) -> Optional[np.ndarray]:
  135. """
  136. 分割人体
  137. Args:
  138. frame: 输入图像
  139. person_bbox: 人体边界框 (x, y, w, h)
  140. Returns:
  141. 人体分割掩码
  142. """
  143. if self.rknn is None:
  144. return None
  145. x, y, w, h = person_bbox
  146. # 裁剪人体区域
  147. person_roi = frame[y:y+h, x:x+w]
  148. if person_roi.size == 0:
  149. return None
  150. try:
  151. # 预处理
  152. canvas, scale, pad_w, pad_h, h0, w0 = self._letterbox(person_roi)
  153. # RKNN 输入: NHWC (1, H, W, C), RGB, float32 normalized 0-1
  154. img = canvas[..., ::-1].astype(np.float32) / 255.0
  155. blob = img[None, ...] # (1, 640, 640, 3)
  156. # 推理
  157. outputs = self.rknn.inference(inputs=[blob])
  158. # 后处理
  159. mask = self._postprocess_segmentation(outputs, scale, pad_w, pad_h, w0, h0)
  160. return mask
  161. except Exception as e:
  162. print(f"分割错误: {e}")
  163. return None
  164. def release(self):
  165. """释放 RKNN 资源"""
  166. if self.rknn is not None:
  167. self.rknn.release()
  168. self.rknn = None
  169. def extract_person_region(self, frame: np.ndarray,
  170. person_bbox: Tuple[int, int, int, int],
  171. padding: float = 0.1) -> Tuple[np.ndarray, Tuple[int, int]]:
  172. """
  173. 提取人体区域
  174. Args:
  175. frame: 输入图像
  176. person_bbox: 人体边界框
  177. padding: 边界填充比例
  178. Returns:
  179. (人体区域图像, 原始位置偏移)
  180. """
  181. x, y, w, h = person_bbox
  182. # 添加填充
  183. pad_w = int(w * padding)
  184. pad_h = int(h * padding)
  185. x1 = max(0, x - pad_w)
  186. y1 = max(0, y - pad_h)
  187. x2 = min(frame.shape[1], x + w + pad_w)
  188. y2 = min(frame.shape[0], y + h + pad_h)
  189. person_region = frame[y1:y2, x1:x2]
  190. offset = (x1, y1)
  191. return person_region, offset
  192. class OCRRecognizer:
  193. """
  194. OCR识别器
  195. 使用llama-server API接口进行OCR识别
  196. """
  197. def __init__(self, config: Dict = None):
  198. """
  199. 初始化OCR
  200. Args:
  201. config: API配置
  202. """
  203. self.config = config or OCR_CONFIG
  204. self.api_host = self.config.get('api_host', 'localhost')
  205. self.api_port = self.config.get('api_port', 8111)
  206. self.model = self.config.get('model', 'PaddleOCR-VL-1.5-GGUF.gguf')
  207. self.prompt = self.config.get('prompt', '请识别图片中的数字编号,只返回数字,不要其他内容')
  208. self.temperature = self.config.get('temperature', 0.3)
  209. self.timeout = self.config.get('timeout', 30)
  210. # 检查API是否可用
  211. self._check_api()
  212. def _check_api(self):
  213. """检查API是否可用"""
  214. try:
  215. import http.client
  216. # localhost通常使用HTTP而非HTTPS
  217. use_https = self.api_host not in ['localhost', '127.0.0.1']
  218. conn_class = http.client.HTTPSConnection if use_https else http.client.HTTPConnection
  219. conn = conn_class(self.api_host, self.api_port, timeout=5)
  220. conn.request("GET", "/")
  221. res = conn.getresponse()
  222. conn.close()
  223. print(f"llama-server API已连接: {self.api_host}:{self.api_port}")
  224. except Exception as e:
  225. print(f"连接llama-server失败: {e}")
  226. print(f"请确保llama-server运行在 {self.api_host}:{self.api_port}")
  227. def _image_to_base64(self, image: np.ndarray) -> str:
  228. """
  229. 将图像转换为base64编码
  230. Args:
  231. image: 输入图像
  232. Returns:
  233. base64编码字符串
  234. """
  235. import base64
  236. _, buffer = cv2.imencode('.jpg', image)
  237. base64_str = base64.b64encode(buffer).decode('utf-8')
  238. return f"data:image/jpeg;base64,{base64_str}"
  239. def recognize(self, image: np.ndarray,
  240. prompt: str = None,
  241. detect_only_numbers: bool = True,
  242. max_retries: int = 3) -> List[OCRResult]:
  243. """
  244. 使用llama-server API识别图像中的文字
  245. Args:
  246. image: 输入图像
  247. prompt: 自定义提示词
  248. detect_only_numbers: 是否只检测数字编号
  249. max_retries: 最大重试次数
  250. Returns:
  251. 识别结果列表
  252. """
  253. if image is None:
  254. return []
  255. import http.client
  256. import json
  257. import re
  258. results = []
  259. last_error = None
  260. for attempt in range(max_retries):
  261. try:
  262. # 准备图像数据
  263. image_base64 = self._image_to_base64(image)
  264. # 构建请求
  265. use_prompt = prompt or self.prompt
  266. payload = {
  267. "model": self.model,
  268. "messages": [
  269. {
  270. "role": "user",
  271. "content": [
  272. {
  273. "type": "text",
  274. "text": use_prompt
  275. },
  276. {
  277. "type": "image_url",
  278. "image_url": {
  279. "url": image_base64
  280. }
  281. }
  282. ]
  283. }
  284. ],
  285. "temperature": self.temperature,
  286. "stream": False
  287. }
  288. headers = {
  289. 'Content-Type': 'application/json',
  290. 'Accept': 'application/json',
  291. }
  292. # 发送请求 - localhost使用HTTP
  293. use_https = self.api_host not in ['localhost', '127.0.0.1']
  294. conn_class = http.client.HTTPSConnection if use_https else http.client.HTTPConnection
  295. conn = conn_class(
  296. self.api_host,
  297. self.api_port,
  298. timeout=self.timeout
  299. )
  300. conn.request("POST", "/v1/chat/completions",
  301. json.dumps(payload), headers)
  302. res = conn.getresponse()
  303. data = res.read()
  304. conn.close()
  305. # 解析响应
  306. response = json.loads(data.decode('utf-8'))
  307. if 'choices' in response and len(response['choices']) > 0:
  308. content = response['choices'][0]['message']['content']
  309. # 从响应中提取数字/编号
  310. text = content.strip()
  311. # 如果只检测数字,提取数字部分
  312. if detect_only_numbers:
  313. # 匹配数字、字母数字组合
  314. numbers = re.findall(r'[A-Za-z]*\d+[A-Za-z0-9]*', text)
  315. if numbers:
  316. text = numbers[0]
  317. # 创建结果
  318. if text:
  319. results.append(OCRResult(
  320. text=text,
  321. confidence=1.0, # API不返回置信度,设为1.0
  322. bbox=(0, 0, image.shape[1], image.shape[0])
  323. ))
  324. return results # 成功则直接返回
  325. except Exception as e:
  326. last_error = e
  327. print(f"OCR API识别错误 (尝试 {attempt + 1}/{max_retries}): {e}")
  328. if attempt < max_retries - 1:
  329. import time
  330. time.sleep(0.5 * (attempt + 1)) # 指数退避
  331. # 所有重试都失败
  332. if last_error:
  333. print(f"OCR API识别最终失败: {last_error}")
  334. return results
  335. def recognize_number(self, image: np.ndarray) -> Optional[str]:
  336. """
  337. 识别图像中的编号
  338. Args:
  339. image: 输入图像
  340. Returns:
  341. 编号文本
  342. """
  343. results = self.recognize(image, detect_only_numbers=True)
  344. if results:
  345. return results[0].text
  346. return None
  347. class OCRRecognizerLocal:
  348. """
  349. 本地OCR识别器 (备用)
  350. 使用PaddleOCR或EasyOCR进行识别
  351. """
  352. def __init__(self, use_gpu: bool = True, languages: List[str] = None):
  353. """
  354. 初始化OCR
  355. Args:
  356. use_gpu: 是否使用GPU
  357. languages: 支持的语言列表
  358. """
  359. self.use_gpu = use_gpu
  360. self.languages = languages or ['ch', 'en']
  361. self.ocr = None
  362. self._load_ocr()
  363. def _load_ocr(self):
  364. """加载OCR引擎"""
  365. try:
  366. from paddleocr import PaddleOCR
  367. self.ocr = PaddleOCR(
  368. use_angle_cls=True,
  369. lang='ch' if 'ch' in self.languages else 'en',
  370. use_gpu=self.use_gpu,
  371. show_log=False
  372. )
  373. print("成功加载PaddleOCR")
  374. except ImportError:
  375. print("未安装PaddleOCR")
  376. self.ocr = None
  377. except Exception as e:
  378. print(f"加载OCR失败: {e}")
  379. def recognize(self, image: np.ndarray,
  380. detect_only_numbers: bool = True) -> List[OCRResult]:
  381. """识别图像中的文字"""
  382. if self.ocr is None or image is None:
  383. return []
  384. results = []
  385. try:
  386. ocr_results = self.ocr.ocr(image, cls=True)
  387. if ocr_results and len(ocr_results) > 0:
  388. for line in ocr_results[0]:
  389. if line is None:
  390. continue
  391. bbox_points, (text, conf) = line
  392. if conf < 0.5:
  393. continue
  394. x1 = int(min(p[0] for p in bbox_points))
  395. y1 = int(min(p[1] for p in bbox_points))
  396. x2 = int(max(p[0] for p in bbox_points))
  397. y2 = int(max(p[1] for p in bbox_points))
  398. results.append(OCRResult(
  399. text=text,
  400. confidence=conf,
  401. bbox=(x1, y1, x2-x1, y2-y1)
  402. ))
  403. except Exception as e:
  404. print(f"OCR识别错误: {e}")
  405. return results
  406. class NumberDetector:
  407. """
  408. 编号检测器
  409. 在人体图像中检测编号
  410. 使用llama-server API进行OCR识别
  411. """
  412. def __init__(self, use_api: bool = True, ocr_config: Dict = None):
  413. """
  414. 初始化检测器
  415. Args:
  416. use_api: 是否使用API进行OCR
  417. ocr_config: OCR配置
  418. """
  419. self.segmenter = PersonSegmenter(use_gpu=False)
  420. # 使用API OCR或本地OCR
  421. if use_api:
  422. self.ocr = OCRRecognizer(ocr_config)
  423. print("使用llama-server API进行OCR识别")
  424. else:
  425. self.ocr = OCRRecognizerLocal()
  426. print("使用本地OCR进行识别")
  427. # 编号可能出现的区域 (相对于人体边界框的比例)
  428. self.search_regions = [
  429. {'name': '胸部', 'y_ratio': (0.2, 0.5), 'x_ratio': (0.2, 0.8)},
  430. {'name': '腹部', 'y_ratio': (0.5, 0.8), 'x_ratio': (0.2, 0.8)},
  431. {'name': '背部上方', 'y_ratio': (0.1, 0.4), 'x_ratio': (0.1, 0.9)},
  432. ]
  433. def detect_number(self, frame: np.ndarray,
  434. person_bbox: Tuple[int, int, int, int]) -> PersonInfo:
  435. """
  436. 检测人体编号
  437. Args:
  438. frame: 输入图像
  439. person_bbox: 人体边界框
  440. Returns:
  441. 人员信息
  442. """
  443. x, y, w, h = person_bbox
  444. # 提取人体区域
  445. person_region, offset = self.segmenter.extract_person_region(
  446. frame, person_bbox
  447. )
  448. person_info = PersonInfo(
  449. person_id=-1,
  450. person_bbox=person_bbox,
  451. ocr_results=[]
  452. )
  453. # 在不同区域搜索编号
  454. best_result = None
  455. best_confidence = 0
  456. for region in self.search_regions:
  457. # 计算搜索区域
  458. y1 = int(h * region['y_ratio'][0])
  459. y2 = int(h * region['y_ratio'][1])
  460. x1 = int(w * region['x_ratio'][0])
  461. x2 = int(w * region['x_ratio'][1])
  462. # 确保在图像范围内
  463. y1 = max(0, min(y1, person_region.shape[0]))
  464. y2 = max(0, min(y2, person_region.shape[0]))
  465. x1 = max(0, min(x1, person_region.shape[1]))
  466. x2 = max(0, min(x2, person_region.shape[1]))
  467. if y2 <= y1 or x2 <= x1:
  468. continue
  469. # 裁剪区域
  470. roi = person_region[y1:y2, x1:x2]
  471. # OCR识别
  472. ocr_results = self.ocr.recognize(roi)
  473. for result in ocr_results:
  474. # 调整坐标到原始图像坐标系
  475. adjusted_bbox = (
  476. result.bbox[0] + x1 + offset[0],
  477. result.bbox[1] + y1 + offset[1],
  478. result.bbox[2],
  479. result.bbox[3]
  480. )
  481. result.bbox = adjusted_bbox
  482. result.location = region['name']
  483. person_info.ocr_results.append(result)
  484. # 更新最佳结果
  485. if result.confidence > best_confidence:
  486. best_confidence = result.confidence
  487. best_result = result
  488. # 设置最佳结果作为编号
  489. if best_result:
  490. person_info.number_text = best_result.text
  491. person_info.number_confidence = best_result.confidence
  492. person_info.number_location = best_result.location
  493. return person_info
  494. def detect_numbers_batch(self, frame: np.ndarray,
  495. person_bboxes: List[Tuple[int, int, int, int]]) -> List[PersonInfo]:
  496. """
  497. 批量检测人体编号
  498. Args:
  499. frame: 输入图像
  500. person_bboxes: 人体边界框列表
  501. Returns:
  502. 人员信息列表
  503. """
  504. results = []
  505. for i, bbox in enumerate(person_bboxes):
  506. person_info = self.detect_number(frame, bbox)
  507. person_info.person_id = i
  508. results.append(person_info)
  509. return results
  510. def release(self):
  511. """释放资源"""
  512. if hasattr(self.segmenter, 'release'):
  513. self.segmenter.release()
  514. def preprocess_for_ocr(image: np.ndarray) -> np.ndarray:
  515. """
  516. OCR预处理
  517. Args:
  518. image: 输入图像
  519. Returns:
  520. 预处理后的图像
  521. """
  522. if image is None:
  523. return None
  524. # 转换为灰度图
  525. if len(image.shape) == 3:
  526. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  527. else:
  528. gray = image
  529. # 自适应直方图均衡化
  530. clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
  531. enhanced = clahe.apply(gray)
  532. # 降噪
  533. denoised = cv2.fastNlMeansDenoising(enhanced, None, 10)
  534. # 二值化
  535. _, binary = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
  536. return binary