Explorar o código

:art: Kernel API `/api/query/sql` support `||` operator https://github.com/siyuan-note/siyuan/issues/9662

Daniel hai 1 ano
pai
achega
c65c32caf9
Modificáronse 3 ficheiros con 44 adicións e 16 borrados
  1. 1 0
      kernel/go.mod
  2. 2 0
      kernel/go.sum
  3. 41 16
      kernel/sql/block_query.go

+ 1 - 0
kernel/go.mod

@@ -139,6 +139,7 @@ require (
 	github.com/richardlehane/mscfb v1.0.4 // indirect
 	github.com/richardlehane/msoleps v1.0.3 // indirect
 	github.com/rivo/uniseg v0.4.4 // indirect
+	github.com/rqlite/sql v0.0.0-20221103124402-8f9ff0ceb8f0 // indirect
 	github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 // indirect
 	github.com/shoenig/go-m1cpu v0.1.6 // indirect
 	github.com/shopspring/decimal v1.3.1 // indirect

+ 2 - 0
kernel/go.sum

@@ -335,6 +335,8 @@ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE
 github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
 github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
 github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
+github.com/rqlite/sql v0.0.0-20221103124402-8f9ff0ceb8f0 h1:C8DZB5okjhCSd7zvkOM+zxGz7S6ulUFIL34bpkqFk+0=
+github.com/rqlite/sql v0.0.0-20221103124402-8f9ff0ceb8f0/go.mod h1:ib9zVtNgRKiGuoMyUqqL5aNpk+r+++YlyiVIkclVqPg=
 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 h1:OkMGxebDjyw0ULyrTYWeN0UNCCkmCWfjPnIA2W6oviI=
 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06/go.mod h1:+ePHsJ1keEjQtpvf9HHw0f4ZeJ0TLRsxhunSI2hYJSs=
 github.com/sashabaranov/go-openai v1.17.6 h1:hYXRPM1xO6QLOJhWEOMlSg/l3jERiKDKd1qIoK22lvs=

+ 41 - 16
kernel/sql/block_query.go

@@ -27,6 +27,7 @@ import (
 	"github.com/88250/lute/ast"
 	"github.com/88250/vitess-sqlparser/sqlparser"
 	"github.com/emirpasic/gods/sets/hashset"
+	sqlparser2 "github.com/rqlite/sql"
 	"github.com/siyuan-note/logging"
 	"github.com/siyuan-note/siyuan/kernel/treenode"
 	"github.com/siyuan-note/siyuan/kernel/util"
@@ -384,24 +385,48 @@ func QueryNoLimit(stmt string) (ret []map[string]interface{}, err error) {
 }
 
 func Query(stmt string, limit int) (ret []map[string]interface{}, err error) {
-	parsedStmt, err := sqlparser.Parse(stmt)
+	// Kernel API `/api/query/sql` support `||` operator https://github.com/siyuan-note/siyuan/issues/9662
+	// 这里为了支持 || 操作符,使用了另一个 sql 解析器,但是这个解析器无法处理 UNION https://github.com/siyuan-note/siyuan/issues/8226
+	// 考虑到 UNION 的使用场景不多,这里还是以支持 || 操作符为主
+	p := sqlparser2.NewParser(strings.NewReader(stmt))
+	parsedStmt2, err := p.ParseStatement()
 	if nil != err {
-		return queryRawStmt(stmt, limit)
-	}
+		if !strings.Contains(stmt, "||") {
+			// 这个解析器无法处理 || 连接字符串操作符
+			parsedStmt, err2 := sqlparser.Parse(stmt)
+			if nil != err2 {
+				return queryRawStmt(stmt, limit)
+			}
 
-	switch parsedStmt.(type) {
-	case *sqlparser.Select:
-		limitClause := getLimitClause(parsedStmt, limit)
-		slct := parsedStmt.(*sqlparser.Select)
-		slct.Limit = limitClause
-		stmt = sqlparser.String(slct)
-	case *sqlparser.Union:
-		limitClause := getLimitClause(parsedStmt, limit)
-		union := parsedStmt.(*sqlparser.Union)
-		union.Limit = limitClause
-		stmt = sqlparser.String(union)
-	default:
-		return queryRawStmt(stmt, limit)
+			switch parsedStmt.(type) {
+			case *sqlparser.Select:
+				limitClause := getLimitClause(parsedStmt, limit)
+				slct := parsedStmt.(*sqlparser.Select)
+				slct.Limit = limitClause
+				stmt = sqlparser.String(slct)
+			case *sqlparser.Union:
+				// Kernel API `/api/query/sql` support `UNION` statement https://github.com/siyuan-note/siyuan/issues/8226
+				limitClause := getLimitClause(parsedStmt, limit)
+				union := parsedStmt.(*sqlparser.Union)
+				union.Limit = limitClause
+				stmt = sqlparser.String(union)
+			default:
+				return queryRawStmt(stmt, limit)
+			}
+		} else {
+			return queryRawStmt(stmt, limit)
+		}
+	} else {
+		switch parsedStmt2.(type) {
+		case *sqlparser2.SelectStatement:
+			slct := parsedStmt2.(*sqlparser2.SelectStatement)
+			if nil == slct.LimitExpr {
+				slct.LimitExpr = &sqlparser2.NumberLit{Value: strconv.Itoa(limit)}
+			}
+			stmt = slct.String()
+		default:
+			return queryRawStmt(stmt, limit)
+		}
 	}
 
 	ret = []map[string]interface{}{}