ai.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. // SiYuan - Refactor your thinking
  2. // Copyright (c) 2020-present, b3log.org
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package model
  17. import (
  18. "bytes"
  19. "strings"
  20. "github.com/88250/lute/ast"
  21. "github.com/88250/lute/parse"
  22. "github.com/sashabaranov/go-openai"
  23. "github.com/siyuan-note/siyuan/kernel/treenode"
  24. "github.com/siyuan-note/siyuan/kernel/util"
  25. )
  26. func ChatGPT(msg string) (ret string) {
  27. if !isOpenAIAPIEnabled() {
  28. return
  29. }
  30. return chatGPT(msg, false)
  31. }
  32. func ChatGPTWithAction(ids []string, action string) (ret string) {
  33. if !isOpenAIAPIEnabled() {
  34. return
  35. }
  36. msg := getBlocksContent(ids)
  37. ret = chatGPTWithAction(msg, action, false)
  38. return
  39. }
  40. var cachedContextMsg []string
  41. func chatGPT(msg string, cloud bool) (ret string) {
  42. ret, retCtxMsgs, err := chatGPTContinueWrite(msg, cachedContextMsg, cloud)
  43. if nil != err {
  44. return
  45. }
  46. cachedContextMsg = append(cachedContextMsg, retCtxMsgs...)
  47. return
  48. }
  49. func chatGPTWithAction(msg string, action string, cloud bool) (ret string) {
  50. action = strings.TrimSpace(action)
  51. if "" != action {
  52. msg = action + ":\n\n" + msg
  53. }
  54. ret, _, err := chatGPTContinueWrite(msg, nil, cloud)
  55. if nil != err {
  56. return
  57. }
  58. return
  59. }
  60. func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string, err error) {
  61. util.PushEndlessProgress("Requesting...")
  62. defer util.ClearPushProgress(100)
  63. if 7 < len(contextMsgs) {
  64. contextMsgs = contextMsgs[len(contextMsgs)-7:]
  65. }
  66. var gpt GPT
  67. if cloud {
  68. gpt = &CloudGPT{}
  69. } else {
  70. gpt = &OpenAIGPT{c: util.NewOpenAIClient(Conf.AI.OpenAI.APIKey, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIBaseURL)}
  71. }
  72. buf := &bytes.Buffer{}
  73. for i := 0; i < 7; i++ {
  74. part, stop, chatErr := gpt.chat(msg, contextMsgs)
  75. buf.WriteString(part)
  76. if stop || nil != chatErr {
  77. break
  78. }
  79. util.PushEndlessProgress("Continue requesting...")
  80. }
  81. ret = buf.String()
  82. ret = strings.TrimSpace(ret)
  83. retContextMsgs = append(retContextMsgs, msg, ret)
  84. return
  85. }
  86. func isOpenAIAPIEnabled() bool {
  87. if "" == Conf.AI.OpenAI.APIKey {
  88. util.PushMsg(Conf.Language(193), 5000)
  89. return false
  90. }
  91. return true
  92. }
  93. func getBlocksContent(ids []string) string {
  94. var nodes []*ast.Node
  95. trees := map[string]*parse.Tree{}
  96. for _, id := range ids {
  97. bt := treenode.GetBlockTree(id)
  98. if nil == bt {
  99. continue
  100. }
  101. var tree *parse.Tree
  102. if tree = trees[bt.RootID]; nil == tree {
  103. tree, _ = loadTreeByBlockID(bt.RootID)
  104. if nil == tree {
  105. continue
  106. }
  107. trees[bt.RootID] = tree
  108. }
  109. if node := treenode.GetNodeInTree(tree, id); nil != node {
  110. if ast.NodeDocument == node.Type {
  111. for child := node.FirstChild; nil != child; child = child.Next {
  112. nodes = append(nodes, child)
  113. }
  114. } else {
  115. nodes = append(nodes, node)
  116. }
  117. }
  118. }
  119. luteEngine := util.NewLute()
  120. buf := bytes.Buffer{}
  121. for _, node := range nodes {
  122. md := treenode.ExportNodeStdMd(node, luteEngine)
  123. buf.WriteString(md)
  124. buf.WriteString("\n\n")
  125. }
  126. return buf.String()
  127. }
  128. type GPT interface {
  129. chat(msg string, contextMsgs []string) (partRet string, stop bool, err error)
  130. }
  131. type OpenAIGPT struct {
  132. c *openai.Client
  133. }
  134. func (gpt *OpenAIGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) {
  135. return util.ChatGPT(msg, contextMsgs, gpt.c, Conf.AI.OpenAI.APIModel, Conf.AI.OpenAI.APIMaxTokens, Conf.AI.OpenAI.APITimeout)
  136. }
  137. type CloudGPT struct {
  138. }
  139. func (gpt *CloudGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) {
  140. return CloudChatGPT(msg, contextMsgs)
  141. }