httpd.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package httpd
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "fmt"
  7. "io/ioutil"
  8. "net/http"
  9. "path/filepath"
  10. "time"
  11. "github.com/drakkan/sftpgo/ldapauthserver/config"
  12. "github.com/drakkan/sftpgo/ldapauthserver/logger"
  13. "github.com/drakkan/sftpgo/ldapauthserver/utils"
  14. "github.com/go-chi/chi"
  15. "github.com/go-chi/chi/middleware"
  16. "github.com/go-chi/render"
  17. )
  18. const (
  19. logSender = "httpd"
  20. versionPath = "/api/v1/version"
  21. checkAuthPath = "/api/v1/check_auth"
  22. maxRequestSize = 1 << 18 // 256KB
  23. )
  24. var (
  25. ldapConfig config.LDAPConfig
  26. httpAuth httpAuthProvider
  27. certMgr *certManager
  28. rootCAs *x509.CertPool
  29. )
  30. // StartHTTPServer initializes and starts the HTTP Server
  31. func StartHTTPServer(configDir string, httpConfig config.HTTPDConfig) error {
  32. var err error
  33. authUserFile := getConfigPath(httpConfig.AuthUserFile, configDir)
  34. httpAuth, err = newBasicAuthProvider(authUserFile)
  35. if err != nil {
  36. return err
  37. }
  38. router := chi.NewRouter()
  39. router.Use(middleware.RequestID)
  40. router.Use(middleware.RealIP)
  41. router.Use(logger.NewStructuredLogger(logger.GetLogger()))
  42. router.Use(middleware.Recoverer)
  43. router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  44. sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
  45. }))
  46. router.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  47. sendAPIResponse(w, r, nil, "Method not allowed", http.StatusMethodNotAllowed)
  48. }))
  49. router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) {
  50. render.JSON(w, r, utils.GetAppVersion())
  51. })
  52. router.Group(func(router chi.Router) {
  53. router.Use(checkAuth)
  54. router.Post(checkAuthPath, checkSFTPGoUserAuth)
  55. })
  56. ldapConfig = config.GetLDAPConfig()
  57. loadCACerts(configDir)
  58. certificateFile := getConfigPath(httpConfig.CertificateFile, configDir)
  59. certificateKeyFile := getConfigPath(httpConfig.CertificateKeyFile, configDir)
  60. httpServer := &http.Server{
  61. Addr: fmt.Sprintf("%s:%d", httpConfig.BindAddress, httpConfig.BindPort),
  62. Handler: router,
  63. ReadTimeout: 70 * time.Second,
  64. WriteTimeout: 70 * time.Second,
  65. IdleTimeout: 120 * time.Second,
  66. MaxHeaderBytes: 1 << 16, // 64KB
  67. }
  68. if len(certificateFile) > 0 && len(certificateKeyFile) > 0 {
  69. certMgr, err = newCertManager(certificateFile, certificateKeyFile)
  70. if err != nil {
  71. return err
  72. }
  73. config := &tls.Config{
  74. GetCertificate: certMgr.GetCertificateFunc(),
  75. }
  76. httpServer.TLSConfig = config
  77. return httpServer.ListenAndServeTLS("", "")
  78. }
  79. return httpServer.ListenAndServe()
  80. }
  81. func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
  82. var errorString string
  83. if err != nil {
  84. errorString = err.Error()
  85. }
  86. resp := apiResponse{
  87. Error: errorString,
  88. Message: message,
  89. HTTPStatus: code,
  90. }
  91. ctx := context.WithValue(r.Context(), render.StatusCtxKey, code)
  92. render.JSON(w, r.WithContext(ctx), resp)
  93. }
  94. func loadCACerts(configDir string) error {
  95. var err error
  96. rootCAs, err = x509.SystemCertPool()
  97. if err != nil {
  98. rootCAs = x509.NewCertPool()
  99. }
  100. for _, ca := range ldapConfig.CACertificates {
  101. caPath := getConfigPath(ca, configDir)
  102. certs, err := ioutil.ReadFile(caPath)
  103. if err != nil {
  104. logger.Warn(logSender, "", "error loading ca cert %#v: %v", caPath, err)
  105. return err
  106. }
  107. if !rootCAs.AppendCertsFromPEM(certs) {
  108. logger.Warn(logSender, "", "unable to add ca cert %#v", caPath)
  109. } else {
  110. logger.Debug(logSender, "", "ca cert %#v added to the trusted certificates", caPath)
  111. }
  112. }
  113. return nil
  114. }
  115. // ReloadTLSCertificate reloads the TLS certificate and key from the configured paths
  116. func ReloadTLSCertificate() {
  117. if certMgr != nil {
  118. certMgr.loadCertificate()
  119. }
  120. }
  121. func getConfigPath(name, configDir string) string {
  122. if !utils.IsFileInputValid(name) {
  123. return ""
  124. }
  125. if len(name) > 0 && !filepath.IsAbs(name) {
  126. return filepath.Join(configDir, name)
  127. }
  128. return name
  129. }