diff --git a/docs/httpfs.md b/docs/httpfs.md index 78e06b41..28237f96 100644 --- a/docs/httpfs.md +++ b/docs/httpfs.md @@ -16,3 +16,12 @@ Here is a mapping between HTTP response codes and protocol errors: - `501`, means not supported error - `200`, `201`, mean no error - any other response code means a generic error + +HTTPFs can also connect to UNIX domain sockets. To use UNIX domain sockets you need to set an endpoint with the following conventions: + +- the URL schema can be `http` or `https` as usual. +- The URL host must be `unix`. +- The socket path is mandatory and is set using the `socket_path` query parameter. The path must be query escaped. +- The optional API prefix can be set using the `api_prefix` query parameter. The prefix must be query escaped. + +Here is an example endpoint for UNIX domain socket connections: `http://unix?socket_path=%2Ftmp%2Fsftpgofs.sock&api_prefix=%2Fapi%2Fv1`. In this case we are connecting using the `HTTP` protocol to the socket `/tmp/sftpgofs.sock` and we use the `/api/v1` prefix for API URLs. diff --git a/go.mod b/go.mod index 8825638d..a7a89274 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,7 @@ require ( github.com/minio/sio v0.3.0 github.com/otiai10/copy v1.7.0 github.com/pires/go-proxyproto v0.6.2 - github.com/pkg/sftp v1.13.5-0.20220303113417-dcfc1d5e4162 + github.com/pkg/sftp v1.13.5 github.com/pquerna/otp v1.3.0 github.com/prometheus/client_golang v1.12.2 github.com/robfig/cron/v3 v3.0.1 diff --git a/go.sum b/go.sum index 933869e1..d0651e78 100644 --- a/go.sum +++ b/go.sum @@ -648,8 +648,8 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= -github.com/pkg/sftp v1.13.5-0.20220303113417-dcfc1d5e4162 h1:uJSlAAzEUQq5tpfK+SWIIx/3UJ4EpjAYuMqZpKYrmw4= -github.com/pkg/sftp v1.13.5-0.20220303113417-dcfc1d5e4162/go.mod h1:wHDZ0IZX6JcBYRK1TH9bcVq8G7TLpVHYIGJRFnmPfxg= +github.com/pkg/sftp v1.13.5 h1:a3RLUqkyjYRtBTZJZ1VRrKbN3zhuPLlUc3sphVz81go= +github.com/pkg/sftp v1.13.5/go.mod h1:wHDZ0IZX6JcBYRK1TH9bcVq8G7TLpVHYIGJRFnmPfxg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index 1b67d8a9..afc5a166 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -2435,6 +2435,16 @@ func TestAddUserInvalidFsConfig(t *testing.T) { if assert.NoError(t, err) { assert.Contains(t, string(resp), "invalid endpoint schema") } + u.FsConfig.HTTPConfig.Endpoint = "http://unix?api_prefix=v1" + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + if assert.NoError(t, err) { + assert.Contains(t, string(resp), "invalid unix domain socket path") + } + u.FsConfig.HTTPConfig.Endpoint = "http://unix?socket_path=test.sock" + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + if assert.NoError(t, err) { + assert.Contains(t, string(resp), "invalid unix domain socket path") + } } func TestUserRedactedPassword(t *testing.T) { diff --git a/httpdtest/httpfsimpl.go b/httpdtest/httpfsimpl.go index 5c1ab008..2e697a71 100644 --- a/httpdtest/httpfsimpl.go +++ b/httpdtest/httpfsimpl.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "mime" + "net" "net/http" "net/url" "os" @@ -38,6 +39,7 @@ const ( ) // StartTestHTTPFs starts a test HTTP service that implements httpfs +// and listens on the specified port func StartTestHTTPFs(port int) error { fs := httpFsImpl{ port: port, @@ -45,10 +47,20 @@ func StartTestHTTPFs(port int) error { return fs.Run() } +// StartTestHTTPFsOverUnixSocket starts a test HTTP service that implements httpfs +// and listens on the specified UNIX domain socket path +func StartTestHTTPFsOverUnixSocket(socketPath string) error { + fs := httpFsImpl{ + unixSocketPath: socketPath, + } + return fs.Run() +} + type httpFsImpl struct { - router *chi.Mux - basePath string - port int + router *chi.Mux + basePath string + port int + unixSocketPath string } type apiResponse struct { @@ -473,7 +485,18 @@ func (fs *httpFsImpl) Run() error { MaxHeaderBytes: 1 << 16, // 64KB } - return httpServer.ListenAndServe() + if fs.unixSocketPath == "" { + return httpServer.ListenAndServe() + } + err := os.Remove(fs.unixSocketPath) + if err != nil && !os.IsNotExist(err) { + return err + } + listener, err := net.Listen("unix", fs.unixSocketPath) + if err != nil { + return err + } + return httpServer.Serve(listener) } func getStatFromInfo(info os.FileInfo) map[string]any { diff --git a/sftpd/httpfs_test.go b/sftpd/httpfs_test.go index 76524ab1..8e85044c 100644 --- a/sftpd/httpfs_test.go +++ b/sftpd/httpfs_test.go @@ -5,9 +5,11 @@ import ( "io/fs" "math" "net/http" + "net/url" "os" "path" "path/filepath" + "runtime" "testing" "time" @@ -26,6 +28,10 @@ const ( defaultHTTPFsUsername = "httpfs_user" ) +var ( + httpFsSocketPath = filepath.Join(os.TempDir(), "httpfs.sock") +) + func TestBasicHTTPFsHandling(t *testing.T) { usePubKey := true u := getTestUserWithHTTPFs(usePubKey) @@ -271,6 +277,49 @@ func TestHTTPFsWalk(t *testing.T) { assert.NoError(t, err) } +func TestHTTPFsOverUNIXSocket(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("UNIX domain sockets are not supported on Windows") + } + assert.Eventually(t, func() bool { + _, err := os.Stat(httpFsSocketPath) + return err == nil + }, 1*time.Second, 50*time.Millisecond) + usePubKey := true + u := getTestUserWithHTTPFs(usePubKey) + u.FsConfig.HTTPConfig.Endpoint = fmt.Sprintf("http://unix?socket_path=%s&api_prefix=%s", + url.QueryEscape(httpFsSocketPath), url.QueryEscape("/api/v1")) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = checkBasicSFTP(client) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Remove(testFileName) + assert.NoError(t, err) + err = client.Mkdir(testFileName) + assert.NoError(t, err) + err = client.RemoveDirectory(testFileName) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func getTestUserWithHTTPFs(usePubKey bool) dataprovider.User { u := getTestUser(usePubKey) u.FsConfig.Provider = sdk.HTTPFilesystemProvider @@ -284,6 +333,14 @@ func getTestUserWithHTTPFs(usePubKey bool) dataprovider.User { } func startHTTPFs() { + if runtime.GOOS != osWindows { + go func() { + if err := httpdtest.StartTestHTTPFsOverUnixSocket(httpFsSocketPath); err != nil { + logger.ErrorToConsole("could not start HTTPfs test server over UNIX socket: %v", err) + os.Exit(1) + } + }() + } go func() { if err := httpdtest.StartTestHTTPFs(httpFsPort); err != nil { logger.ErrorToConsole("could not start HTTPfs test server: %v", err) diff --git a/vfs/httpfs.go b/vfs/httpfs.go index f721c7d8..059190c8 100644 --- a/vfs/httpfs.go +++ b/vfs/httpfs.go @@ -9,6 +9,7 @@ import ( "io" "io/fs" "mime" + "net" "net/http" "net/url" "os" @@ -42,6 +43,10 @@ type HTTPFsConfig struct { APIKey *kms.Secret `json:"api_key,omitempty"` } +func (c *HTTPFsConfig) isUnixDomainSocket() bool { + return strings.HasPrefix(c.Endpoint, "http://unix") || strings.HasPrefix(c.Endpoint, "https://unix") +} + // HideConfidentialData hides confidential data func (c *HTTPFsConfig) HideConfidentialData() { if c.Password != nil { @@ -95,13 +100,19 @@ func (c *HTTPFsConfig) validate() error { return errors.New("httpfs: endpoint cannot be empty") } c.Endpoint = strings.TrimRight(c.Endpoint, "/") - _, err := url.Parse(c.Endpoint) + endpointURL, err := url.Parse(c.Endpoint) if err != nil { return fmt.Errorf("httpfs: invalid endpoint: %w", err) } if !util.IsStringPrefixInSlice(c.Endpoint, supportedEndpointSchema) { return errors.New("httpfs: invalid endpoint schema: http and https are supported") } + if endpointURL.Host == "unix" { + socketPath := endpointURL.Query().Get("socket_path") + if !filepath.IsAbs(socketPath) { + return fmt.Errorf("httpfs: invalid unix domain socket path: %q", socketPath) + } + } if c.Password.IsEncrypted() && !c.Password.IsValid() { return errors.New("httpfs: invalid encrypted password") } @@ -179,14 +190,43 @@ func NewHTTPFs(connectionID, localTempDir, mountPath string, config HTTPFsConfig transport.MaxResponseHeaderBytes = 1 << 16 transport.WriteBufferSize = 1 << 16 transport.ReadBufferSize = 1 << 16 + if fs.config.isUnixDomainSocket() { + endpointURL, err := url.Parse(fs.config.Endpoint) + if err != nil { + return nil, err + } + if endpointURL.Host == "unix" { + socketPath := endpointURL.Query().Get("socket_path") + if !filepath.IsAbs(socketPath) { + return nil, fmt.Errorf("httpfs: invalid unix domain socket path: %q", socketPath) + } + if endpointURL.Scheme == "https" { + transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + var tlsConfig *tls.Config + var d tls.Dialer + if config.SkipTLSVerify { + tlsConfig = getInsecureTLSConfig() + } + d.Config = tlsConfig + return d.DialContext(ctx, "unix", socketPath) + } + } else { + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", socketPath) + } + } + endpointURL.Path = path.Join(endpointURL.Path, endpointURL.Query().Get("api_prefix")) + endpointURL.RawQuery = "" + endpointURL.RawFragment = "" + fs.config.Endpoint = endpointURL.String() + } + } if config.SkipTLSVerify { if transport.TLSClientConfig != nil { transport.TLSClientConfig.InsecureSkipVerify = true } else { - transport.TLSClientConfig = &tls.Config{ - NextProtos: []string{"h2", "http/1.1"}, - InsecureSkipVerify: true, - } + transport.TLSClientConfig = getInsecureTLSConfig() } } fs.client = &http.Client{ @@ -646,6 +686,13 @@ func getErrorFromResponseCode(code int) error { } } +func getInsecureTLSConfig() *tls.Config { + return &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + InsecureSkipVerify: true, + } +} + type wrapReader struct { reader io.Reader }