openai.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. // SiYuan - Refactor your thinking
  2. // Copyright (c) 2020-present, b3log.org
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package util
  17. import (
  18. "context"
  19. "net/http"
  20. "net/url"
  21. "strings"
  22. "time"
  23. "github.com/sashabaranov/go-openai"
  24. "github.com/siyuan-note/logging"
  25. )
  26. func ChatGPT(msg string, contextMsgs []string, c *openai.Client, model string, maxTokens, timeout int) (ret string, stop bool, err error) {
  27. var reqMsgs []openai.ChatCompletionMessage
  28. for _, ctxMsg := range contextMsgs {
  29. reqMsgs = append(reqMsgs, openai.ChatCompletionMessage{
  30. Role: "user",
  31. Content: ctxMsg,
  32. })
  33. }
  34. reqMsgs = append(reqMsgs, openai.ChatCompletionMessage{
  35. Role: "user",
  36. Content: msg,
  37. })
  38. req := openai.ChatCompletionRequest{
  39. Model: model,
  40. MaxTokens: maxTokens,
  41. Messages: reqMsgs,
  42. }
  43. ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
  44. defer cancel()
  45. resp, err := c.CreateChatCompletion(ctx, req)
  46. if nil != err {
  47. PushErrMsg("Requesting failed, please check kernel log for more details", 3000)
  48. logging.LogErrorf("create chat completion failed: %s", err)
  49. stop = true
  50. return
  51. }
  52. if 1 > len(resp.Choices) {
  53. stop = true
  54. return
  55. }
  56. buf := &strings.Builder{}
  57. choice := resp.Choices[0]
  58. buf.WriteString(choice.Message.Content)
  59. if "length" == choice.FinishReason {
  60. stop = false
  61. } else {
  62. stop = true
  63. }
  64. ret = buf.String()
  65. ret = strings.TrimSpace(ret)
  66. return
  67. }
  68. func NewOpenAIClient(apiKey, apiProxy, apiBaseURL, apiUserAgent string) *openai.Client {
  69. config := openai.DefaultConfig(apiKey)
  70. transport := &http.Transport{}
  71. if "" != apiProxy {
  72. proxyUrl, err := url.Parse(apiProxy)
  73. if nil != err {
  74. logging.LogErrorf("OpenAI API proxy failed: %v", err)
  75. } else {
  76. transport.Proxy = http.ProxyURL(proxyUrl)
  77. }
  78. }
  79. config.HTTPClient = &http.Client{Transport: newAddHeaderTransport(transport, apiUserAgent)}
  80. config.BaseURL = apiBaseURL
  81. return openai.NewClientWithConfig(config)
  82. }
  // AddHeaderTransport is an http.RoundTripper that sets a fixed User-Agent
  // header on every request before delegating to the wrapped RoundTripper.
  type AddHeaderTransport struct {
  	// RoundTripper performs the actual exchange. When nil, RoundTrip falls
  	// back to http.DefaultTransport.
  	RoundTripper http.RoundTripper
  	// UserAgent is the value written to the User-Agent header.
  	UserAgent string
  }

  // RoundTrip sends a copy of req carrying adt.UserAgent as its User-Agent
  // header and returns the wrapped RoundTripper's response and error unchanged.
  //
  // The caller's request is never modified: the http.RoundTripper contract
  // forbids it, so the header is written to a clone. Any User-Agent already
  // present on the request is replaced rather than joined by a second value.
  func (adt *AddHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  	out := req.Clone(req.Context())
  	if out.Header == nil {
  		// Clone keeps a nil header map nil; writing to it would panic.
  		out.Header = make(http.Header)
  	}
  	out.Header.Set("User-Agent", adt.UserAgent)

  	next := adt.RoundTripper
  	if next == nil {
  		next = http.DefaultTransport
  	}
  	return next.RoundTrip(out)
  }

  // newAddHeaderTransport wraps transport so that every request it sends
  // carries userAgent as its User-Agent header.
  //
  // A nil transport leaves the RoundTripper field nil instead of storing a
  // typed nil pointer in the interface, which would compare non-nil and panic
  // on first use.
  func newAddHeaderTransport(transport *http.Transport, userAgent string) *AddHeaderTransport {
  	adt := &AddHeaderTransport{UserAgent: userAgent}
  	if transport != nil {
  		adt.RoundTripper = transport
  	}
  	return adt
  }