static files: refactor neutered http.FileSystem

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2023-11-03 17:22:28 +01:00
parent 75c2bcff8f
commit 87b12af932
No known key found for this signature in database
GPG key ID: 935D2952DEC4EECF
3 changed files with 30 additions and 15 deletions

View file

@ -1059,11 +1059,10 @@ func fileServer(r chi.Router, path string, root http.FileSystem, disableDirector
r.Get(path, func(w http.ResponseWriter, r *http.Request) { r.Get(path, func(w http.ResponseWriter, r *http.Request) {
rctx := chi.RouteContext(r.Context()) rctx := chi.RouteContext(r.Context())
pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*") pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*")
handler := http.FileServer(root)
if disableDirectoryIndex { if disableDirectoryIndex {
handler = neuter(handler) root = neuteredFileSystem{root}
} }
fs := http.StripPrefix(pathPrefix, handler) fs := http.StripPrefix(pathPrefix, http.FileServer(root))
fs.ServeHTTP(w, r) fs.ServeHTTP(w, r)
}) })
} }
@ -1223,3 +1222,30 @@ func resolveInstallationCode() string {
} }
return installationCode return installationCode
} }
type neuteredFileSystem struct {
fs http.FileSystem
}
func (nfs neuteredFileSystem) Open(name string) (http.File, error) {
f, err := nfs.fs.Open(name)
if err != nil {
return nil, err
}
s, err := f.Stat()
if err != nil {
return nil, err
}
if s.IsDir() {
index := path.Join(name, "index.html")
if _, err := nfs.fs.Open(index); err != nil {
defer f.Close()
return nil, err
}
}
return f, nil
}

View file

@ -24426,7 +24426,7 @@ func TestStaticFilesMock(t *testing.T) {
req, err = http.NewRequest(http.MethodGet, location, nil) req, err = http.NewRequest(http.MethodGet, location, nil)
assert.NoError(t, err) assert.NoError(t, err)
rr = executeRequest(req) rr = executeRequest(req)
checkResponseCode(t, http.StatusForbidden, rr) checkResponseCode(t, http.StatusNotFound, rr)
req, err = http.NewRequest(http.MethodGet, "/openapi", nil) req, err = http.NewRequest(http.MethodGet, "/openapi", nil)
assert.NoError(t, err) assert.NoError(t, err)

View file

@ -553,14 +553,3 @@ func checkPartialAuth(w http.ResponseWriter, r *http.Request, audience string, t
} }
return nil return nil
} }
func neuter(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/") {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}