🎨 加入针对内容块的人工智能辅助支持 https://github.com/siyuan-note/siyuan/issues/7566

This commit is contained in:
Liang Ding 2023-03-08 21:28:29 +08:00
parent 5c3055d601
commit 20af6b0b7a
No known key found for this signature in database
GPG key ID: 136F30F901A2231D
3 changed files with 41 additions and 15 deletions

View file

@ -22,6 +22,7 @@ import (
"github.com/88250/lute/ast"
"github.com/88250/lute/parse"
gogpt "github.com/sashabaranov/go-gpt3"
"github.com/siyuan-note/siyuan/kernel/treenode"
"github.com/siyuan-note/siyuan/kernel/util"
)
@ -47,18 +48,24 @@ func ChatGPTWithAction(ids []string, action string) (ret string) {
var cachedContextMsg []string
func chatGPT(msg string, cloud bool) (ret string) {
ret, retCtxMsgs := chatGPTContinueWrite(msg, cachedContextMsg, cloud)
ret, retCtxMsgs, err := chatGPTContinueWrite(msg, cachedContextMsg, cloud)
if nil != err {
return
}
cachedContextMsg = append(cachedContextMsg, retCtxMsgs...)
return
}
func chatGPTWithAction(msg string, action string, cloud bool) (ret string) {
msg = action + ":\n\n" + msg
ret, _ = chatGPTContinueWrite(msg, nil, cloud)
ret, _, err := chatGPTContinueWrite(msg, nil, cloud)
if nil != err {
return
}
return
}
func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string) {
func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string, err error) {
util.PushEndlessProgress("Requesting...")
defer util.ClearPushProgress(100)
@ -66,19 +73,19 @@ func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret str
contextMsgs = contextMsgs[len(contextMsgs)-7:]
}
c := util.NewOpenAIClient()
var gpt GPT
if cloud {
gpt = &CloudGPT{}
} else {
gpt = &OpenAIGPT{c: util.NewOpenAIClient()}
}
buf := &bytes.Buffer{}
for i := 0; i < 7; i++ {
var part string
var stop bool
if cloud {
part, stop = CloudChatGPT(msg, contextMsgs)
} else {
part, stop = util.ChatGPT(msg, contextMsgs, c)
}
part, stop, chatErr := gpt.chat(msg, contextMsgs)
buf.WriteString(part)
if stop {
if stop || nil != chatErr {
break
}
@ -138,3 +145,22 @@ func getBlocksContent(ids []string) string {
}
return buf.String()
}
type GPT interface {
chat(msg string, contextMsgs []string) (partRet string, stop bool, err error)
}
type OpenAIGPT struct {
c *gogpt.Client
}
func (gpt *OpenAIGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) {
return util.ChatGPT(msg, contextMsgs, gpt.c)
}
type CloudGPT struct {
}
func (gpt *CloudGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) {
return CloudChatGPT(msg, contextMsgs)
}

View file

@ -36,7 +36,7 @@ import (
var ErrFailedToConnectCloudServer = errors.New("failed to connect cloud server")
func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool) {
func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool, err error) {
if nil == Conf.User {
return
}
@ -57,7 +57,7 @@ func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool) {
requestResult := gulu.Ret.NewResult()
request := httpclient.NewCloudRequest30s()
_, err := request.
_, err = request.
SetSuccessResult(requestResult).
SetCookies(&http.Cookie{Name: "symphony", Value: Conf.User.UserToken}).
SetBody(payload).

View file

@ -37,7 +37,7 @@ var (
OpenAIAPIBaseURL = "https://api.openai.com/v1"
)
func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool) {
func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool, err error) {
var reqMsgs []gogpt.ChatCompletionMessage
for _, ctxMsg := range contextMsgs {