瀏覽代碼

:art: API `/api/query/sql` support `UNION` statement https://github.com/siyuan-note/siyuan/issues/8226

Liang Ding 2 年之前
父節點
當前提交
3fdf00bbbd
共有 1 個文件被更改,包括 33 次插入22 次删除
  1. 33 22
      kernel/sql/block_query.go

+ 33 - 22
kernel/sql/block_query.go

@@ -391,31 +391,17 @@ func Query(stmt string, limit int) (ret []map[string]interface{}, err error) {
 
 	switch parsedStmt.(type) {
 	case *sqlparser.Select:
+		limitClause := getLimitClause(parsedStmt, limit)
 		slct := parsedStmt.(*sqlparser.Select)
-		if nil == slct.Limit {
-			slct.Limit = &sqlparser.Limit{
-				Rowcount: &sqlparser.SQLVal{
-					Type: sqlparser.IntVal,
-					Val:  []byte(strconv.Itoa(limit)),
-				},
-			}
-		} else {
-			if nil != slct.Limit.Rowcount && 0 < len(slct.Limit.Rowcount.(*sqlparser.SQLVal).Val) {
-				limit, _ = strconv.Atoi(string(slct.Limit.Rowcount.(*sqlparser.SQLVal).Val))
-				if 0 >= limit {
-					limit = 32
-				}
-			}
-
-			slct.Limit.Rowcount = &sqlparser.SQLVal{
-				Type: sqlparser.IntVal,
-				Val:  []byte(strconv.Itoa(limit)),
-			}
-		}
-
+		slct.Limit = limitClause
 		stmt = sqlparser.String(slct)
+	case *sqlparser.Union:
+		limitClause := getLimitClause(parsedStmt, limit)
+		union := parsedStmt.(*sqlparser.Union)
+		union.Limit = limitClause
+		stmt = sqlparser.String(union)
 	default:
-		return
+		return queryRawStmt(stmt, limit)
 	}
 
 	ret = []map[string]interface{}{}
@@ -452,6 +438,31 @@ func Query(stmt string, limit int) (ret []map[string]interface{}, err error) {
 	return
 }
 
+func getLimitClause(parsedStmt sqlparser.Statement, limit int) (ret *sqlparser.Limit) {
+	switch parsedStmt.(type) {
+	case *sqlparser.Select:
+		slct := parsedStmt.(*sqlparser.Select)
+		if nil != slct.Limit {
+			ret = slct.Limit
+		}
+	case *sqlparser.Union:
+		union := parsedStmt.(*sqlparser.Union)
+		if nil != union.Limit {
+			ret = union.Limit
+		}
+	}
+
+	if nil == ret || nil == ret.Rowcount {
+		ret = &sqlparser.Limit{
+			Rowcount: &sqlparser.SQLVal{
+				Type: sqlparser.IntVal,
+				Val:  []byte(strconv.Itoa(limit)),
+			},
+		}
+	}
+	return
+}
+
 func queryRawStmt(stmt string, limit int) (ret []map[string]interface{}, err error) {
 	rows, err := query(stmt)
 	if nil != err {