llm_service.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. """
  2. 大模型 API 调用模块
  3. 用于安全状态判断和 OCR 编号识别
  4. """
import base64
import http.client
import json
import os
import re
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
  14. @dataclass
  15. class LLMResponse:
  16. """大模型响应"""
  17. content: str # 响应内容
  18. success: bool = True # 是否成功
  19. error: str = "" # 错误信息
  20. latency: float = 0.0 # 响应延迟(秒)
  21. class LLMClient:
  22. """
  23. 大模型 API 客户端
  24. 支持 OpenAI 兼容接口 (千问、llama-server 等)
  25. """
  26. def __init__(self, config: Dict[str, Any] = None):
  27. """
  28. 初始化客户端
  29. Args:
  30. config: 配置字典
  31. """
  32. self.config = config or {}
  33. # API 配置
  34. self.api_host = self.config.get('api_host', 'localhost')
  35. self.api_port = self.config.get('api_port', 8111)
  36. self.api_key = self.config.get('api_key', '')
  37. self.model = self.config.get('model', 'Qwen2.5-VL-7B-Instruct')
  38. # 超时和重试
  39. self.timeout = self.config.get('timeout', 30)
  40. self.max_retries = self.config.get('max_retries', 3)
  41. self.retry_delay = self.config.get('retry_delay', 1.0)
  42. # 是否使用 HTTPS
  43. self.use_https = self.config.get('use_https',
  44. self.api_host not in ['localhost', '127.0.0.1'])
  45. def _image_to_base64(self, image: np.ndarray) -> str:
  46. """将图像转换为 base64 编码"""
  47. if image is None:
  48. return ""
  49. # 确保图像是连续的
  50. image = np.ascontiguousarray(image)
  51. # 编码为 JPEG
  52. success, buffer = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, 85])
  53. if not success:
  54. return ""
  55. base64_str = base64.b64encode(buffer).decode('utf-8')
  56. return f"data:image/jpeg;base64,{base64_str}"
  57. def chat(self, messages: List[Dict], temperature: float = 0.3,
  58. max_tokens: int = 1024, stream: bool = False) -> LLMResponse:
  59. """
  60. 发送聊天请求
  61. Args:
  62. messages: 消息列表
  63. temperature: 温度参数
  64. max_tokens: 最大生成 token 数
  65. stream: 是否流式输出
  66. Returns:
  67. LLMResponse 响应对象
  68. """
  69. payload = {
  70. "model": self.model,
  71. "messages": messages,
  72. "temperature": temperature,
  73. "max_tokens": max_tokens,
  74. "stream": stream
  75. }
  76. headers = {
  77. 'Content-Type': 'application/json',
  78. 'Accept': 'application/json',
  79. }
  80. if self.api_key:
  81. headers['Authorization'] = f'Bearer {self.api_key}'
  82. last_error = None
  83. for attempt in range(self.max_retries):
  84. try:
  85. start_time = time.time()
  86. # 创建连接
  87. conn_class = http.client.HTTPSConnection if self.use_https else http.client.HTTPConnection
  88. conn = conn_class(self.api_host, self.api_port, timeout=self.timeout)
  89. conn.request("POST", "/v1/chat/completions",
  90. json.dumps(payload), headers)
  91. res = conn.getresponse()
  92. data = res.read()
  93. conn.close()
  94. latency = time.time() - start_time
  95. if res.status != 200:
  96. error_msg = f"HTTP {res.status}: {data.decode('utf-8', errors='ignore')}"
  97. return LLMResponse(content="", success=False, error=error_msg, latency=latency)
  98. response = json.loads(data.decode('utf-8'))
  99. if 'choices' in response and len(response['choices']) > 0:
  100. content = response['choices'][0]['message']['content']
  101. return LLMResponse(content=content, success=True, latency=latency)
  102. elif 'error' in response:
  103. return LLMResponse(content="", success=False,
  104. error=response['error'].get('message', 'Unknown error'),
  105. latency=latency)
  106. else:
  107. return LLMResponse(content="", success=False,
  108. error="Invalid response format", latency=latency)
  109. except json.JSONDecodeError as e:
  110. last_error = f"JSON 解析错误: {e}"
  111. except http.client.HTTPException as e:
  112. last_error = f"HTTP 错误: {e}"
  113. except Exception as e:
  114. last_error = str(e)
  115. # 重试
  116. if attempt < self.max_retries - 1:
  117. time.sleep(self.retry_delay * (attempt + 1))
  118. return LLMResponse(content="", success=False, error=last_error or "Unknown error")
  119. def vision_chat(self, image: np.ndarray, prompt: str,
  120. temperature: float = 0.3) -> LLMResponse:
  121. """
  122. 视觉语言模型对话
  123. Args:
  124. image: 图像
  125. prompt: 提示词
  126. temperature: 温度参数
  127. Returns:
  128. LLMResponse 响应对象
  129. """
  130. image_base64 = self._image_to_base64(image)
  131. messages = [
  132. {
  133. "role": "user",
  134. "content": [
  135. {"type": "text", "text": prompt},
  136. {"type": "image_url", "image_url": {"url": image_base64}}
  137. ]
  138. }
  139. ]
  140. return self.chat(messages, temperature=temperature)
  141. def check_connection(self) -> bool:
  142. """检查 API 连接"""
  143. try:
  144. conn_class = http.client.HTTPSConnection if self.use_https else http.client.HTTPConnection
  145. conn = conn_class(self.api_host, self.api_port, timeout=5)
  146. conn.request("GET", "/v1/models")
  147. res = conn.getresponse()
  148. conn.close()
  149. return res.status in [200, 404] # 404 也表示服务在运行
  150. except:
  151. return False
  152. class SafetyAnalyzer:
  153. """
  154. 安全状态分析器
  155. 使用大模型判断安全状态
  156. """
  157. # 安全分析提示词
  158. SAFETY_PROMPT = """你是一个施工现场安全管理助手。请分析这张图片中的安全情况。
  159. 请检查以下几点:
  160. 1. 图片中是否有人员?
  161. 2. 人员是否佩戴了安全帽?
  162. 3. 人员是否穿着反光衣/安全背心?
  163. 请以 JSON 格式回复,格式如下:
  164. {
  165. "has_person": true/false,
  166. "person_count": 数字,
  167. "safety_status": "safe" 或 "violation",
  168. "violations": ["违规项1", "违规项2"],
  169. "description": "简要描述",
  170. "confidence": 0.0-1.0
  171. }
  172. 只返回 JSON,不要其他内容。"""
  173. def __init__(self, llm_config: Dict[str, Any] = None):
  174. """
  175. 初始化分析器
  176. Args:
  177. llm_config: LLM 配置
  178. """
  179. self.llm = LLMClient(llm_config)
  180. self.enabled = True
  181. def analyze(self, image: np.ndarray) -> Dict[str, Any]:
  182. """
  183. 分析图像中的安全状态
  184. Args:
  185. image: 输入图像
  186. Returns:
  187. 分析结果字典
  188. """
  189. if not self.enabled or image is None:
  190. return self._default_result()
  191. # 调用大模型
  192. response = self.llm.vision_chat(image, self.SAFETY_PROMPT, temperature=0.1)
  193. if not response.success:
  194. print(f"安全分析失败: {response.error}")
  195. return self._default_result()
  196. # 解析结果
  197. try:
  198. # 尝试提取 JSON
  199. content = response.content.strip()
  200. # 处理 markdown 代码块
  201. if '```json' in content:
  202. content = content.split('```json')[1].split('```')[0]
  203. elif '```' in content:
  204. content = content.split('```')[1].split('```')[0]
  205. result = json.loads(content.strip())
  206. # 验证必要字段
  207. if 'has_person' not in result:
  208. result['has_person'] = False
  209. if 'safety_status' not in result:
  210. result['safety_status'] = 'unknown'
  211. if 'violations' not in result:
  212. result['violations'] = []
  213. result['success'] = True
  214. result['latency'] = response.latency
  215. return result
  216. except json.JSONDecodeError as e:
  217. print(f"解析安全分析结果失败: {e}")
  218. print(f"原始响应: {response.content[:200]}")
  219. return self._default_result()
  220. def _default_result(self) -> Dict[str, Any]:
  221. """返回默认结果"""
  222. return {
  223. 'has_person': False,
  224. 'person_count': 0,
  225. 'safety_status': 'unknown',
  226. 'violations': [],
  227. 'description': '',
  228. 'confidence': 0.0,
  229. 'success': False
  230. }
  231. def check_person_safety(self, person_image: np.ndarray) -> Dict[str, Any]:
  232. """
  233. 检查单个人员的安全状态
  234. Args:
  235. person_image: 人员图像(裁剪后的人体区域)
  236. Returns:
  237. 安全状态字典
  238. """
  239. prompt = """分析这张图片中人员的安全装备佩戴情况。
  240. 请检查:
  241. 1. 是否佩戴安全帽?
  242. 2. 是否穿着反光衣/安全背心?
  243. 以 JSON 格式回复:
  244. {
  245. "has_helmet": true/false,
  246. "has_vest": true/false,
  247. "is_violation": true/false,
  248. "violation_desc": "违规描述,如果没有违规则为空",
  249. "confidence": 0.0-1.0
  250. }
  251. 只返回 JSON。"""
  252. if person_image is None:
  253. return {'has_helmet': False, 'has_vest': False, 'is_violation': True,
  254. 'violation_desc': '无法识别', 'confidence': 0.0}
  255. response = self.llm.vision_chat(person_image, prompt, temperature=0.1)
  256. if not response.success:
  257. return {'has_helmet': False, 'has_vest': False, 'is_violation': True,
  258. 'violation_desc': '识别失败', 'confidence': 0.0}
  259. try:
  260. content = response.content.strip()
  261. if '```json' in content:
  262. content = content.split('```json')[1].split('```')[0]
  263. elif '```' in content:
  264. content = content.split('```')[1].split('```')[0]
  265. result = json.loads(content.strip())
  266. result['success'] = True
  267. return result
  268. except:
  269. return {'has_helmet': False, 'has_vest': False, 'is_violation': True,
  270. 'violation_desc': '解析失败', 'confidence': 0.0}
  271. class NumberRecognizer:
  272. """
  273. 编号识别器
  274. 使用大模型进行 OCR 编号识别
  275. """
  276. NUMBER_PROMPT = """请识别这张图片中工作人员衣服上的编号或工号。
  277. 只返回识别到的编号数字,如果没有看到编号则返回 "无"。
  278. 不要返回其他内容。"""
  279. def __init__(self, llm_config: Dict[str, Any] = None):
  280. """
  281. 初始化识别器
  282. Args:
  283. llm_config: LLM 配置
  284. """
  285. self.llm = LLMClient(llm_config)
  286. def recognize(self, image: np.ndarray) -> Dict[str, Any]:
  287. """
  288. 识别图像中的编号
  289. Args:
  290. image: 输入图像
  291. Returns:
  292. 识别结果 {'number': str, 'confidence': float, 'success': bool}
  293. """
  294. if image is None:
  295. return {'number': None, 'confidence': 0.0, 'success': False}
  296. response = self.llm.vision_chat(image, self.NUMBER_PROMPT, temperature=0.1)
  297. if not response.success:
  298. return {'number': None, 'confidence': 0.0, 'success': False,
  299. 'error': response.error}
  300. content = response.content.strip()
  301. # 处理结果
  302. if content == '无' or '无' in content or not content:
  303. return {'number': None, 'confidence': 0.0, 'success': True}
  304. # 提取数字/字母数字组合
  305. import re
  306. matches = re.findall(r'[A-Za-z]*\d+[A-Za-z0-9]*', content)
  307. if matches:
  308. number = matches[0]
  309. return {'number': number, 'confidence': 0.9, 'success': True}
  310. # 如果没有匹配到,返回原始内容
  311. return {'number': content, 'confidence': 0.5, 'success': True}
  312. def recognize_person_number(self, person_image: np.ndarray,
  313. search_chest: bool = True) -> Dict[str, Any]:
  314. """
  315. 识别人员编号(在胸部/背部区域搜索)
  316. Args:
  317. person_image: 人员图像
  318. search_chest: 是否搜索胸部区域
  319. Returns:
  320. 识别结果
  321. """
  322. if person_image is None:
  323. return {'number': None, 'confidence': 0.0, 'success': False}
  324. h, w = person_image.shape[:2]
  325. # 如果图像较大,先尝试裁剪胸部区域
  326. if search_chest and h > 100 and w > 100:
  327. # 胸部区域:上半身中间部分
  328. y1 = int(h * 0.15)
  329. y2 = int(h * 0.55)
  330. x1 = int(w * 0.15)
  331. x2 = int(w * 0.85)
  332. chest_region = person_image[y1:y2, x1:x2]
  333. # 先在胸部区域搜索
  334. result = self.recognize(chest_region)
  335. if result.get('number'):
  336. result['location'] = '胸部'
  337. return result
  338. # 整图识别
  339. result = self.recognize(person_image)
  340. result['location'] = '全身'
  341. return result
  342. def create_safety_analyzer(config: Dict[str, Any] = None) -> SafetyAnalyzer:
  343. """创建安全分析器"""
  344. return SafetyAnalyzer(config)
  345. def create_number_recognizer(config: Dict[str, Any] = None) -> NumberRecognizer:
  346. """创建编号识别器"""
  347. return NumberRecognizer(config)