Browse Source

static files: refactor neutered http.FileSystem

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 1 year ago
parent
commit
87b12af932
3 changed files with 30 additions and 15 deletions
  1. 29 3
      internal/httpd/httpd.go
  2. 1 1
      internal/httpd/httpd_test.go
  3. 0 11
      internal/httpd/middleware.go

+ 29 - 3
internal/httpd/httpd.go

@@ -1059,11 +1059,10 @@ func fileServer(r chi.Router, path string, root http.FileSystem, disableDirector
 	r.Get(path, func(w http.ResponseWriter, r *http.Request) {
 		rctx := chi.RouteContext(r.Context())
 		pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*")
-		handler := http.FileServer(root)
 		if disableDirectoryIndex {
-			handler = neuter(handler)
+			root = neuteredFileSystem{root}
 		}
-		fs := http.StripPrefix(pathPrefix, handler)
+		fs := http.StripPrefix(pathPrefix, http.FileServer(root))
 		fs.ServeHTTP(w, r)
 	})
 }
@@ -1223,3 +1222,30 @@ func resolveInstallationCode() string {
 	}
 	return installationCode
 }
+
+type neuteredFileSystem struct {
+	fs http.FileSystem
+}
+
+func (nfs neuteredFileSystem) Open(name string) (http.File, error) {
+	f, err := nfs.fs.Open(name)
+	if err != nil {
+		return nil, err
+	}
+
+	s, err := f.Stat()
+	if err != nil {
+		return nil, err
+	}
+
+	if s.IsDir() {
+		index := path.Join(name, "index.html")
+		if _, err := nfs.fs.Open(index); err != nil {
+			defer f.Close()
+
+			return nil, err
+		}
+	}
+
+	return f, nil
+}

+ 1 - 1
internal/httpd/httpd_test.go

@@ -24426,7 +24426,7 @@ func TestStaticFilesMock(t *testing.T) {
 	req, err = http.NewRequest(http.MethodGet, location, nil)
 	assert.NoError(t, err)
 	rr = executeRequest(req)
-	checkResponseCode(t, http.StatusForbidden, rr)
+	checkResponseCode(t, http.StatusNotFound, rr)
 
 	req, err = http.NewRequest(http.MethodGet, "/openapi", nil)
 	assert.NoError(t, err)

+ 0 - 11
internal/httpd/middleware.go

@@ -553,14 +553,3 @@ func checkPartialAuth(w http.ResponseWriter, r *http.Request, audience string, t
 	}
 	return nil
 }
-
-func neuter(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		if strings.HasSuffix(r.URL.Path, "/") {
-			http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
-			return
-		}
-
-		next.ServeHTTP(w, r)
-	})
-}