Просмотр исходного кода

:art: AI 加入上下文信息

Liang Ding 2 лет назад
Родитель
Сommit
5eae66a35e
1 измененных файлов с 41 добавлено и 18 удалено
  1. 41 18
      kernel/util/openai.go

+ 41 - 18
kernel/util/openai.go

@@ -40,7 +40,15 @@ var (
 	OpenAIAPIMaxTokens = 0
 )
 
+var cachedContextMsg []string
+
 func ChatGPT(msg string) (ret string) {
+	ret, retCtxMsgs := ChatGPTContinueWrite(msg, cachedContextMsg)
+	cachedContextMsg = append(cachedContextMsg, retCtxMsgs...)
+	return
+}
+
+func ChatGPTContinueWrite(msg string, contextMsgs []string) (ret string, retContextMsgs []string) {
 	if "" == OpenAIAPIKey {
 		return
 	}
@@ -48,30 +56,31 @@ func ChatGPT(msg string) (ret string) {
 	PushEndlessProgress("Requesting...")
 	defer ClearPushProgress(100)
 
-	config := gogpt.DefaultConfig(OpenAIAPIKey)
-	if "" != OpenAIAPIProxy {
-		proxyUrl, err := url.Parse(OpenAIAPIProxy)
-		if nil != err {
-			logging.LogErrorf("OpenAI API proxy error: %v", err)
-		} else {
-			config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}}
-		}
+	c := newOpenAIClient()
+
+	var reqMsgs []gogpt.ChatCompletionMessage
+	if 7 < len(contextMsgs) {
+		contextMsgs = contextMsgs[len(contextMsgs)-7:]
 	}
 
-	c := gogpt.NewClientWithConfig(config)
-	ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout)
-	defer cancel()
+	for _, ctxMsg := range contextMsgs {
+		reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{
+			Role:    "user",
+			Content: ctxMsg,
+		})
+	}
+	reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{
+		Role:    "user",
+		Content: msg,
+	})
+
 	req := gogpt.ChatCompletionRequest{
 		Model:     gogpt.GPT3Dot5Turbo,
 		MaxTokens: OpenAIAPIMaxTokens,
-		Messages: []gogpt.ChatCompletionMessage{
-			{
-				Role:    "user",
-				Content: msg,
-			},
-		},
+		Messages:  reqMsgs,
 	}
-
+	ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout)
+	defer cancel()
 	stream, err := c.CreateChatCompletionStream(ctx, req)
 	if nil != err {
 		logging.LogErrorf("create chat completion stream failed: %s", err)
@@ -100,9 +109,23 @@ func ChatGPT(msg string) (ret string) {
 
 	ret = buf.String()
 	ret = strings.TrimSpace(ret)
+	retContextMsgs = append(retContextMsgs, msg, ret)
 	return
 }
 
+func newOpenAIClient() *gogpt.Client {
+	config := gogpt.DefaultConfig(OpenAIAPIKey)
+	if "" != OpenAIAPIProxy {
+		proxyUrl, err := url.Parse(OpenAIAPIProxy)
+		if nil != err {
+			logging.LogErrorf("OpenAI API proxy failed: %v", err)
+		} else {
+			config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}}
+		}
+	}
+	return gogpt.NewClientWithConfig(config)
+}
+
 func initOpenAI() {
 	OpenAIAPIKey = os.Getenv("SIYUAN_OPENAI_API_KEY")
 	if "" == OpenAIAPIKey {