static files: refactor neutered http.FileSystem
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
parent
75c2bcff8f
commit
87b12af932
3 changed files with 30 additions and 15 deletions
|
@ -1059,11 +1059,10 @@ func fileServer(r chi.Router, path string, root http.FileSystem, disableDirector
|
||||||
r.Get(path, func(w http.ResponseWriter, r *http.Request) {
|
r.Get(path, func(w http.ResponseWriter, r *http.Request) {
|
||||||
rctx := chi.RouteContext(r.Context())
|
rctx := chi.RouteContext(r.Context())
|
||||||
pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*")
|
pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*")
|
||||||
handler := http.FileServer(root)
|
|
||||||
if disableDirectoryIndex {
|
if disableDirectoryIndex {
|
||||||
handler = neuter(handler)
|
root = neuteredFileSystem{root}
|
||||||
}
|
}
|
||||||
fs := http.StripPrefix(pathPrefix, handler)
|
fs := http.StripPrefix(pathPrefix, http.FileServer(root))
|
||||||
fs.ServeHTTP(w, r)
|
fs.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1223,3 +1222,30 @@ func resolveInstallationCode() string {
|
||||||
}
|
}
|
||||||
return installationCode
|
return installationCode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type neuteredFileSystem struct {
|
||||||
|
fs http.FileSystem
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nfs neuteredFileSystem) Open(name string) (http.File, error) {
|
||||||
|
f, err := nfs.fs.Open(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.IsDir() {
|
||||||
|
index := path.Join(name, "index.html")
|
||||||
|
if _, err := nfs.fs.Open(index); err != nil {
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
|
@ -24426,7 +24426,7 @@ func TestStaticFilesMock(t *testing.T) {
|
||||||
req, err = http.NewRequest(http.MethodGet, location, nil)
|
req, err = http.NewRequest(http.MethodGet, location, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
rr = executeRequest(req)
|
rr = executeRequest(req)
|
||||||
checkResponseCode(t, http.StatusForbidden, rr)
|
checkResponseCode(t, http.StatusNotFound, rr)
|
||||||
|
|
||||||
req, err = http.NewRequest(http.MethodGet, "/openapi", nil)
|
req, err = http.NewRequest(http.MethodGet, "/openapi", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
|
@ -553,14 +553,3 @@ func checkPartialAuth(w http.ResponseWriter, r *http.Request, audience string, t
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func neuter(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if strings.HasSuffix(r.URL.Path, "/") {
|
|
||||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in a new issue