Просмотр исходного кода

:art: 支持应用内配置人工智能 https://github.com/siyuan-note/siyuan/issues/7714

Liang Ding 2 лет назад
Родитель
Сommit
59004a664f
4 измененных файлов с 22 добавлено и 66 удалено
  1. 3 3
      kernel/model/ai.go
  2. 9 0
      kernel/model/conf.go
  3. 10 61
      kernel/util/openai.go
  4. 0 2
      kernel/util/working.go

+ 3 - 3
kernel/model/ai.go

@@ -77,7 +77,7 @@ func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret str
 	if cloud {
 		gpt = &CloudGPT{}
 	} else {
-		gpt = &OpenAIGPT{c: util.NewOpenAIClient()}
+		gpt = &OpenAIGPT{c: util.NewOpenAIClient(Conf.AI.OpenAI.APIKey, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIBaseURL)}
 	}
 
 	buf := &bytes.Buffer{}
@@ -99,7 +99,7 @@ func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret str
 }
 
 func isOpenAIAPIEnabled() bool {
-	if "" == util.OpenAIAPIKey {
+	if "" == Conf.AI.OpenAI.APIKey {
 		util.PushMsg(Conf.Language(193), 5000)
 		return false
 	}
@@ -155,7 +155,7 @@ type OpenAIGPT struct {
 }
 
 func (gpt *OpenAIGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) {
-	return util.ChatGPT(msg, contextMsgs, gpt.c)
+	return util.ChatGPT(msg, contextMsgs, gpt.c, Conf.AI.OpenAI.APIMaxTokens, Conf.AI.OpenAI.APITimeout)
 }
 
 type CloudGPT struct {

+ 9 - 0
kernel/model/conf.go

@@ -323,6 +323,15 @@ func InitConf() {
 		Conf.AI = conf.NewAI()
 	}
 
+	if "" != Conf.AI.OpenAI.APIKey {
+		logging.LogInfof("OpenAI API enabled\n"+
+			"    baseURL=%s\n"+
+			"    timeout=%ds\n"+
+			"    proxy=%s\n"+
+			"    maxTokens=%d",
+			Conf.AI.OpenAI.APIBaseURL, Conf.AI.OpenAI.APITimeout, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIMaxTokens)
+	}
+
 	Conf.ReadOnly = util.ReadOnly
 
 	if "" != util.AccessAuthCode {

+ 10 - 61
kernel/util/openai.go

@@ -18,26 +18,15 @@ package util
 
 import (
 	"context"
+	gogpt "github.com/sashabaranov/go-gpt3"
+	"github.com/siyuan-note/logging"
 	"net/http"
 	"net/url"
-	"os"
-	"strconv"
 	"strings"
 	"time"
-
-	gogpt "github.com/sashabaranov/go-gpt3"
-	"github.com/siyuan-note/logging"
 )
 
-var (
-	OpenAIAPIKey       = ""
-	OpenAIAPITimeout   = 30 * time.Second
-	OpenAIAPIProxy     = ""
-	OpenAIAPIMaxTokens = 0
-	OpenAIAPIBaseURL   = "https://api.openai.com/v1"
-)
-
-func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool, err error) {
+func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client, maxTokens, timeout int) (ret string, stop bool, err error) {
 	var reqMsgs []gogpt.ChatCompletionMessage
 
 	for _, ctxMsg := range contextMsgs {
@@ -53,10 +42,10 @@ func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, sto
 
 	req := gogpt.ChatCompletionRequest{
 		Model:     gogpt.GPT3Dot5Turbo,
-		MaxTokens: OpenAIAPIMaxTokens,
+		MaxTokens: maxTokens,
 		Messages:  reqMsgs,
 	}
-	ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
 	defer cancel()
 	resp, err := c.CreateChatCompletion(ctx, req)
 	if nil != err {
@@ -85,10 +74,10 @@ func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, sto
 	return
 }
 
-func NewOpenAIClient() *gogpt.Client {
-	config := gogpt.DefaultConfig(OpenAIAPIKey)
-	if "" != OpenAIAPIProxy {
-		proxyUrl, err := url.Parse(OpenAIAPIProxy)
+func NewOpenAIClient(apiKey, apiProxy, apiBaseURL string) *gogpt.Client {
+	config := gogpt.DefaultConfig(apiKey)
+	if "" != apiProxy {
+		proxyUrl, err := url.Parse(apiProxy)
 		if nil != err {
 			logging.LogErrorf("OpenAI API proxy failed: %v", err)
 		} else {
@@ -96,46 +85,6 @@ func NewOpenAIClient() *gogpt.Client {
 		}
 	}
 
-	config.BaseURL = OpenAIAPIBaseURL
+	config.BaseURL = apiBaseURL
 	return gogpt.NewClientWithConfig(config)
 }
-
-func initOpenAI() {
-	OpenAIAPIKey = os.Getenv("SIYUAN_OPENAI_API_KEY")
-	if "" == OpenAIAPIKey {
-		return
-	}
-
-	timeout := os.Getenv("SIYUAN_OPENAI_API_TIMEOUT")
-	if "" != timeout {
-		timeoutInt, err := strconv.Atoi(timeout)
-		if nil == err {
-			OpenAIAPITimeout = time.Duration(timeoutInt) * time.Second
-		}
-	}
-
-	proxy := os.Getenv("SIYUAN_OPENAI_API_PROXY")
-	if "" != proxy {
-		OpenAIAPIProxy = proxy
-	}
-
-	maxTokens := os.Getenv("SIYUAN_OPENAI_API_MAX_TOKENS")
-	if "" != maxTokens {
-		maxTokensInt, err := strconv.Atoi(maxTokens)
-		if nil == err {
-			OpenAIAPIMaxTokens = maxTokensInt
-		}
-	}
-
-	baseURL := os.Getenv("SIYUAN_OPENAI_API_BASE_URL")
-	if "" != baseURL {
-		OpenAIAPIBaseURL = baseURL
-	}
-
-	logging.LogInfof("OpenAI API enabled\n"+
-		"    baseURL=%s\n"+
-		"    timeout=%ds\n"+
-		"    proxy=%s\n"+
-		"    maxTokens=%d",
-		OpenAIAPIBaseURL, int(OpenAIAPITimeout.Seconds()), OpenAIAPIProxy, OpenAIAPIMaxTokens)
-}

+ 0 - 2
kernel/util/working.go

@@ -118,8 +118,6 @@ func Boot() {
 	bootBanner := figure.NewColorFigure("SiYuan", "isometric3", "green", true)
 	logging.LogInfof("\n" + bootBanner.String())
 	logBootInfo()
-
-	initOpenAI()
 }
 
 func setBootDetails(details string) {