diff --git a/internal/common/connection.go b/internal/common/connection.go index 1dad8c5d..46140e44 100644 --- a/internal/common/connection.go +++ b/internal/common/connection.go @@ -1610,7 +1610,7 @@ func (c *BaseConnection) GetOpUnsupportedError() error { func getQuotaExceededError(protocol string) error { switch protocol { case ProtocolSFTP: - return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrQuotaExceeded.Error()) + return fmt.Errorf("%w: %w", sftp.ErrSSHFxFailure, ErrQuotaExceeded) case ProtocolFTP: return ftpserver.ErrStorageExceeded default: @@ -1621,7 +1621,7 @@ func getQuotaExceededError(protocol string) error { func getReadQuotaExceededError(protocol string) error { switch protocol { case ProtocolSFTP: - return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrReadQuotaExceeded.Error()) + return fmt.Errorf("%w: %w", sftp.ErrSSHFxFailure, ErrReadQuotaExceeded) default: return ErrReadQuotaExceeded } @@ -1655,15 +1655,21 @@ func (c *BaseConnection) IsQuotaExceededError(err error) bool { } } +func isSFTPGoError(err error) bool { + return errors.Is(err, ErrPermissionDenied) || errors.Is(err, ErrNotExist) || errors.Is(err, ErrOpUnsupported) || + errors.Is(err, ErrQuotaExceeded) || errors.Is(err, ErrReadQuotaExceeded) || + errors.Is(err, vfs.ErrStorageSizeUnavailable) || errors.Is(err, ErrShuttingDown) +} + // GetGenericError returns an appropriate generic error for the connection protocol func (c *BaseConnection) GetGenericError(err error) error { switch c.protocol { case ProtocolSFTP: - if err == vfs.ErrStorageSizeUnavailable { - return fmt.Errorf("%w: %v", sftp.ErrSSHFxOpUnsupported, err.Error()) + if errors.Is(err, vfs.ErrStorageSizeUnavailable) || errors.Is(err, ErrOpUnsupported) || errors.Is(err, sftp.ErrSSHFxOpUnsupported) { + return fmt.Errorf("%w: %w", sftp.ErrSSHFxOpUnsupported, err) } - if err == ErrShuttingDown { - return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, err.Error()) + if isSFTPGoError(err) { + return fmt.Errorf("%w: %w", sftp.ErrSSHFxFailure, err) } if err != nil { var pathError *fs.PathError @@ -1672,13 +1678,10 @@ func (c *BaseConnection) GetGenericError(err error) error { return fmt.Errorf("%w: %v %v", sftp.ErrSSHFxFailure, pathError.Op, pathError.Err.Error()) } c.Log(logger.LevelError, "generic error: %+v", err) - return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrGenericFailure.Error()) } return sftp.ErrSSHFxFailure default: - if err == ErrPermissionDenied || err == ErrNotExist || err == ErrOpUnsupported || - err == ErrQuotaExceeded || err == ErrReadQuotaExceeded || err == vfs.ErrStorageSizeUnavailable || - err == ErrShuttingDown { + if isSFTPGoError(err) { return err } c.Log(logger.LevelError, "generic error: %+v", err) diff --git a/internal/common/transfer.go b/internal/common/transfer.go index aaedbf62..4b1272bd 100644 --- a/internal/common/transfer.go +++ b/internal/common/transfer.go @@ -217,16 +217,11 @@ func (t *BaseTransfer) SetCancelFn(cancelFn func()) { // converts it into a more understandable form for the client if it is a // well-known type of error func (t *BaseTransfer) ConvertError(err error) error { - if t.Fs.IsNotExist(err) { - return t.Connection.GetNotExistError() - } else if t.Fs.IsPermission(err) { - return t.Connection.GetPermissionDeniedError() - } var pathError *fs.PathError if errors.As(err, &pathError) { return fmt.Errorf("%s %s: %s", pathError.Op, t.GetVirtualPath(), pathError.Err.Error()) } - return err + return t.Connection.GetFsError(t.Fs, err) } // CheckRead returns an error if read if not allowed diff --git a/internal/sftpd/internal_test.go b/internal/sftpd/internal_test.go index b7b18e05..7f9464a3 100644 --- a/internal/sftpd/internal_test.go +++ b/internal/sftpd/internal_test.go @@ -1794,7 +1794,7 @@ func TestTransferFailingReader(t *testing.T) { require.NoError(t, err) buf := make([]byte, 32) _, err = transfer.ReadAt(buf, 0) - assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error()) + assert.ErrorIs(t, err, sftp.ErrSSHFxOpUnsupported) if c, ok := transfer.(io.Closer); ok { err = c.Close() assert.NoError(t, err) @@ -1809,14 +1809,14 @@ func TestTransferFailingReader(t *testing.T) { errRead := errors.New("read is not allowed") tr := newTransfer(baseTransfer, nil, r, errRead) _, err = tr.ReadAt(buf, 0) - assert.EqualError(t, err, errRead.Error()) + assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) err = tr.Close() assert.NoError(t, err) tr = newTransfer(baseTransfer, nil, nil, errRead) _, err = tr.ReadAt(buf, 0) - assert.EqualError(t, err, errRead.Error()) + assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) err = tr.Close() assert.NoError(t, err) diff --git a/internal/sftpd/sftpd_test.go b/internal/sftpd/sftpd_test.go index 73e13593..6f128cc6 100644 --- a/internal/sftpd/sftpd_test.go +++ b/internal/sftpd/sftpd_test.go @@ -4756,7 +4756,9 @@ func TestQuotaLimits(t *testing.T) { err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client) if assert.Error(t, err) { assert.Contains(t, err.Error(), "SSH_FX_FAILURE") - assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + if user.Username == localUser.Username { + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } } _, err = client.Stat(testFileName1) assert.Error(t, err) diff --git a/internal/version/version.go b/internal/version/version.go index 6b57a65c..53f008dd 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -17,7 +17,7 @@ package version import "strings" -const version = "2.5.6" +const version = "2.5.7-dev" var ( commit = ""