publish.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. // SiYuan - Refactor your thinking
  2. // Copyright (c) 2020-present, b3log.org
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package proxy
import (
	"crypto/subtle"
	"fmt"
	"net"
	"net/http"
	"net/http/httputil"
	"strconv"

	"github.com/siyuan-note/logging"
	"github.com/siyuan-note/siyuan/kernel/model"
	"github.com/siyuan-note/siyuan/kernel/util"
)
  27. type PublishServiceTransport struct{}
  28. var (
  29. Host = "0.0.0.0"
  30. Port = "0"
  31. listener net.Listener
  32. transport = PublishServiceTransport{}
  33. proxy = &httputil.ReverseProxy{
  34. Rewrite: rewrite,
  35. Transport: transport,
  36. }
  37. )
  38. func InitPublishService() (uint16, error) {
  39. model.InitAccounts()
  40. if listener != nil {
  41. if !model.Conf.Publish.Enable {
  42. // 关闭发布服务
  43. closePublishListener()
  44. return 0, nil
  45. }
  46. if port, err := util.ParsePort(Port); err != nil {
  47. return 0, err
  48. } else if port != model.Conf.Publish.Port {
  49. // 关闭原端口的发布服务
  50. if err = closePublishListener(); err != nil {
  51. return 0, err
  52. }
  53. // 重新启动新端口的发布服务
  54. initPublishService()
  55. }
  56. } else {
  57. if !model.Conf.Publish.Enable {
  58. return 0, nil
  59. }
  60. // 启动新端口的发布服务
  61. initPublishService()
  62. }
  63. return util.ParsePort(Port)
  64. }
  65. func initPublishService() (err error) {
  66. if err = initPublishListener(); err == nil {
  67. go startPublishReverseProxyService()
  68. }
  69. return
  70. }
  71. func initPublishListener() (err error) {
  72. // Start new listener
  73. listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", Host, model.Conf.Publish.Port))
  74. if err != nil {
  75. logging.LogErrorf("start listener failed: %s", err)
  76. return
  77. }
  78. _, Port, err = net.SplitHostPort(listener.Addr().String())
  79. if nil != err {
  80. logging.LogErrorf("split host and port failed: %s", err)
  81. return
  82. }
  83. return
  84. }
  85. func closePublishListener() (err error) {
  86. listener_ := listener
  87. listener = nil
  88. if err = listener_.Close(); err != nil {
  89. logging.LogErrorf("close listener %s failed: %s", listener_.Addr().String(), err)
  90. listener = listener_
  91. }
  92. return
  93. }
  94. func startPublishReverseProxyService() {
  95. logging.LogInfof("publish service [%s:%s] is running", Host, Port)
  96. // 服务进行时一直阻塞
  97. if err := http.Serve(listener, proxy); nil != err {
  98. if listener != nil {
  99. logging.LogErrorf("boot publish service failed: %s", err)
  100. }
  101. }
  102. logging.LogInfof("publish service [%s:%s] is stopped", Host, Port)
  103. }
  104. func rewrite(r *httputil.ProxyRequest) {
  105. r.SetURL(util.ServerURL)
  106. r.SetXForwarded()
  107. // r.Out.Host = r.In.Host // if desired
  108. }
  109. func (PublishServiceTransport) RoundTrip(request *http.Request) (response *http.Response, err error) {
  110. if model.Conf.Publish.Auth.Enable {
  111. // Basic Auth
  112. username, password, ok := request.BasicAuth()
  113. account := model.GetBasicAuthAccount(username)
  114. if !ok ||
  115. account == nil ||
  116. account.Username == "" || // 匿名用户
  117. account.Password != password {
  118. return &http.Response{
  119. StatusCode: http.StatusUnauthorized,
  120. Status: http.StatusText(http.StatusUnauthorized),
  121. Proto: request.Proto,
  122. ProtoMajor: request.ProtoMajor,
  123. ProtoMinor: request.ProtoMinor,
  124. Request: request,
  125. Header: http.Header{
  126. "WWW-Authenticate": {"Basic realm=" + strconv.Quote("Authorization Required")},
  127. },
  128. Close: false,
  129. ContentLength: -1,
  130. }, nil
  131. } else {
  132. // set JWT
  133. request.Header.Set(model.XAuthTokenKey, account.Token)
  134. }
  135. } else {
  136. request.Header.Set(model.XAuthTokenKey, model.GetBasicAuthAccount("").Token)
  137. }
  138. response, err = http.DefaultTransport.RoundTrip(request)
  139. return
  140. }