# llm_service.py
"""
LLM API client module.

Used for safety-status assessment and OCR serial-number recognition.
"""
  5. import os
  6. import json
  7. import time
  8. import base64
  9. import http.client
  10. from typing import Optional, Dict, Any, List
  11. from dataclasses import dataclass
  12. import cv2
  13. import numpy as np
  14. @dataclass
  15. class LLMResponse:
  16. """大模型响应"""
  17. content: str # 响应内容
  18. success: bool = True # 是否成功
  19. error: str = "" # 错误信息
  20. latency: float = 0.0 # 响应延迟(秒)
  21. class LLMClient:
  22. """
  23. 大模型 API 客户端
  24. 支持 OpenAI 兼容接口 (千问、llama-server 等)
  25. """
  26. def __init__(self, config: Dict[str, Any] = None):
  27. """
  28. 初始化客户端
  29. Args:
  30. config: 配置字典
  31. """
  32. self.config = config or {}
  33. # API 配置
  34. self.api_host = self.config.get('api_host', 'localhost')
  35. self.api_port = self.config.get('api_port', 8111)
  36. self.api_key = self.config.get('api_key', '')
  37. self.model = self.config.get('model', 'Qwen2.5-VL-7B-Instruct')
  38. # 超时和重试
  39. self.timeout = self.config.get('timeout', 30)
  40. self.max_retries = self.config.get('max_retries', 3)
  41. self.retry_delay = self.config.get('retry_delay', 1.0)
  42. # 是否使用 HTTPS
  43. self.use_https = self.config.get('use_https',
  44. self.api_host not in ['localhost', '127.0.0.1'])
  45. def _image_to_base64(self, image: np.ndarray) -> str:
  46. """将图像转换为 base64 编码"""
  47. if image is None:
  48. return ""
  49. # 确保图像是连续的
  50. image = np.ascontiguousarray(image)
  51. # 编码为 JPEG
  52. success, buffer = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, 85])
  53. if not success:
  54. return ""
  55. base64_str = base64.b64encode(buffer).decode('utf-8')
  56. return f"data:image/jpeg;base64,{base64_str}"
  57. def chat(self, messages: List[Dict], temperature: float = 0.3,
  58. max_tokens: int = 1024, stream: bool = False) -> LLMResponse:
  59. """
  60. 发送聊天请求
  61. Args:
  62. messages: 消息列表
  63. temperature: 温度参数
  64. max_tokens: 最大生成 token 数
  65. stream: 是否流式输出
  66. Returns:
  67. LLMResponse 响应对象
  68. """
  69. payload = {
  70. "model": self.model,
  71. "messages": messages,
  72. "temperature": temperature,
  73. "max_tokens": max_tokens,
  74. "stream": stream
  75. }
  76. headers = {
  77. 'Content-Type': 'application/json',
  78. 'Accept': 'application/json',
  79. }
  80. if self.api_key:
  81. headers['Authorization'] = f'Bearer {self.api_key}'
  82. last_error = None
  83. for attempt in range(self.max_retries):
  84. conn = None
  85. try:
  86. start_time = time.time()
  87. # 创建连接
  88. conn_class = http.client.HTTPSConnection if self.use_https else http.client.HTTPConnection
  89. conn = conn_class(self.api_host, self.api_port, timeout=self.timeout)
  90. conn.request("POST", "/v1/chat/completions",
  91. json.dumps(payload), headers)
  92. res = conn.getresponse()
  93. data = res.read()
  94. latency = time.time() - start_time
  95. if res.status != 200:
  96. error_msg = f"HTTP {res.status}: {data.decode('utf-8', errors='ignore')}"
  97. return LLMResponse(content="", success=False, error=error_msg, latency=latency)
  98. response = json.loads(data.decode('utf-8'))
  99. if 'choices' in response and len(response['choices']) > 0:
  100. content = response['choices'][0]['message']['content']
  101. return LLMResponse(content=content, success=True, latency=latency)
  102. elif 'error' in response:
  103. return LLMResponse(content="", success=False,
  104. error=response['error'].get('message', 'Unknown error'),
  105. latency=latency)
  106. else:
  107. return LLMResponse(content="", success=False,
  108. error="Invalid response format", latency=latency)
  109. except json.JSONDecodeError as e:
  110. last_error = f"JSON 解析错误: {e}"
  111. except http.client.HTTPException as e:
  112. last_error = f"HTTP 错误: {e}"
  113. except Exception as e:
  114. last_error = str(e)
  115. finally:
  116. if conn:
  117. conn.close()
  118. # 重试
  119. if attempt < self.max_retries - 1:
  120. time.sleep(self.retry_delay * (attempt + 1))
  121. return LLMResponse(content="", success=False, error=last_error or "Unknown error")
  122. def vision_chat(self, image: np.ndarray, prompt: str,
  123. temperature: float = 0.3) -> LLMResponse:
  124. """
  125. 视觉语言模型对话
  126. Args:
  127. image: 图像
  128. prompt: 提示词
  129. temperature: 温度参数
  130. Returns:
  131. LLMResponse 响应对象
  132. """
  133. image_base64 = self._image_to_base64(image)
  134. messages = [
  135. {
  136. "role": "user",
  137. "content": [
  138. {"type": "text", "text": prompt},
  139. {"type": "image_url", "image_url": {"url": image_base64}}
  140. ]
  141. }
  142. ]
  143. return self.chat(messages, temperature=temperature)
  144. def check_connection(self) -> bool:
  145. """检查 API 连接"""
  146. conn = None
  147. try:
  148. conn_class = http.client.HTTPSConnection if self.use_https else http.client.HTTPConnection
  149. conn = conn_class(self.api_host, self.api_port, timeout=5)
  150. conn.request("GET", "/v1/models")
  151. res = conn.getresponse()
  152. return res.status in [200, 404] # 404 也表示服务在运行
  153. except:
  154. return False
  155. finally:
  156. if conn:
  157. conn.close()
  158. class SafetyAnalyzer:
  159. """
  160. 安全状态分析器
  161. 使用大模型判断安全状态
  162. """
  163. # 安全分析提示词
  164. SAFETY_PROMPT = """你是一个施工现场安全管理助手。请分析这张图片中的安全情况。
  165. 请检查以下几点:
  166. 1. 图片中是否有人员?
  167. 2. 人员是否佩戴了安全帽?
  168. 3. 人员是否穿着反光衣/安全背心?
  169. 请以 JSON 格式回复,格式如下:
  170. {
  171. "has_person": true/false,
  172. "person_count": 数字,
  173. "safety_status": "safe" 或 "violation",
  174. "violations": ["违规项1", "违规项2"],
  175. "description": "简要描述",
  176. "confidence": 0.0-1.0
  177. }
  178. 只返回 JSON,不要其他内容。"""
  179. def __init__(self, llm_config: Dict[str, Any] = None):
  180. """
  181. 初始化分析器
  182. Args:
  183. llm_config: LLM 配置
  184. """
  185. self.llm = LLMClient(llm_config)
  186. self.enabled = True
  187. def analyze(self, image: np.ndarray) -> Dict[str, Any]:
  188. """
  189. 分析图像中的安全状态
  190. Args:
  191. image: 输入图像
  192. Returns:
  193. 分析结果字典
  194. """
  195. if not self.enabled or image is None:
  196. return self._default_result()
  197. # 调用大模型
  198. response = self.llm.vision_chat(image, self.SAFETY_PROMPT, temperature=0.1)
  199. if not response.success:
  200. print(f"安全分析失败: {response.error}")
  201. return self._default_result()
  202. # 解析结果
  203. try:
  204. # 尝试提取 JSON
  205. content = response.content.strip()
  206. # 处理 markdown 代码块
  207. if '```json' in content:
  208. content = content.split('```json')[1].split('```')[0]
  209. elif '```' in content:
  210. content = content.split('```')[1].split('```')[0]
  211. result = json.loads(content.strip())
  212. # 验证必要字段
  213. if 'has_person' not in result:
  214. result['has_person'] = False
  215. if 'safety_status' not in result:
  216. result['safety_status'] = 'unknown'
  217. if 'violations' not in result:
  218. result['violations'] = []
  219. result['success'] = True
  220. result['latency'] = response.latency
  221. return result
  222. except json.JSONDecodeError as e:
  223. print(f"解析安全分析结果失败: {e}")
  224. print(f"原始响应: {response.content[:200]}")
  225. return self._default_result()
  226. def _default_result(self) -> Dict[str, Any]:
  227. """返回默认结果"""
  228. return {
  229. 'has_person': False,
  230. 'person_count': 0,
  231. 'safety_status': 'unknown',
  232. 'violations': [],
  233. 'description': '',
  234. 'confidence': 0.0,
  235. 'success': False
  236. }
  237. def check_person_safety(self, person_image: np.ndarray) -> Dict[str, Any]:
  238. """
  239. 检查单个人员的安全状态
  240. Args:
  241. person_image: 人员图像(裁剪后的人体区域)
  242. Returns:
  243. 安全状态字典
  244. """
  245. prompt = """分析这张图片中人员的安全装备佩戴情况。
  246. 请检查:
  247. 1. 是否佩戴安全帽?
  248. 2. 是否穿着反光衣/安全背心?
  249. 以 JSON 格式回复:
  250. {
  251. "has_helmet": true/false,
  252. "has_vest": true/false,
  253. "is_violation": true/false,
  254. "violation_desc": "违规描述,如果没有违规则为空",
  255. "confidence": 0.0-1.0
  256. }
  257. 只返回 JSON。"""
  258. if person_image is None:
  259. return {'has_helmet': False, 'has_vest': False, 'is_violation': True,
  260. 'violation_desc': '无法识别', 'confidence': 0.0}
  261. response = self.llm.vision_chat(person_image, prompt, temperature=0.1)
  262. if not response.success:
  263. return {'has_helmet': False, 'has_vest': False, 'is_violation': True,
  264. 'violation_desc': '识别失败', 'confidence': 0.0}
  265. try:
  266. content = response.content.strip()
  267. if '```json' in content:
  268. content = content.split('```json')[1].split('```')[0]
  269. elif '```' in content:
  270. content = content.split('```')[1].split('```')[0]
  271. result = json.loads(content.strip())
  272. result['success'] = True
  273. return result
  274. except:
  275. return {'has_helmet': False, 'has_vest': False, 'is_violation': True,
  276. 'violation_desc': '解析失败', 'confidence': 0.0}
  277. def create_safety_analyzer(config: Dict[str, Any] = None) -> SafetyAnalyzer:
  278. """创建安全分析器"""
  279. return SafetyAnalyzer(config)