openai.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. // SiYuan - Build Your Eternal Digital Garden
  2. // Copyright (c) 2020-present, b3log.org
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package util
  17. import (
  18. "context"
  19. "net/http"
  20. "net/url"
  21. "os"
  22. "strconv"
  23. "strings"
  24. "time"
  25. gogpt "github.com/sashabaranov/go-gpt3"
  26. "github.com/siyuan-note/logging"
  27. )
  28. var (
  29. OpenAIAPIKey = ""
  30. OpenAIAPITimeout = 30 * time.Second
  31. OpenAIAPIProxy = ""
  32. OpenAIAPIMaxTokens = 0
  33. OpenAIAPIBaseURL = "https://api.openai.com/v1"
  34. )
  35. func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool, err error) {
  36. var reqMsgs []gogpt.ChatCompletionMessage
  37. for _, ctxMsg := range contextMsgs {
  38. reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{
  39. Role: "user",
  40. Content: ctxMsg,
  41. })
  42. }
  43. reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{
  44. Role: "user",
  45. Content: msg,
  46. })
  47. req := gogpt.ChatCompletionRequest{
  48. Model: gogpt.GPT3Dot5Turbo,
  49. MaxTokens: OpenAIAPIMaxTokens,
  50. Messages: reqMsgs,
  51. }
  52. ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout)
  53. defer cancel()
  54. resp, err := c.CreateChatCompletion(ctx, req)
  55. if nil != err {
  56. PushErrMsg("Requesting failed, please check kernel log for more details", 3000)
  57. logging.LogErrorf("create chat completion failed: %s", err)
  58. stop = true
  59. return
  60. }
  61. if 1 > len(resp.Choices) {
  62. stop = true
  63. return
  64. }
  65. buf := &strings.Builder{}
  66. choice := resp.Choices[0]
  67. buf.WriteString(choice.Message.Content)
  68. if "length" == choice.FinishReason {
  69. stop = false
  70. } else {
  71. stop = true
  72. }
  73. ret = buf.String()
  74. ret = strings.TrimSpace(ret)
  75. return
  76. }
  77. func NewOpenAIClient() *gogpt.Client {
  78. config := gogpt.DefaultConfig(OpenAIAPIKey)
  79. if "" != OpenAIAPIProxy {
  80. proxyUrl, err := url.Parse(OpenAIAPIProxy)
  81. if nil != err {
  82. logging.LogErrorf("OpenAI API proxy failed: %v", err)
  83. } else {
  84. config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}}
  85. }
  86. }
  87. config.BaseURL = OpenAIAPIBaseURL
  88. return gogpt.NewClientWithConfig(config)
  89. }
  90. func initOpenAI() {
  91. OpenAIAPIKey = os.Getenv("SIYUAN_OPENAI_API_KEY")
  92. if "" == OpenAIAPIKey {
  93. return
  94. }
  95. timeout := os.Getenv("SIYUAN_OPENAI_API_TIMEOUT")
  96. if "" != timeout {
  97. timeoutInt, err := strconv.Atoi(timeout)
  98. if nil == err {
  99. OpenAIAPITimeout = time.Duration(timeoutInt) * time.Second
  100. }
  101. }
  102. proxy := os.Getenv("SIYUAN_OPENAI_API_PROXY")
  103. if "" != proxy {
  104. OpenAIAPIProxy = proxy
  105. }
  106. maxTokens := os.Getenv("SIYUAN_OPENAI_API_MAX_TOKENS")
  107. if "" != maxTokens {
  108. maxTokensInt, err := strconv.Atoi(maxTokens)
  109. if nil == err {
  110. OpenAIAPIMaxTokens = maxTokensInt
  111. }
  112. }
  113. baseURL := os.Getenv("SIYUAN_OPENAI_API_BASE_URL")
  114. if "" != baseURL {
  115. OpenAIAPIBaseURL = baseURL
  116. }
  117. logging.LogInfof("OpenAI API enabled\n"+
  118. " baseURL=%s\n"+
  119. " timeout=%ds\n"+
  120. " proxy=%s\n"+
  121. " maxTokens=%d",
  122. OpenAIAPIBaseURL, int(OpenAIAPITimeout.Seconds()), OpenAIAPIProxy, OpenAIAPIMaxTokens)
  123. }