Browse Source

:art: Allow GPT Model selection in UI https://github.com/siyuan-note/siyuan/issues/8142

Liang Ding 2 years ago
parent
commit
7daafcb3af

+ 2 - 0
app/appearance/langs/en_US.json

@@ -35,6 +35,8 @@
   "apiTimeoutTip": "The timeout period for initiating a request, unit: second",
   "apiTimeoutTip": "The timeout period for initiating a request, unit: second",
   "apiProxy": "Network Proxy",
   "apiProxy": "Network Proxy",
   "apiProxyTip": "The network proxy that initiates the request, such as <code class='fn__code'>socks://127.0.0.1:1080</code>",
   "apiProxyTip": "The network proxy that initiates the request, such as <code class='fn__code'>socks://127.0.0.1:1080</code>",
+  "apiModel": "Model",
+  "apiModelTip": "The <code class='fn__code'>model</code> parameter passed in when requesting the API is used to control the style of the generated text",
   "apiMaxTokens": "Maximum number of Tokens",
   "apiMaxTokens": "Maximum number of Tokens",
   "apiMaxTokensTip": "The <code class='fn__code'>max_tokens</code> parameter passed in when requesting the API is used to control the length of the generated text",
   "apiMaxTokensTip": "The <code class='fn__code'>max_tokens</code> parameter passed in when requesting the API is used to control the length of the generated text",
   "apiBaseURL": "API Base URL",
   "apiBaseURL": "API Base URL",

+ 2 - 0
app/appearance/langs/es_ES.json

@@ -35,6 +35,8 @@
   "apiTimeoutTip": "El tiempo de espera para iniciar una solicitud, unidad: segundo",
   "apiTimeoutTip": "El tiempo de espera para iniciar una solicitud, unidad: segundo",
   "apiProxy": "Proxy web",
   "apiProxy": "Proxy web",
   "apiProxyTip": "El proxy de red que inicia la solicitud, como <code class='fn__code'>socks://127.0.0.1:1080</code>",
   "apiProxyTip": "El proxy de red que inicia la solicitud, como <code class='fn__code'>socks://127.0.0.1:1080</code>",
+  "apiModel": "Modelo",
+  "apiModelTip": "El parámetro <code class='fn__code'>model</code> pasado al solicitar la API se usa para controlar el estilo del texto generado",
   "apiMaxTokens": "Número máximo de tokens",
   "apiMaxTokens": "Número máximo de tokens",
   "apiMaxTokensTip": "El parámetro <code class='fn__code'>max_tokens</code> que se pasa al solicitar la API se usa para controlar la longitud del texto generado",
   "apiMaxTokensTip": "El parámetro <code class='fn__code'>max_tokens</code> que se pasa al solicitar la API se usa para controlar la longitud del texto generado",
   "apiBaseURL": "URL base de la API",
   "apiBaseURL": "URL base de la API",

+ 2 - 0
app/appearance/langs/fr_FR.json

@@ -35,6 +35,8 @@
   "apiTimeoutTip": "Le délai d'attente pour lancer une requête, unité : seconde",
   "apiTimeoutTip": "Le délai d'attente pour lancer une requête, unité : seconde",
   "apiProxy": "Proxy Web",
   "apiProxy": "Proxy Web",
   "apiProxyTip": "Le proxy réseau qui lance la requête, tel que <code class='fn__code'>socks://127.0.0.1:1080</code>",
   "apiProxyTip": "Le proxy réseau qui lance la requête, tel que <code class='fn__code'>socks://127.0.0.1:1080</code>",
+  "apiModel": "Modèle",
+  "apiModelTip": "Le paramètre <code class='fn__code'>model</code> transmis lors de la demande de l'API est utilisé pour contrôler le style du texte généré",
   "apiMaxTokens": "Nombre maximum de jetons",
   "apiMaxTokens": "Nombre maximum de jetons",
   "apiMaxTokensTip": "Le paramètre <code class='fn__code'>max_tokens</code> transmis lors de la demande de l'API est utilisé pour contrôler la longueur du texte généré",
   "apiMaxTokensTip": "Le paramètre <code class='fn__code'>max_tokens</code> transmis lors de la demande de l'API est utilisé pour contrôler la longueur du texte généré",
   "apiBaseURL": "URL de base de l'API",
   "apiBaseURL": "URL de base de l'API",

+ 2 - 0
app/appearance/langs/zh_CHT.json

@@ -35,6 +35,8 @@
   "apiTimeoutTip": "發起請求的超時時間,單位:秒",
   "apiTimeoutTip": "發起請求的超時時間,單位:秒",
   "apiProxy": "網絡代理",
   "apiProxy": "網絡代理",
   "apiProxyTip": "發起請求的網絡代理,如 <code class='fn__code'>socks://127.0.0.1:1080</code>",
   "apiProxyTip": "發起請求的網絡代理,如 <code class='fn__code'>socks://127.0.0.1:1080</code>",
+  "apiModel": "模型",
+  "apiModelTip": "請求 API 時傳入的 <code class='fn__code'>model</code> 參數,用於控制生成的文本風格",
   "apiMaxTokens": "最大 Token 數",
   "apiMaxTokens": "最大 Token 數",
   "apiMaxTokensTip": "請求 API 時傳入的 <code class='fn__code'>max_tokens</code> 參數,用於控制生成的文本長度",
   "apiMaxTokensTip": "請求 API 時傳入的 <code class='fn__code'>max_tokens</code> 參數,用於控制生成的文本長度",
   "apiBaseURL": "API 基礎地址",
   "apiBaseURL": "API 基礎地址",

+ 2 - 0
app/appearance/langs/zh_CN.json

@@ -35,6 +35,8 @@
   "apiTimeoutTip": "发起请求的超时时间,单位:秒",
   "apiTimeoutTip": "发起请求的超时时间,单位:秒",
   "apiProxy": "网络代理",
   "apiProxy": "网络代理",
   "apiProxyTip": "发起请求的网络代理,如 <code class='fn__code'>socks://127.0.0.1:1080</code>",
   "apiProxyTip": "发起请求的网络代理,如 <code class='fn__code'>socks://127.0.0.1:1080</code>",
+  "apiModel": "模型",
+  "apiModelTip": "请求 API 时传入的 <code class='fn__code'>model</code> 参数,用于控制生成的文本风格",
   "apiMaxTokens": "最大 Token 数",
   "apiMaxTokens": "最大 Token 数",
   "apiMaxTokensTip": "请求 API 时传入的 <code class='fn__code'>max_tokens</code> 参数,用于控制生成的文本长度",
   "apiMaxTokensTip": "请求 API 时传入的 <code class='fn__code'>max_tokens</code> 参数,用于控制生成的文本长度",
   "apiBaseURL": "API 基础地址",
   "apiBaseURL": "API 基础地址",

+ 27 - 1
app/src/config/ai.ts

@@ -11,6 +11,19 @@ export const ai = {
     <input class="b3-text-field fn__flex-center fn__block" type="number" step="1" min="5" max="600" id="apiTimeout" value="${window.siyuan.config.ai.openAI.apiTimeout}"/>     
     <input class="b3-text-field fn__flex-center fn__block" type="number" step="1" min="5" max="600" id="apiTimeout" value="${window.siyuan.config.ai.openAI.apiTimeout}"/>     
     <div class="b3-label__text">${window.siyuan.languages.apiTimeoutTip}</div>
     <div class="b3-label__text">${window.siyuan.languages.apiTimeoutTip}</div>
 </div>
 </div>
+<div class="b3-label">
+    ${window.siyuan.languages.apiModel}
+    <div class="b3-label__text">
+        ${window.siyuan.languages.apiModelTip}
+    </div>
+    <div class="b3-label__text fn__flex config__item" style="padding: 4px 0 4px 4px;">
+        <select id="apiModel" class="b3-select">
+            <option value="gpt-4" ${window.siyuan.config.ai.openAI.apiModel === "gpt-4" ? "selected" : ""}>gpt-4</option>
+            <option value="gpt-4-32k" ${window.siyuan.config.ai.openAI.apiModel === "gpt-4-32k" ? "selected" : ""}>gpt-4-32k</option>
+            <option value="gpt-3.5-turbo" ${window.siyuan.config.ai.openAI.apiModel === "gpt-3.5-turbo" ? "selected" : ""}>gpt-3.5-turbo</option>
+        </select>
+    </div>
+</div>
 <div class="b3-label">
 <div class="b3-label">
     ${window.siyuan.languages.apiMaxTokens}
     ${window.siyuan.languages.apiMaxTokens}
     <div class="fn__hr"></div>
     <div class="fn__hr"></div>
@@ -44,6 +57,18 @@ export const ai = {
     <span class="fn__space"></span>
     <span class="fn__space"></span>
     <input class="b3-text-field fn__flex-center fn__size200" type="number" step="1" min="5" max="600" id="apiTimeout" value="${window.siyuan.config.ai.openAI.apiTimeout}"/>
     <input class="b3-text-field fn__flex-center fn__size200" type="number" step="1" min="5" max="600" id="apiTimeout" value="${window.siyuan.config.ai.openAI.apiTimeout}"/>
 </label>
 </label>
+<label class="fn__flex b3-label config__item">
+    <div class="fn__flex-1">
+        ${window.siyuan.languages.apiModel}
+        <div class="b3-label__text">${window.siyuan.languages.apiModelTip}</div>
+    </div>
+    <span class="fn__space"></span>
+    <select id="apiModel" class="b3-select fn__flex-center fn__size200">
+        <option value="gpt-4" ${window.siyuan.config.ai.openAI.apiModel === "gpt-4" ? "selected" : ""}>gpt-4</option>
+        <option value="gpt-4-32k" ${window.siyuan.config.ai.openAI.apiModel === "gpt-4-32k" ? "selected" : ""}>gpt-4-32k</option>
+        <option value="gpt-3.5-turbo" ${window.siyuan.config.ai.openAI.apiModel === "gpt-3.5-turbo" ? "selected" : ""}>gpt-3.5-turbo</option>
+    </select>
+</label>
 <label class="fn__flex b3-label">
 <label class="fn__flex b3-label">
     <div class="fn__flex-1">
     <div class="fn__flex-1">
         ${window.siyuan.languages.apiMaxTokens}
         ${window.siyuan.languages.apiMaxTokens}
@@ -89,12 +114,13 @@ export const ai = {
 </div>`;
 </div>`;
     },
     },
     bindEvent: () => {
     bindEvent: () => {
-        ai.element.querySelectorAll("input").forEach((item) => {
+        ai.element.querySelectorAll("input,select").forEach((item) => {
             item.addEventListener("change", () => {
             item.addEventListener("change", () => {
                 fetchPost("/api/setting/setAI", {
                 fetchPost("/api/setting/setAI", {
                     openAI: {
                     openAI: {
                         apiBaseURL: (ai.element.querySelector("#apiBaseURL") as HTMLInputElement).value,
                         apiBaseURL: (ai.element.querySelector("#apiBaseURL") as HTMLInputElement).value,
                         apiKey: (ai.element.querySelector("#apiKey") as HTMLInputElement).value,
                         apiKey: (ai.element.querySelector("#apiKey") as HTMLInputElement).value,
+                        apiModel: (ai.element.querySelector("#apiModel") as HTMLSelectElement).value,
                         apiMaxTokens: parseInt((ai.element.querySelector("#apiMaxTokens") as HTMLInputElement).value),
                         apiMaxTokens: parseInt((ai.element.querySelector("#apiMaxTokens") as HTMLInputElement).value),
                         apiProxy: (ai.element.querySelector("#apiProxy") as HTMLInputElement).value,
                         apiProxy: (ai.element.querySelector("#apiProxy") as HTMLInputElement).value,
                         apiTimeout: parseInt((ai.element.querySelector("#apiTimeout") as HTMLInputElement).value),
                         apiTimeout: parseInt((ai.element.querySelector("#apiTimeout") as HTMLInputElement).value),

+ 1 - 0
app/src/types/index.d.ts

@@ -424,6 +424,7 @@ declare interface IConfig {
         openAI: {
         openAI: {
             apiBaseURL: string
             apiBaseURL: string
             apiKey: string
             apiKey: string
+            apiModel: string
             apiMaxTokens: number
             apiMaxTokens: number
             apiProxy: string
             apiProxy: string
             apiTimeout: number
             apiTimeout: number

+ 1 - 0
kernel/conf/ai.go

@@ -29,6 +29,7 @@ type OpenAI struct {
 	APIKey       string `json:"apiKey"`
 	APIKey       string `json:"apiKey"`
 	APITimeout   int    `json:"apiTimeout"`
 	APITimeout   int    `json:"apiTimeout"`
 	APIProxy     string `json:"apiProxy"`
 	APIProxy     string `json:"apiProxy"`
+	APIModel     string `json:"apiModel"`
 	APIMaxTokens int    `json:"apiMaxTokens"`
 	APIMaxTokens int    `json:"apiMaxTokens"`
 	APIBaseURL   string `json:"apiBaseURL"`
 	APIBaseURL   string `json:"apiBaseURL"`
 }
 }

+ 1 - 1
kernel/go.mod

@@ -42,7 +42,7 @@ require (
 	github.com/panjf2000/ants/v2 v2.7.3
 	github.com/panjf2000/ants/v2 v2.7.3
 	github.com/patrickmn/go-cache v2.1.0+incompatible
 	github.com/patrickmn/go-cache v2.1.0+incompatible
 	github.com/radovskyb/watcher v1.0.7
 	github.com/radovskyb/watcher v1.0.7
-	github.com/sashabaranov/go-gpt3 v1.4.0
+	github.com/sashabaranov/go-openai v1.9.0
 	github.com/shirou/gopsutil/v3 v3.23.2
 	github.com/shirou/gopsutil/v3 v3.23.2
 	github.com/siyuan-note/dejavu v0.0.0-20230425070132-9eeaf90cb5ba
 	github.com/siyuan-note/dejavu v0.0.0-20230425070132-9eeaf90cb5ba
 	github.com/siyuan-note/encryption v0.0.0-20220713091850-5ecd92177b75
 	github.com/siyuan-note/encryption v0.0.0-20220713091850-5ecd92177b75

+ 2 - 2
kernel/go.sum

@@ -267,8 +267,8 @@ github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6po
 github.com/rwtodd/Go.Sed v0.0.0-20210816025313-55464686f9ef/go.mod h1:8AEUvGVi2uQ5b24BIhcr0GCcpd/RNAFWaN2CJFrWIIQ=
 github.com/rwtodd/Go.Sed v0.0.0-20210816025313-55464686f9ef/go.mod h1:8AEUvGVi2uQ5b24BIhcr0GCcpd/RNAFWaN2CJFrWIIQ=
 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 h1:OkMGxebDjyw0ULyrTYWeN0UNCCkmCWfjPnIA2W6oviI=
 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 h1:OkMGxebDjyw0ULyrTYWeN0UNCCkmCWfjPnIA2W6oviI=
 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06/go.mod h1:+ePHsJ1keEjQtpvf9HHw0f4ZeJ0TLRsxhunSI2hYJSs=
 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06/go.mod h1:+ePHsJ1keEjQtpvf9HHw0f4ZeJ0TLRsxhunSI2hYJSs=
-github.com/sashabaranov/go-gpt3 v1.4.0 h1:UqHYdXgJNtNvTtbzDnnQgkQ9TgTnHtCXx966uFTYXvU=
-github.com/sashabaranov/go-gpt3 v1.4.0/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ=
+github.com/sashabaranov/go-openai v1.9.0 h1:NoiO++IISxxJ1pRc0n7uZvMGMake0G+FJ1XPwXtprsA=
+github.com/sashabaranov/go-openai v1.9.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
 github.com/scylladb/termtables v0.0.0-20191203121021-c4c0b6d42ff4/go.mod h1:C1a7PQSMz9NShzorzCiG2fk9+xuCgLkPeCvMHYR2OWg=
 github.com/scylladb/termtables v0.0.0-20191203121021-c4c0b6d42ff4/go.mod h1:C1a7PQSMz9NShzorzCiG2fk9+xuCgLkPeCvMHYR2OWg=
 github.com/shirou/gopsutil/v3 v3.23.2 h1:PAWSuiAszn7IhPMBtXsbSCafej7PqUOvY6YywlQUExU=
 github.com/shirou/gopsutil/v3 v3.23.2 h1:PAWSuiAszn7IhPMBtXsbSCafej7PqUOvY6YywlQUExU=
 github.com/shirou/gopsutil/v3 v3.23.2/go.mod h1:gv0aQw33GLo3pG8SiWKiQrbDzbRY1K80RyZJ7V4Th1M=
 github.com/shirou/gopsutil/v3 v3.23.2/go.mod h1:gv0aQw33GLo3pG8SiWKiQrbDzbRY1K80RyZJ7V4Th1M=

+ 3 - 3
kernel/model/ai.go

@@ -22,7 +22,7 @@ import (
 
 
 	"github.com/88250/lute/ast"
 	"github.com/88250/lute/ast"
 	"github.com/88250/lute/parse"
 	"github.com/88250/lute/parse"
-	gogpt "github.com/sashabaranov/go-gpt3"
+	"github.com/sashabaranov/go-openai"
 	"github.com/siyuan-note/siyuan/kernel/treenode"
 	"github.com/siyuan-note/siyuan/kernel/treenode"
 	"github.com/siyuan-note/siyuan/kernel/util"
 	"github.com/siyuan-note/siyuan/kernel/util"
 )
 )
@@ -154,11 +154,11 @@ type GPT interface {
 }
 }
 
 
 type OpenAIGPT struct {
 type OpenAIGPT struct {
-	c *gogpt.Client
+	c *openai.Client
 }
 }
 
 
 func (gpt *OpenAIGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) {
 func (gpt *OpenAIGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) {
-	return util.ChatGPT(msg, contextMsgs, gpt.c, Conf.AI.OpenAI.APIMaxTokens, Conf.AI.OpenAI.APITimeout)
+	return util.ChatGPT(msg, contextMsgs, gpt.c, Conf.AI.OpenAI.APIModel, Conf.AI.OpenAI.APIMaxTokens, Conf.AI.OpenAI.APITimeout)
 }
 }
 
 
 type CloudGPT struct {
 type CloudGPT struct {

+ 6 - 1
kernel/model/conf.go

@@ -19,6 +19,7 @@ package model
 import (
 import (
 	"bytes"
 	"bytes"
 	"fmt"
 	"fmt"
+	"github.com/sashabaranov/go-openai"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"runtime"
 	"runtime"
@@ -337,14 +338,18 @@ func InitConf() {
 	if nil == Conf.AI {
 	if nil == Conf.AI {
 		Conf.AI = conf.NewAI()
 		Conf.AI = conf.NewAI()
 	}
 	}
+	if "" == Conf.AI.OpenAI.APIModel {
+		Conf.AI.OpenAI.APIModel = openai.GPT4
+	}
 
 
 	if "" != Conf.AI.OpenAI.APIKey {
 	if "" != Conf.AI.OpenAI.APIKey {
 		logging.LogInfof("OpenAI API enabled\n"+
 		logging.LogInfof("OpenAI API enabled\n"+
 			"    baseURL=%s\n"+
 			"    baseURL=%s\n"+
 			"    timeout=%ds\n"+
 			"    timeout=%ds\n"+
 			"    proxy=%s\n"+
 			"    proxy=%s\n"+
+			"    model=%s\n"+
 			"    maxTokens=%d",
 			"    maxTokens=%d",
-			Conf.AI.OpenAI.APIBaseURL, Conf.AI.OpenAI.APITimeout, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIMaxTokens)
+			Conf.AI.OpenAI.APIBaseURL, Conf.AI.OpenAI.APITimeout, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIModel, Conf.AI.OpenAI.APIMaxTokens)
 	}
 	}
 
 
 	Conf.ReadOnly = util.ReadOnly
 	Conf.ReadOnly = util.ReadOnly

+ 12 - 11
kernel/util/openai.go

@@ -18,30 +18,31 @@ package util
 
 
 import (
 import (
 	"context"
 	"context"
-	gogpt "github.com/sashabaranov/go-gpt3"
-	"github.com/siyuan-note/logging"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
 	"strings"
 	"strings"
 	"time"
 	"time"
+
+	"github.com/sashabaranov/go-openai"
+	"github.com/siyuan-note/logging"
 )
 )
 
 
-func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client, maxTokens, timeout int) (ret string, stop bool, err error) {
-	var reqMsgs []gogpt.ChatCompletionMessage
+func ChatGPT(msg string, contextMsgs []string, c *openai.Client, model string, maxTokens, timeout int) (ret string, stop bool, err error) {
+	var reqMsgs []openai.ChatCompletionMessage
 
 
 	for _, ctxMsg := range contextMsgs {
 	for _, ctxMsg := range contextMsgs {
-		reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{
+		reqMsgs = append(reqMsgs, openai.ChatCompletionMessage{
 			Role:    "user",
 			Role:    "user",
 			Content: ctxMsg,
 			Content: ctxMsg,
 		})
 		})
 	}
 	}
-	reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{
+	reqMsgs = append(reqMsgs, openai.ChatCompletionMessage{
 		Role:    "user",
 		Role:    "user",
 		Content: msg,
 		Content: msg,
 	})
 	})
 
 
-	req := gogpt.ChatCompletionRequest{
-		Model:     gogpt.GPT3Dot5Turbo,
+	req := openai.ChatCompletionRequest{
+		Model:     model,
 		MaxTokens: maxTokens,
 		MaxTokens: maxTokens,
 		Messages:  reqMsgs,
 		Messages:  reqMsgs,
 	}
 	}
@@ -74,8 +75,8 @@ func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client, maxTokens, timeo
 	return
 	return
 }
 }
 
 
-func NewOpenAIClient(apiKey, apiProxy, apiBaseURL string) *gogpt.Client {
-	config := gogpt.DefaultConfig(apiKey)
+func NewOpenAIClient(apiKey, apiProxy, apiBaseURL string) *openai.Client {
+	config := openai.DefaultConfig(apiKey)
 	if "" != apiProxy {
 	if "" != apiProxy {
 		proxyUrl, err := url.Parse(apiProxy)
 		proxyUrl, err := url.Parse(apiProxy)
 		if nil != err {
 		if nil != err {
@@ -86,5 +87,5 @@ func NewOpenAIClient(apiKey, apiProxy, apiBaseURL string) *gogpt.Client {
 	}
 	}
 
 
 	config.BaseURL = apiBaseURL
 	config.BaseURL = apiBaseURL
-	return gogpt.NewClientWithConfig(config)
+	return openai.NewClientWithConfig(config)
 }
 }