ocr_recognizer.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. """
  2. OCR识别模块
  3. 负责人体分割和编号OCR识别
  4. """
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np

from config import OCR_CONFIG
@dataclass
class OCRResult:
    """One OCR recognition result: the recognized text and where it was found."""
    text: str  # recognized text
    confidence: float  # recognition confidence score
    bbox: Tuple[int, int, int, int]  # bounding box (x, y, w, h)
    location: str = ""  # position description (e.g. "chest", "back")
  17. @dataclass
  18. class PersonInfo:
  19. """人员信息"""
  20. person_id: int # 人员ID
  21. person_bbox: Tuple[int, int, int, int] # 人体边界框
  22. number_text: Optional[str] = None # 编号文本
  23. number_confidence: float = 0.0 # 编号置信度
  24. number_location: str = "" # 编号位置
  25. ocr_results: List[OCRResult] = None # 所有OCR结果
  26. class PersonSegmenter:
  27. """
  28. 人体分割器
  29. 将人体从背景中分割出来
  30. """
  31. def __init__(self, use_gpu: bool = True):
  32. """
  33. 初始化分割器
  34. Args:
  35. use_gpu: 是否使用GPU
  36. """
  37. self.use_gpu = use_gpu
  38. self.segmentor = None
  39. self._load_model()
  40. def _load_model(self):
  41. """加载分割模型"""
  42. try:
  43. # 使用YOLO11分割模型
  44. from ultralytics import YOLO
  45. self.segmentor = YOLO('yolo11n-seg.pt') # YOLO11分割模型
  46. print("成功加载YOLO11人体分割模型")
  47. except Exception as e:
  48. print(f"加载分割模型失败: {e}")
  49. self.segmentor = None
  50. def segment_person(self, frame: np.ndarray,
  51. person_bbox: Tuple[int, int, int, int]) -> Optional[np.ndarray]:
  52. """
  53. 分割人体
  54. Args:
  55. frame: 输入图像
  56. person_bbox: 人体边界框 (x, y, w, h)
  57. Returns:
  58. 人体分割掩码 (或分割后的人体图像)
  59. """
  60. if self.segmentor is None:
  61. return None
  62. x, y, w, h = person_bbox
  63. # 裁剪人体区域
  64. person_roi = frame[y:y+h, x:x+w]
  65. try:
  66. # 使用分割模型
  67. results = self.segmentor(person_roi, classes=[0], verbose=False) # class 0 = person
  68. if results and len(results) > 0 and results[0].masks is not None:
  69. masks = results[0].masks.data
  70. if len(masks) > 0:
  71. # 获取第一个掩码
  72. mask = masks[0].cpu().numpy()
  73. mask = cv2.resize(mask, (w, h))
  74. mask = (mask > 0.5).astype(np.uint8) * 255
  75. return mask
  76. except Exception as e:
  77. print(f"分割错误: {e}")
  78. return None
  79. def extract_person_region(self, frame: np.ndarray,
  80. person_bbox: Tuple[int, int, int, int],
  81. padding: float = 0.1) -> Tuple[np.ndarray, Tuple[int, int]]:
  82. """
  83. 提取人体区域
  84. Args:
  85. frame: 输入图像
  86. person_bbox: 人体边界框
  87. padding: 边界填充比例
  88. Returns:
  89. (人体区域图像, 原始位置偏移)
  90. """
  91. x, y, w, h = person_bbox
  92. # 添加填充
  93. pad_w = int(w * padding)
  94. pad_h = int(h * padding)
  95. x1 = max(0, x - pad_w)
  96. y1 = max(0, y - pad_h)
  97. x2 = min(frame.shape[1], x + w + pad_w)
  98. y2 = min(frame.shape[0], y + h + pad_h)
  99. person_region = frame[y1:y2, x1:x2]
  100. offset = (x1, y1)
  101. return person_region, offset
  102. class OCRRecognizer:
  103. """
  104. OCR识别器
  105. 使用llama-server API接口进行OCR识别
  106. """
  107. def __init__(self, config: Dict = None):
  108. """
  109. 初始化OCR
  110. Args:
  111. config: API配置
  112. """
  113. self.config = config or OCR_CONFIG
  114. self.api_host = self.config.get('api_host', 'localhost')
  115. self.api_port = self.config.get('api_port', 8111)
  116. self.model = self.config.get('model', 'PaddleOCR-VL-1.5-GGUF.gguf')
  117. self.prompt = self.config.get('prompt', '请识别图片中的数字编号,只返回数字,不要其他内容')
  118. self.temperature = self.config.get('temperature', 0.3)
  119. self.timeout = self.config.get('timeout', 30)
  120. # 检查API是否可用
  121. self._check_api()
  122. def _check_api(self):
  123. """检查API是否可用"""
  124. try:
  125. import http.client
  126. # localhost通常使用HTTP而非HTTPS
  127. use_https = self.api_host not in ['localhost', '127.0.0.1']
  128. conn_class = http.client.HTTPSConnection if use_https else http.client.HTTPConnection
  129. conn = conn_class(self.api_host, self.api_port, timeout=5)
  130. conn.request("GET", "/")
  131. res = conn.getresponse()
  132. conn.close()
  133. print(f"llama-server API已连接: {self.api_host}:{self.api_port}")
  134. except Exception as e:
  135. print(f"连接llama-server失败: {e}")
  136. print(f"请确保llama-server运行在 {self.api_host}:{self.api_port}")
  137. def _image_to_base64(self, image: np.ndarray) -> str:
  138. """
  139. 将图像转换为base64编码
  140. Args:
  141. image: 输入图像
  142. Returns:
  143. base64编码字符串
  144. """
  145. import base64
  146. _, buffer = cv2.imencode('.jpg', image)
  147. base64_str = base64.b64encode(buffer).decode('utf-8')
  148. return f"data:image/jpeg;base64,{base64_str}"
  149. def recognize(self, image: np.ndarray,
  150. prompt: str = None,
  151. detect_only_numbers: bool = True,
  152. max_retries: int = 3) -> List[OCRResult]:
  153. """
  154. 使用llama-server API识别图像中的文字
  155. Args:
  156. image: 输入图像
  157. prompt: 自定义提示词
  158. detect_only_numbers: 是否只检测数字编号
  159. max_retries: 最大重试次数
  160. Returns:
  161. 识别结果列表
  162. """
  163. if image is None:
  164. return []
  165. import http.client
  166. import json
  167. import re
  168. results = []
  169. last_error = None
  170. for attempt in range(max_retries):
  171. try:
  172. # 准备图像数据
  173. image_base64 = self._image_to_base64(image)
  174. # 构建请求
  175. use_prompt = prompt or self.prompt
  176. payload = {
  177. "model": self.model,
  178. "messages": [
  179. {
  180. "role": "user",
  181. "content": [
  182. {
  183. "type": "text",
  184. "text": use_prompt
  185. },
  186. {
  187. "type": "image_url",
  188. "image_url": {
  189. "url": image_base64
  190. }
  191. }
  192. ]
  193. }
  194. ],
  195. "temperature": self.temperature,
  196. "stream": False
  197. }
  198. headers = {
  199. 'Content-Type': 'application/json',
  200. 'Accept': 'application/json',
  201. }
  202. # 发送请求 - localhost使用HTTP
  203. use_https = self.api_host not in ['localhost', '127.0.0.1']
  204. conn_class = http.client.HTTPSConnection if use_https else http.client.HTTPConnection
  205. conn = conn_class(
  206. self.api_host,
  207. self.api_port,
  208. timeout=self.timeout
  209. )
  210. conn.request("POST", "/v1/chat/completions",
  211. json.dumps(payload), headers)
  212. res = conn.getresponse()
  213. data = res.read()
  214. conn.close()
  215. # 解析响应
  216. response = json.loads(data.decode('utf-8'))
  217. if 'choices' in response and len(response['choices']) > 0:
  218. content = response['choices'][0]['message']['content']
  219. # 从响应中提取数字/编号
  220. text = content.strip()
  221. # 如果只检测数字,提取数字部分
  222. if detect_only_numbers:
  223. # 匹配数字、字母数字组合
  224. numbers = re.findall(r'[A-Za-z]*\d+[A-Za-z0-9]*', text)
  225. if numbers:
  226. text = numbers[0]
  227. # 创建结果
  228. if text:
  229. results.append(OCRResult(
  230. text=text,
  231. confidence=1.0, # API不返回置信度,设为1.0
  232. bbox=(0, 0, image.shape[1], image.shape[0])
  233. ))
  234. return results # 成功则直接返回
  235. except Exception as e:
  236. last_error = e
  237. print(f"OCR API识别错误 (尝试 {attempt + 1}/{max_retries}): {e}")
  238. if attempt < max_retries - 1:
  239. import time
  240. time.sleep(0.5 * (attempt + 1)) # 指数退避
  241. # 所有重试都失败
  242. if last_error:
  243. print(f"OCR API识别最终失败: {last_error}")
  244. return results
  245. def recognize_number(self, image: np.ndarray) -> Optional[str]:
  246. """
  247. 识别图像中的编号
  248. Args:
  249. image: 输入图像
  250. Returns:
  251. 编号文本
  252. """
  253. results = self.recognize(image, detect_only_numbers=True)
  254. if results:
  255. return results[0].text
  256. return None
  257. class OCRRecognizerLocal:
  258. """
  259. 本地OCR识别器 (备用)
  260. 使用PaddleOCR或EasyOCR进行识别
  261. """
  262. def __init__(self, use_gpu: bool = True, languages: List[str] = None):
  263. """
  264. 初始化OCR
  265. Args:
  266. use_gpu: 是否使用GPU
  267. languages: 支持的语言列表
  268. """
  269. self.use_gpu = use_gpu
  270. self.languages = languages or ['ch', 'en']
  271. self.ocr = None
  272. self._load_ocr()
  273. def _load_ocr(self):
  274. """加载OCR引擎"""
  275. try:
  276. from paddleocr import PaddleOCR
  277. self.ocr = PaddleOCR(
  278. use_angle_cls=True,
  279. lang='ch' if 'ch' in self.languages else 'en',
  280. use_gpu=self.use_gpu,
  281. show_log=False
  282. )
  283. print("成功加载PaddleOCR")
  284. except ImportError:
  285. print("未安装PaddleOCR")
  286. self.ocr = None
  287. except Exception as e:
  288. print(f"加载OCR失败: {e}")
  289. def recognize(self, image: np.ndarray,
  290. detect_only_numbers: bool = True) -> List[OCRResult]:
  291. """识别图像中的文字"""
  292. if self.ocr is None or image is None:
  293. return []
  294. results = []
  295. try:
  296. ocr_results = self.ocr.ocr(image, cls=True)
  297. if ocr_results and len(ocr_results) > 0:
  298. for line in ocr_results[0]:
  299. if line is None:
  300. continue
  301. bbox_points, (text, conf) = line
  302. if conf < 0.5:
  303. continue
  304. x1 = int(min(p[0] for p in bbox_points))
  305. y1 = int(min(p[1] for p in bbox_points))
  306. x2 = int(max(p[0] for p in bbox_points))
  307. y2 = int(max(p[1] for p in bbox_points))
  308. results.append(OCRResult(
  309. text=text,
  310. confidence=conf,
  311. bbox=(x1, y1, x2-x1, y2-y1)
  312. ))
  313. except Exception as e:
  314. print(f"OCR识别错误: {e}")
  315. return results
  316. class NumberDetector:
  317. """
  318. 编号检测器
  319. 在人体图像中检测编号
  320. 使用llama-server API进行OCR识别
  321. """
  322. def __init__(self, use_api: bool = True, ocr_config: Dict = None):
  323. """
  324. 初始化检测器
  325. Args:
  326. use_api: 是否使用API进行OCR
  327. ocr_config: OCR配置
  328. """
  329. self.segmenter = PersonSegmenter(use_gpu=False)
  330. # 使用API OCR或本地OCR
  331. if use_api:
  332. self.ocr = OCRRecognizer(ocr_config)
  333. print("使用llama-server API进行OCR识别")
  334. else:
  335. self.ocr = OCRRecognizerLocal()
  336. print("使用本地OCR进行识别")
  337. # 编号可能出现的区域 (相对于人体边界框的比例)
  338. self.search_regions = [
  339. {'name': '胸部', 'y_ratio': (0.2, 0.5), 'x_ratio': (0.2, 0.8)},
  340. {'name': '腹部', 'y_ratio': (0.5, 0.8), 'x_ratio': (0.2, 0.8)},
  341. {'name': '背部上方', 'y_ratio': (0.1, 0.4), 'x_ratio': (0.1, 0.9)},
  342. ]
  343. def detect_number(self, frame: np.ndarray,
  344. person_bbox: Tuple[int, int, int, int]) -> PersonInfo:
  345. """
  346. 检测人体编号
  347. Args:
  348. frame: 输入图像
  349. person_bbox: 人体边界框
  350. Returns:
  351. 人员信息
  352. """
  353. x, y, w, h = person_bbox
  354. # 提取人体区域
  355. person_region, offset = self.segmenter.extract_person_region(
  356. frame, person_bbox
  357. )
  358. person_info = PersonInfo(
  359. person_id=-1,
  360. person_bbox=person_bbox,
  361. ocr_results=[]
  362. )
  363. # 在不同区域搜索编号
  364. best_result = None
  365. best_confidence = 0
  366. for region in self.search_regions:
  367. # 计算搜索区域
  368. y1 = int(h * region['y_ratio'][0])
  369. y2 = int(h * region['y_ratio'][1])
  370. x1 = int(w * region['x_ratio'][0])
  371. x2 = int(w * region['x_ratio'][1])
  372. # 确保在图像范围内
  373. y1 = max(0, min(y1, person_region.shape[0]))
  374. y2 = max(0, min(y2, person_region.shape[0]))
  375. x1 = max(0, min(x1, person_region.shape[1]))
  376. x2 = max(0, min(x2, person_region.shape[1]))
  377. if y2 <= y1 or x2 <= x1:
  378. continue
  379. # 裁剪区域
  380. roi = person_region[y1:y2, x1:x2]
  381. # OCR识别
  382. ocr_results = self.ocr.recognize(roi)
  383. for result in ocr_results:
  384. # 调整坐标到原始图像坐标系
  385. adjusted_bbox = (
  386. result.bbox[0] + x1 + offset[0],
  387. result.bbox[1] + y1 + offset[1],
  388. result.bbox[2],
  389. result.bbox[3]
  390. )
  391. result.bbox = adjusted_bbox
  392. result.location = region['name']
  393. person_info.ocr_results.append(result)
  394. # 更新最佳结果
  395. if result.confidence > best_confidence:
  396. best_confidence = result.confidence
  397. best_result = result
  398. # 设置最佳结果作为编号
  399. if best_result:
  400. person_info.number_text = best_result.text
  401. person_info.number_confidence = best_result.confidence
  402. person_info.number_location = best_result.location
  403. return person_info
  404. def detect_numbers_batch(self, frame: np.ndarray,
  405. person_bboxes: List[Tuple[int, int, int, int]]) -> List[PersonInfo]:
  406. """
  407. 批量检测人体编号
  408. Args:
  409. frame: 输入图像
  410. person_bboxes: 人体边界框列表
  411. Returns:
  412. 人员信息列表
  413. """
  414. results = []
  415. for i, bbox in enumerate(person_bboxes):
  416. person_info = self.detect_number(frame, bbox)
  417. person_info.person_id = i
  418. results.append(person_info)
  419. return results
  420. def preprocess_for_ocr(image: np.ndarray) -> np.ndarray:
  421. """
  422. OCR预处理
  423. Args:
  424. image: 输入图像
  425. Returns:
  426. 预处理后的图像
  427. """
  428. if image is None:
  429. return None
  430. # 转换为灰度图
  431. if len(image.shape) == 3:
  432. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  433. else:
  434. gray = image
  435. # 自适应直方图均衡化
  436. clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
  437. enhanced = clahe.apply(gray)
  438. # 降噪
  439. denoised = cv2.fastNlMeansDenoising(enhanced, None, 10)
  440. # 二值化
  441. _, binary = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
  442. return binary