httpd.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package httpd
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "fmt"
  7. "net/http"
  8. "os"
  9. "path/filepath"
  10. "time"
  11. "github.com/drakkan/sftpgo/ldapauthserver/config"
  12. "github.com/drakkan/sftpgo/ldapauthserver/logger"
  13. "github.com/drakkan/sftpgo/ldapauthserver/utils"
  14. "github.com/go-chi/chi/v5"
  15. "github.com/go-chi/chi/v5/middleware"
  16. "github.com/go-chi/render"
  17. )
  const (
	// logSender identifies this package in log entries.
	logSender = "httpd"

	// REST endpoints served by the router built in StartHTTPServer.
	versionPath   = "/api/v1/version"
	checkAuthPath = "/api/v1/check_auth"

	// maxRequestSize is 256 KiB; presumably used elsewhere in the package to
	// cap request bodies — verify at its use site.
	maxRequestSize = 256 << 10
)
  24. var (
  25. ldapConfig config.LDAPConfig
  26. httpAuth httpAuthProvider
  27. certMgr *certManager
  28. rootCAs *x509.CertPool
  29. )
  // StartHTTPServer initializes and starts the HTTP Server.
//
// Setup runs in this order:
//   - load the basic-auth user file named by httpConfig.AuthUserFile;
//   - read the LDAP configuration and add any extra CA certificates it lists
//     to the trusted pool;
//   - build the API router.
//
// The call then blocks serving requests. HTTPS (TLS 1.2 or newer) is used
// only when both httpConfig.CertificateFile and httpConfig.CertificateKeyFile
// resolve to non-empty paths. Otherwise plain HTTP is served, with a warning
// if only one of the two was provided.
//
// Relative file paths are resolved against configDir. The returned error is
// either the first setup failure or whatever the listener returns.
func StartHTTPServer(configDir string, httpConfig config.HTTPDConfig) error {
	var err error

	authUserFile := getConfigPath(httpConfig.AuthUserFile, configDir)
	httpAuth, err = newBasicAuthProvider(authUserFile)
	if err != nil {
		return err
	}

	ldapConfig = config.GetLDAPConfig()
	// A configured CA file that cannot be read is a startup error rather than
	// being silently left out of the trusted pool.
	if err = loadCACerts(configDir); err != nil {
		return fmt.Errorf("loading LDAP CA certificates: %w", err)
	}

	router := chi.NewRouter()
	router.Use(middleware.RequestID)
	router.Use(middleware.RealIP)
	router.Use(logger.NewStructuredLogger(logger.GetLogger()))
	router.Use(middleware.Recoverer)

	// Unknown routes and wrong methods get the same JSON envelope as API errors.
	router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
	}))
	router.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		sendAPIResponse(w, r, nil, "Method not allowed", http.StatusMethodNotAllowed)
	}))

	router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) {
		render.JSON(w, r, utils.GetAppVersion())
	})
	// Routes in this group pass through the checkAuth middleware first.
	router.Group(func(authRouter chi.Router) {
		authRouter.Use(checkAuth)
		authRouter.Post(checkAuthPath, checkSFTPGoUserAuth)
	})

	httpServer := &http.Server{
		Addr:           fmt.Sprintf("%s:%d", httpConfig.BindAddress, httpConfig.BindPort),
		Handler:        router,
		ReadTimeout:    70 * time.Second,
		WriteTimeout:   70 * time.Second,
		IdleTimeout:    120 * time.Second,
		MaxHeaderBytes: 1 << 16, // 64KB
	}

	certificateFile := getConfigPath(httpConfig.CertificateFile, configDir)
	certificateKeyFile := getConfigPath(httpConfig.CertificateKeyFile, configDir)
	if certificateFile == "" || certificateKeyFile == "" {
		if certificateFile != "" || certificateKeyFile != "" {
			logger.Warn(logSender, "", "both a certificate and a key file are required to enable HTTPS, serving plain HTTP")
		}
		return httpServer.ListenAndServe()
	}

	certMgr, err = newCertManager(certificateFile, certificateKeyFile)
	if err != nil {
		return err
	}
	httpServer.TLSConfig = &tls.Config{
		GetCertificate: certMgr.GetCertificateFunc(),
		MinVersion:     tls.VersionTLS12,
	}
	// The certificate is supplied by TLSConfig.GetCertificate, which allows
	// ReloadTLSCertificate to swap it at runtime, so no file paths are passed.
	return httpServer.ListenAndServeTLS("", "")
}
  82. func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
  83. var errorString string
  84. if err != nil {
  85. errorString = err.Error()
  86. }
  87. resp := apiResponse{
  88. Error: errorString,
  89. Message: message,
  90. HTTPStatus: code,
  91. }
  92. ctx := context.WithValue(r.Context(), render.StatusCtxKey, code)
  93. render.JSON(w, r.WithContext(ctx), resp)
  94. }
  95. func loadCACerts(configDir string) error {
  96. var err error
  97. rootCAs, err = x509.SystemCertPool()
  98. if err != nil {
  99. rootCAs = x509.NewCertPool()
  100. }
  101. for _, ca := range ldapConfig.CACertificates {
  102. caPath := getConfigPath(ca, configDir)
  103. certs, err := os.ReadFile(caPath)
  104. if err != nil {
  105. logger.Warn(logSender, "", "error loading ca cert %q: %v", caPath, err)
  106. return err
  107. }
  108. if !rootCAs.AppendCertsFromPEM(certs) {
  109. logger.Warn(logSender, "", "unable to add ca cert %q", caPath)
  110. } else {
  111. logger.Debug(logSender, "", "ca cert %q added to the trusted certificates", caPath)
  112. }
  113. }
  114. return nil
  115. }
  116. // ReloadTLSCertificate reloads the TLS certificate and key from the configured paths
  117. func ReloadTLSCertificate() {
  118. if certMgr != nil {
  119. certMgr.loadCertificate()
  120. }
  121. }
  // getConfigPath resolves a configured file name against configDir.
//   - It returns "" when utils.IsFileInputValid rejects name.
//   - It returns name unchanged when name is empty or already absolute.
//   - Otherwise it returns filepath.Join(configDir, name).
func getConfigPath(name, configDir string) string {
	switch {
	case !utils.IsFileInputValid(name):
		return ""
	case name == "" || filepath.IsAbs(name):
		return name
	default:
		return filepath.Join(configDir, name)
	}
}