Browse Source

:art: 加入针对内容块的人工智能辅助支持 https://github.com/siyuan-note/siyuan/issues/7566

Liang Ding 2 years ago
parent
commit
20af6b0b7a
3 changed files with 41 additions and 15 deletions
  1. 38 12
      kernel/model/ai.go
  2. 2 2
      kernel/model/liandi.go
  3. 1 1
      kernel/util/openai.go

+ 38 - 12
kernel/model/ai.go

@@ -22,6 +22,7 @@ import (
 
 	"github.com/88250/lute/ast"
 	"github.com/88250/lute/parse"
+	gogpt "github.com/sashabaranov/go-gpt3"
 	"github.com/siyuan-note/siyuan/kernel/treenode"
 	"github.com/siyuan-note/siyuan/kernel/util"
 )
@@ -47,18 +48,24 @@ func ChatGPTWithAction(ids []string, action string) (ret string) {
 var cachedContextMsg []string
 
 func chatGPT(msg string, cloud bool) (ret string) {
-	ret, retCtxMsgs := chatGPTContinueWrite(msg, cachedContextMsg, cloud)
+	ret, retCtxMsgs, err := chatGPTContinueWrite(msg, cachedContextMsg, cloud)
+	if nil != err {
+		return
+	}
 	cachedContextMsg = append(cachedContextMsg, retCtxMsgs...)
 	return
 }
 
 func chatGPTWithAction(msg string, action string, cloud bool) (ret string) {
 	msg = action + ":\n\n" + msg
-	ret, _ = chatGPTContinueWrite(msg, nil, cloud)
+	ret, _, err := chatGPTContinueWrite(msg, nil, cloud)
+	if nil != err {
+		return
+	}
 	return
 }
 
-func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string) {
+func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string, err error) {
 	util.PushEndlessProgress("Requesting...")
 	defer util.ClearPushProgress(100)
 
@@ -66,19 +73,19 @@ func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret str
 		contextMsgs = contextMsgs[len(contextMsgs)-7:]
 	}
 
-	c := util.NewOpenAIClient()
+	var gpt GPT
+	if cloud {
+		gpt = &CloudGPT{}
+	} else {
+		gpt = &OpenAIGPT{c: util.NewOpenAIClient()}
+	}
+
 	buf := &bytes.Buffer{}
 	for i := 0; i < 7; i++ {
-		var part string
-		var stop bool
-		if cloud {
-			part, stop = CloudChatGPT(msg, contextMsgs)
-		} else {
-			part, stop = util.ChatGPT(msg, contextMsgs, c)
-		}
+		part, stop, chatErr := gpt.chat(msg, contextMsgs)
 		buf.WriteString(part)
 
-		if stop {
+		if stop || nil != chatErr {
 			break
 		}
 
@@ -138,3 +145,22 @@ func getBlocksContent(ids []string) string {
 	}
 	return buf.String()
 }
+
+type GPT interface {
+	chat(msg string, contextMsgs []string) (partRet string, stop bool, err error)
+}
+
+type OpenAIGPT struct {
+	c *gogpt.Client
+}
+
+func (gpt *OpenAIGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) {
+	return util.ChatGPT(msg, contextMsgs, gpt.c)
+}
+
+type CloudGPT struct {
+}
+
+func (gpt *CloudGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) {
+	return CloudChatGPT(msg, contextMsgs)
+}

+ 2 - 2
kernel/model/liandi.go

@@ -36,7 +36,7 @@ import (
 
 var ErrFailedToConnectCloudServer = errors.New("failed to connect cloud server")
 
-func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool) {
+func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool, err error) {
 	if nil == Conf.User {
 		return
 	}
@@ -57,7 +57,7 @@ func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool) {
 
 	requestResult := gulu.Ret.NewResult()
 	request := httpclient.NewCloudRequest30s()
-	_, err := request.
+	_, err = request.
 		SetSuccessResult(requestResult).
 		SetCookies(&http.Cookie{Name: "symphony", Value: Conf.User.UserToken}).
 		SetBody(payload).

+ 1 - 1
kernel/util/openai.go

@@ -37,7 +37,7 @@ var (
 	OpenAIAPIBaseURL   = "https://api.openai.com/v1"
 )
 
-func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool) {
+func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool, err error) {
 	var reqMsgs []gogpt.ChatCompletionMessage
 
 	for _, ctxMsg := range contextMsgs {