Refact cwhub: simplify enable/disable/download (#2597)
* Extract methods createInstallLink(), removeInstallLink(), simplify - the result of filepath.Join is already Cleaned - no need to log the creation of parentDir - filepath.Abs() only returns error if the current working directory has been removed * Extract method Item.fetch() * Replace Create() + Write() -> WriteFile()
This commit is contained in:
parent
d9b0d440bf
commit
65473d4e05
6 changed files with 139 additions and 143 deletions
|
@ -48,14 +48,9 @@ func testHub(t *testing.T, update bool) *Hub {
|
|||
err = os.MkdirAll(local.InstallDataDir, 0o700)
|
||||
require.NoError(t, err)
|
||||
|
||||
index, err := os.Create(local.HubIndexFile)
|
||||
err = os.WriteFile(local.HubIndexFile, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = index.WriteString(`{}`)
|
||||
require.NoError(t, err)
|
||||
|
||||
index.Close()
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.RemoveAll(tmpDir)
|
||||
})
|
||||
|
|
|
@ -10,12 +10,42 @@ import (
|
|||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// enable creates a symlink between actual config file at hub.HubDir and hub.ConfigDir
|
||||
// Handles collections recursively
|
||||
func (i *Item) enable() error {
|
||||
parentDir := filepath.Clean(i.hub.local.InstallDir + "/" + i.Type + "/" + i.Stage + "/")
|
||||
// installLink returns the location of the symlink to the actual config file (eg. /etc/crowdsec/collections/xyz.yaml)
|
||||
func (i *Item) installLink() string {
|
||||
return filepath.Join(i.hub.local.InstallDir, i.Type, i.Stage, i.FileName)
|
||||
}
|
||||
|
||||
// create directories if needed
|
||||
// makeLink creates a symlink between the actual config file at hub.HubDir and hub.ConfigDir
|
||||
func (i *Item) createInstallLink() error {
|
||||
dest, err := filepath.Abs(i.installLink())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
destDir := filepath.Dir(dest)
|
||||
if err = os.MkdirAll(destDir, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("while creating %s: %w", destDir, err)
|
||||
}
|
||||
|
||||
if _, err = os.Lstat(dest); !os.IsNotExist(err) {
|
||||
log.Infof("%s already exists.", dest)
|
||||
return nil
|
||||
}
|
||||
|
||||
src, err := filepath.Abs(filepath.Join(i.hub.local.HubDir, i.RemotePath))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = os.Symlink(src, dest); err != nil {
|
||||
return fmt.Errorf("while creating symlink from %s to %s: %w", src, dest, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// enable enables the item by creating a symlink to the downloaded content, and also enables sub-items
|
||||
func (i *Item) enable() error {
|
||||
if i.Installed {
|
||||
if i.Tainted {
|
||||
return fmt.Errorf("%s is tainted, won't enable unless --force", i.Name)
|
||||
|
@ -32,40 +62,14 @@ func (i *Item) enable() error {
|
|||
}
|
||||
}
|
||||
|
||||
if _, err := os.Stat(parentDir); os.IsNotExist(err) {
|
||||
log.Infof("%s doesn't exist, create", parentDir)
|
||||
|
||||
if err = os.MkdirAll(parentDir, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("while creating directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// install sub-items if any
|
||||
for _, sub := range i.SubItems() {
|
||||
if err := sub.enable(); err != nil {
|
||||
return fmt.Errorf("while installing %s: %w", sub.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// check if file already exists where it should in configdir (eg /etc/crowdsec/collections/)
|
||||
if _, err := os.Lstat(parentDir + "/" + i.FileName); !os.IsNotExist(err) {
|
||||
log.Infof("%s already exists.", parentDir+"/"+i.FileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// hub.ConfigDir + target.RemotePath
|
||||
srcPath, err := filepath.Abs(i.hub.local.HubDir + "/" + i.RemotePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("while getting source path: %w", err)
|
||||
}
|
||||
|
||||
dstPath, err := filepath.Abs(parentDir + "/" + i.FileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("while getting destination path: %w", err)
|
||||
}
|
||||
|
||||
if err = os.Symlink(srcPath, dstPath); err != nil {
|
||||
return fmt.Errorf("while creating symlink from %s to %s: %w", srcPath, dstPath, err)
|
||||
if err := i.createInstallLink(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Enabled %s: %s", i.Type, i.Name)
|
||||
|
@ -81,12 +85,11 @@ func (i *Item) purge() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
itempath := i.hub.local.HubDir + "/" + i.RemotePath
|
||||
src := filepath.Join(i.hub.local.HubDir, i.RemotePath)
|
||||
|
||||
// disable hub file
|
||||
if err := os.Remove(itempath); err != nil {
|
||||
if err := os.Remove(src); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
log.Debugf("%s doesn't exist, no need to remove", itempath)
|
||||
log.Debugf("%s doesn't exist, no need to remove", src)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -94,7 +97,48 @@ func (i *Item) purge() error {
|
|||
}
|
||||
|
||||
i.Downloaded = false
|
||||
log.Infof("Removed source file [%s]: %s", i.Name, itempath)
|
||||
log.Infof("Removed source file [%s]: %s", i.Name, src)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Item) removeInstallLink() error {
|
||||
syml, err := filepath.Abs(i.installLink())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stat, err := os.Lstat(syml)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if it's managed by hub, it's a symlink to csconfig.GConfig.hub.HubDir / ...
|
||||
if stat.Mode()&os.ModeSymlink == 0 {
|
||||
log.Warningf("%s (%s) isn't a symlink, can't disable", i.Name, syml)
|
||||
return fmt.Errorf("%s isn't managed by hub", i.Name)
|
||||
}
|
||||
|
||||
hubpath, err := os.Readlink(syml)
|
||||
if err != nil {
|
||||
return fmt.Errorf("while reading symlink: %w", err)
|
||||
}
|
||||
|
||||
src, err := filepath.Abs(i.hub.local.HubDir + "/" + i.RemotePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hubpath != src {
|
||||
log.Warningf("%s (%s) isn't a symlink to %s", i.Name, syml, src)
|
||||
return fmt.Errorf("%s isn't managed by hub", i.Name)
|
||||
}
|
||||
|
||||
if err := os.Remove(syml); err != nil {
|
||||
return fmt.Errorf("while removing symlink: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Removed symlink [%s]: %s", i.Name, syml)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -111,6 +155,7 @@ func (i *Item) disable(purge bool, force bool) error {
|
|||
}
|
||||
|
||||
for _, sub := range i.SubItems() {
|
||||
// TODO XXX: if the other collection(s) are direct or indirect dependencies of the current one, it's good to go
|
||||
if len(sub.BelongsToCollections) > 1 {
|
||||
log.Infof("%s was not removed because it belongs to another collection", sub.Name)
|
||||
continue
|
||||
|
@ -126,48 +171,13 @@ func (i *Item) disable(purge bool, force bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
syml, err := filepath.Abs(i.hub.local.InstallDir + "/" + i.Type + "/" + i.Stage + "/" + i.FileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stat, err := os.Lstat(syml)
|
||||
err := i.removeInstallLink()
|
||||
if os.IsNotExist(err) {
|
||||
// we only accept to "delete" non existing items if it's a forced purge
|
||||
if !purge && !force {
|
||||
return fmt.Errorf("can't delete %s: %s doesn't exist", i.Name, syml)
|
||||
return fmt.Errorf("can't disable %s: %s doesn't exist", i.Name, i.installLink())
|
||||
}
|
||||
} else {
|
||||
// if it's managed by hub, it's a symlink to csconfig.GConfig.hub.HubDir / ...
|
||||
if stat.Mode()&os.ModeSymlink == 0 {
|
||||
log.Warningf("%s (%s) isn't a symlink, can't disable", i.Name, syml)
|
||||
return fmt.Errorf("%s isn't managed by hub", i.Name)
|
||||
}
|
||||
|
||||
hubpath, err := os.Readlink(syml)
|
||||
if err != nil {
|
||||
return fmt.Errorf("while reading symlink: %w", err)
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(i.hub.local.HubDir + "/" + i.RemotePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("while abs path: %w", err)
|
||||
}
|
||||
|
||||
if hubpath != absPath {
|
||||
log.Warningf("%s (%s) isn't a symlink to %s", i.Name, syml, absPath)
|
||||
return fmt.Errorf("%s isn't managed by hub", i.Name)
|
||||
}
|
||||
|
||||
if err := os.Remove(syml); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
log.Debugf("%s doesn't exist, no need to remove", syml)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("while removing symlink: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Removed symlink [%s]: %s", i.Name, syml)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
i.Installed = false
|
||||
|
|
|
@ -161,14 +161,46 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (i *Item) download(overwrite bool) error {
|
||||
// fetch downloads the item from the hub, verifies the hash and returns the body
|
||||
func (i *Item) fetch() ([]byte, error) {
|
||||
url, err := i.hub.remote.urlTo(i.RemotePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build hub item request: %w", err)
|
||||
return nil, fmt.Errorf("failed to build hub item request: %w", err)
|
||||
}
|
||||
|
||||
tdir := i.hub.local.HubDir
|
||||
resp, err := hubClient.Get(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("while downloading %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("bad http code %d for %s", resp.StatusCode, url)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("while downloading %s: %w", url, err)
|
||||
}
|
||||
|
||||
hash := sha256.New()
|
||||
if _, err = hash.Write(body); err != nil {
|
||||
return nil, fmt.Errorf("while hashing %s: %w", i.Name, err)
|
||||
}
|
||||
|
||||
meow := hex.EncodeToString(hash.Sum(nil))
|
||||
if meow != i.Versions[i.Version].Digest {
|
||||
log.Errorf("Downloaded version doesn't match index, please 'hub update'")
|
||||
log.Debugf("got %s, expected %s", meow, i.Versions[i.Version].Digest)
|
||||
|
||||
return nil, fmt.Errorf("invalid download hash for %s", i.Name)
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// download downloads the item from the hub and writes it to the hub directory
|
||||
func (i *Item) download(overwrite bool) error {
|
||||
// if user didn't --force, don't overwrite local, tainted, up-to-date files
|
||||
if !overwrite {
|
||||
if i.Tainted {
|
||||
|
@ -182,56 +214,29 @@ func (i *Item) download(overwrite bool) error {
|
|||
}
|
||||
}
|
||||
|
||||
resp, err := hubClient.Get(url)
|
||||
body, err := i.fetch()
|
||||
if err != nil {
|
||||
return fmt.Errorf("while downloading %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("bad http code %d for %s", resp.StatusCode, url)
|
||||
return err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("while downloading %s: %w", url, err)
|
||||
}
|
||||
|
||||
hash := sha256.New()
|
||||
if _, err = hash.Write(body); err != nil {
|
||||
return fmt.Errorf("while hashing %s: %w", i.Name, err)
|
||||
}
|
||||
|
||||
meow := hex.EncodeToString(hash.Sum(nil))
|
||||
if meow != i.Versions[i.Version].Digest {
|
||||
log.Errorf("Downloaded version doesn't match index, please 'hub update'")
|
||||
log.Debugf("got %s, expected %s", meow, i.Versions[i.Version].Digest)
|
||||
|
||||
return fmt.Errorf("invalid download hash for %s", i.Name)
|
||||
}
|
||||
tdir := i.hub.local.HubDir
|
||||
|
||||
//all good, install
|
||||
//check if parent dir exists
|
||||
tmpdirs := strings.Split(tdir+"/"+i.RemotePath, "/")
|
||||
parentDir := strings.Join(tmpdirs[:len(tmpdirs)-1], "/")
|
||||
|
||||
// ensure that target file is within target dir
|
||||
finalPath, err := filepath.Abs(tdir + "/" + i.RemotePath)
|
||||
finalPath, err := filepath.Abs(filepath.Join(tdir, i.RemotePath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("filepath.Abs error on %s: %w", tdir+"/"+i.RemotePath, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure that target file is within target dir
|
||||
if !strings.HasPrefix(finalPath, tdir) {
|
||||
return fmt.Errorf("path %s escapes %s, abort", i.RemotePath, tdir)
|
||||
}
|
||||
|
||||
// check dir
|
||||
if _, err = os.Stat(parentDir); os.IsNotExist(err) {
|
||||
log.Debugf("%s doesn't exist, create", parentDir)
|
||||
parentDir := filepath.Dir(finalPath)
|
||||
|
||||
if err = os.MkdirAll(parentDir, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("while creating parent directories: %w", err)
|
||||
}
|
||||
if err = os.MkdirAll(parentDir, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("while creating %s: %w", parentDir, err)
|
||||
}
|
||||
|
||||
// check actual file
|
||||
|
@ -242,15 +247,8 @@ func (i *Item) download(overwrite bool) error {
|
|||
log.Infof("%s: OK", i.Name)
|
||||
}
|
||||
|
||||
f, err := os.Create(tdir + "/" + i.RemotePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("while opening file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = f.Write(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("while writing file: %w", err)
|
||||
if err = os.WriteFile(finalPath, body, 0o644); err != nil {
|
||||
return fmt.Errorf("while writing %s: %w", finalPath, err)
|
||||
}
|
||||
|
||||
i.Downloaded = true
|
||||
|
|
|
@ -69,5 +69,5 @@ func TestDownloadIndex(t *testing.T) {
|
|||
}
|
||||
|
||||
err = hub.remote.downloadIndex("/does/not/exist/index.json")
|
||||
cstest.RequireErrorContains(t, err, "while opening hub index file: open /does/not/exist/index.json:")
|
||||
cstest.RequireErrorContains(t, err, "failed to write hub index: open /does/not/exist/index.json:")
|
||||
}
|
||||
|
|
|
@ -250,7 +250,7 @@ func (i *Item) versionStatus() int {
|
|||
}
|
||||
|
||||
// validPath returns true if the (relative) path is allowed for the item
|
||||
// dirNmae: the directory name (ie. crowdsecurity)
|
||||
// dirNname: the directory name (ie. crowdsecurity)
|
||||
// fileName: the filename (ie. apache2-logs.yaml)
|
||||
func (i *Item) validPath(dirName, fileName string) bool {
|
||||
return (dirName+"/"+fileName == i.Name+".yaml") || (dirName+"/"+fileName == i.Name+".yml")
|
||||
|
|
|
@ -70,18 +70,11 @@ func (r *RemoteHubCfg) downloadIndex(localPath string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
file, err := os.Create(localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("while opening hub index file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
wsize, err := file.Write(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("while writing hub index file: %w", err)
|
||||
if err = os.WriteFile(localPath, body, 0o644); err != nil {
|
||||
return fmt.Errorf("failed to write hub index: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Wrote index to %s, %d bytes", localPath, wsize)
|
||||
log.Infof("Wrote index to %s, %d bytes", localPath, len(body))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue