server.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. package webdavd
import (
	"context"
	"crypto/subtle"
	"crypto/tls"
	"errors"
	"fmt"
	"net/http"
	"path"
	"path/filepath"
	"runtime/debug"
	"strings"
	"time"

	"github.com/rs/cors"
	"github.com/rs/xid"
	"golang.org/x/net/webdav"

	"github.com/drakkan/sftpgo/common"
	"github.com/drakkan/sftpgo/dataprovider"
	"github.com/drakkan/sftpgo/logger"
	"github.com/drakkan/sftpgo/metrics"
	"github.com/drakkan/sftpgo/utils"
)
  22. var (
  23. err401 = errors.New("Unauthorized")
  24. err403 = errors.New("Forbidden")
  25. xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
  26. xRealIP = http.CanonicalHeaderKey("X-Real-IP")
  27. )
// webDavServer is the http.Handler that authenticates each request and serves
// it through a per-request webdav.Handler.
type webDavServer struct {
	// config supplies the bind address/port, the CORS options and the user cache settings
	config *Configuration
	// certMgr is set by newServer only when both a certificate and a key file are
	// configured; a non-nil value makes listenAndServe serve HTTPS
	certMgr *common.CertManager
	// status is filled in by listenAndServe (active flag, address, protocol)
	status ServiceStatus
}
  33. func newServer(config *Configuration, configDir string) (*webDavServer, error) {
  34. var err error
  35. server := &webDavServer{
  36. config: config,
  37. certMgr: nil,
  38. }
  39. certificateFile := getConfigPath(config.CertificateFile, configDir)
  40. certificateKeyFile := getConfigPath(config.CertificateKeyFile, configDir)
  41. if len(certificateFile) > 0 && len(certificateKeyFile) > 0 {
  42. server.certMgr, err = common.NewCertManager(certificateFile, certificateKeyFile, logSender)
  43. if err != nil {
  44. return server, err
  45. }
  46. }
  47. return server, nil
  48. }
  49. func (s *webDavServer) listenAndServe() error {
  50. addr := fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.BindPort)
  51. s.status.IsActive = true
  52. s.status.Address = addr
  53. s.status.Protocol = "HTTP"
  54. httpServer := &http.Server{
  55. Addr: addr,
  56. Handler: server,
  57. ReadHeaderTimeout: 30 * time.Second,
  58. IdleTimeout: 120 * time.Second,
  59. MaxHeaderBytes: 1 << 16, // 64KB
  60. }
  61. if s.config.Cors.Enabled {
  62. c := cors.New(cors.Options{
  63. AllowedOrigins: s.config.Cors.AllowedOrigins,
  64. AllowedMethods: s.config.Cors.AllowedMethods,
  65. AllowedHeaders: s.config.Cors.AllowedHeaders,
  66. ExposedHeaders: s.config.Cors.ExposedHeaders,
  67. MaxAge: s.config.Cors.MaxAge,
  68. AllowCredentials: s.config.Cors.AllowCredentials,
  69. OptionsPassthrough: true,
  70. })
  71. httpServer.Handler = c.Handler(server)
  72. } else {
  73. httpServer.Handler = server
  74. }
  75. if s.certMgr != nil {
  76. s.status.Protocol = "HTTPS"
  77. httpServer.TLSConfig = &tls.Config{
  78. GetCertificate: s.certMgr.GetCertificateFunc(),
  79. MinVersion: tls.VersionTLS12,
  80. }
  81. return httpServer.ListenAndServeTLS("", "")
  82. }
  83. return httpServer.ListenAndServe()
  84. }
  85. func (s *webDavServer) checkRequestMethod(ctx context.Context, r *http.Request, connection *Connection, prefix string) {
  86. // see RFC4918, section 9.4
  87. if r.Method == http.MethodGet {
  88. p := strings.TrimPrefix(path.Clean(r.URL.Path), prefix)
  89. info, err := connection.Stat(ctx, p)
  90. if err == nil && info.IsDir() {
  91. r.Method = "PROPFIND"
  92. if r.Header.Get("Depth") == "" {
  93. r.Header.Add("Depth", "1")
  94. }
  95. }
  96. }
  97. }
  98. // ServeHTTP implements the http.Handler interface
  99. func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  100. defer func() {
  101. if r := recover(); r != nil {
  102. logger.Error(logSender, "", "panic in ServeHTTP: %#v stack strace: %v", r, string(debug.Stack()))
  103. http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
  104. }
  105. }()
  106. checkRemoteAddress(r)
  107. if err := common.Config.ExecutePostConnectHook(r.RemoteAddr, common.ProtocolWebDAV); err != nil {
  108. http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
  109. return
  110. }
  111. user, _, lockSystem, err := s.authenticate(r)
  112. if err != nil {
  113. w.Header().Set("WWW-Authenticate", "Basic realm=\"SFTPGo WebDAV\"")
  114. http.Error(w, err401.Error(), http.StatusUnauthorized)
  115. return
  116. }
  117. if path.Clean(r.URL.Path) == "/" && (r.Method == http.MethodGet || r.Method == "PROPFIND" || r.Method == http.MethodOptions) {
  118. http.Redirect(w, r, path.Join("/", user.Username), http.StatusMovedPermanently)
  119. return
  120. }
  121. connectionID, err := s.validateUser(user, r)
  122. if err != nil {
  123. updateLoginMetrics(user.Username, r.RemoteAddr, err)
  124. http.Error(w, err.Error(), http.StatusForbidden)
  125. return
  126. }
  127. fs, err := user.GetFilesystem(connectionID)
  128. if err != nil {
  129. updateLoginMetrics(user.Username, r.RemoteAddr, err)
  130. http.Error(w, err.Error(), http.StatusInternalServerError)
  131. return
  132. }
  133. updateLoginMetrics(user.Username, r.RemoteAddr, err)
  134. ctx := context.WithValue(r.Context(), requestIDKey, connectionID)
  135. ctx = context.WithValue(ctx, requestStartKey, time.Now())
  136. connection := &Connection{
  137. BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolWebDAV, user, fs),
  138. request: r,
  139. }
  140. common.Connections.Add(connection)
  141. defer common.Connections.Remove(connection.GetID())
  142. dataprovider.UpdateLastLogin(user) //nolint:errcheck
  143. prefix := path.Join("/", user.Username)
  144. s.checkRequestMethod(ctx, r, connection, prefix)
  145. handler := webdav.Handler{
  146. Prefix: prefix,
  147. FileSystem: connection,
  148. LockSystem: lockSystem,
  149. Logger: writeLog,
  150. }
  151. handler.ServeHTTP(w, r.WithContext(ctx))
  152. }
  153. func (s *webDavServer) authenticate(r *http.Request) (dataprovider.User, bool, webdav.LockSystem, error) {
  154. var user dataprovider.User
  155. var err error
  156. username, password, ok := r.BasicAuth()
  157. if !ok {
  158. return user, false, nil, err401
  159. }
  160. result, ok := dataprovider.GetCachedWebDAVUser(username)
  161. if ok {
  162. cachedUser := result.(*dataprovider.CachedUser)
  163. if cachedUser.IsExpired() {
  164. dataprovider.RemoveCachedWebDAVUser(username)
  165. } else {
  166. if len(password) > 0 && cachedUser.Password == password {
  167. return cachedUser.User, true, cachedUser.LockSystem, nil
  168. }
  169. updateLoginMetrics(username, r.RemoteAddr, dataprovider.ErrInvalidCredentials)
  170. return user, false, nil, dataprovider.ErrInvalidCredentials
  171. }
  172. }
  173. user, err = dataprovider.CheckUserAndPass(username, password, utils.GetIPFromRemoteAddress(r.RemoteAddr), common.ProtocolWebDAV)
  174. if err != nil {
  175. updateLoginMetrics(username, r.RemoteAddr, err)
  176. return user, false, nil, err
  177. }
  178. lockSystem := webdav.NewMemLS()
  179. if password != "" {
  180. cachedUser := &dataprovider.CachedUser{
  181. User: user,
  182. Password: password,
  183. LockSystem: lockSystem,
  184. }
  185. if s.config.Cache.Users.ExpirationTime > 0 {
  186. cachedUser.Expiration = time.Now().Add(time.Duration(s.config.Cache.Users.ExpirationTime) * time.Minute)
  187. }
  188. dataprovider.CacheWebDAVUser(cachedUser, s.config.Cache.Users.MaxSize)
  189. if user.FsConfig.Provider != dataprovider.SFTPFilesystemProvider {
  190. // for sftp fs check root path does nothing so don't open a useless SFTP connection
  191. tempFs, err := user.GetFilesystem("temp")
  192. if err == nil {
  193. tempFs.CheckRootPath(user.Username, user.UID, user.GID)
  194. tempFs.Close()
  195. }
  196. }
  197. }
  198. return user, false, lockSystem, nil
  199. }
  200. func (s *webDavServer) validateUser(user dataprovider.User, r *http.Request) (string, error) {
  201. connID := xid.New().String()
  202. connectionID := fmt.Sprintf("%v_%v", common.ProtocolWebDAV, connID)
  203. uriSegments := strings.Split(path.Clean(r.URL.Path), "/")
  204. if len(uriSegments) < 2 || uriSegments[1] != user.Username {
  205. logger.Debug(logSender, connectionID, "URI %#v not allowed for user %#v", r.URL.Path, user.Username)
  206. return connID, err403
  207. }
  208. if !filepath.IsAbs(user.HomeDir) {
  209. logger.Warn(logSender, connectionID, "user %#v has an invalid home dir: %#v. Home dir must be an absolute path, login not allowed",
  210. user.Username, user.HomeDir)
  211. return connID, fmt.Errorf("cannot login user with invalid home dir: %#v", user.HomeDir)
  212. }
  213. if utils.IsStringInSlice(common.ProtocolWebDAV, user.Filters.DeniedProtocols) {
  214. logger.Debug(logSender, connectionID, "cannot login user %#v, protocol DAV is not allowed", user.Username)
  215. return connID, fmt.Errorf("Protocol DAV is not allowed for user %#v", user.Username)
  216. }
  217. if !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, nil) {
  218. logger.Debug(logSender, connectionID, "cannot login user %#v, password login method is not allowed", user.Username)
  219. return connID, fmt.Errorf("Password login method is not allowed for user %#v", user.Username)
  220. }
  221. if user.MaxSessions > 0 {
  222. activeSessions := common.Connections.GetActiveSessions(user.Username)
  223. if activeSessions >= user.MaxSessions {
  224. logger.Debug(logSender, connID, "authentication refused for user: %#v, too many open sessions: %v/%v", user.Username,
  225. activeSessions, user.MaxSessions)
  226. return connID, fmt.Errorf("too many open sessions: %v", activeSessions)
  227. }
  228. }
  229. if dataprovider.GetQuotaTracking() > 0 && user.HasOverlappedMappedPaths() {
  230. logger.Debug(logSender, connectionID, "cannot login user %#v, overlapping mapped folders are allowed only with quota tracking disabled",
  231. user.Username)
  232. return connID, errors.New("overlapping mapped folders are allowed only with quota tracking disabled")
  233. }
  234. if !user.IsLoginFromAddrAllowed(r.RemoteAddr) {
  235. logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, r.RemoteAddr)
  236. return connID, fmt.Errorf("Login for user %#v is not allowed from this address: %v", user.Username, r.RemoteAddr)
  237. }
  238. return connID, nil
  239. }
  240. func writeLog(r *http.Request, err error) {
  241. scheme := "http"
  242. if r.TLS != nil {
  243. scheme = "https"
  244. }
  245. fields := map[string]interface{}{
  246. "remote_addr": r.RemoteAddr,
  247. "proto": r.Proto,
  248. "method": r.Method,
  249. "user_agent": r.UserAgent(),
  250. "uri": fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)}
  251. if reqID, ok := r.Context().Value(requestIDKey).(string); ok {
  252. fields["request_id"] = reqID
  253. }
  254. if reqStart, ok := r.Context().Value(requestStartKey).(time.Time); ok {
  255. fields["elapsed_ms"] = time.Since(reqStart).Nanoseconds() / 1000000
  256. }
  257. logger.GetLogger().Info().
  258. Timestamp().
  259. Str("sender", logSender).
  260. Fields(fields).
  261. Err(err).
  262. Send()
  263. }
  264. func checkRemoteAddress(r *http.Request) {
  265. if common.Config.ProxyProtocol != 0 {
  266. return
  267. }
  268. var ip string
  269. if xrip := r.Header.Get(xRealIP); xrip != "" {
  270. ip = xrip
  271. } else if xff := r.Header.Get(xForwardedFor); xff != "" {
  272. i := strings.Index(xff, ", ")
  273. if i == -1 {
  274. i = len(xff)
  275. }
  276. ip = strings.TrimSpace(xff[:i])
  277. }
  278. if len(ip) > 0 {
  279. r.RemoteAddr = ip
  280. }
  281. }
  282. func updateLoginMetrics(username, remoteAddress string, err error) {
  283. metrics.AddLoginAttempt(dataprovider.LoginMethodPassword)
  284. ip := utils.GetIPFromRemoteAddress(remoteAddress)
  285. if err != nil {
  286. logger.ConnectionFailedLog(username, ip, dataprovider.LoginMethodPassword, common.ProtocolWebDAV, err.Error())
  287. }
  288. metrics.AddLoginResult(dataprovider.LoginMethodPassword, err)
  289. dataprovider.ExecutePostLoginHook(username, dataprovider.LoginMethodPassword, ip, common.ProtocolWebDAV, err)
  290. }