فهرست منبع

httpfs: add support for UNIX domain sockets

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 3 سال پیش
والد
کامیت
6f4475ff72
7فایلهای تغییر یافته به همراه158 افزوده شده و 12 حذف شده
  1. 9 0
      docs/httpfs.md
  2. 1 1
      go.mod
  3. 2 2
      go.sum
  4. 10 0
      httpd/httpd_test.go
  5. 27 4
      httpdtest/httpfsimpl.go
  6. 57 0
      sftpd/httpfs_test.go
  7. 52 5
      vfs/httpfs.go

+ 9 - 0
docs/httpfs.md

@@ -16,3 +16,12 @@ Here is a mapping between HTTP response codes and protocol errors:
 - `501`, means not supported error
 - `200`, `201`, mean no error
 - any other response code means a generic error
+
+HTTPFs can also connect to UNIX domain sockets. To use UNIX domain sockets you need to set an endpoint with the following conventions:
+
+- the URL schema can be `http` or `https` as usual.
+- The URL host must be `unix`.
+- The socket path is mandatory and is set using the `socket_path` query parameter. The path must be query escaped.
+- The optional API prefix can be set using the `api_prefix` query parameter. The prefix must be query escaped.
+
+Here is an example endpoint for UNIX domain socket connections: `http://unix?socket_path=%2Ftmp%2Fsftpgofs.sock&api_prefix=%2Fapi%2Fv1`. In this case we are connecting using the `HTTP` protocol to the socket `/tmp/sftpgofs.sock` and we use the `/api/v1` prefix for API URLs.

+ 1 - 1
go.mod

@@ -44,7 +44,7 @@ require (
 	github.com/minio/sio v0.3.0
 	github.com/otiai10/copy v1.7.0
 	github.com/pires/go-proxyproto v0.6.2
-	github.com/pkg/sftp v1.13.5-0.20220303113417-dcfc1d5e4162
+	github.com/pkg/sftp v1.13.5
 	github.com/pquerna/otp v1.3.0
 	github.com/prometheus/client_golang v1.12.2
 	github.com/robfig/cron/v3 v3.0.1

+ 2 - 2
go.sum

@@ -648,8 +648,8 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
-github.com/pkg/sftp v1.13.5-0.20220303113417-dcfc1d5e4162 h1:uJSlAAzEUQq5tpfK+SWIIx/3UJ4EpjAYuMqZpKYrmw4=
-github.com/pkg/sftp v1.13.5-0.20220303113417-dcfc1d5e4162/go.mod h1:wHDZ0IZX6JcBYRK1TH9bcVq8G7TLpVHYIGJRFnmPfxg=
+github.com/pkg/sftp v1.13.5 h1:a3RLUqkyjYRtBTZJZ1VRrKbN3zhuPLlUc3sphVz81go=
+github.com/pkg/sftp v1.13.5/go.mod h1:wHDZ0IZX6JcBYRK1TH9bcVq8G7TLpVHYIGJRFnmPfxg=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=

+ 10 - 0
httpd/httpd_test.go

@@ -2435,6 +2435,16 @@ func TestAddUserInvalidFsConfig(t *testing.T) {
 	if assert.NoError(t, err) {
 		assert.Contains(t, string(resp), "invalid endpoint schema")
 	}
+	u.FsConfig.HTTPConfig.Endpoint = "http://unix?api_prefix=v1"
+	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
+	if assert.NoError(t, err) {
+		assert.Contains(t, string(resp), "invalid unix domain socket path")
+	}
+	u.FsConfig.HTTPConfig.Endpoint = "http://unix?socket_path=test.sock"
+	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
+	if assert.NoError(t, err) {
+		assert.Contains(t, string(resp), "invalid unix domain socket path")
+	}
 }
 
 func TestUserRedactedPassword(t *testing.T) {

+ 27 - 4
httpdtest/httpfsimpl.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"mime"
+	"net"
 	"net/http"
 	"net/url"
 	"os"
@@ -38,6 +39,7 @@ const (
 )
 
 // StartTestHTTPFs starts a test HTTP service that implements httpfs
+// and listens on the specified port
 func StartTestHTTPFs(port int) error {
 	fs := httpFsImpl{
 		port: port,
@@ -45,10 +47,20 @@ func StartTestHTTPFs(port int) error {
 	return fs.Run()
 }
 
+// StartTestHTTPFsOverUnixSocket starts a test HTTP service that implements httpfs
+// and listens on the specified UNIX domain socket path
+func StartTestHTTPFsOverUnixSocket(socketPath string) error {
+	fs := httpFsImpl{
+		unixSocketPath: socketPath,
+	}
+	return fs.Run()
+}
+
 type httpFsImpl struct {
-	router   *chi.Mux
-	basePath string
-	port     int
+	router         *chi.Mux
+	basePath       string
+	port           int
+	unixSocketPath string
 }
 
 type apiResponse struct {
@@ -473,7 +485,18 @@ func (fs *httpFsImpl) Run() error {
 		MaxHeaderBytes: 1 << 16, // 64KB
 	}
 
-	return httpServer.ListenAndServe()
+	if fs.unixSocketPath == "" {
+		return httpServer.ListenAndServe()
+	}
+	err := os.Remove(fs.unixSocketPath)
+	if err != nil && !os.IsNotExist(err) {
+		return err
+	}
+	listener, err := net.Listen("unix", fs.unixSocketPath)
+	if err != nil {
+		return err
+	}
+	return httpServer.Serve(listener)
 }
 
 func getStatFromInfo(info os.FileInfo) map[string]any {

+ 57 - 0
sftpd/httpfs_test.go

@@ -5,9 +5,11 @@ import (
 	"io/fs"
 	"math"
 	"net/http"
+	"net/url"
 	"os"
 	"path"
 	"path/filepath"
+	"runtime"
 	"testing"
 	"time"
 
@@ -26,6 +28,10 @@ const (
 	defaultHTTPFsUsername = "httpfs_user"
 )
 
+var (
+	httpFsSocketPath = filepath.Join(os.TempDir(), "httpfs.sock")
+)
+
 func TestBasicHTTPFsHandling(t *testing.T) {
 	usePubKey := true
 	u := getTestUserWithHTTPFs(usePubKey)
@@ -271,6 +277,49 @@ func TestHTTPFsWalk(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestHTTPFsOverUNIXSocket(t *testing.T) {
+	if runtime.GOOS == osWindows {
+		t.Skip("UNIX domain sockets are not supported on Windows")
+	}
+	assert.Eventually(t, func() bool {
+		_, err := os.Stat(httpFsSocketPath)
+		return err == nil
+	}, 1*time.Second, 50*time.Millisecond)
+	usePubKey := true
+	u := getTestUserWithHTTPFs(usePubKey)
+	u.FsConfig.HTTPConfig.Endpoint = fmt.Sprintf("http://unix?socket_path=%s&api_prefix=%s",
+		url.QueryEscape(httpFsSocketPath), url.QueryEscape("/api/v1"))
+	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
+	assert.NoError(t, err)
+	conn, client, err := getSftpClient(user, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+
+		err = checkBasicSFTP(client)
+		assert.NoError(t, err)
+		testFilePath := filepath.Join(homeBasePath, testFileName)
+		testFileSize := int64(65535)
+		err = createTestFile(testFilePath, testFileSize)
+		assert.NoError(t, err)
+		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
+		assert.NoError(t, err)
+		err = client.Remove(testFileName)
+		assert.NoError(t, err)
+		err = client.Mkdir(testFileName)
+		assert.NoError(t, err)
+		err = client.RemoveDirectory(testFileName)
+		assert.NoError(t, err)
+		err = os.Remove(testFilePath)
+		assert.NoError(t, err)
+	}
+
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+}
+
 func getTestUserWithHTTPFs(usePubKey bool) dataprovider.User {
 	u := getTestUser(usePubKey)
 	u.FsConfig.Provider = sdk.HTTPFilesystemProvider
@@ -284,6 +333,14 @@ func getTestUserWithHTTPFs(usePubKey bool) dataprovider.User {
 }
 
 func startHTTPFs() {
+	if runtime.GOOS != osWindows {
+		go func() {
+			if err := httpdtest.StartTestHTTPFsOverUnixSocket(httpFsSocketPath); err != nil {
+				logger.ErrorToConsole("could not start HTTPfs test server over UNIX socket: %v", err)
+				os.Exit(1)
+			}
+		}()
+	}
 	go func() {
 		if err := httpdtest.StartTestHTTPFs(httpFsPort); err != nil {
 			logger.ErrorToConsole("could not start HTTPfs test server: %v", err)

+ 52 - 5
vfs/httpfs.go

@@ -9,6 +9,7 @@ import (
 	"io"
 	"io/fs"
 	"mime"
+	"net"
 	"net/http"
 	"net/url"
 	"os"
@@ -42,6 +43,10 @@ type HTTPFsConfig struct {
 	APIKey   *kms.Secret `json:"api_key,omitempty"`
 }
 
+func (c *HTTPFsConfig) isUnixDomainSocket() bool {
+	return strings.HasPrefix(c.Endpoint, "http://unix") || strings.HasPrefix(c.Endpoint, "https://unix")
+}
+
 // HideConfidentialData hides confidential data
 func (c *HTTPFsConfig) HideConfidentialData() {
 	if c.Password != nil {
@@ -95,13 +100,19 @@ func (c *HTTPFsConfig) validate() error {
 		return errors.New("httpfs: endpoint cannot be empty")
 	}
 	c.Endpoint = strings.TrimRight(c.Endpoint, "/")
-	_, err := url.Parse(c.Endpoint)
+	endpointURL, err := url.Parse(c.Endpoint)
 	if err != nil {
 		return fmt.Errorf("httpfs: invalid endpoint: %w", err)
 	}
 	if !util.IsStringPrefixInSlice(c.Endpoint, supportedEndpointSchema) {
 		return errors.New("httpfs: invalid endpoint schema: http and https are supported")
 	}
+	if endpointURL.Host == "unix" {
+		socketPath := endpointURL.Query().Get("socket_path")
+		if !filepath.IsAbs(socketPath) {
+			return fmt.Errorf("httpfs: invalid unix domain socket path: %q", socketPath)
+		}
+	}
 	if c.Password.IsEncrypted() && !c.Password.IsValid() {
 		return errors.New("httpfs: invalid encrypted password")
 	}
@@ -179,14 +190,43 @@ func NewHTTPFs(connectionID, localTempDir, mountPath string, config HTTPFsConfig
 	transport.MaxResponseHeaderBytes = 1 << 16
 	transport.WriteBufferSize = 1 << 16
 	transport.ReadBufferSize = 1 << 16
+	if fs.config.isUnixDomainSocket() {
+		endpointURL, err := url.Parse(fs.config.Endpoint)
+		if err != nil {
+			return nil, err
+		}
+		if endpointURL.Host == "unix" {
+			socketPath := endpointURL.Query().Get("socket_path")
+			if !filepath.IsAbs(socketPath) {
+				return nil, fmt.Errorf("httpfs: invalid unix domain socket path: %q", socketPath)
+			}
+			if endpointURL.Scheme == "https" {
+				transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+					var tlsConfig *tls.Config
+					var d tls.Dialer
+					if config.SkipTLSVerify {
+						tlsConfig = getInsecureTLSConfig()
+					}
+					d.Config = tlsConfig
+					return d.DialContext(ctx, "unix", socketPath)
+				}
+			} else {
+				transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+					var d net.Dialer
+					return d.DialContext(ctx, "unix", socketPath)
+				}
+			}
+			endpointURL.Path = path.Join(endpointURL.Path, endpointURL.Query().Get("api_prefix"))
+			endpointURL.RawQuery = ""
+			endpointURL.RawFragment = ""
+			fs.config.Endpoint = endpointURL.String()
+		}
+	}
 	if config.SkipTLSVerify {
 		if transport.TLSClientConfig != nil {
 			transport.TLSClientConfig.InsecureSkipVerify = true
 		} else {
-			transport.TLSClientConfig = &tls.Config{
-				NextProtos:         []string{"h2", "http/1.1"},
-				InsecureSkipVerify: true,
-			}
+			transport.TLSClientConfig = getInsecureTLSConfig()
 		}
 	}
 	fs.client = &http.Client{
@@ -646,6 +686,13 @@ func getErrorFromResponseCode(code int) error {
 	}
 }
 
+func getInsecureTLSConfig() *tls.Config {
+	return &tls.Config{
+		NextProtos:         []string{"h2", "http/1.1"},
+		InsecureSkipVerify: true,
+	}
+}
+
 type wrapReader struct {
 	reader io.Reader
 }