sftpgo/vfs/sftpfs.go
Nicola Murino 4a6a4ce28d
sftpfs: map path resolution error to permission denied
we do the same for os fs so that the problematic directory is excluded
from the webdav listing instead of failing the whole directory listing
2021-10-16 10:32:18 +02:00

800 lines
21 KiB
Go

package vfs
import (
"bufio"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"sync"
"time"
"github.com/eikenb/pipeat"
"github.com/pkg/sftp"
"github.com/rs/xid"
"golang.org/x/crypto/ssh"
"github.com/drakkan/sftpgo/v2/kms"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/sdk"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/version"
)
const (
// sftpFsName is the name for the SFTP Fs implementation
sftpFsName = "sftpfs"
)
// ErrSFTPLoop defines the error to return if an SFTP loop is detected
var ErrSFTPLoop = errors.New("SFTP loop or nested local SFTP folders detected")
// SFTPFsConfig defines the configuration for SFTP based filesystem
type SFTPFsConfig struct {
sdk.SFTPFsConfig
forbiddenSelfUsernames []string `json:"-"`
}
// HideConfidentialData hides confidential data
func (c *SFTPFsConfig) HideConfidentialData() {
if c.Password != nil {
c.Password.Hide()
}
if c.PrivateKey != nil {
c.PrivateKey.Hide()
}
}
func (c *SFTPFsConfig) isEqual(other *SFTPFsConfig) bool {
if c.Endpoint != other.Endpoint {
return false
}
if c.Username != other.Username {
return false
}
if c.Prefix != other.Prefix {
return false
}
if c.DisableCouncurrentReads != other.DisableCouncurrentReads {
return false
}
if c.BufferSize != other.BufferSize {
return false
}
if len(c.Fingerprints) != len(other.Fingerprints) {
return false
}
for _, fp := range c.Fingerprints {
if !util.IsStringInSlice(fp, other.Fingerprints) {
return false
}
}
c.setEmptyCredentialsIfNil()
other.setEmptyCredentialsIfNil()
if !c.Password.IsEqual(other.Password) {
return false
}
return c.PrivateKey.IsEqual(other.PrivateKey)
}
func (c *SFTPFsConfig) setEmptyCredentialsIfNil() {
if c.Password == nil {
c.Password = kms.NewEmptySecret()
}
if c.PrivateKey == nil {
c.PrivateKey = kms.NewEmptySecret()
}
}
// Validate returns an error if the configuration is not valid
func (c *SFTPFsConfig) Validate() error {
c.setEmptyCredentialsIfNil()
if c.Endpoint == "" {
return errors.New("endpoint cannot be empty")
}
_, _, err := net.SplitHostPort(c.Endpoint)
if err != nil {
return fmt.Errorf("invalid endpoint: %v", err)
}
if c.Username == "" {
return errors.New("username cannot be empty")
}
if c.BufferSize < 0 || c.BufferSize > 16 {
return errors.New("invalid buffer_size, valid range is 0-16")
}
if err := c.validateCredentials(); err != nil {
return err
}
if c.Prefix != "" {
c.Prefix = util.CleanPath(c.Prefix)
} else {
c.Prefix = "/"
}
return nil
}
func (c *SFTPFsConfig) validateCredentials() error {
if c.Password.IsEmpty() && c.PrivateKey.IsEmpty() {
return errors.New("credentials cannot be empty")
}
if c.Password.IsEncrypted() && !c.Password.IsValid() {
return errors.New("invalid encrypted password")
}
if !c.Password.IsEmpty() && !c.Password.IsValidInput() {
return errors.New("invalid password")
}
if c.PrivateKey.IsEncrypted() && !c.PrivateKey.IsValid() {
return errors.New("invalid encrypted private key")
}
if !c.PrivateKey.IsEmpty() && !c.PrivateKey.IsValidInput() {
return errors.New("invalid private key")
}
return nil
}
// EncryptCredentials encrypts password and/or private key if they are in plain text
func (c *SFTPFsConfig) EncryptCredentials(additionalData string) error {
if c.Password.IsPlain() {
c.Password.SetAdditionalData(additionalData)
if err := c.Password.Encrypt(); err != nil {
return err
}
}
if c.PrivateKey.IsPlain() {
c.PrivateKey.SetAdditionalData(additionalData)
if err := c.PrivateKey.Encrypt(); err != nil {
return err
}
}
return nil
}
// SFTPFs is a Fs implementation for SFTP backends
type SFTPFs struct {
sync.Mutex
connectionID string
// if not empty this fs is mouted as virtual folder in the specified path
mountPath string
localTempDir string
config *SFTPFsConfig
sshClient *ssh.Client
sftpClient *sftp.Client
err chan error
}
// NewSFTPFs returns an SFTPFs object that allows to interact with an SFTP server
func NewSFTPFs(connectionID, mountPath, localTempDir string, forbiddenSelfUsernames []string, config SFTPFsConfig) (Fs, error) {
if localTempDir == "" {
if tempPath != "" {
localTempDir = tempPath
} else {
localTempDir = filepath.Clean(os.TempDir())
}
}
if err := config.Validate(); err != nil {
return nil, err
}
if !config.Password.IsEmpty() {
if err := config.Password.TryDecrypt(); err != nil {
return nil, err
}
}
if !config.PrivateKey.IsEmpty() {
if err := config.PrivateKey.TryDecrypt(); err != nil {
return nil, err
}
}
config.forbiddenSelfUsernames = forbiddenSelfUsernames
sftpFs := &SFTPFs{
connectionID: connectionID,
mountPath: mountPath,
localTempDir: localTempDir,
config: &config,
err: make(chan error, 1),
}
err := sftpFs.createConnection()
return sftpFs, err
}
// Name returns the name for the Fs implementation
func (fs *SFTPFs) Name() string {
return fmt.Sprintf("%v %#v", sftpFsName, fs.config.Endpoint)
}
// ConnectionID returns the connection ID associated to this Fs implementation
func (fs *SFTPFs) ConnectionID() string {
return fs.connectionID
}
// Stat returns a FileInfo describing the named file
func (fs *SFTPFs) Stat(name string) (os.FileInfo, error) {
if err := fs.checkConnection(); err != nil {
return nil, err
}
return fs.sftpClient.Stat(name)
}
// Lstat returns a FileInfo describing the named file
func (fs *SFTPFs) Lstat(name string) (os.FileInfo, error) {
if err := fs.checkConnection(); err != nil {
return nil, err
}
return fs.sftpClient.Lstat(name)
}
// Open opens the named file for reading
func (fs *SFTPFs) Open(name string, offset int64) (File, *pipeat.PipeReaderAt, func(), error) {
if err := fs.checkConnection(); err != nil {
return nil, nil, nil, err
}
f, err := fs.sftpClient.Open(name)
if err != nil {
return nil, nil, nil, err
}
if fs.config.BufferSize == 0 {
return f, nil, nil, err
}
if offset > 0 {
_, err = f.Seek(offset, io.SeekStart)
if err != nil {
f.Close()
return nil, nil, nil, err
}
}
r, w, err := pipeat.PipeInDir(fs.localTempDir)
if err != nil {
f.Close()
return nil, nil, nil, err
}
go func() {
// if we enable buffering the client stalls
//br := bufio.NewReaderSize(f, int(fs.config.BufferSize)*1024*1024)
//n, err := fs.copy(w, br)
n, err := io.Copy(w, f)
w.CloseWithError(err) //nolint:errcheck
f.Close()
fsLog(fs, logger.LevelDebug, "download completed, path: %#v size: %v, err: %v", name, n, err)
}()
return nil, r, nil, nil
}
// Create creates or opens the named file for writing
func (fs *SFTPFs) Create(name string, flag int) (File, *PipeWriter, func(), error) {
err := fs.checkConnection()
if err != nil {
return nil, nil, nil, err
}
if fs.config.BufferSize == 0 {
var f File
if flag == 0 {
f, err = fs.sftpClient.Create(name)
} else {
f, err = fs.sftpClient.OpenFile(name, flag)
}
return f, nil, nil, err
}
// buffering is enabled
f, err := fs.sftpClient.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC)
if err != nil {
return nil, nil, nil, err
}
r, w, err := pipeat.PipeInDir(fs.localTempDir)
if err != nil {
f.Close()
return nil, nil, nil, err
}
p := NewPipeWriter(w)
go func() {
bw := bufio.NewWriterSize(f, int(fs.config.BufferSize)*1024*1024)
// we don't use io.Copy since bufio.Writer implements io.WriterTo and
// so it calls the sftp.File WriteTo method without buffering
n, err := fs.copy(bw, r)
errFlush := bw.Flush()
if err == nil && errFlush != nil {
err = errFlush
}
var errTruncate error
if err != nil {
errTruncate = f.Truncate(n)
}
errClose := f.Close()
if err == nil && errClose != nil {
err = errClose
}
r.CloseWithError(err) //nolint:errcheck
p.Done(err)
fsLog(fs, logger.LevelDebug, "upload completed, path: %#v, readed bytes: %v, err: %v err truncate: %v",
name, n, err, errTruncate)
}()
return nil, p, nil, nil
}
// Rename renames (moves) source to target.
func (fs *SFTPFs) Rename(source, target string) error {
if err := fs.checkConnection(); err != nil {
return err
}
return fs.sftpClient.Rename(source, target)
}
// Remove removes the named file or (empty) directory.
func (fs *SFTPFs) Remove(name string, isDir bool) error {
if err := fs.checkConnection(); err != nil {
return err
}
return fs.sftpClient.Remove(name)
}
// Mkdir creates a new directory with the specified name and default permissions
func (fs *SFTPFs) Mkdir(name string) error {
if err := fs.checkConnection(); err != nil {
return err
}
return fs.sftpClient.Mkdir(name)
}
// MkdirAll creates a directory named path, along with any necessary parents,
// and returns nil, or else returns an error.
// If path is already a directory, MkdirAll does nothing and returns nil.
func (fs *SFTPFs) MkdirAll(name string, uid int, gid int) error {
if err := fs.checkConnection(); err != nil {
return err
}
return fs.sftpClient.MkdirAll(name)
}
// Symlink creates source as a symbolic link to target.
func (fs *SFTPFs) Symlink(source, target string) error {
if err := fs.checkConnection(); err != nil {
return err
}
return fs.sftpClient.Symlink(source, target)
}
// Readlink returns the destination of the named symbolic link
func (fs *SFTPFs) Readlink(name string) (string, error) {
if err := fs.checkConnection(); err != nil {
return "", err
}
return fs.sftpClient.ReadLink(name)
}
// Chown changes the numeric uid and gid of the named file.
func (fs *SFTPFs) Chown(name string, uid int, gid int) error {
if err := fs.checkConnection(); err != nil {
return err
}
return fs.sftpClient.Chown(name, uid, gid)
}
// Chmod changes the mode of the named file to mode.
func (fs *SFTPFs) Chmod(name string, mode os.FileMode) error {
if err := fs.checkConnection(); err != nil {
return err
}
return fs.sftpClient.Chmod(name, mode)
}
// Chtimes changes the access and modification times of the named file.
func (fs *SFTPFs) Chtimes(name string, atime, mtime time.Time) error {
if err := fs.checkConnection(); err != nil {
return err
}
return fs.sftpClient.Chtimes(name, atime, mtime)
}
// Truncate changes the size of the named file.
func (fs *SFTPFs) Truncate(name string, size int64) error {
if err := fs.checkConnection(); err != nil {
return err
}
return fs.sftpClient.Truncate(name, size)
}
// ReadDir reads the directory named by dirname and returns
// a list of directory entries.
func (fs *SFTPFs) ReadDir(dirname string) ([]os.FileInfo, error) {
if err := fs.checkConnection(); err != nil {
return nil, err
}
return fs.sftpClient.ReadDir(dirname)
}
// IsUploadResumeSupported returns true if resuming uploads is supported.
func (fs *SFTPFs) IsUploadResumeSupported() bool {
return fs.config.BufferSize == 0
}
// IsAtomicUploadSupported returns true if atomic upload is supported.
func (fs *SFTPFs) IsAtomicUploadSupported() bool {
return fs.config.BufferSize == 0
}
// IsNotExist returns a boolean indicating whether the error is known to
// report that a file or directory does not exist
func (*SFTPFs) IsNotExist(err error) bool {
return os.IsNotExist(err)
}
// IsPermission returns a boolean indicating whether the error is known to
// report that permission is denied.
func (*SFTPFs) IsPermission(err error) bool {
if _, ok := err.(*pathResolutionError); ok {
return true
}
return os.IsPermission(err)
}
// IsNotSupported returns true if the error indicate an unsupported operation
func (*SFTPFs) IsNotSupported(err error) bool {
if err == nil {
return false
}
return err == ErrVfsUnsupported
}
// CheckRootPath creates the specified local root directory if it does not exists
func (fs *SFTPFs) CheckRootPath(username string, uid int, gid int) bool {
if fs.config.BufferSize > 0 {
// we need a local directory for temporary files
osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "")
osFs.CheckRootPath(username, uid, gid)
}
if fs.config.Prefix == "/" {
return true
}
if err := fs.MkdirAll(fs.config.Prefix, uid, gid); err != nil {
fsLog(fs, logger.LevelDebug, "error creating root directory %#v for user %#v: %v", fs.config.Prefix, username, err)
return false
}
return true
}
// ScanRootDirContents returns the number of files contained in a directory and
// their size
func (fs *SFTPFs) ScanRootDirContents() (int, int64, error) {
return fs.GetDirSize(fs.config.Prefix)
}
// GetAtomicUploadPath returns the path to use for an atomic upload
func (*SFTPFs) GetAtomicUploadPath(name string) string {
dir := path.Dir(name)
guid := xid.New().String()
return path.Join(dir, ".sftpgo-upload."+guid+"."+path.Base(name))
}
// GetRelativePath returns the path for a file relative to the sftp prefix if any.
// This is the path as seen by SFTPGo users
func (fs *SFTPFs) GetRelativePath(name string) string {
rel := path.Clean(name)
if rel == "." {
rel = ""
}
if !path.IsAbs(rel) {
return "/" + rel
}
if fs.config.Prefix != "/" {
if !strings.HasPrefix(rel, fs.config.Prefix) {
rel = "/"
}
rel = path.Clean("/" + strings.TrimPrefix(rel, fs.config.Prefix))
}
if fs.mountPath != "" {
rel = path.Join(fs.mountPath, rel)
}
return rel
}
// Walk walks the file tree rooted at root, calling walkFn for each file or
// directory in the tree, including root
func (fs *SFTPFs) Walk(root string, walkFn filepath.WalkFunc) error {
if err := fs.checkConnection(); err != nil {
return err
}
walker := fs.sftpClient.Walk(root)
for walker.Step() {
err := walker.Err()
if err != nil {
return err
}
err = walkFn(walker.Path(), walker.Stat(), err)
if err != nil {
return err
}
}
return nil
}
// Join joins any number of path elements into a single path
func (*SFTPFs) Join(elem ...string) string {
return path.Join(elem...)
}
// HasVirtualFolders returns true if folders are emulated
func (*SFTPFs) HasVirtualFolders() bool {
return false
}
// ResolvePath returns the matching filesystem path for the specified virtual path
func (fs *SFTPFs) ResolvePath(virtualPath string) (string, error) {
if fs.mountPath != "" {
virtualPath = strings.TrimPrefix(virtualPath, fs.mountPath)
}
if !path.IsAbs(virtualPath) {
virtualPath = path.Clean("/" + virtualPath)
}
fsPath := fs.Join(fs.config.Prefix, virtualPath)
if fs.config.Prefix != "/" && fsPath != "/" {
// we need to check if this path is a symlink outside the given prefix
// or a file/dir inside a dir symlinked outside the prefix
if err := fs.checkConnection(); err != nil {
return "", err
}
var validatedPath string
var err error
validatedPath, err = fs.getRealPath(fsPath)
if err != nil && !os.IsNotExist(err) {
fsLog(fs, logger.LevelWarn, "Invalid path resolution, original path %v resolved %#v err: %v",
virtualPath, fsPath, err)
return "", err
} else if os.IsNotExist(err) {
for os.IsNotExist(err) {
validatedPath = path.Dir(validatedPath)
if validatedPath == "/" {
err = nil
break
}
validatedPath, err = fs.getRealPath(validatedPath)
}
if err != nil {
fsLog(fs, logger.LevelWarn, "Invalid path resolution, dir %#v original path %#v resolved %#v err: %v",
validatedPath, virtualPath, fsPath, err)
return "", err
}
}
if err := fs.isSubDir(validatedPath); err != nil {
fsLog(fs, logger.LevelWarn, "Invalid path resolution, dir %#v original path %#v resolved %#v err: %v",
validatedPath, virtualPath, fsPath, err)
return "", err
}
}
return fsPath, nil
}
// getRealPath returns the real remote path trying to resolve symbolic links if any
func (fs *SFTPFs) getRealPath(name string) (string, error) {
info, err := fs.sftpClient.Lstat(name)
if err != nil {
return name, err
}
if info.Mode()&os.ModeSymlink != 0 {
return fs.sftpClient.ReadLink(name)
}
return name, err
}
func (fs *SFTPFs) isSubDir(name string) error {
if name == fs.config.Prefix {
return nil
}
if len(name) < len(fs.config.Prefix) {
err := fmt.Errorf("path %#v is not inside: %#v", name, fs.config.Prefix)
return &pathResolutionError{err: err.Error()}
}
if !strings.HasPrefix(name, fs.config.Prefix+"/") {
err := fmt.Errorf("path %#v is not inside: %#v", name, fs.config.Prefix)
return &pathResolutionError{err: err.Error()}
}
return nil
}
// GetDirSize returns the number of files and the size for a folder
// including any subfolders
func (fs *SFTPFs) GetDirSize(dirname string) (int, int64, error) {
numFiles := 0
size := int64(0)
if err := fs.checkConnection(); err != nil {
return numFiles, size, err
}
isDir, err := IsDirectory(fs, dirname)
if err == nil && isDir {
walker := fs.sftpClient.Walk(dirname)
for walker.Step() {
err := walker.Err()
if err != nil {
return numFiles, size, err
}
if walker.Stat().Mode().IsRegular() {
size += walker.Stat().Size()
numFiles++
}
}
}
return numFiles, size, err
}
// GetMimeType returns the content type
func (fs *SFTPFs) GetMimeType(name string) (string, error) {
if err := fs.checkConnection(); err != nil {
return "", err
}
f, err := fs.sftpClient.OpenFile(name, os.O_RDONLY)
if err != nil {
return "", err
}
defer f.Close()
var buf [512]byte
n, err := io.ReadFull(f, buf[:])
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return "", err
}
ctype := http.DetectContentType(buf[:n])
// Rewind file.
_, err = f.Seek(0, io.SeekStart)
return ctype, err
}
// GetAvailableDiskSize return the available size for the specified path
func (fs *SFTPFs) GetAvailableDiskSize(dirName string) (*sftp.StatVFS, error) {
if err := fs.checkConnection(); err != nil {
return nil, err
}
if _, ok := fs.sftpClient.HasExtension("statvfs@openssh.com"); !ok {
return nil, ErrStorageSizeUnavailable
}
return fs.sftpClient.StatVFS(dirName)
}
// Close the connection
func (fs *SFTPFs) Close() error {
fs.Lock()
defer fs.Unlock()
var sftpErr, sshErr error
if fs.sftpClient != nil {
sftpErr = fs.sftpClient.Close()
}
if fs.sshClient != nil {
sshErr = fs.sshClient.Close()
}
if sftpErr != nil {
return sftpErr
}
return sshErr
}
func (fs *SFTPFs) copy(dst io.Writer, src io.Reader) (written int64, err error) {
buf := make([]byte, 32768)
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw < 0 || nr < nw {
nw = 0
if ew == nil {
ew = errors.New("invalid write")
}
}
written += int64(nw)
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
}
func (fs *SFTPFs) checkConnection() error {
err := fs.closed()
if err == nil {
return nil
}
return fs.createConnection()
}
func (fs *SFTPFs) createConnection() error {
fs.Lock()
defer fs.Unlock()
var err error
clientConfig := &ssh.ClientConfig{
User: fs.config.Username,
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
fp := ssh.FingerprintSHA256(key)
if util.IsStringInSlice(fp, sftpFingerprints) {
if util.IsStringInSlice(fs.config.Username, fs.config.forbiddenSelfUsernames) {
fsLog(fs, logger.LevelWarn, "SFTP loop or nested local SFTP folders detected, mount path %#v, username %#v, forbidden usernames: %+v",
fs.mountPath, fs.config.Username, fs.config.forbiddenSelfUsernames)
return ErrSFTPLoop
}
}
if len(fs.config.Fingerprints) > 0 {
for _, provided := range fs.config.Fingerprints {
if provided == fp {
return nil
}
}
return fmt.Errorf("invalid fingerprint %#v", fp)
}
fsLog(fs, logger.LevelWarn, "login without host key validation, please provide at least a fingerprint!")
return nil
},
ClientVersion: fmt.Sprintf("SSH-2.0-SFTPGo_%v", version.Get().Version),
}
if fs.config.PrivateKey.GetPayload() != "" {
signer, err := ssh.ParsePrivateKey([]byte(fs.config.PrivateKey.GetPayload()))
if err != nil {
fs.err <- err
return err
}
clientConfig.Auth = append(clientConfig.Auth, ssh.PublicKeys(signer))
}
if fs.config.Password.GetPayload() != "" {
clientConfig.Auth = append(clientConfig.Auth, ssh.Password(fs.config.Password.GetPayload()))
}
fs.sshClient, err = ssh.Dial("tcp", fs.config.Endpoint, clientConfig)
if err != nil {
fs.err <- err
return err
}
fs.sftpClient, err = sftp.NewClient(fs.sshClient)
if err != nil {
fs.sshClient.Close()
fs.err <- err
return err
}
if fs.config.DisableCouncurrentReads {
fsLog(fs, logger.LevelDebug, "disabling concurrent reads")
opt := sftp.UseConcurrentReads(false)
opt(fs.sftpClient) //nolint:errcheck
}
if fs.config.BufferSize > 0 {
fsLog(fs, logger.LevelDebug, "enabling concurrent writes")
opt := sftp.UseConcurrentWrites(true)
opt(fs.sftpClient) //nolint:errcheck
}
go fs.wait()
return nil
}
func (fs *SFTPFs) wait() {
// we wait on the sftp client otherwise if the channel is closed but not the connection
// we don't detect the event.
fs.err <- fs.sftpClient.Wait()
fsLog(fs, logger.LevelDebug, "sftp channel closed")
fs.Lock()
defer fs.Unlock()
if fs.sshClient != nil {
fs.sshClient.Close()
}
}
func (fs *SFTPFs) closed() error {
select {
case err := <-fs.err:
return err
default:
return nil
}
}