소스 검색

* 修改count 复杂子句的场景优化

chen.cheng 2 달 전
부모
커밋
88a8002af9

+ 29 - 6
bd-park/park-backend/park-infrastructure/src/main/java/com/huashe/park/infrastructure/cfg/mybatis/TenantSqlInterceptor.java

@@ -20,6 +20,7 @@ import org.springframework.stereotype.Component;
 import com.huashe.common.annotation.mybatis.Tenant;
 import com.huashe.park.common.animations.mybatis.MethodMetadataCache;
 
+import cn.hutool.core.lang.Pair;
 import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.LongValue;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
@@ -30,6 +31,7 @@ import net.sf.jsqlparser.statement.Statement;
 import net.sf.jsqlparser.statement.select.PlainSelect;
 import net.sf.jsqlparser.statement.select.Select;
 import net.sf.jsqlparser.statement.select.SelectBody;
+import net.sf.jsqlparser.statement.select.SubSelect;
 
 @Intercepts({
     @Signature(type = StatementHandler.class, method = "prepare", args = {
@@ -47,9 +49,9 @@ public class TenantSqlInterceptor extends BaseInterceptor implements Interceptor
         // 获取原始 SQL 和 MappedStatement
         String originalSql = (String) metaObject.getValue("delegate.boundSql.sql");
         MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
-        Method mapperMethod = getMapperMethod(mappedStatement);
+        Pair<Method, Boolean> mapperMethod = getMapperMethod(mappedStatement);
         // 判断是否需要添加租户条件
-        if (!ObjectUtils.allNotNull(mapperMethod, getLoginUser().getUser().getTenantId())) {
+        if (!ObjectUtils.allNotNull(mapperMethod.getKey(), getLoginUser().getUser().getTenantId())) {
             return invocation.proceed();
         }
         // 解析 SQL 语法树
@@ -65,8 +67,27 @@ public class TenantSqlInterceptor extends BaseInterceptor implements Interceptor
         PlainSelect plainSelect = (PlainSelect) selectBody;
         // 构建租户条件表达式(tenant_id = ?)
         EqualsTo tenantCondition = new EqualsTo();
-        tenantCondition.setLeftExpression(handleColumn(mapperMethod));
+        tenantCondition.setLeftExpression(handleColumn(mapperMethod.getKey()));
         tenantCondition.setRightExpression(new LongValue(getLoginUser().getUser().getTenantId()));
+        // *_COUNT 结尾的查询方法且from组合查询子句添加租户条件
+        // SELECT count(0) FROM (SELECT * FROM flow_task AS t LEFT JOIN flow_user uu ON uu.associated = t.id)
+        // table_count
+        if (mapperMethod.getValue() && plainSelect.getFromItem() instanceof SubSelect) {
+            SelectBody subSelect = ((SubSelect) plainSelect.getFromItem()).getSelectBody();
+            PlainSelect subPlainSelect = (PlainSelect) subSelect;
+            // 添加到原有WHERE条件
+            if (subPlainSelect.getWhere() != null) {
+                subPlainSelect.setWhere(new AndExpression(subPlainSelect.getWhere(), tenantCondition));
+            }
+            else {
+                subPlainSelect.setWhere(tenantCondition);
+            }
+            // 生成修改后的 SQL
+            String modifiedSql = plainSelect.toString();
+            metaObject.setValue("delegate.boundSql.sql", modifiedSql);
+
+            return invocation.proceed();
+        }
 
         // 修改 WHERE 条件
         Expression where = plainSelect.getWhere();
@@ -87,17 +108,19 @@ public class TenantSqlInterceptor extends BaseInterceptor implements Interceptor
         return invocation.proceed();
     }
 
-    private Method getMapperMethod(MappedStatement mappedStatement) {
+    private Pair<Method, Boolean> getMapperMethod(MappedStatement mappedStatement) {
         String methodId = mappedStatement.getId();
         int lastDotIndex = methodId.lastIndexOf(".");
         String methodName = methodId.substring(lastDotIndex + 1);
         List<Method> annotatedMethods = MethodMetadataCache.getAnnotatedMethods(mappedStatement);
+        boolean endWithCount = false;
         if (methodName.endsWith("_COUNT")) {
             methodName = methodName.replaceFirst("_COUNT", "");
+            endWithCount = true;
         }
         String finalMethodName = methodName;
-        return annotatedMethods.stream().filter(method -> method.getName().equals(finalMethodName)).findFirst()
-            .orElse(null);
+        return Pair.of(annotatedMethods.stream().filter(method -> method.getName().equals(finalMethodName)).findFirst()
+            .orElse(null), endWithCount);
     }
 
     private Column handleColumn(Method mapperMethod) {