broadcast.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. // SiYuan - Refactor your thinking
  2. // Copyright (c) 2020-present, b3log.org
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package api
  17. import (
  18. "fmt"
  19. "net/http"
  20. "sync"
  21. "github.com/88250/gulu"
  22. "github.com/gin-gonic/gin"
  23. "github.com/olahol/melody"
  24. "github.com/siyuan-note/logging"
  25. "github.com/siyuan-note/siyuan/kernel/util"
  26. )
  27. type Channel struct {
  28. Name string `json:"name"`
  29. Count int `json:"count"`
  30. }
  31. var (
  32. BroadcastChannels = sync.Map{}
  33. )
  34. /*
  35. broadcast create a broadcast channel WebSocket connection
  36. @param
  37. query.channel: channel name
  38. @example
  39. ws://localhost:6806/ws/broadcast?channel=test
  40. */
  41. func broadcast(c *gin.Context) {
  42. var (
  43. channel string = c.Query("channel")
  44. broadcastChannel *melody.Melody
  45. )
  46. if _broadcastChannel, exist := BroadcastChannels.Load(channel); exist {
  47. // channel exists, use it
  48. broadcastChannel = _broadcastChannel.(*melody.Melody)
  49. subscribe(c, broadcastChannel, channel)
  50. } else {
  51. // channel not found, create a new one
  52. broadcastChannel := melody.New()
  53. broadcastChannel.Config.MaxMessageSize = 1024 * 1024 * 128 // 128 MiB
  54. BroadcastChannels.Store(channel, broadcastChannel)
  55. subscribe(c, broadcastChannel, channel)
  56. // broadcast string message to other session
  57. broadcastChannel.HandleMessage(func(s *melody.Session, msg []byte) {
  58. broadcastChannel.BroadcastOthers(msg, s)
  59. })
  60. // broadcast binary message to other session
  61. broadcastChannel.HandleMessageBinary(func(s *melody.Session, msg []byte) {
  62. broadcastChannel.BroadcastBinaryOthers(msg, s)
  63. })
  64. // recycling
  65. broadcastChannel.HandleClose(func(s *melody.Session, status int, reason string) error {
  66. channel := s.Keys["channel"].(string)
  67. logging.LogInfof("close broadcast session in channel [%s] with status code %d: %s", channel, status, reason)
  68. count := broadcastChannel.Len()
  69. if count == 0 {
  70. BroadcastChannels.Delete(channel)
  71. logging.LogInfof("dispose broadcast channel [%s]", channel)
  72. }
  73. return nil
  74. })
  75. }
  76. }
  77. // subscribe creates a new websocket session to a channel
  78. func subscribe(c *gin.Context, broadcastChannel *melody.Melody, channel string) {
  79. if err := broadcastChannel.HandleRequestWithKeys(
  80. c.Writer,
  81. c.Request,
  82. map[string]interface{}{
  83. "channel": channel,
  84. },
  85. ); nil != err {
  86. logging.LogErrorf("create broadcast channel failed: %s", err)
  87. return
  88. }
  89. }
  90. /*
  91. postMessage send string message to a broadcast channel
  92. @param
  93. body.channel: channel name
  94. body.message: message payload
  95. @returns
  96. body.data.channel.name: channel name
  97. body.data.channel.count: indicate how many websocket session received the message
  98. */
  99. func postMessage(c *gin.Context) {
  100. ret := gulu.Ret.NewResult()
  101. defer c.JSON(http.StatusOK, ret)
  102. arg, ok := util.JsonArg(c, ret)
  103. if !ok {
  104. return
  105. }
  106. channel := arg["channel"].(string)
  107. message := arg["message"].(string)
  108. if _broadcastChannel, ok := BroadcastChannels.Load(channel); !ok {
  109. err := fmt.Errorf("broadcast channel [%s] not found", channel)
  110. logging.LogWarnf(err.Error())
  111. ret.Code = -1
  112. ret.Msg = err.Error()
  113. return
  114. } else {
  115. var broadcastChannel = _broadcastChannel.(*melody.Melody)
  116. if err := broadcastChannel.Broadcast([]byte(message)); nil != err {
  117. logging.LogErrorf("broadcast message failed: %s", err)
  118. ret.Code = -2
  119. ret.Msg = err.Error()
  120. return
  121. }
  122. count := broadcastChannel.Len()
  123. ret.Data = map[string]interface{}{
  124. "channel": &Channel{
  125. Name: channel,
  126. Count: count,
  127. },
  128. }
  129. }
  130. }
  131. /*
  132. getChannelInfo gets the information of a broadcast channel
  133. @param
  134. body.name: channel name
  135. @returns
  136. body.data.channel: the channel information
  137. */
  138. func getChannelInfo(c *gin.Context) {
  139. ret := gulu.Ret.NewResult()
  140. defer c.JSON(http.StatusOK, ret)
  141. arg, ok := util.JsonArg(c, ret)
  142. if !ok {
  143. return
  144. }
  145. name := arg["name"].(string)
  146. if _broadcastChannel, ok := BroadcastChannels.Load(name); !ok {
  147. err := fmt.Errorf("broadcast channel [%s] not found", name)
  148. logging.LogWarnf(err.Error())
  149. ret.Code = -1
  150. ret.Msg = err.Error()
  151. return
  152. } else {
  153. var broadcastChannel = _broadcastChannel.(*melody.Melody)
  154. count := broadcastChannel.Len()
  155. ret.Data = map[string]interface{}{
  156. "channel": &Channel{
  157. Name: name,
  158. Count: count,
  159. },
  160. }
  161. }
  162. }
  163. /*
  164. getChannels gets the channel name and lintener number of all broadcast chanel
  165. @returns
  166. body.data.channels: {
  167. name: channel name
  168. count: listener count
  169. }[]
  170. */
  171. func getChannels(c *gin.Context) {
  172. ret := gulu.Ret.NewResult()
  173. defer c.JSON(http.StatusOK, ret)
  174. channels := []*Channel{}
  175. BroadcastChannels.Range(func(key, value any) bool {
  176. broadcastChannel := value.(*melody.Melody)
  177. channels = append(channels, &Channel{
  178. Name: key.(string),
  179. Count: broadcastChannel.Len(),
  180. })
  181. return true
  182. })
  183. ret.Data = map[string]interface{}{
  184. "channels": channels,
  185. }
  186. }