openai.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. // SiYuan - Build Your Eternal Digital Garden
  2. // Copyright (c) 2020-present, b3log.org
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package util
  17. import (
  18. "context"
  19. "net/http"
  20. "net/url"
  21. "os"
  22. "strconv"
  23. "strings"
  24. "time"
  25. gogpt "github.com/sashabaranov/go-gpt3"
  26. "github.com/siyuan-note/logging"
  27. )
  28. var (
  29. OpenAIAPIKey = ""
  30. OpenAIAPITimeout = 15 * time.Second
  31. OpenAIAPIProxy = ""
  32. OpenAIAPIMaxTokens = 4096
  33. )
  34. func ChatGPT(msg string) (ret string) {
  35. if "" == OpenAIAPIKey {
  36. return
  37. }
  38. config := gogpt.DefaultConfig(OpenAIAPIKey)
  39. if "" != OpenAIAPIProxy {
  40. proxyUrl, err := url.Parse(OpenAIAPIProxy)
  41. if nil != err {
  42. logging.LogErrorf("OpenAI API proxy error: %v", err)
  43. } else {
  44. config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}}
  45. }
  46. }
  47. c := gogpt.NewClientWithConfig(config)
  48. ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout)
  49. defer cancel()
  50. req := gogpt.ChatCompletionRequest{
  51. Model: gogpt.GPT3Dot5Turbo,
  52. MaxTokens: OpenAIAPIMaxTokens,
  53. Messages: []gogpt.ChatCompletionMessage{
  54. {
  55. Role: "user",
  56. Content: msg,
  57. },
  58. },
  59. }
  60. resp, err := c.CreateChatCompletion(ctx, req)
  61. if nil != err {
  62. logging.LogErrorf("create chat completion failed: %s", err)
  63. return
  64. }
  65. if 0 < len(resp.Choices) {
  66. ret = resp.Choices[0].Message.Content
  67. ret = strings.TrimSpace(ret)
  68. }
  69. return
  70. }
  71. func initOpenAI() {
  72. OpenAIAPIKey = os.Getenv("SIYUAN_OPENAI_API_KEY")
  73. if "" == OpenAIAPIKey {
  74. return
  75. }
  76. timeout := os.Getenv("SIYUAN_OPENAI_API_TIMEOUT")
  77. if "" != timeout {
  78. timeoutInt, err := strconv.Atoi(timeout)
  79. if nil == err {
  80. OpenAIAPITimeout = time.Duration(timeoutInt) * time.Second
  81. }
  82. }
  83. proxy := os.Getenv("SIYUAN_OPENAI_API_PROXY")
  84. if "" != proxy {
  85. OpenAIAPIProxy = proxy
  86. }
  87. maxTokens := os.Getenv("SIYUAN_OPENAI_API_MAX_TOKENS")
  88. if "" != maxTokens {
  89. maxTokensInt, err := strconv.Atoi(maxTokens)
  90. if nil == err {
  91. OpenAIAPIMaxTokens = maxTokensInt
  92. }
  93. }
  94. if 1 > OpenAIAPIMaxTokens {
  95. OpenAIAPIMaxTokens = 4096
  96. }
  97. logging.LogInfof("OpenAI API enabled [maxTokens=%d, timeout=%ds, proxy=%s]", OpenAIAPIMaxTokens, int(OpenAIAPITimeout.Seconds()), OpenAIAPIProxy)
  98. }