Преглед изворни кода

feat(ota): add firmware detect feature and optimize ota workflow

1. add detectFirmware api and frontend detection button
2. add ota detect endpoint to parse firmware info automatically
3. refactor ota upgrade logic to reuse background task function
4. add dtu config persistence support
wenhongquan пре 12 часа
родитељ
комит
f121f5dcb6
5 измењених фајлова са 261 додато и 83 уклоњено
  1. 225 81
      backend/app.py
  2. BIN
      firmware_v1.0.0.tar.gz
  3. BIN
      firmware_v1.0.1.tar.gz
  4. 3 0
      frontend/src/api/apiService.js
  5. 33 2
      frontend/src/views/OTAUpgrade.vue

+ 225 - 81
backend/app.py

@@ -99,9 +99,28 @@ dtu_config = {
     'firmware_version': 'v1.0.0',
     'hardware_version': 'v1.0',
     'heartbeat_interval': DTU_HEARTBEAT_INTERVAL,
-    'enabled': True  # 是否启用DTU MQTT协议
+    'enabled': True
 }
 
+DTU_CONFIG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'dtu_config.json')
+
+def load_dtu_config():
+    global dtu_config
+    try:
+        if os.path.exists(DTU_CONFIG_FILE):
+            with open(DTU_CONFIG_FILE, 'r') as f:
+                saved = json.load(f)
+                dtu_config.update(saved)
+    except Exception as e:
+        logger.warning(f"加载DTU配置失败: {e}")
+
+def save_dtu_config():
+    try:
+        with open(DTU_CONFIG_FILE, 'w') as f:
+            json.dump(dtu_config, f, indent=2)
+    except Exception as e:
+        logger.warning(f"保存DTU配置失败: {e}")
+
 # 端口状态追踪(用于事件检测)
 port_state = {}  # {panel_id: {port_id: {'last_uid': str, 'expected_uid': str, 'alarm_count': int}}}
 
@@ -714,6 +733,18 @@ def dtu_handle_control(topic, payload):
             response_payload['payload']['success'] = 'error' not in sensor_result
             response_payload['payload']['result'] = env_sensor
 
+        elif command == 'OTA_UPGRADE':
+            params = payload.get('payload', {}).get('params', {})
+            if params.get('firmware_url'):
+                import threading
+                threading.Thread(target=lambda: _run_ota(params), daemon=True).start()
+            response_payload['payload']['success'] = True
+            response_payload['payload']['result'] = {
+                'firmware_version': params.get('firmware_version'),
+                'ota_status': 'DOWNLOADING',
+                'restart_required': False
+            }
+
         elif command == 'OTA_CANCEL':
             # 取消OTA升级
             # 向OTA控制寄存器写入取消命令 (0x0000=取消)
@@ -2288,7 +2319,13 @@ def dtu_control():
         data = request.json
         command = data.get('command')
         if command == 'REBOOT':
-            return jsonify({'success': True, 'message': '重启命令已发送'})
+            import subprocess
+            logger.warning("执行系统重启命令")
+            threading.Thread(target=lambda: (
+                time.sleep(1),
+                subprocess.run(['reboot'], capture_output=True)
+            ), daemon=True).start()
+            return jsonify({'success': True, 'message': '系统正在重启...'})
         return jsonify({'success': False, 'message': f'未知命令: {command}'}), 400
     except Exception as e:
         logger.error(f"DTU控制命令失败: {str(e)}")
@@ -2307,6 +2344,53 @@ ota_status = {
 }
 
 
+@app.route('/api/dtu/ota_detect', methods=['POST'])
+def ota_detect():
+    """检测固件包信息(自动下载并解析 manifest)"""
+    try:
+        url = request.json.get('url', '')
+        if not url:
+            return jsonify({'success': False, 'message': '请提供固件 URL'}), 400
+
+        import urllib.request, tempfile, tarfile, json, hashlib
+        tmp = tempfile.mktemp(suffix='.tar.gz')
+        try:
+            urllib.request.urlretrieve(url, tmp)
+        except Exception as e:
+            return jsonify({'success': False, 'message': f'下载失败: {str(e)}'}), 400
+
+        file_size = os.path.getsize(tmp)
+        h = hashlib.md5()
+        with open(tmp, 'rb') as f:
+            for chunk in iter(lambda: f.read(65536), b''):
+                h.update(chunk)
+        md5sum = h.hexdigest()
+
+        version = ''
+        try:
+            with tarfile.open(tmp, 'r:gz') as tar:
+                m = tar.extractfile('firmware/firmware.json')
+                if m:
+                    manifest = json.loads(m.read())
+                    version = manifest.get('version', '')
+        except Exception as e:
+            logger.warning(f"解析 firmware.json 失败: {e}")
+
+        os.remove(tmp)
+        return jsonify({
+            'success': True,
+            'data': {
+                'file_size': file_size,
+                'checksum': md5sum,
+                'checksum_type': 'MD5',
+                'firmware_version': version
+            }
+        })
+    except Exception as e:
+        logger.error(f"检测固件失败: {str(e)}")
+        return jsonify({'success': False, 'message': str(e)}), 500
+
+
 @app.route('/api/dtu/ota_status', methods=['GET'])
 def get_ota_status():
     """获取OTA升级状态"""
@@ -2328,68 +2412,159 @@ def get_ota_status():
         return jsonify({'success': False, 'message': str(e)}), 500
 
 
+def handle_ota_status(dtu_id, payload):
+    """处理MQTT上报的OTA状态"""
+    ota_status['status'] = payload.get('ota_status', 'IDLE')
+    ota_status['progress'] = payload.get('ota_progress', 0)
+    ota_status['firmware_version'] = payload.get('firmware_version')
+    ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
+    error_code = payload.get('error_code')
+    if error_code and error_code != 0:
+        ota_status['error_code'] = error_code
+        ota_status['error_message'] = payload.get('error_message', get_ota_error_message(error_code))
+        ota_status['status'] = 'FAILED'
+    elif ota_status['status'] == 'SUCCESS':
+        ota_status['error_code'] = None
+        ota_status['error_message'] = None
+        if ota_status.get('firmware_version'):
+            dtu_config['firmware_version'] = ota_status['firmware_version']
+            save_dtu_config()
+    logger.info(f"MQTT OTA状态更新: {ota_status['status']}, 进度: {ota_status['progress']}%")
+
+
+def get_ota_error_message(error_code):
+    error_messages = {1021: '已是目标版本', 1022: '校验失败', 1023: '下载失败', 1024: '写入失败', 1025: '存储空间不足'}
+    return error_messages.get(error_code, f'未知错误码: {error_code}')
+
+
+def _run_ota(params):
+    """执行OTA升级(后台线程,供HTTP和MQTT共用)"""
+    import subprocess, os, hashlib, shutil, tarfile, tempfile, urllib.request
+    url = params['firmware_url']
+    version = params['firmware_version']
+    file_size = params['file_size']
+    checksum = params['checksum']
+    checksum_type = params.get('checksum_type', 'MD5').upper()
+    force = params.get('force_upgrade', False)
+    ota_status['status'] = 'DOWNLOADING'
+    ota_status['progress'] = 0
+    ota_status['target_version'] = version
+    ota_status['error_code'] = None
+    ota_status['error_message'] = None
+    ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
+    try:
+        tmp_dir = tempfile.mkdtemp(prefix='ota_')
+        fw_path = os.path.join(tmp_dir, 'firmware.tar.gz')
+        ota_status['progress'] = 5
+        ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
+        logger.info(f"OTA: 开始下载固件 {url}")
+        req = urllib.request.Request(url, headers={'User-Agent': 'OTA-Updater'})
+        with urllib.request.urlopen(req, timeout=120) as resp:
+            with open(fw_path, 'wb') as f:
+                total = int(resp.headers.get('Content-Length', 0))
+                downloaded = 0
+                while True:
+                    chunk = resp.read(65536)
+                    if not chunk: break
+                    f.write(chunk)
+                    downloaded += len(chunk)
+                    if total:
+                        pct = 5 + int(downloaded / total * 30)
+                        ota_status['progress'] = min(pct, 35)
+        dl_size = os.path.getsize(fw_path)
+        ota_status['progress'] = 40
+        ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
+        if abs(dl_size - file_size) > 1024:
+            raise Exception(f"文件大小不匹配: 预期{file_size}, 实际{dl_size}")
+        ota_status['status'] = 'VERIFYING'
+        ota_status['progress'] = 50
+        ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
+        h = hashlib.new(checksum_type)
+        with open(fw_path, 'rb') as f:
+            for chunk in iter(lambda: f.read(65536), b''): h.update(chunk)
+        actual_checksum = h.hexdigest().lower()
+        if actual_checksum != checksum.lower():
+            raise Exception(f"校验和不匹配: 预期{checksum}, 实际{actual_checksum}")
+        ota_status['progress'] = 60
+        ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
+        extract_dir = os.path.join(tmp_dir, 'firmware')
+        os.makedirs(extract_dir, exist_ok=True)
+        with tarfile.open(fw_path, 'r:gz') as tar: tar.extractall(extract_dir)
+        manifest_path = os.path.join(extract_dir, 'firmware', 'firmware.json')
+        if not os.path.exists(manifest_path):
+            raise Exception("固件包缺少 firmware.json")
+        with open(manifest_path, 'r') as f: manifest = json.load(f)
+        fw_version = manifest.get('version', version)
+        if not force:
+            current_ver = dtu_config.get('firmware_version', 'v0.0.0')
+            if fw_version == current_ver:
+                raise Exception(f'已是目标版本 {current_ver}')
+        ota_status['status'] = 'FLASHING'
+        ota_status['progress'] = 70
+        ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
+        fw_root = os.path.join(extract_dir, 'firmware')
+        project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        excludes = {'__pycache__', '.git', 'log', 'venv'}
+        for root, dirs, files in os.walk(fw_root):
+            rel = os.path.relpath(root, fw_root)
+            if rel == '.': rel = ''
+            parts = rel.split(os.sep) if rel else []
+            if parts and parts[0] in excludes: continue
+            for fname in files:
+                if fname == 'firmware.json': continue
+                src = os.path.join(root, fname)
+                dst = os.path.join(project_dir, rel, fname)
+                os.makedirs(os.path.dirname(dst), exist_ok=True)
+                shutil.copy2(src, dst)
+        ota_status['progress'] = 90
+        ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
+        dtu_config['firmware_version'] = fw_version
+        save_dtu_config()
+        ota_status['status'] = 'SUCCESS'
+        ota_status['progress'] = 100
+        ota_status['firmware_version'] = fw_version
+        ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
+        logger.info(f"OTA: 升级成功 {fw_version}")
+        shutil.rmtree(tmp_dir, ignore_errors=True)
+        for i in range(10, 0, -1):
+            ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
+            time.sleep(1)
+        logger.info("OTA: 重启服务...")
+        subprocess.Popen(
+            [subprocess.sys.executable, '-m', 'flask', 'run', '--host=0.0.0.0', '--port=5001'],
+            cwd=os.path.dirname(os.path.abspath(__file__)),
+            stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+        )
+        os._exit(0)
+    except Exception as e:
+        ota_status['status'] = 'FAILED'
+        ota_status['error_message'] = str(e)
+        ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
+        logger.error(f"OTA: 升级失败 - {str(e)}")
+        if 'tmp_dir' in dir() and tmp_dir and os.path.exists(tmp_dir):
+            shutil.rmtree(tmp_dir, ignore_errors=True)
+
+
 @app.route('/api/dtu/ota_upgrade', methods=['POST'])
 def trigger_ota_upgrade():
     """触发OTA升级"""
     try:
-        if not mqtt_status:
-            return jsonify({'success': False, 'message': 'MQTT未连接'}), 400
-
         data = request.json
         if not data:
             return jsonify({'success': False, 'message': '请求体不能为空'}), 400
 
-        # 验证必填字段
         required_fields = ['firmware_url', 'firmware_version', 'file_size', 'checksum', 'checksum_type']
         missing = [f for f in required_fields if not data.get(f)]
         if missing:
             return jsonify({'success': False, 'message': f'缺少必填字段: {", ".join(missing)}'}), 400
 
-        # 构建OTA升级命令
-        msg_id = f"ota_{int(time.time() * 1000)}"
-        control_topic = build_dtu_topic(
-            dtu_config['customer_id'],
-            'dtu',
-            dtu_config['dtu_id'],
-            'control'
-        )
-
-        payload = {
-            'msg_id': msg_id,
-            'timestamp': int(time.time() * 1000),
-            'dtu_id': dtu_config['dtu_id'],
-            'type': 'CONTROL',
-            'payload': {
-                'command': 'OTA_UPGRADE',
-                'target': 'dtu',
-                'params': {
-                    'firmware_url': data['firmware_url'],
-                    'firmware_version': data['firmware_version'],
-                    'file_size': data['file_size'],
-                    'checksum': data['checksum'],
-                    'checksum_type': data['checksum_type'],
-                    'force_upgrade': data.get('force_upgrade', False)
-                }
-            }
-        }
-
-        # 发布OTA升级命令
-        mqtt_client.publish(control_topic, json.dumps(payload))
-        logger.info(f"OTA升级命令已发送: {data['firmware_version']}")
-
-        # 更新OTA状态
-        ota_status['status'] = 'DOWNLOADING'
-        ota_status['progress'] = 0
-        ota_status['target_version'] = data['firmware_version']
-        ota_status['error_code'] = None
-        ota_status['error_message'] = None
-        ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
+        t = threading.Thread(target=_run_ota, args=(data,), daemon=True)
+        t.start()
 
         return jsonify({
             'success': True,
-            'message': 'OTA升级命令已发送',
+            'message': 'OTA升级已启动',
             'data': {
-                'msg_id': msg_id,
                 'target_version': data['firmware_version'],
                 'ota_status': 'DOWNLOADING'
             }
@@ -2399,40 +2574,6 @@ def trigger_ota_upgrade():
         return jsonify({'success': False, 'message': str(e)}), 500
 
 
-# 处理OTA状态上报
-def handle_ota_status(dtu_id, payload):
-    """处理OTA状态上报"""
-    ota_status['status'] = payload.get('ota_status', 'IDLE')
-    ota_status['progress'] = payload.get('ota_progress', 0)
-    ota_status['firmware_version'] = payload.get('firmware_version')
-    ota_status['last_update'] = time.strftime('%Y-%m-%d %H:%M:%S')
-
-    error_code = payload.get('error_code')
-    if error_code and error_code != 0:
-        ota_status['error_code'] = error_code
-        ota_status['error_message'] = payload.get('error_message', get_ota_error_message(error_code))
-        ota_status['status'] = 'FAILED'
-    elif ota_status['status'] == 'SUCCESS':
-        ota_status['error_code'] = None
-        ota_status['error_message'] = None
-        # 更新DTU配置中的固件版本
-        dtu_config['firmware_version'] = ota_status['firmware_version']
-
-    logger.info(f"OTA状态更新: {ota_status['status']}, 进度: {ota_status['progress']}%")
-
-
-def get_ota_error_message(error_code):
-    """获取OTA错误码对应的错误信息"""
-    error_messages = {
-        1021: '已是目标版本',
-        1022: '校验失败',
-        1023: '下载失败',
-        1024: '写入失败',
-        1025: '存储空间不足'
-    }
-    return error_messages.get(error_code, f'未知错误码: {error_code}')
-
-
 # ==================== 环境传感器 API ====================
 
 # 环境传感器数据存储
@@ -2626,6 +2767,9 @@ if __name__ == '__main__':
         logger.info('启动串口-MQTT网关服务...')
         logger.info(f"配置信息: 主机={FLASK_HOST}, 端口={FLASK_PORT}, 调试模式={FLASK_DEBUG}")
 
+        # 加载DTU配置
+        load_dtu_config()
+
         # 加载设备配置
         loaded = load_device_config()
         logger.info(f"已加载 {len(loaded)} 个设备配置")

BIN
firmware_v1.0.0.tar.gz


BIN
firmware_v1.0.1.tar.gz


+ 3 - 0
frontend/src/api/apiService.js

@@ -107,6 +107,9 @@ const apiService = {
     // 获取OTA升级状态
     getOtaStatus: () => apiClient.get('/dtu/ota_status'),
 
+    // 检测固件信息
+    detectFirmware: (url) => apiClient.post('/dtu/ota_detect', { url }),
+
     // 触发OTA升级
     triggerOtaUpgrade: (data) => apiClient.post('/dtu/ota_upgrade', data)
   },

+ 33 - 2
frontend/src/views/OTAUpgrade.vue

@@ -60,9 +60,13 @@
             <a-form-item label="固件URL" name="firmware_url">
               <a-input
                 v-model:value="form.firmware_url"
-                placeholder="http://example.com/firmware.bin"
+                placeholder="http://example.com/firmware.tar.gz"
                 :disabled="upgrading"
-              />
+              >
+                <template #suffix>
+                  <a-button size="small" type="link" @click="detectFirmware" :loading="detecting">检测</a-button>
+                </template>
+              </a-input>
             </a-form-item>
           </a-col>
           <a-col :span="6">
@@ -181,6 +185,7 @@ import apiService from '../api/apiService'
 // 状态
 const loading = ref(false)
 const upgrading = ref(false)
+const detecting = ref(false)
 const pollTimer = ref(null)
 const currentFirmware = ref('')
 
@@ -272,6 +277,32 @@ const progressColor = computed(() => {
   return '#1890ff'
 })
 
+// 检测固件信息
+async function detectFirmware() {
+  if (!form.value.firmware_url) {
+    message.warning('请先输入固件URL')
+    return
+  }
+  detecting.value = true
+  try {
+    const res = await apiService.dtu.detectFirmware(form.value.firmware_url)
+    if (res.data.success) {
+      const d = res.data.data
+      form.value.firmware_version = d.firmware_version
+      form.value.file_size = d.file_size
+      form.value.checksum = d.checksum
+      form.value.checksum_type = d.checksum_type
+      message.success('已自动获取固件信息')
+    } else {
+      message.error(res.data.message || '检测失败')
+    }
+  } catch (e) {
+    message.error('检测失败: ' + (e.response?.data?.message || e.message))
+  } finally {
+    detecting.value = false
+  }
+}
+
 // 获取OTA状态
 async function fetchOtaStatus() {
   try {