소스 검색

feat: add support for custom action hooks

Signed-off-by: Mark Sagi-Kazar <mark.sagikazar@gmail.com>
Mark Sagi-Kazar 4 년 전
부모
커밋
75750e3a79
4개의 변경된 파일119개의 추가작업 그리고 78개의 파일을 삭제
  1. 105 64
      common/actions.go
  2. 9 9
      common/actions_test.go
  3. 3 3
      common/connection.go
  4. 2 2
      common/transfer.go

+ 105 - 64
common/actions.go

@@ -34,8 +34,29 @@ type ProtocolActions struct {
 	Hook string `json:"hook" mapstructure:"hook"`
 }
 
-// actionNotification defines a notification for a Protocol Action
-type actionNotification struct {
+var actionHandler ActionHandler = defaultActionHandler{}
+
+// InitializeActionHandler lets the user choose an action handler implementation.
+//
+// Do NOT call this function after application initialization.
+func InitializeActionHandler(handler ActionHandler) {
+	actionHandler = handler
+}
+
+// SSHCommandActionNotification executes the defined action for the specified SSH command.
+func SSHCommandActionNotification(user *dataprovider.User, filePath, target, sshCmd string, err error) {
+	notification := newActionNotification(user, operationSSHCmd, filePath, target, sshCmd, ProtocolSSH, 0, err)
+
+	go actionHandler.Handle(notification) // nolint:errcheck
+}
+
+// ActionHandler handles a notification for a Protocol Action.
+type ActionHandler interface {
+	Handle(notification ActionNotification) error
+}
+
+// ActionNotification defines a notification for a Protocol Action.
+type ActionNotification struct {
 	Action     string `json:"action"`
 	Username   string `json:"username"`
 	Path       string `json:"path"`
@@ -49,29 +70,29 @@ type actionNotification struct {
 	Protocol   string `json:"protocol"`
 }
 
-// SSHCommandActionNotification executes the defined action for the specified SSH command
-func SSHCommandActionNotification(user *dataprovider.User, filePath, target, sshCmd string, err error) {
-	action := newActionNotification(user, operationSSHCmd, filePath, target, sshCmd, ProtocolSSH, 0, err)
-	go action.execute() //nolint:errcheck
-}
-
-func newActionNotification(user *dataprovider.User, operation, filePath, target, sshCmd, protocol string, fileSize int64,
-	err error) actionNotification {
-	bucket := ""
-	endpoint := ""
+func newActionNotification(
+	user *dataprovider.User,
+	operation, filePath, target, sshCmd, protocol string,
+	fileSize int64,
+	err error,
+) ActionNotification {
+	var bucket, endpoint string
 	status := 1
+
 	if user.FsConfig.Provider == dataprovider.S3FilesystemProvider {
 		bucket = user.FsConfig.S3Config.Bucket
 		endpoint = user.FsConfig.S3Config.Endpoint
 	} else if user.FsConfig.Provider == dataprovider.GCSFilesystemProvider {
 		bucket = user.FsConfig.GCSConfig.Bucket
 	}
+
 	if err == ErrQuotaExceeded {
 		status = 2
 	} else if err != nil {
 		status = 0
 	}
-	return actionNotification{
+
+	return ActionNotification{
 		Action:     operation,
 		Username:   user.Username,
 		Path:       filePath,
@@ -86,72 +107,92 @@ func newActionNotification(user *dataprovider.User, operation, filePath, target,
 	}
 }
 
-func (a *actionNotification) asJSON() []byte {
-	res, _ := json.Marshal(a)
-	return res
+type defaultActionHandler struct{}
+
+func (h defaultActionHandler) Handle(notification ActionNotification) error {
+	if !utils.IsStringInSlice(notification.Action, Config.Actions.ExecuteOn) {
+		return errUnconfiguredAction
+	}
+
+	if Config.Actions.Hook == "" {
+		logger.Warn(notification.Protocol, "", "Unable to send notification, no hook is defined")
+
+		return errNoHook
+	}
+
+	if strings.HasPrefix(Config.Actions.Hook, "http") {
+		return h.handleHTTP(notification)
+	}
+
+	return h.handleCommand(notification)
 }
 
-func (a *actionNotification) asEnvVars() []string {
-	return []string{fmt.Sprintf("SFTPGO_ACTION=%v", a.Action),
-		fmt.Sprintf("SFTPGO_ACTION_USERNAME=%v", a.Username),
-		fmt.Sprintf("SFTPGO_ACTION_PATH=%v", a.Path),
-		fmt.Sprintf("SFTPGO_ACTION_TARGET=%v", a.TargetPath),
-		fmt.Sprintf("SFTPGO_ACTION_SSH_CMD=%v", a.SSHCmd),
-		fmt.Sprintf("SFTPGO_ACTION_FILE_SIZE=%v", a.FileSize),
-		fmt.Sprintf("SFTPGO_ACTION_FS_PROVIDER=%v", a.FsProvider),
-		fmt.Sprintf("SFTPGO_ACTION_BUCKET=%v", a.Bucket),
-		fmt.Sprintf("SFTPGO_ACTION_ENDPOINT=%v", a.Endpoint),
-		fmt.Sprintf("SFTPGO_ACTION_STATUS=%v", a.Status),
-		fmt.Sprintf("SFTPGO_ACTION_PROTOCOL=%v", a.Protocol),
+func (h defaultActionHandler) handleHTTP(notification ActionNotification) error {
+	u, err := url.Parse(Config.Actions.Hook)
+	if err != nil {
+		logger.Warn(notification.Protocol, "", "Invalid hook %#v for operation %#v: %v", Config.Actions.Hook, notification.Action, err)
+
+		return err
 	}
+
+	startTime := time.Now()
+	respCode := 0
+
+	httpClient := httpclient.GetHTTPClient()
+
+	var b bytes.Buffer
+	_ = json.NewEncoder(&b).Encode(notification)
+
+	resp, err := httpClient.Post(u.String(), "application/json", &b)
+	if err == nil {
+		respCode = resp.StatusCode
+		resp.Body.Close()
+
+		if respCode != http.StatusOK {
+			err = errUnexpectedHTTResponse
+		}
+	}
+
+	logger.Debug(notification.Protocol, "", "notified operation %#v to URL: %v status code: %v, elapsed: %v err: %v", notification.Action, u.String(), respCode, time.Since(startTime), err)
+
+	return err
 }
 
-func (a *actionNotification) executeNotificationCommand() error {
+func (h defaultActionHandler) handleCommand(notification ActionNotification) error {
 	if !filepath.IsAbs(Config.Actions.Hook) {
 		err := fmt.Errorf("invalid notification command %#v", Config.Actions.Hook)
-		logger.Warn(a.Protocol, "", "unable to execute notification command: %v", err)
+		logger.Warn(notification.Protocol, "", "unable to execute notification command: %v", err)
+
 		return err
 	}
+
 	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()
-	cmd := exec.CommandContext(ctx, Config.Actions.Hook, a.Action, a.Username, a.Path, a.TargetPath, a.SSHCmd)
-	cmd.Env = append(os.Environ(), a.asEnvVars()...)
+
+	cmd := exec.CommandContext(ctx, Config.Actions.Hook, notification.Action, notification.Username, notification.Path, notification.TargetPath, notification.SSHCmd)
+	cmd.Env = append(os.Environ(), notificationAsEnvVars(notification)...)
+
 	startTime := time.Now()
 	err := cmd.Run()
-	logger.Debug(a.Protocol, "", "executed command %#v with arguments: %#v, %#v, %#v, %#v, %#v, elapsed: %v, error: %v",
-		Config.Actions.Hook, a.Action, a.Username, a.Path, a.TargetPath, a.SSHCmd, time.Since(startTime), err)
+
+	logger.Debug(notification.Protocol, "", "executed command %#v with arguments: %#v, %#v, %#v, %#v, %#v, elapsed: %v, error: %v",
+		Config.Actions.Hook, notification.Action, notification.Username, notification.Path, notification.TargetPath, notification.SSHCmd, time.Since(startTime), err)
+
 	return err
 }
 
-func (a *actionNotification) execute() error {
-	if !utils.IsStringInSlice(a.Action, Config.Actions.ExecuteOn) {
-		return errUnconfiguredAction
-	}
-	if len(Config.Actions.Hook) == 0 {
-		logger.Warn(a.Protocol, "", "Unable to send notification, no hook is defined")
-		return errNoHook
-	}
-	if strings.HasPrefix(Config.Actions.Hook, "http") {
-		var url *url.URL
-		url, err := url.Parse(Config.Actions.Hook)
-		if err != nil {
-			logger.Warn(a.Protocol, "", "Invalid hook %#v for operation %#v: %v", Config.Actions.Hook, a.Action, err)
-			return err
-		}
-		startTime := time.Now()
-		httpClient := httpclient.GetHTTPClient()
-		resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(a.asJSON()))
-		respCode := 0
-		if err == nil {
-			respCode = resp.StatusCode
-			resp.Body.Close()
-			if respCode != http.StatusOK {
-				err = errUnexpectedHTTResponse
-			}
-		}
-		logger.Debug(a.Protocol, "", "notified operation %#v to URL: %v status code: %v, elapsed: %v err: %v",
-			a.Action, url.String(), respCode, time.Since(startTime), err)
-		return err
+func notificationAsEnvVars(notification ActionNotification) []string {
+	return []string{
+		fmt.Sprintf("SFTPGO_ACTION=%v", notification.Action),
+		fmt.Sprintf("SFTPGO_ACTION_USERNAME=%v", notification.Username),
+		fmt.Sprintf("SFTPGO_ACTION_PATH=%v", notification.Path),
+		fmt.Sprintf("SFTPGO_ACTION_TARGET=%v", notification.TargetPath),
+		fmt.Sprintf("SFTPGO_ACTION_SSH_CMD=%v", notification.SSHCmd),
+		fmt.Sprintf("SFTPGO_ACTION_FILE_SIZE=%v", notification.FileSize),
+		fmt.Sprintf("SFTPGO_ACTION_FS_PROVIDER=%v", notification.FsProvider),
+		fmt.Sprintf("SFTPGO_ACTION_BUCKET=%v", notification.Bucket),
+		fmt.Sprintf("SFTPGO_ACTION_ENDPOINT=%v", notification.Endpoint),
+		fmt.Sprintf("SFTPGO_ACTION_STATUS=%v", notification.Status),
+		fmt.Sprintf("SFTPGO_ACTION_PROTOCOL=%v", notification.Protocol),
 	}
-	return a.executeNotificationCommand()
 }

+ 9 - 9
common/actions_test.go

@@ -58,15 +58,15 @@ func TestActionHTTP(t *testing.T) {
 		Username: "username",
 	}
 	a := newActionNotification(user, operationDownload, "path", "target", "", ProtocolSFTP, 123, nil)
-	err := a.execute()
+	err := actionHandler.Handle(a)
 	assert.NoError(t, err)
 
 	Config.Actions.Hook = "http://invalid:1234"
-	err = a.execute()
+	err = actionHandler.Handle(a)
 	assert.Error(t, err)
 
 	Config.Actions.Hook = fmt.Sprintf("http://%v/404", httpAddr)
-	err = a.execute()
+	err = actionHandler.Handle(a)
 	if assert.Error(t, err) {
 		assert.EqualError(t, err, errUnexpectedHTTResponse.Error())
 	}
@@ -91,7 +91,7 @@ func TestActionCMD(t *testing.T) {
 		Username: "username",
 	}
 	a := newActionNotification(user, operationDownload, "path", "target", "", ProtocolSFTP, 123, nil)
-	err = a.execute()
+	err = actionHandler.Handle(a)
 	assert.NoError(t, err)
 
 	SSHCommandActionNotification(user, "path", "target", "sha1sum", nil)
@@ -115,26 +115,26 @@ func TestWrongActions(t *testing.T) {
 	}
 
 	a := newActionNotification(user, operationUpload, "", "", "", ProtocolSFTP, 123, nil)
-	err := a.execute()
+	err := actionHandler.Handle(a)
 	assert.Error(t, err, "action with bad command must fail")
 
 	a.Action = operationDelete
-	err = a.execute()
+	err = actionHandler.Handle(a)
 	assert.EqualError(t, err, errUnconfiguredAction.Error())
 
 	Config.Actions.Hook = "http://foo\x7f.com/"
 	a.Action = operationUpload
-	err = a.execute()
+	err = actionHandler.Handle(a)
 	assert.Error(t, err, "action with bad url must fail")
 
 	Config.Actions.Hook = ""
-	err = a.execute()
+	err = actionHandler.Handle(a)
 	if assert.Error(t, err) {
 		assert.EqualError(t, err, errNoHook.Error())
 	}
 
 	Config.Actions.Hook = "relative path"
-	err = a.execute()
+	err = actionHandler.Handle(a)
 	if assert.Error(t, err) {
 		assert.EqualError(t, err, fmt.Sprintf("invalid notification command %#v", Config.Actions.Hook))
 	}

+ 3 - 3
common/connection.go

@@ -249,7 +249,7 @@ func (c *BaseConnection) RemoveFile(fsPath, virtualPath string, info os.FileInfo
 	}
 	size := info.Size()
 	action := newActionNotification(&c.User, operationPreDelete, fsPath, "", "", c.protocol, size, nil)
-	actionErr := action.execute()
+	actionErr := actionHandler.Handle(action)
 	if actionErr == nil {
 		c.Log(logger.LevelDebug, "remove for path %#v handled by pre-delete action", fsPath)
 	} else {
@@ -273,7 +273,7 @@ func (c *BaseConnection) RemoveFile(fsPath, virtualPath string, info os.FileInfo
 	}
 	if actionErr != nil {
 		action := newActionNotification(&c.User, operationDelete, fsPath, "", "", c.protocol, size, nil)
-		go action.execute() //nolint:errcheck
+		go actionHandler.Handle(action) // nolint:errcheck
 	}
 	return nil
 }
@@ -392,7 +392,7 @@ func (c *BaseConnection) Rename(fsSourcePath, fsTargetPath, virtualSourcePath, v
 		"", "", "", -1)
 	action := newActionNotification(&c.User, operationRename, fsSourcePath, fsTargetPath, "", c.protocol, 0, nil)
 	// the returned error is used in test cases only, we already log the error inside action.execute
-	go action.execute() //nolint:errcheck
+	go actionHandler.Handle(action) // nolint:errcheck
 
 	return nil
 }

+ 2 - 2
common/transfer.go

@@ -220,7 +220,7 @@ func (t *BaseTransfer) Close() error {
 			t.Connection.ID, t.Connection.protocol)
 		action := newActionNotification(&t.Connection.User, operationDownload, t.fsPath, "", "", t.Connection.protocol,
 			atomic.LoadInt64(&t.BytesSent), t.ErrTransfer)
-		go action.execute() //nolint:errcheck
+		go actionHandler.Handle(action) //nolint:errcheck
 	} else {
 		fileSize := atomic.LoadInt64(&t.BytesReceived) + t.MinWriteOffset
 		info, err := t.Fs.Stat(t.fsPath)
@@ -233,7 +233,7 @@ func (t *BaseTransfer) Close() error {
 			t.Connection.ID, t.Connection.protocol)
 		action := newActionNotification(&t.Connection.User, operationUpload, t.fsPath, "", "", t.Connection.protocol,
 			fileSize, t.ErrTransfer)
-		go action.execute() //nolint:errcheck
+		go actionHandler.Handle(action) //nolint:errcheck
 	}
 	if t.ErrTransfer != nil {
 		t.Connection.Log(logger.LevelWarn, "transfer error: %v, path: %#v", t.ErrTransfer, t.fsPath)