Переглянути джерело

:art: 接入云端人工智能接口公测 https://github.com/siyuan-note/siyuan/issues/7601

Liang Ding 2 роки тому
батько
коміт
1c9e0356cc
3 змінених файлів з 122 додано та 53 видалено
  1. 60 6
      kernel/model/ai.go
  2. 60 0
      kernel/model/liandi.go
  3. 2 47
      kernel/util/openai.go

+ 60 - 6
kernel/model/ai.go

@@ -18,6 +18,7 @@ package model
 
 import (
 	"bytes"
+	"strings"
 
 	"github.com/88250/lute/ast"
 	"github.com/88250/lute/parse"
@@ -25,22 +26,75 @@ import (
 	"github.com/siyuan-note/siyuan/kernel/util"
 )
 
+func ChatGPT(msg string) (ret string) {
+	cloud := IsSubscriber()
+	if !cloud && !isOpenAIAPIEnabled() {
+		return
+	}
+
+	cloud = false
+
+	return chatGPT(msg, cloud)
+}
+
 func ChatGPTWithAction(ids []string, action string) (ret string) {
-	if !isOpenAIAPIEnabled() {
+	cloud := IsSubscriber()
+	if !cloud && !isOpenAIAPIEnabled() {
 		return
 	}
 
+	cloud = false
+
 	msg := getBlocksContent(ids)
-	ret = util.ChatGPTWithAction(msg, action)
+	ret = chatGPTWithAction(msg, action, cloud)
 	return
 }
 
-func ChatGPT(msg string) (ret string) {
-	if !isOpenAIAPIEnabled() {
-		return
+var cachedContextMsg []string
+
+func chatGPT(msg string, cloud bool) (ret string) {
+	ret, retCtxMsgs := chatGPTContinueWrite(msg, cachedContextMsg, cloud)
+	cachedContextMsg = append(cachedContextMsg, retCtxMsgs...)
+	return
+}
+
+func chatGPTWithAction(msg string, action string, cloud bool) (ret string) {
+	msg = action + ":\n\n" + msg
+	ret, _ = chatGPTContinueWrite(msg, nil, cloud)
+	return
+}
+
+func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string) {
+	util.PushEndlessProgress("Requesting...")
+	defer util.ClearPushProgress(100)
+
+	if 7 < len(contextMsgs) {
+		contextMsgs = contextMsgs[len(contextMsgs)-7:]
+	}
+
+	c := util.NewOpenAIClient()
+	buf := &bytes.Buffer{}
+	for i := 0; i < 7; i++ {
+		var part string
+		var stop bool
+		if cloud {
+			part, stop = CloudChatGPT(msg, contextMsgs)
+		} else {
+			part, stop = util.ChatGPT(msg, contextMsgs, c)
+		}
+		buf.WriteString(part)
+
+		if stop {
+			break
+		}
+
+		util.PushEndlessProgress("Continue requesting...")
 	}
 
-	return util.ChatGPT(msg)
+	ret = buf.String()
+	ret = strings.TrimSpace(ret)
+	retContextMsgs = append(retContextMsgs, msg, ret)
+	return
 }
 
 func isOpenAIAPIEnabled() bool {

+ 60 - 0
kernel/model/liandi.go

@@ -36,6 +36,66 @@ import (
 
 var ErrFailedToConnectCloudServer = errors.New("failed to connect cloud server")
 
+func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool) {
+	if nil == Conf.User {
+		return
+	}
+
+	payload := map[string]interface{}{}
+	var messages []map[string]interface{}
+	for _, contextMsg := range contextMsgs {
+		messages = append(messages, map[string]interface{}{
+			"role":    "user",
+			"content": contextMsg,
+		})
+	}
+	messages = append(messages, map[string]interface{}{
+		"role":    "user",
+		"content": msg,
+	})
+	payload["messages"] = messages
+
+	requestResult := gulu.Ret.NewResult()
+	request := httpclient.NewCloudRequest30s()
+	_, err := request.
+		SetSuccessResult(requestResult).
+		SetCookies(&http.Cookie{Name: "symphony", Value: Conf.User.UserToken}).
+		SetBody(payload).
+		Post(util.AliyunServer + "/apis/siyuan/ai/chatGPT")
+	if nil != err {
+		logging.LogErrorf("chat gpt failed: %s", err)
+		err = ErrFailedToConnectCloudServer
+		return
+	}
+	if 0 != requestResult.Code {
+		err = errors.New(requestResult.Msg)
+		stop = true
+		return
+	}
+
+	data := requestResult.Data.(map[string]interface{})
+	choices := data["choices"].([]interface{})
+	if 1 > len(choices) {
+		stop = true
+		return
+	}
+	choice := choices[0].(map[string]interface{})
+	message := choice["message"].(map[string]interface{})
+	ret = message["content"].(string)
+
+	if nil != choice["finish_reason"] {
+		finishReason := choice["finish_reason"].(string)
+		if "length" == finishReason {
+			stop = false
+		} else {
+			stop = true
+		}
+	} else {
+		stop = true
+	}
+	return
+}
+
 func StartFreeTrial() (err error) {
 	if nil == Conf.User {
 		return errors.New(Conf.Language(31))

+ 2 - 47
kernel/util/openai.go

@@ -17,7 +17,6 @@
 package util
 
 import (
-	"bytes"
 	"context"
 	"net/http"
 	"net/url"
@@ -37,52 +36,8 @@ var (
 	OpenAIAPIMaxTokens = 0
 )
 
-var cachedContextMsg []string
-
-func ChatGPT(msg string) (ret string) {
-	ret, retCtxMsgs := ChatGPTContinueWrite(msg, cachedContextMsg)
-	cachedContextMsg = append(cachedContextMsg, retCtxMsgs...)
-	return
-}
-
-func ChatGPTWithAction(msg string, action string) (ret string) {
-	msg = action + ":\n\n" + msg
-	ret, _ = ChatGPTContinueWrite(msg, nil)
-	return
-}
-
-func ChatGPTContinueWrite(msg string, contextMsgs []string) (ret string, retContextMsgs []string) {
-	if "" == OpenAIAPIKey {
-		return
-	}
-
-	PushEndlessProgress("Requesting...")
-	defer ClearPushProgress(100)
-
-	c := newOpenAIClient()
-	buf := &bytes.Buffer{}
-	for i := 0; i < 7; i++ {
-		part, stop := chatGPT(msg, contextMsgs, c)
-		buf.WriteString(part)
-
-		if stop {
-			break
-		}
-
-		PushEndlessProgress("Continue requesting...")
-	}
-
-	ret = buf.String()
-	ret = strings.TrimSpace(ret)
-	retContextMsgs = append(retContextMsgs, msg, ret)
-	return
-}
-
-func chatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool) {
+func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool) {
 	var reqMsgs []gogpt.ChatCompletionMessage
-	if 7 < len(contextMsgs) {
-		contextMsgs = contextMsgs[len(contextMsgs)-7:]
-	}
 
 	for _, ctxMsg := range contextMsgs {
 		reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{
@@ -129,7 +84,7 @@ func chatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, sto
 	return
 }
 
-func newOpenAIClient() *gogpt.Client {
+func NewOpenAIClient() *gogpt.Client {
 	config := gogpt.DefaultConfig(OpenAIAPIKey)
 	if "" != OpenAIAPIProxy {
 		proxyUrl, err := url.Parse(OpenAIAPIProxy)