Переглянути джерело

:art: API `/api/query/sql` add `LIMIT` clause https://github.com/siyuan-note/siyuan/issues/8167

Liang Ding 2 роки тому
батько
коміт
a48154ba84
3 змінених файлів з 91 додано та 8 видалено
  1. 2 1
      kernel/api/sql.go
  2. 4 4
      kernel/model/search.go
  3. 85 3
      kernel/sql/block_query.go

+ 2 - 1
kernel/api/sql.go

@@ -21,6 +21,7 @@ import (
 
 	"github.com/88250/gulu"
 	"github.com/gin-gonic/gin"
+	"github.com/siyuan-note/siyuan/kernel/model"
 	"github.com/siyuan-note/siyuan/kernel/sql"
 	"github.com/siyuan-note/siyuan/kernel/util"
 )
@@ -35,7 +36,7 @@ func SQL(c *gin.Context) {
 	}
 
 	stmt := arg["stmt"].(string)
-	result, err := sql.Query(stmt)
+	result, err := sql.Query(stmt, model.Conf.Search.Limit)
 	if nil != err {
 		ret.Code = 1
 		ret.Msg = err.Error()

+ 4 - 4
kernel/model/search.go

@@ -624,7 +624,7 @@ func searchBySQL(stmt string, beforeLen, page int) (ret []*Block, matchedBlockCo
 		stmt = strings.ReplaceAll(stmt, "select * ", "select COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` ")
 	}
 	stmt = removeLimitClause(stmt)
-	result, _ := sql.Query(stmt)
+	result, _ := sql.QueryNoLimit(stmt)
 	if 1 > len(ret) {
 		return
 	}
@@ -745,7 +745,7 @@ func fullTextSearchCountByRegexp(exp, boxFilter, pathFilter, typeFilter string)
 	fieldFilter := fieldRegexp(exp)
 	stmt := "SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `blocks` WHERE " + fieldFilter + " AND type IN " + typeFilter
 	stmt += boxFilter + pathFilter
-	result, _ := sql.Query(stmt)
+	result, _ := sql.QueryNoLimit(stmt)
 	if 1 > len(result) {
 		return
 	}
@@ -785,7 +785,7 @@ func fullTextSearchByFTS(query, boxFilter, pathFilter, typeFilter, orderBy strin
 func fullTextSearchCount(query, boxFilter, pathFilter, typeFilter string) (matchedBlockCount, matchedRootCount int) {
 	query = gulu.Str.RemoveInvisible(query)
 	if ast.IsNodeIDPattern(query) {
-		ret, _ := sql.Query("SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `blocks` WHERE `id` = '" + query + "'")
+		ret, _ := sql.QueryNoLimit("SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `blocks` WHERE `id` = '" + query + "'")
 		if 1 > len(ret) {
 			return
 		}
@@ -802,7 +802,7 @@ func fullTextSearchCount(query, boxFilter, pathFilter, typeFilter string) (match
 	stmt := "SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `" + table + "` WHERE (`" + table + "` MATCH '" + columnFilter() + ":(" + query + ")'"
 	stmt += ") AND type IN " + typeFilter
 	stmt += boxFilter + pathFilter
-	result, _ := sql.Query(stmt)
+	result, _ := sql.QueryNoLimit(stmt)
 	if 1 > len(result) {
 		return
 	}

+ 85 - 3
kernel/sql/block_query.go

@@ -19,6 +19,7 @@ package sql
 import (
 	"bytes"
 	"database/sql"
+	"math"
 	"sort"
 	"strconv"
 	"strings"
@@ -378,7 +379,45 @@ func QueryBookmarkLabels() (ret []string) {
 	return
 }
 
-func Query(stmt string) (ret []map[string]interface{}, err error) {
+func QueryNoLimit(stmt string) (ret []map[string]interface{}, err error) {
+	return queryRawStmt(stmt, math.MaxInt)
+}
+
+func Query(stmt string, limit int) (ret []map[string]interface{}, err error) {
+	parsedStmt, err := sqlparser.Parse(stmt)
+	if nil != err {
+		return queryRawStmt(stmt, limit)
+	}
+
+	switch parsedStmt.(type) {
+	case *sqlparser.Select:
+		slct := parsedStmt.(*sqlparser.Select)
+		if nil == slct.Limit {
+			slct.Limit = &sqlparser.Limit{
+				Rowcount: &sqlparser.SQLVal{
+					Type: sqlparser.IntVal,
+					Val:  []byte(strconv.Itoa(limit)),
+				},
+			}
+		} else {
+			if nil != slct.Limit.Rowcount && 0 < len(slct.Limit.Rowcount.(*sqlparser.SQLVal).Val) {
+				limit, _ = strconv.Atoi(string(slct.Limit.Rowcount.(*sqlparser.SQLVal).Val))
+				if 0 >= limit {
+					limit = 32
+				}
+			}
+
+			slct.Limit.Rowcount = &sqlparser.SQLVal{
+				Type: sqlparser.IntVal,
+				Val:  []byte(strconv.Itoa(limit)),
+			}
+		}
+
+		stmt = sqlparser.String(slct)
+	default:
+		return
+	}
+
 	ret = []map[string]interface{}{}
 	rows, err := query(stmt)
 	if nil != err {
@@ -413,6 +452,49 @@ func Query(stmt string) (ret []map[string]interface{}, err error) {
 	return
 }
 
+func queryRawStmt(stmt string, limit int) (ret []map[string]interface{}, err error) {
+	rows, err := query(stmt)
+	if nil != err {
+		if strings.Contains(err.Error(), "syntax error") {
+			return
+		}
+		return
+	}
+	defer rows.Close()
+
+	cols, err := rows.Columns()
+	if nil != err || nil == cols {
+		return
+	}
+
+	noLimit := !strings.Contains(strings.ToLower(stmt), " limit ")
+	var count, errCount int
+	for rows.Next() {
+		columns := make([]interface{}, len(cols))
+		columnPointers := make([]interface{}, len(cols))
+		for i := range columns {
+			columnPointers[i] = &columns[i]
+		}
+
+		if err = rows.Scan(columnPointers...); nil != err {
+			return
+		}
+
+		m := make(map[string]interface{})
+		for i, colName := range cols {
+			val := columnPointers[i].(*interface{})
+			m[colName] = *val
+		}
+
+		ret = append(ret, m)
+		count++
+		if (noLimit && limit < count) || 0 < errCount {
+			break
+		}
+	}
+	return
+}
+
 func SelectBlocksRawStmtNoParse(stmt string, limit int) (ret []*Block) {
 	return selectBlocksRawStmt(stmt, limit)
 }
@@ -491,7 +573,7 @@ func selectBlocksRawStmt(stmt string, limit int) (ret []*Block) {
 	}
 	defer rows.Close()
 
-	confLimit := !strings.Contains(strings.ToLower(stmt), " limit ")
+	noLimit := !strings.Contains(strings.ToLower(stmt), " limit ")
 	var count, errCount int
 	for rows.Next() {
 		count++
@@ -502,7 +584,7 @@ func selectBlocksRawStmt(stmt string, limit int) (ret []*Block) {
 			errCount++
 		}
 
-		if (confLimit && limit < count) || 0 < errCount {
+		if (noLimit && limit < count) || 0 < errCount {
 			break
 		}
 	}