diff --git a/pkg/cwhub/cwhub_test.go b/pkg/cwhub/cwhub_test.go index fd95dd2eb..3a260cb9a 100644 --- a/pkg/cwhub/cwhub_test.go +++ b/pkg/cwhub/cwhub_test.go @@ -48,14 +48,9 @@ func testHub(t *testing.T, update bool) *Hub { err = os.MkdirAll(local.InstallDataDir, 0o700) require.NoError(t, err) - index, err := os.Create(local.HubIndexFile) + err = os.WriteFile(local.HubIndexFile, []byte("{}"), 0o644) require.NoError(t, err) - _, err = index.WriteString(`{}`) - require.NoError(t, err) - - index.Close() - t.Cleanup(func() { os.RemoveAll(tmpDir) }) diff --git a/pkg/cwhub/enable.go b/pkg/cwhub/enable.go index 2f3bdb00b..9887558f6 100644 --- a/pkg/cwhub/enable.go +++ b/pkg/cwhub/enable.go @@ -10,12 +10,42 @@ import ( log "github.com/sirupsen/logrus" ) -// enable creates a symlink between actual config file at hub.HubDir and hub.ConfigDir -// Handles collections recursively -func (i *Item) enable() error { - parentDir := filepath.Clean(i.hub.local.InstallDir + "/" + i.Type + "/" + i.Stage + "/") +// installLink returns the location of the symlink to the actual config file (eg. /etc/crowdsec/collections/xyz.yaml) +func (i *Item) installLink() string { + return filepath.Join(i.hub.local.InstallDir, i.Type, i.Stage, i.FileName) +} - // create directories if needed +// makeLink creates a symlink between the actual config file at hub.HubDir and hub.ConfigDir +func (i *Item) createInstallLink() error { + dest, err := filepath.Abs(i.installLink()) + if err != nil { + return err + } + + destDir := filepath.Dir(dest) + if err = os.MkdirAll(destDir, os.ModePerm); err != nil { + return fmt.Errorf("while creating %s: %w", destDir, err) + } + + if _, err = os.Lstat(dest); !os.IsNotExist(err) { + log.Infof("%s already exists.", dest) + return nil + } + + src, err := filepath.Abs(filepath.Join(i.hub.local.HubDir, i.RemotePath)) + if err != nil { + return err + } + + if err = os.Symlink(src, dest); err != nil { + return fmt.Errorf("while creating symlink from %s to %s: %w", src, dest, err) + } + + return nil +} + +// enable enables the item by creating a symlink to the downloaded content, and also enables sub-items +func (i *Item) enable() error { if i.Installed { if i.Tainted { return fmt.Errorf("%s is tainted, won't enable unless --force", i.Name) @@ -32,40 +62,14 @@ func (i *Item) enable() error { } } - if _, err := os.Stat(parentDir); os.IsNotExist(err) { - log.Infof("%s doesn't exist, create", parentDir) - - if err = os.MkdirAll(parentDir, os.ModePerm); err != nil { - return fmt.Errorf("while creating directory: %w", err) - } - } - - // install sub-items if any for _, sub := range i.SubItems() { if err := sub.enable(); err != nil { return fmt.Errorf("while installing %s: %w", sub.Name, err) } } - // check if file already exists where it should in configdir (eg /etc/crowdsec/collections/) - if _, err := os.Lstat(parentDir + "/" + i.FileName); !os.IsNotExist(err) { - log.Infof("%s already exists.", parentDir+"/"+i.FileName) - return nil - } - - // hub.ConfigDir + target.RemotePath - srcPath, err := filepath.Abs(i.hub.local.HubDir + "/" + i.RemotePath) - if err != nil { - return fmt.Errorf("while getting source path: %w", err) - } - - dstPath, err := filepath.Abs(parentDir + "/" + i.FileName) - if err != nil { - return fmt.Errorf("while getting destination path: %w", err) - } - - if err = os.Symlink(srcPath, dstPath); err != nil { - return fmt.Errorf("while creating symlink from %s to %s: %w", srcPath, dstPath, err) + if err := i.createInstallLink(); err != nil { + return err } log.Infof("Enabled %s: %s", i.Type, i.Name) @@ -81,12 +85,11 @@ func (i *Item) purge() error { return nil } - itempath := i.hub.local.HubDir + "/" + i.RemotePath + src := filepath.Join(i.hub.local.HubDir, i.RemotePath) - // disable hub file - if err := os.Remove(itempath); err != nil { + if err := os.Remove(src); err != nil { if os.IsNotExist(err) { - log.Debugf("%s doesn't exist, no need to remove", itempath) + log.Debugf("%s doesn't exist, no need to remove", src) return nil } @@ -94,7 +97,48 @@ func (i *Item) purge() error { } i.Downloaded = false - log.Infof("Removed source file [%s]: %s", i.Name, itempath) + log.Infof("Removed source file [%s]: %s", i.Name, src) + + return nil +} + +func (i *Item) removeInstallLink() error { + syml, err := filepath.Abs(i.installLink()) + if err != nil { + return err + } + + stat, err := os.Lstat(syml) + if err != nil { + return err + } + + // if it's managed by hub, it's a symlink to csconfig.GConfig.hub.HubDir / ... + if stat.Mode()&os.ModeSymlink == 0 { + log.Warningf("%s (%s) isn't a symlink, can't disable", i.Name, syml) + return fmt.Errorf("%s isn't managed by hub", i.Name) + } + + hubpath, err := os.Readlink(syml) + if err != nil { + return fmt.Errorf("while reading symlink: %w", err) + } + + src, err := filepath.Abs(i.hub.local.HubDir + "/" + i.RemotePath) + if err != nil { + return err + } + + if hubpath != src { + log.Warningf("%s (%s) isn't a symlink to %s", i.Name, syml, src) + return fmt.Errorf("%s isn't managed by hub", i.Name) + } + + if err := os.Remove(syml); err != nil { + return fmt.Errorf("while removing symlink: %w", err) + } + + log.Infof("Removed symlink [%s]: %s", i.Name, syml) return nil } @@ -111,6 +155,7 @@ func (i *Item) disable(purge bool, force bool) error { } for _, sub := range i.SubItems() { + // TODO XXX: if the other collection(s) are direct or indirect dependencies of the current one, it's good to go if len(sub.BelongsToCollections) > 1 { log.Infof("%s was not removed because it belongs to another collection", sub.Name) continue @@ -126,48 +171,13 @@ func (i *Item) disable(purge bool, force bool) error { return nil } - syml, err := filepath.Abs(i.hub.local.InstallDir + "/" + i.Type + "/" + i.Stage + "/" + i.FileName) - if err != nil { - return err - } - - stat, err := os.Lstat(syml) + err := i.removeInstallLink() if os.IsNotExist(err) { - // we only accept to "delete" non existing items if it's a forced purge if !purge && !force { - return fmt.Errorf("can't delete %s: %s doesn't exist", i.Name, syml) + return fmt.Errorf("can't disable %s: %s doesn't exist", i.Name, i.installLink()) } - } else { - // if it's managed by hub, it's a symlink to csconfig.GConfig.hub.HubDir / ... - if stat.Mode()&os.ModeSymlink == 0 { - log.Warningf("%s (%s) isn't a symlink, can't disable", i.Name, syml) - return fmt.Errorf("%s isn't managed by hub", i.Name) - } - - hubpath, err := os.Readlink(syml) - if err != nil { - return fmt.Errorf("while reading symlink: %w", err) - } - - absPath, err := filepath.Abs(i.hub.local.HubDir + "/" + i.RemotePath) - if err != nil { - return fmt.Errorf("while abs path: %w", err) - } - - if hubpath != absPath { - log.Warningf("%s (%s) isn't a symlink to %s", i.Name, syml, absPath) - return fmt.Errorf("%s isn't managed by hub", i.Name) - } - - if err := os.Remove(syml); err != nil { - if os.IsNotExist(err) { - log.Debugf("%s doesn't exist, no need to remove", syml) - return nil - } - return fmt.Errorf("while removing symlink: %w", err) - } - - log.Infof("Removed symlink [%s]: %s", i.Name, syml) + } else if err != nil { + return err } i.Installed = false diff --git a/pkg/cwhub/helpers.go b/pkg/cwhub/helpers.go index 2427a8fcb..ba14c07fb 100644 --- a/pkg/cwhub/helpers.go +++ b/pkg/cwhub/helpers.go @@ -161,14 +161,46 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) error { return nil } -func (i *Item) download(overwrite bool) error { +// fetch downloads the item from the hub, verifies the hash and returns the body +func (i *Item) fetch() ([]byte, error) { url, err := i.hub.remote.urlTo(i.RemotePath) if err != nil { - return fmt.Errorf("failed to build hub item request: %w", err) + return nil, fmt.Errorf("failed to build hub item request: %w", err) } - tdir := i.hub.local.HubDir + resp, err := hubClient.Get(url) + if err != nil { + return nil, fmt.Errorf("while downloading %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("bad http code %d for %s", resp.StatusCode, url) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("while downloading %s: %w", url, err) + } + + hash := sha256.New() + if _, err = hash.Write(body); err != nil { + return nil, fmt.Errorf("while hashing %s: %w", i.Name, err) + } + + meow := hex.EncodeToString(hash.Sum(nil)) + if meow != i.Versions[i.Version].Digest { + log.Errorf("Downloaded version doesn't match index, please 'hub update'") + log.Debugf("got %s, expected %s", meow, i.Versions[i.Version].Digest) + + return nil, fmt.Errorf("invalid download hash for %s", i.Name) + } + + return body, nil +} + +// download downloads the item from the hub and writes it to the hub directory +func (i *Item) download(overwrite bool) error { // if user didn't --force, don't overwrite local, tainted, up-to-date files if !overwrite { if i.Tainted { @@ -182,56 +214,29 @@ func (i *Item) download(overwrite bool) error { } } - resp, err := hubClient.Get(url) + body, err := i.fetch() if err != nil { - return fmt.Errorf("while downloading %s: %w", url, err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("bad http code %d for %s", resp.StatusCode, url) + return err } - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("while downloading %s: %w", url, err) - } - - hash := sha256.New() - if _, err = hash.Write(body); err != nil { - return fmt.Errorf("while hashing %s: %w", i.Name, err) - } - - meow := hex.EncodeToString(hash.Sum(nil)) - if meow != i.Versions[i.Version].Digest { - log.Errorf("Downloaded version doesn't match index, please 'hub update'") - log.Debugf("got %s, expected %s", meow, i.Versions[i.Version].Digest) - - return fmt.Errorf("invalid download hash for %s", i.Name) - } + tdir := i.hub.local.HubDir //all good, install - //check if parent dir exists - tmpdirs := strings.Split(tdir+"/"+i.RemotePath, "/") - parentDir := strings.Join(tmpdirs[:len(tmpdirs)-1], "/") - // ensure that target file is within target dir - finalPath, err := filepath.Abs(tdir + "/" + i.RemotePath) + finalPath, err := filepath.Abs(filepath.Join(tdir, i.RemotePath)) if err != nil { - return fmt.Errorf("filepath.Abs error on %s: %w", tdir+"/"+i.RemotePath, err) + return err } + // ensure that target file is within target dir if !strings.HasPrefix(finalPath, tdir) { return fmt.Errorf("path %s escapes %s, abort", i.RemotePath, tdir) } - // check dir - if _, err = os.Stat(parentDir); os.IsNotExist(err) { - log.Debugf("%s doesn't exist, create", parentDir) + parentDir := filepath.Dir(finalPath) - if err = os.MkdirAll(parentDir, os.ModePerm); err != nil { - return fmt.Errorf("while creating parent directories: %w", err) - } + if err = os.MkdirAll(parentDir, os.ModePerm); err != nil { + return fmt.Errorf("while creating %s: %w", parentDir, err) } // check actual file @@ -242,15 +247,8 @@ func (i *Item) download(overwrite bool) error { log.Infof("%s: OK", i.Name) } - f, err := os.Create(tdir + "/" + i.RemotePath) - if err != nil { - return fmt.Errorf("while opening file: %w", err) - } - defer f.Close() - - _, err = f.Write(body) - if err != nil { - return fmt.Errorf("while writing file: %w", err) + if err = os.WriteFile(finalPath, body, 0o644); err != nil { + return fmt.Errorf("while writing %s: %w", finalPath, err) } i.Downloaded = true diff --git a/pkg/cwhub/hub_test.go b/pkg/cwhub/hub_test.go index 1bba5e90a..56e2bf376 100644 --- a/pkg/cwhub/hub_test.go +++ b/pkg/cwhub/hub_test.go @@ -69,5 +69,5 @@ func TestDownloadIndex(t *testing.T) { } err = hub.remote.downloadIndex("/does/not/exist/index.json") - cstest.RequireErrorContains(t, err, "while opening hub index file: open /does/not/exist/index.json:") + cstest.RequireErrorContains(t, err, "failed to write hub index: open /does/not/exist/index.json:") } diff --git a/pkg/cwhub/items.go b/pkg/cwhub/items.go index 8e3063a9d..37e853d3a 100644 --- a/pkg/cwhub/items.go +++ b/pkg/cwhub/items.go @@ -250,7 +250,7 @@ func (i *Item) versionStatus() int { } // validPath returns true if the (relative) path is allowed for the item -// dirNmae: the directory name (ie. crowdsecurity) +// dirNname: the directory name (ie. crowdsecurity) // fileName: the filename (ie. apache2-logs.yaml) func (i *Item) validPath(dirName, fileName string) bool { return (dirName+"/"+fileName == i.Name+".yaml") || (dirName+"/"+fileName == i.Name+".yml") diff --git a/pkg/cwhub/remote.go b/pkg/cwhub/remote.go index ad97a4efb..c3855d5e0 100644 --- a/pkg/cwhub/remote.go +++ b/pkg/cwhub/remote.go @@ -70,18 +70,11 @@ func (r *RemoteHubCfg) downloadIndex(localPath string) error { return nil } - file, err := os.Create(localPath) - if err != nil { - return fmt.Errorf("while opening hub index file: %w", err) - } - defer file.Close() - - wsize, err := file.Write(body) - if err != nil { - return fmt.Errorf("while writing hub index file: %w", err) + if err = os.WriteFile(localPath, body, 0o644); err != nil { + return fmt.Errorf("failed to write hub index: %w", err) } - log.Infof("Wrote index to %s, %d bytes", localPath, wsize) + log.Infof("Wrote index to %s, %d bytes", localPath, len(body)) return nil }