|
@@ -18,26 +18,15 @@ package util
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
+ gogpt "github.com/sashabaranov/go-gpt3"
|
|
|
+ "github.com/siyuan-note/logging"
|
|
|
"net/http"
|
|
|
"net/url"
|
|
|
- "os"
|
|
|
- "strconv"
|
|
|
"strings"
|
|
|
"time"
|
|
|
-
|
|
|
- gogpt "github.com/sashabaranov/go-gpt3"
|
|
|
- "github.com/siyuan-note/logging"
|
|
|
)
|
|
|
|
|
|
-var (
|
|
|
- OpenAIAPIKey = ""
|
|
|
- OpenAIAPITimeout = 30 * time.Second
|
|
|
- OpenAIAPIProxy = ""
|
|
|
- OpenAIAPIMaxTokens = 0
|
|
|
- OpenAIAPIBaseURL = "https://api.openai.com/v1"
|
|
|
-)
|
|
|
-
|
|
|
-func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool, err error) {
|
|
|
+func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client, maxTokens, timeout int) (ret string, stop bool, err error) {
|
|
|
var reqMsgs []gogpt.ChatCompletionMessage
|
|
|
|
|
|
for _, ctxMsg := range contextMsgs {
|
|
@@ -53,10 +42,10 @@ func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, sto
|
|
|
|
|
|
req := gogpt.ChatCompletionRequest{
|
|
|
Model: gogpt.GPT3Dot5Turbo,
|
|
|
- MaxTokens: OpenAIAPIMaxTokens,
|
|
|
+ MaxTokens: maxTokens,
|
|
|
Messages: reqMsgs,
|
|
|
}
|
|
|
- ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout)
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
|
|
defer cancel()
|
|
|
resp, err := c.CreateChatCompletion(ctx, req)
|
|
|
if nil != err {
|
|
@@ -85,10 +74,10 @@ func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, sto
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func NewOpenAIClient() *gogpt.Client {
|
|
|
- config := gogpt.DefaultConfig(OpenAIAPIKey)
|
|
|
- if "" != OpenAIAPIProxy {
|
|
|
- proxyUrl, err := url.Parse(OpenAIAPIProxy)
|
|
|
+func NewOpenAIClient(apiKey, apiProxy, apiBaseURL string) *gogpt.Client {
|
|
|
+ config := gogpt.DefaultConfig(apiKey)
|
|
|
+ if "" != apiProxy {
|
|
|
+ proxyUrl, err := url.Parse(apiProxy)
|
|
|
if nil != err {
|
|
|
logging.LogErrorf("OpenAI API proxy failed: %v", err)
|
|
|
} else {
|
|
@@ -96,46 +85,6 @@ func NewOpenAIClient() *gogpt.Client {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- config.BaseURL = OpenAIAPIBaseURL
|
|
|
+ config.BaseURL = apiBaseURL
|
|
|
return gogpt.NewClientWithConfig(config)
|
|
|
}
|
|
|
-
|
|
|
-func initOpenAI() {
|
|
|
- OpenAIAPIKey = os.Getenv("SIYUAN_OPENAI_API_KEY")
|
|
|
- if "" == OpenAIAPIKey {
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- timeout := os.Getenv("SIYUAN_OPENAI_API_TIMEOUT")
|
|
|
- if "" != timeout {
|
|
|
- timeoutInt, err := strconv.Atoi(timeout)
|
|
|
- if nil == err {
|
|
|
- OpenAIAPITimeout = time.Duration(timeoutInt) * time.Second
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- proxy := os.Getenv("SIYUAN_OPENAI_API_PROXY")
|
|
|
- if "" != proxy {
|
|
|
- OpenAIAPIProxy = proxy
|
|
|
- }
|
|
|
-
|
|
|
- maxTokens := os.Getenv("SIYUAN_OPENAI_API_MAX_TOKENS")
|
|
|
- if "" != maxTokens {
|
|
|
- maxTokensInt, err := strconv.Atoi(maxTokens)
|
|
|
- if nil == err {
|
|
|
- OpenAIAPIMaxTokens = maxTokensInt
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- baseURL := os.Getenv("SIYUAN_OPENAI_API_BASE_URL")
|
|
|
- if "" != baseURL {
|
|
|
- OpenAIAPIBaseURL = baseURL
|
|
|
- }
|
|
|
-
|
|
|
- logging.LogInfof("OpenAI API enabled\n"+
|
|
|
- " baseURL=%s\n"+
|
|
|
- " timeout=%ds\n"+
|
|
|
- " proxy=%s\n"+
|
|
|
- " maxTokens=%d",
|
|
|
- OpenAIAPIBaseURL, int(OpenAIAPITimeout.Seconds()), OpenAIAPIProxy, OpenAIAPIMaxTokens)
|
|
|
-}
|