123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- package httpd
- import (
- "context"
- "crypto/tls"
- "crypto/x509"
- "fmt"
- "net/http"
- "os"
- "path/filepath"
- "time"
- "github.com/drakkan/sftpgo/ldapauthserver/config"
- "github.com/drakkan/sftpgo/ldapauthserver/logger"
- "github.com/drakkan/sftpgo/ldapauthserver/utils"
- "github.com/go-chi/chi/v5"
- "github.com/go-chi/chi/v5/middleware"
- "github.com/go-chi/render"
- )
const (
	// logSender is the sender name this package passes to the logger.
	logSender = "httpd"

	// API routes served by the HTTP server.
	versionPath   = "/api/v1/version"
	checkAuthPath = "/api/v1/check_auth"

	// maxRequestSize is 256 KiB (1 << 18). Presumably the cap applied to
	// incoming request bodies; its use site is outside this chunk — confirm.
	maxRequestSize = 256 << 10
)
- var (
- ldapConfig config.LDAPConfig
- httpAuth httpAuthProvider
- certMgr *certManager
- rootCAs *x509.CertPool
- )
- // StartHTTPServer initializes and starts the HTTP Server
- func StartHTTPServer(configDir string, httpConfig config.HTTPDConfig) error {
- var err error
- authUserFile := getConfigPath(httpConfig.AuthUserFile, configDir)
- httpAuth, err = newBasicAuthProvider(authUserFile)
- if err != nil {
- return err
- }
- router := chi.NewRouter()
- router.Use(middleware.RequestID)
- router.Use(middleware.RealIP)
- router.Use(logger.NewStructuredLogger(logger.GetLogger()))
- router.Use(middleware.Recoverer)
- router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
- }))
- router.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- sendAPIResponse(w, r, nil, "Method not allowed", http.StatusMethodNotAllowed)
- }))
- router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) {
- render.JSON(w, r, utils.GetAppVersion())
- })
- router.Group(func(router chi.Router) {
- router.Use(checkAuth)
- router.Post(checkAuthPath, checkSFTPGoUserAuth)
- })
- ldapConfig = config.GetLDAPConfig()
- loadCACerts(configDir)
- certificateFile := getConfigPath(httpConfig.CertificateFile, configDir)
- certificateKeyFile := getConfigPath(httpConfig.CertificateKeyFile, configDir)
- httpServer := &http.Server{
- Addr: fmt.Sprintf("%s:%d", httpConfig.BindAddress, httpConfig.BindPort),
- Handler: router,
- ReadTimeout: 70 * time.Second,
- WriteTimeout: 70 * time.Second,
- IdleTimeout: 120 * time.Second,
- MaxHeaderBytes: 1 << 16, // 64KB
- }
- if len(certificateFile) > 0 && len(certificateKeyFile) > 0 {
- certMgr, err = newCertManager(certificateFile, certificateKeyFile)
- if err != nil {
- return err
- }
- config := &tls.Config{
- GetCertificate: certMgr.GetCertificateFunc(),
- MinVersion: tls.VersionTLS12,
- }
- httpServer.TLSConfig = config
- return httpServer.ListenAndServeTLS("", "")
- }
- return httpServer.ListenAndServe()
- }
- func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
- var errorString string
- if err != nil {
- errorString = err.Error()
- }
- resp := apiResponse{
- Error: errorString,
- Message: message,
- HTTPStatus: code,
- }
- ctx := context.WithValue(r.Context(), render.StatusCtxKey, code)
- render.JSON(w, r.WithContext(ctx), resp)
- }
- func loadCACerts(configDir string) error {
- var err error
- rootCAs, err = x509.SystemCertPool()
- if err != nil {
- rootCAs = x509.NewCertPool()
- }
- for _, ca := range ldapConfig.CACertificates {
- caPath := getConfigPath(ca, configDir)
- certs, err := os.ReadFile(caPath)
- if err != nil {
- logger.Warn(logSender, "", "error loading ca cert %q: %v", caPath, err)
- return err
- }
- if !rootCAs.AppendCertsFromPEM(certs) {
- logger.Warn(logSender, "", "unable to add ca cert %q", caPath)
- } else {
- logger.Debug(logSender, "", "ca cert %q added to the trusted certificates", caPath)
- }
- }
- return nil
- }
- // ReloadTLSCertificate reloads the TLS certificate and key from the configured paths
- func ReloadTLSCertificate() {
- if certMgr != nil {
- certMgr.loadCertificate()
- }
- }
- func getConfigPath(name, configDir string) string {
- if !utils.IsFileInputValid(name) {
- return ""
- }
- if len(name) > 0 && !filepath.IsAbs(name) {
- return filepath.Join(configDir, name)
- }
- return name
- }
|