Pārlūkot izejas kodu

vfs: make PipeReader an interface

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 1 gadu atpakaļ
vecāks
revīzija
eec9c449d4

+ 1 - 1
internal/ftpd/transfer.go

@@ -32,7 +32,7 @@ type transfer struct {
 	expectedOffset int64
 }
 
-func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader *vfs.PipeReader,
+func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader vfs.PipeReader,
 	expectedOffset int64) *transfer {
 	var writer io.WriteCloser
 	var reader io.ReadCloser

+ 1 - 1
internal/httpd/file.go

@@ -28,7 +28,7 @@ type httpdFile struct {
 	isFinished bool
 }
 
-func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader *vfs.PipeReader) *httpdFile {
+func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader vfs.PipeReader) *httpdFile {
 	var writer io.WriteCloser
 	var reader io.ReadCloser
 	if baseTransfer.File != nil {

+ 2 - 2
internal/sftpd/sftpd_test.go

@@ -8939,8 +8939,8 @@ func TestStatVFSCloudBackend(t *testing.T) {
 	u := getTestUser(usePubKey)
 	u.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
 	u.FsConfig.AzBlobConfig.SASURL = kms.NewPlainSecret("https://myaccount.blob.core.windows.net/sasurl")
-	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
-	assert.NoError(t, err)
+	user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
+	assert.NoError(t, err, string(resp))
 	conn, client, err := getSftpClient(user, usePubKey)
 	if assert.NoError(t, err) {
 		defer conn.Close()

+ 1 - 1
internal/sftpd/transfer.go

@@ -58,7 +58,7 @@ type transfer struct {
 	isFinished bool
 }
 
-func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader *vfs.PipeReader,
+func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader vfs.PipeReader,
 	errForRead error) *transfer {
 	var writer writerAtCloser
 	var reader readerAtCloser

+ 2 - 2
internal/vfs/azblobfs.go

@@ -202,7 +202,7 @@ func (fs *AzureBlobFs) Lstat(name string) (os.FileInfo, error) {
 }
 
 // Open opens the named file for reading
-func (fs *AzureBlobFs) Open(name string, offset int64) (File, *PipeReader, func(), error) {
+func (fs *AzureBlobFs) Open(name string, offset int64) (File, PipeReader, func(), error) {
 	r, w, err := pipeat.PipeInDir(fs.localTempDir)
 	if err != nil {
 		return nil, nil, nil, err
@@ -991,7 +991,7 @@ func (fs *AzureBlobFs) downloadPart(ctx context.Context, blockBlob *blockblob.Cl
 }
 
 func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *blockblob.Client,
-	offset int64, writer io.WriterAt, pipeReader *PipeReader,
+	offset int64, writer io.WriterAt, pipeReader PipeReader,
 ) error {
 	props, err := blockBlob.GetProperties(ctx, &blob.GetPropertiesOptions{})
 	metric.AZHeadObjectCompleted(err)

+ 1 - 1
internal/vfs/cryptfs.go

@@ -79,7 +79,7 @@ func (fs *CryptFs) Name() string {
 }
 
 // Open opens the named file for reading
-func (fs *CryptFs) Open(name string, offset int64) (File, *PipeReader, func(), error) {
+func (fs *CryptFs) Open(name string, offset int64) (File, PipeReader, func(), error) {
 	f, key, err := fs.getFileAndEncryptionKey(name)
 	if err != nil {
 		return nil, nil, nil, err

+ 1 - 1
internal/vfs/gcsfs.go

@@ -126,7 +126,7 @@ func (fs *GCSFs) Lstat(name string) (os.FileInfo, error) {
 }
 
 // Open opens the named file for reading
-func (fs *GCSFs) Open(name string, offset int64) (File, *PipeReader, func(), error) {
+func (fs *GCSFs) Open(name string, offset int64) (File, PipeReader, func(), error) {
 	r, w, err := pipeat.PipeInDir(fs.localTempDir)
 	if err != nil {
 		return nil, nil, nil, err

+ 1 - 1
internal/vfs/httpfs.go

@@ -316,7 +316,7 @@ func (fs *HTTPFs) Lstat(name string) (os.FileInfo, error) {
 }
 
 // Open opens the named file for reading
-func (fs *HTTPFs) Open(name string, offset int64) (File, *PipeReader, func(), error) {
+func (fs *HTTPFs) Open(name string, offset int64) (File, PipeReader, func(), error) {
 	r, w, err := pipeat.PipeInDir(fs.localTempDir)
 	if err != nil {
 		return nil, nil, nil, err

+ 1 - 1
internal/vfs/osfs.go

@@ -101,7 +101,7 @@ func (fs *OsFs) Lstat(name string) (os.FileInfo, error) {
 }
 
 // Open opens the named file for reading
-func (fs *OsFs) Open(name string, offset int64) (File, *PipeReader, func(), error) {
+func (fs *OsFs) Open(name string, offset int64) (File, PipeReader, func(), error) {
 	f, err := os.Open(name)
 	if err != nil {
 		return nil, nil, nil, err

+ 1 - 1
internal/vfs/s3fs.go

@@ -196,7 +196,7 @@ func (fs *S3Fs) Lstat(name string) (os.FileInfo, error) {
 }
 
 // Open opens the named file for reading
-func (fs *S3Fs) Open(name string, offset int64) (File, *PipeReader, func(), error) {
+func (fs *S3Fs) Open(name string, offset int64) (File, PipeReader, func(), error) {
 	r, w, err := pipeat.PipeInDir(fs.localTempDir)
 	if err != nil {
 		return nil, nil, nil, err

+ 1 - 1
internal/vfs/sftpfs.go

@@ -346,7 +346,7 @@ func (fs *SFTPFs) Lstat(name string) (os.FileInfo, error) {
 }
 
 // Open opens the named file for reading
-func (fs *SFTPFs) Open(name string, offset int64) (File, *PipeReader, func(), error) {
+func (fs *SFTPFs) Open(name string, offset int64) (File, PipeReader, func(), error) {
 	client, err := fs.conn.getClient()
 	if err != nil {
 		return nil, nil, nil, err

+ 22 - 9
internal/vfs/vfs.go

@@ -115,7 +115,7 @@ type Fs interface {
 	ConnectionID() string
 	Stat(name string) (os.FileInfo, error)
 	Lstat(name string) (os.FileInfo, error)
-	Open(name string, offset int64) (File, *PipeReader, func(), error)
+	Open(name string, offset int64) (File, PipeReader, func(), error)
 	Create(name string, flag, checks int) (File, PipeWriter, func(), error)
 	Rename(source, target string) (int, int64, error)
 	Remove(name string, isDir bool) error
@@ -189,6 +189,16 @@ type PipeWriter interface {
 	GetWrittenBytes() int64
 }
 
+// PipeReader defines an interface representing a SFTPGo pipe writer
+type PipeReader interface {
+	io.Reader
+	io.ReaderAt
+	io.Closer
+	setMetadata(value map[string]string)
+	setMetadataFromPointerVal(value map[string]*string)
+	Metadata() map[string]string
+}
+
 // Metadater defines an interface to implement to return metadata for a file
 type Metadater interface {
 	Metadata() map[string]string
@@ -628,7 +638,10 @@ func (c *AzBlobFsConfig) ValidateAndEncryptCredentials(additionalData string) er
 func (c *AzBlobFsConfig) checkCredentials() error {
 	if c.SASURL.IsPlain() {
 		_, err := url.Parse(c.SASURL.GetPayload())
-		return util.NewI18nError(err, util.I18nErrorSASURLInvalid)
+		if err != nil {
+			return util.NewI18nError(err, util.I18nErrorSASURLInvalid)
+		}
+		return nil
 	}
 	if c.SASURL.IsEncrypted() && !c.SASURL.IsValid() {
 		return errors.New("invalid encrypted sas_url")
@@ -851,27 +864,27 @@ func (p *pipeWriterAtOffset) Write(buf []byte) (int, error) {
 }
 
 // NewPipeReader initializes a new PipeReader
-func NewPipeReader(r *pipeat.PipeReaderAt) *PipeReader {
-	return &PipeReader{
+func NewPipeReader(r *pipeat.PipeReaderAt) PipeReader {
+	return &pipeReader{
 		PipeReaderAt: r,
 	}
 }
 
-// PipeReader defines a wrapper for pipeat.PipeReaderAt.
-type PipeReader struct {
+// pipeReader defines a wrapper for pipeat.PipeReaderAt.
+type pipeReader struct {
 	*pipeat.PipeReaderAt
 	mu       sync.RWMutex
 	metadata map[string]string
 }
 
-func (p *PipeReader) setMetadata(value map[string]string) {
+func (p *pipeReader) setMetadata(value map[string]string) {
 	p.mu.Lock()
 	defer p.mu.Unlock()
 
 	p.metadata = value
 }
 
-func (p *PipeReader) setMetadataFromPointerVal(value map[string]*string) {
+func (p *pipeReader) setMetadataFromPointerVal(value map[string]*string) {
 	p.mu.Lock()
 	defer p.mu.Unlock()
 
@@ -890,7 +903,7 @@ func (p *PipeReader) setMetadataFromPointerVal(value map[string]*string) {
 }
 
 // Metadata implements the Metadater interface
-func (p *PipeReader) Metadata() map[string]string {
+func (p *pipeReader) Metadata() map[string]string {
 	p.mu.RLock()
 	defer p.mu.RUnlock()
 

+ 1 - 1
internal/webdavd/internal_test.go

@@ -286,7 +286,7 @@ func (fs *MockOsFs) Name() string {
 }
 
 // Open returns nil
-func (fs *MockOsFs) Open(name string, offset int64) (vfs.File, *vfs.PipeReader, func(), error) {
+func (fs *MockOsFs) Open(name string, offset int64) (vfs.File, vfs.PipeReader, func(), error) {
 	if fs.reader != nil {
 		return nil, vfs.NewPipeReader(fs.reader), nil, nil
 	}