ai.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. // SiYuan - Refactor your thinking
  2. // Copyright (c) 2020-present, b3log.org
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package model
  17. import (
  18. "bytes"
  19. "strings"
  20. "github.com/88250/lute/ast"
  21. "github.com/88250/lute/parse"
  22. "github.com/sashabaranov/go-openai"
  23. "github.com/siyuan-note/siyuan/kernel/treenode"
  24. "github.com/siyuan-note/siyuan/kernel/util"
  25. )
  26. func ChatGPT(msg string) (ret string) {
  27. if !isOpenAIAPIEnabled() {
  28. return
  29. }
  30. return chatGPT(msg, false)
  31. }
  32. func ChatGPTWithAction(ids []string, action string) (ret string) {
  33. if !isOpenAIAPIEnabled() {
  34. return
  35. }
  36. if "Clear context" == action {
  37. // AI clear context action https://github.com/siyuan-note/siyuan/issues/10255
  38. cachedContextMsg = nil
  39. return
  40. }
  41. msg := getBlocksContent(ids)
  42. ret = chatGPTWithAction(msg, action, false)
  43. return
  44. }
  45. var cachedContextMsg []string
  46. func chatGPT(msg string, cloud bool) (ret string) {
  47. if "Clear context" == strings.TrimSpace(msg) {
  48. // AI clear context action https://github.com/siyuan-note/siyuan/issues/10255
  49. cachedContextMsg = nil
  50. return
  51. }
  52. ret, retCtxMsgs, err := chatGPTContinueWrite(msg, cachedContextMsg, cloud)
  53. if nil != err {
  54. return
  55. }
  56. cachedContextMsg = append(cachedContextMsg, retCtxMsgs...)
  57. return
  58. }
  59. func chatGPTWithAction(msg string, action string, cloud bool) (ret string) {
  60. action = strings.TrimSpace(action)
  61. if "" != action {
  62. msg = action + ":\n\n" + msg
  63. }
  64. ret, _, err := chatGPTContinueWrite(msg, nil, cloud)
  65. if nil != err {
  66. return
  67. }
  68. return
  69. }
  70. func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string, err error) {
  71. util.PushEndlessProgress("Requesting...")
  72. defer util.ClearPushProgress(100)
  73. if Conf.AI.OpenAI.APIMaxContexts < len(contextMsgs) {
  74. contextMsgs = contextMsgs[len(contextMsgs)-Conf.AI.OpenAI.APIMaxContexts:]
  75. }
  76. var gpt GPT
  77. if cloud {
  78. gpt = &CloudGPT{}
  79. } else {
  80. gpt = &OpenAIGPT{c: util.NewOpenAIClient(Conf.AI.OpenAI.APIKey, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIBaseURL, Conf.AI.OpenAI.APIUserAgent, Conf.AI.OpenAI.APIVersion, Conf.AI.OpenAI.APIProvider)}
  81. }
  82. buf := &bytes.Buffer{}
  83. for i := 0; i < Conf.AI.OpenAI.APIMaxContexts; i++ {
  84. part, stop, chatErr := gpt.chat(msg, contextMsgs)
  85. buf.WriteString(part)
  86. if stop || nil != chatErr {
  87. break
  88. }
  89. util.PushEndlessProgress("Continue requesting...")
  90. }
  91. ret = buf.String()
  92. ret = strings.TrimSpace(ret)
  93. retContextMsgs = append(retContextMsgs, msg, ret)
  94. return
  95. }
  96. func isOpenAIAPIEnabled() bool {
  97. if "" == Conf.AI.OpenAI.APIKey {
  98. util.PushMsg(Conf.Language(193), 5000)
  99. return false
  100. }
  101. return true
  102. }
  103. func getBlocksContent(ids []string) string {
  104. var nodes []*ast.Node
  105. trees := map[string]*parse.Tree{}
  106. for _, id := range ids {
  107. bt := treenode.GetBlockTree(id)
  108. if nil == bt {
  109. continue
  110. }
  111. var tree *parse.Tree
  112. if tree = trees[bt.RootID]; nil == tree {
  113. tree, _ = LoadTreeByBlockID(bt.RootID)
  114. if nil == tree {
  115. continue
  116. }
  117. trees[bt.RootID] = tree
  118. }
  119. if node := treenode.GetNodeInTree(tree, id); nil != node {
  120. if ast.NodeDocument == node.Type {
  121. for child := node.FirstChild; nil != child; child = child.Next {
  122. nodes = append(nodes, child)
  123. }
  124. } else {
  125. nodes = append(nodes, node)
  126. }
  127. }
  128. }
  129. luteEngine := util.NewLute()
  130. buf := bytes.Buffer{}
  131. for _, node := range nodes {
  132. md := treenode.ExportNodeStdMd(node, luteEngine)
  133. buf.WriteString(md)
  134. buf.WriteString("\n\n")
  135. }
  136. return buf.String()
  137. }
  138. type GPT interface {
  139. chat(msg string, contextMsgs []string) (partRet string, stop bool, err error)
  140. }
  141. type OpenAIGPT struct {
  142. c *openai.Client
  143. }
  144. func (gpt *OpenAIGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) {
  145. return util.ChatGPT(msg, contextMsgs, gpt.c, Conf.AI.OpenAI.APIModel, Conf.AI.OpenAI.APIMaxTokens, Conf.AI.OpenAI.APITemperature, Conf.AI.OpenAI.APITimeout)
  146. }
  147. type CloudGPT struct {
  148. }
  149. func (gpt *CloudGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) {
  150. return CloudChatGPT(msg, contextMsgs)
  151. }