httpfs: add support for UNIX domain sockets

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2022-06-12 18:29:49 +02:00
parent 0b9a96ec6b
commit 6f4475ff72
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
7 changed files with 158 additions and 12 deletions

View file

@ -16,3 +16,12 @@ Here is a mapping between HTTP response codes and protocol errors:
- `501`, means not supported error
- `200`, `201`, mean no error
- any other response code means a generic error
HTTPFs can also connect to UNIX domain sockets. To use UNIX domain sockets you need to set an endpoint with the following conventions:
- the URL schema can be `http` or `https` as usual.
- The URL host must be `unix`.
- The socket path is mandatory and is set using the `socket_path` query parameter. The path must be query escaped.
- The optional API prefix can be set using the `api_prefix` query parameter. The prefix must be query escaped.
Here is an example endpoint for UNIX domain socket connections: `http://unix?socket_path=%2Ftmp%2Fsftpgofs.sock&api_prefix=%2Fapi%2Fv1`. In this case we are connecting using the `HTTP` protocol to the socket `/tmp/sftpgofs.sock` and we use the `/api/v1` prefix for API URLs.

2
go.mod
View file

@ -44,7 +44,7 @@ require (
github.com/minio/sio v0.3.0
github.com/otiai10/copy v1.7.0
github.com/pires/go-proxyproto v0.6.2
github.com/pkg/sftp v1.13.5-0.20220303113417-dcfc1d5e4162
github.com/pkg/sftp v1.13.5
github.com/pquerna/otp v1.3.0
github.com/prometheus/client_golang v1.12.2
github.com/robfig/cron/v3 v3.0.1

4
go.sum
View file

@ -648,8 +648,8 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
github.com/pkg/sftp v1.13.5-0.20220303113417-dcfc1d5e4162 h1:uJSlAAzEUQq5tpfK+SWIIx/3UJ4EpjAYuMqZpKYrmw4=
github.com/pkg/sftp v1.13.5-0.20220303113417-dcfc1d5e4162/go.mod h1:wHDZ0IZX6JcBYRK1TH9bcVq8G7TLpVHYIGJRFnmPfxg=
github.com/pkg/sftp v1.13.5 h1:a3RLUqkyjYRtBTZJZ1VRrKbN3zhuPLlUc3sphVz81go=
github.com/pkg/sftp v1.13.5/go.mod h1:wHDZ0IZX6JcBYRK1TH9bcVq8G7TLpVHYIGJRFnmPfxg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=

View file

@ -2435,6 +2435,16 @@ func TestAddUserInvalidFsConfig(t *testing.T) {
if assert.NoError(t, err) {
assert.Contains(t, string(resp), "invalid endpoint schema")
}
u.FsConfig.HTTPConfig.Endpoint = "http://unix?api_prefix=v1"
_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
if assert.NoError(t, err) {
assert.Contains(t, string(resp), "invalid unix domain socket path")
}
u.FsConfig.HTTPConfig.Endpoint = "http://unix?socket_path=test.sock"
_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
if assert.NoError(t, err) {
assert.Contains(t, string(resp), "invalid unix domain socket path")
}
}
func TestUserRedactedPassword(t *testing.T) {

View file

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"mime"
"net"
"net/http"
"net/url"
"os"
@ -38,6 +39,7 @@ const (
)
// StartTestHTTPFs starts a test HTTP service that implements httpfs
// and listens on the specified port
func StartTestHTTPFs(port int) error {
fs := httpFsImpl{
port: port,
@ -45,10 +47,20 @@ func StartTestHTTPFs(port int) error {
return fs.Run()
}
// StartTestHTTPFsOverUnixSocket starts a test HTTP service that implements httpfs
// and listens on the specified UNIX domain socket path
func StartTestHTTPFsOverUnixSocket(socketPath string) error {
fs := httpFsImpl{
unixSocketPath: socketPath,
}
return fs.Run()
}
type httpFsImpl struct {
router *chi.Mux
basePath string
port int
router *chi.Mux
basePath string
port int
unixSocketPath string
}
type apiResponse struct {
@ -473,7 +485,18 @@ func (fs *httpFsImpl) Run() error {
MaxHeaderBytes: 1 << 16, // 64KB
}
return httpServer.ListenAndServe()
if fs.unixSocketPath == "" {
return httpServer.ListenAndServe()
}
err := os.Remove(fs.unixSocketPath)
if err != nil && !os.IsNotExist(err) {
return err
}
listener, err := net.Listen("unix", fs.unixSocketPath)
if err != nil {
return err
}
return httpServer.Serve(listener)
}
func getStatFromInfo(info os.FileInfo) map[string]any {

View file

@ -5,9 +5,11 @@ import (
"io/fs"
"math"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"runtime"
"testing"
"time"
@ -26,6 +28,10 @@ const (
defaultHTTPFsUsername = "httpfs_user"
)
var (
httpFsSocketPath = filepath.Join(os.TempDir(), "httpfs.sock")
)
func TestBasicHTTPFsHandling(t *testing.T) {
usePubKey := true
u := getTestUserWithHTTPFs(usePubKey)
@ -271,6 +277,49 @@ func TestHTTPFsWalk(t *testing.T) {
assert.NoError(t, err)
}
func TestHTTPFsOverUNIXSocket(t *testing.T) {
if runtime.GOOS == osWindows {
t.Skip("UNIX domain sockets are not supported on Windows")
}
assert.Eventually(t, func() bool {
_, err := os.Stat(httpFsSocketPath)
return err == nil
}, 1*time.Second, 50*time.Millisecond)
usePubKey := true
u := getTestUserWithHTTPFs(usePubKey)
u.FsConfig.HTTPConfig.Endpoint = fmt.Sprintf("http://unix?socket_path=%s&api_prefix=%s",
url.QueryEscape(httpFsSocketPath), url.QueryEscape("/api/v1"))
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = checkBasicSFTP(client)
assert.NoError(t, err)
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
err = createTestFile(testFilePath, testFileSize)
assert.NoError(t, err)
err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
assert.NoError(t, err)
err = client.Remove(testFileName)
assert.NoError(t, err)
err = client.Mkdir(testFileName)
assert.NoError(t, err)
err = client.RemoveDirectory(testFileName)
assert.NoError(t, err)
err = os.Remove(testFilePath)
assert.NoError(t, err)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
}
func getTestUserWithHTTPFs(usePubKey bool) dataprovider.User {
u := getTestUser(usePubKey)
u.FsConfig.Provider = sdk.HTTPFilesystemProvider
@ -284,6 +333,14 @@ func getTestUserWithHTTPFs(usePubKey bool) dataprovider.User {
}
func startHTTPFs() {
if runtime.GOOS != osWindows {
go func() {
if err := httpdtest.StartTestHTTPFsOverUnixSocket(httpFsSocketPath); err != nil {
logger.ErrorToConsole("could not start HTTPfs test server over UNIX socket: %v", err)
os.Exit(1)
}
}()
}
go func() {
if err := httpdtest.StartTestHTTPFs(httpFsPort); err != nil {
logger.ErrorToConsole("could not start HTTPfs test server: %v", err)

View file

@ -9,6 +9,7 @@ import (
"io"
"io/fs"
"mime"
"net"
"net/http"
"net/url"
"os"
@ -42,6 +43,10 @@ type HTTPFsConfig struct {
APIKey *kms.Secret `json:"api_key,omitempty"`
}
func (c *HTTPFsConfig) isUnixDomainSocket() bool {
return strings.HasPrefix(c.Endpoint, "http://unix") || strings.HasPrefix(c.Endpoint, "https://unix")
}
// HideConfidentialData hides confidential data
func (c *HTTPFsConfig) HideConfidentialData() {
if c.Password != nil {
@ -95,13 +100,19 @@ func (c *HTTPFsConfig) validate() error {
return errors.New("httpfs: endpoint cannot be empty")
}
c.Endpoint = strings.TrimRight(c.Endpoint, "/")
_, err := url.Parse(c.Endpoint)
endpointURL, err := url.Parse(c.Endpoint)
if err != nil {
return fmt.Errorf("httpfs: invalid endpoint: %w", err)
}
if !util.IsStringPrefixInSlice(c.Endpoint, supportedEndpointSchema) {
return errors.New("httpfs: invalid endpoint schema: http and https are supported")
}
if endpointURL.Host == "unix" {
socketPath := endpointURL.Query().Get("socket_path")
if !filepath.IsAbs(socketPath) {
return fmt.Errorf("httpfs: invalid unix domain socket path: %q", socketPath)
}
}
if c.Password.IsEncrypted() && !c.Password.IsValid() {
return errors.New("httpfs: invalid encrypted password")
}
@ -179,14 +190,43 @@ func NewHTTPFs(connectionID, localTempDir, mountPath string, config HTTPFsConfig
transport.MaxResponseHeaderBytes = 1 << 16
transport.WriteBufferSize = 1 << 16
transport.ReadBufferSize = 1 << 16
if fs.config.isUnixDomainSocket() {
endpointURL, err := url.Parse(fs.config.Endpoint)
if err != nil {
return nil, err
}
if endpointURL.Host == "unix" {
socketPath := endpointURL.Query().Get("socket_path")
if !filepath.IsAbs(socketPath) {
return nil, fmt.Errorf("httpfs: invalid unix domain socket path: %q", socketPath)
}
if endpointURL.Scheme == "https" {
transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
var tlsConfig *tls.Config
var d tls.Dialer
if config.SkipTLSVerify {
tlsConfig = getInsecureTLSConfig()
}
d.Config = tlsConfig
return d.DialContext(ctx, "unix", socketPath)
}
} else {
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", socketPath)
}
}
endpointURL.Path = path.Join(endpointURL.Path, endpointURL.Query().Get("api_prefix"))
endpointURL.RawQuery = ""
endpointURL.RawFragment = ""
fs.config.Endpoint = endpointURL.String()
}
}
if config.SkipTLSVerify {
if transport.TLSClientConfig != nil {
transport.TLSClientConfig.InsecureSkipVerify = true
} else {
transport.TLSClientConfig = &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
}
transport.TLSClientConfig = getInsecureTLSConfig()
}
}
fs.client = &http.Client{
@ -646,6 +686,13 @@ func getErrorFromResponseCode(code int) error {
}
}
func getInsecureTLSConfig() *tls.Config {
return &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
}
}
type wrapReader struct {
reader io.Reader
}