瀏覽代碼

:art: Support custom AI request User-Agent header https://github.com/siyuan-note/siyuan/issues/10351

Daniel 1 年之前
父節點
當前提交
7e87c8c8ad

+ 1 - 0
app/appearance/langs/en_US.json

@@ -273,6 +273,7 @@
   "apiMaxTokensTip": "The <code class='fn__code'>max_tokens</code> parameter passed in when requesting the API is used to control the length of the generated text",
   "apiBaseURL": "API Base URL",
   "apiBaseURLTip": "The base address of the request, such as <code class='fn__code'>https://api.openai.com/v1</code>",
+  "apiUserAgentTip": "The user agent that initiated the request, that is, the HTTP header <code class='fn__code'>User-Agent</code>",
   "skip": "Skip",
   "nextRound": "Next round",
   "save": "Save",

+ 1 - 0
app/appearance/langs/es_ES.json

@@ -273,6 +273,7 @@
   "apiMaxTokensTip": "El parámetro <code class='fn__code'>max_tokens</code> que se pasa al solicitar la API se usa para controlar la longitud del texto generado",
   "apiBaseURL": "URL base de la API",
   "apiBaseURLTip": "La dirección base de la solicitud, como <code class='fn__code'>https://api.openai.com/v1</code>",
+  "apiUserAgentTip": "El agente de usuario que inició la solicitud, es decir, el encabezado HTTP <code class='fn__code'>User-Agent</code>",
   "skip": "barco",
   "nextRound": "Siguiente ronda",
   "save": "Ahorrar",

+ 1 - 0
app/appearance/langs/fr_FR.json

@@ -273,6 +273,7 @@
   "apiMaxTokensTip": "Le paramètre <code class='fn__code'>max_tokens</code> transmis lors de la demande de l'API est utilisé pour contrôler la longueur du texte généré",
   "apiBaseURL": "URL de base de l'API",
   "apiBaseURLTip": "L'adresse de base de la requête, telle que <code class='fn__code'>https://api.openai.com/v1</code>",
+  "apiUserAgentTip": "L'agent utilisateur qui a initié la requête, c'est-à-dire l'en-tête HTTP <code class='fn__code'>User-Agent</code>",
   "skip": "Navire",
   "nextRound": "Prochain tour",
   "save": "Sauvegarder",

+ 1 - 0
app/appearance/langs/zh_CHT.json

@@ -273,6 +273,7 @@
   "apiMaxTokensTip": "請求 API 時傳入的 <code class='fn__code'>max_tokens</code> 參數,用於控制生成的文字長度",
   "apiBaseURL": "API 基礎地址",
   "apiBaseURLTip": "發起請求的基礎地址,如 <code class='fn__code'>https://api.openai.com/v1</code>",
+  "apiUserAgentTip": "發起請求的使用者代理,即 HTTP 標頭 <code class='fn__code'>User-Agent</code>",
   "skip": "跳過",
   "nextRound": "下一輪",
   "save": "保存",

+ 1 - 0
app/appearance/langs/zh_CN.json

@@ -273,6 +273,7 @@
   "apiMaxTokensTip": "请求 API 时传入的 <code class='fn__code'>max_tokens</code> 参数,用于控制生成的文本长度",
   "apiBaseURL": "API 基础地址",
   "apiBaseURLTip": "发起请求的基础地址,如 <code class='fn__code'>https://api.openai.com/v1</code>",
+  "apiUserAgentTip": "发起请求的用户代理,即 HTTP 标头 <code class='fn__code'>User-Agent</code>",
   "skip": "跳过",
   "nextRound": "下一轮",
   "save": "保存",

+ 15 - 0
app/src/config/ai.ts

@@ -50,6 +50,12 @@ export const ai = {
     <div class="fn__hr"></div>
     <input class="b3-text-field fn__block" id="apiBaseURL" value="${window.siyuan.config.ai.openAI.apiBaseURL}"/>
     <div class="b3-label__text">${window.siyuan.languages.apiBaseURLTip}</div>
+</div>
+<div class="b3-label">
+    User-Agent
+    <div class="fn__hr"></div>
+    <input class="b3-text-field fn__block" id="apiUserAgent" value="${window.siyuan.config.ai.openAI.apiUserAgent}"/>
+    <div class="b3-label__text">${window.siyuan.languages.apiUserAgentTip}</div>
 </div>`;
         /// #else
         responsiveHTML = `<div class="fn__flex b3-label">
@@ -106,6 +112,14 @@ export const ai = {
         <span class="fn__hr"></span>
         <input class="b3-text-field fn__block" id="apiBaseURL" value="${window.siyuan.config.ai.openAI.apiBaseURL}"/>
     </div>
+</div>
+<div class="fn__flex b3-label">
+    <div class="fn__block">
+        User-Agent
+        <div class="b3-label__text">${window.siyuan.languages.apiUserAgentTip}</div>
+        <span class="fn__hr"></span>
+        <input class="b3-text-field fn__block" id="apiUserAgent" value="${window.siyuan.config.ai.openAI.apiUserAgent}"/>
+    </div>
 </div>`;
         /// #endif
         return `<div class="fn__flex-column" style="height: 100%">
@@ -124,6 +138,7 @@ export const ai = {
             item.addEventListener("change", () => {
                 fetchPost("/api/setting/setAI", {
                     openAI: {
+                        apiUserAgent: (ai.element.querySelector("#apiUserAgent") as HTMLInputElement).value,
                         apiBaseURL: (ai.element.querySelector("#apiBaseURL") as HTMLInputElement).value,
                         apiKey: (ai.element.querySelector("#apiKey") as HTMLInputElement).value,
                         apiModel: (ai.element.querySelector("#apiModel") as HTMLSelectElement).value,

+ 1 - 0
app/src/types/index.d.ts

@@ -722,6 +722,7 @@ interface IConfig {
     }
     ai: {
         openAI: {
+            apiUserAgent: string
             apiBaseURL: string
             apiKey: string
             apiModel: string

+ 9 - 3
kernel/conf/ai.go

@@ -17,6 +17,7 @@
 package conf
 
 import (
+	"github.com/siyuan-note/siyuan/kernel/util"
 	"os"
 	"strconv"
 
@@ -34,13 +35,15 @@ type OpenAI struct {
 	APIModel     string `json:"apiModel"`
 	APIMaxTokens int    `json:"apiMaxTokens"`
 	APIBaseURL   string `json:"apiBaseURL"`
+	APIUserAgent string `json:"apiUserAgent"`
 }
 
 func NewAI() *AI {
 	openAI := &OpenAI{
-		APITimeout: 30,
-		APIModel:   openai.GPT3Dot5Turbo,
-		APIBaseURL: "https://api.openai.com/v1",
+		APITimeout:   30,
+		APIModel:     openai.GPT3Dot5Turbo,
+		APIBaseURL:   "https://api.openai.com/v1",
+		APIUserAgent: util.UserAgent,
 	}
 
 	openAI.APIKey = os.Getenv("SIYUAN_OPENAI_API_KEY")
@@ -67,5 +70,8 @@ func NewAI() *AI {
 		openAI.APIBaseURL = baseURL
 	}
 
+	if userAgent := os.Getenv("SIYUAN_OPENAI_API_USER_AGENT"); "" != userAgent {
+		openAI.APIUserAgent = userAgent
+	}
 	return &AI{OpenAI: openAI}
 }

+ 1 - 1
kernel/model/ai.go

@@ -92,7 +92,7 @@ func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret str
 	if cloud {
 		gpt = &CloudGPT{}
 	} else {
-		gpt = &OpenAIGPT{c: util.NewOpenAIClient(Conf.AI.OpenAI.APIKey, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIBaseURL)}
+		gpt = &OpenAIGPT{c: util.NewOpenAIClient(Conf.AI.OpenAI.APIKey, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIBaseURL, Conf.AI.OpenAI.APIUserAgent)}
 	}
 
 	buf := &bytes.Buffer{}

+ 10 - 1
kernel/model/conf.go

@@ -405,15 +405,24 @@ func InitConf() {
 	if "" == Conf.AI.OpenAI.APIModel {
 		Conf.AI.OpenAI.APIModel = openai.GPT3Dot5Turbo
 	}
+	if "" == Conf.AI.OpenAI.APIUserAgent {
+		Conf.AI.OpenAI.APIUserAgent = util.UserAgent
+	}
 
 	if "" != Conf.AI.OpenAI.APIKey {
 		logging.LogInfof("OpenAI API enabled\n"+
+			"    userAgent=%s\n"+
 			"    baseURL=%s\n"+
 			"    timeout=%ds\n"+
 			"    proxy=%s\n"+
 			"    model=%s\n"+
 			"    maxTokens=%d",
-			Conf.AI.OpenAI.APIBaseURL, Conf.AI.OpenAI.APITimeout, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIModel, Conf.AI.OpenAI.APIMaxTokens)
+			Conf.AI.OpenAI.APIUserAgent,
+			Conf.AI.OpenAI.APIBaseURL,
+			Conf.AI.OpenAI.APITimeout,
+			Conf.AI.OpenAI.APIProxy,
+			Conf.AI.OpenAI.APIModel,
+			Conf.AI.OpenAI.APIMaxTokens)
 	}
 
 	Conf.ReadOnly = util.ReadOnly

+ 18 - 3
kernel/util/openai.go

@@ -75,17 +75,32 @@ func ChatGPT(msg string, contextMsgs []string, c *openai.Client, model string, m
 	return
 }
 
-func NewOpenAIClient(apiKey, apiProxy, apiBaseURL string) *openai.Client {
+func NewOpenAIClient(apiKey, apiProxy, apiBaseURL, apiUserAgent string) *openai.Client {
 	config := openai.DefaultConfig(apiKey)
+	transport := &http.Transport{}
 	if "" != apiProxy {
 		proxyUrl, err := url.Parse(apiProxy)
 		if nil != err {
 			logging.LogErrorf("OpenAI API proxy failed: %v", err)
 		} else {
-			config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}}
+			transport.Proxy = http.ProxyURL(proxyUrl)
 		}
 	}
-
+	config.HTTPClient = &http.Client{Transport: newAddHeaderTransport(transport, apiUserAgent)}
 	config.BaseURL = apiBaseURL
 	return openai.NewClientWithConfig(config)
 }
+
+type AddHeaderTransport struct {
+	RoundTripper http.RoundTripper
+	UserAgent    string
+}
+
+func (adt *AddHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+	req.Header.Add("User-Agent", adt.UserAgent)
+	return adt.RoundTripper.RoundTrip(req)
+}
+
+func newAddHeaderTransport(transport *http.Transport, userAgent string) *AddHeaderTransport {
+	return &AddHeaderTransport{RoundTripper: transport, UserAgent: userAgent}
+}