Browse Source

parse IP proxy header also if listening on UNIX domain socket

Fixes #867

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 3 years ago
parent
commit
118744a860
7 changed files with 48 additions and 5 deletions
  1. 1 1
      docs/repo.md
  2. 5 0
      httpd/httpd.go
  3. 12 0
      httpd/internal_test.go
  4. 7 2
      httpd/server.go
  5. 12 0
      webdavd/internal_test.go
  6. 6 2
      webdavd/server.go
  7. 5 0
      webdavd/webdavd.go

+ 1 - 1
docs/repo.md

@@ -1,6 +1,6 @@
 # SFTPGo repositories
 
-These repository are available through Oregon State University's free mirror service. Special thanks to Lance Albertson, Director of the Oregon State University Open Source Lab, who helped me with the initial setup.
+These repositories are available through Oregon State University's free mirror service. Special thanks to Lance Albertson, Director of the Oregon State University Open Source Lab, who helped me with the initial setup.
 
 ## APT repo
 

+ 5 - 0
httpd/httpd.go

@@ -476,6 +476,11 @@ func (b *Binding) checkBranding() {
 }
 
 func (b *Binding) parseAllowedProxy() error {
+	if filepath.IsAbs(b.Address) && len(b.ProxyAllowed) > 0 {
+		// unix domain socket
+		b.allowHeadersFrom = []func(net.IP) bool{func(ip net.IP) bool { return true }}
+		return nil
+	}
 	allowedFuncs, err := util.ParseAllowedIPAndRanges(b.ProxyAllowed)
 	if err != nil {
 		return err

+ 12 - 0
httpd/internal_test.go

@@ -1532,6 +1532,18 @@ func TestJWTTokenCleanup(t *testing.T) {
 	stopCleanupTicker()
 }
 
+func TestAllowedProxyUnixDomainSocket(t *testing.T) {
+	b := Binding{
+		Address:      filepath.Join(os.TempDir(), "sock"),
+		ProxyAllowed: []string{"127.0.0.1", "127.0.1.1"},
+	}
+	err := b.parseAllowedProxy()
+	assert.NoError(t, err)
+	if assert.Len(t, b.allowHeadersFrom, 1) {
+		assert.True(t, b.allowHeadersFrom[0](nil))
+	}
+}
+
 func TestProxyHeaders(t *testing.T) {
 	username := "adminTest"
 	password := "testPwd"

+ 7 - 2
httpd/server.go

@@ -9,6 +9,7 @@ import (
 	"log"
 	"net"
 	"net/http"
+	"path/filepath"
 	"strings"
 	"time"
 
@@ -972,9 +973,13 @@ func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request {
 func (s *httpdServer) checkConnection(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
-		ip := net.ParseIP(ipAddr)
+		var ip net.IP
+		isUnixSocket := filepath.IsAbs(s.binding.Address)
+		if !isUnixSocket {
+			ip = net.ParseIP(ipAddr)
+		}
 		areHeadersAllowed := false
-		if ip != nil {
+		if isUnixSocket || ip != nil {
 			for _, allow := range s.binding.allowHeadersFrom {
 				if allow(ip) {
 					parsedIP := util.GetRealIP(r, s.binding.ClientIPProxyHeader, s.binding.ClientIPHeaderDepth)

+ 12 - 0
webdavd/internal_test.go

@@ -412,6 +412,18 @@ func TestUserInvalidParams(t *testing.T) {
 	writeLog(req, http.StatusOK, nil)
 }
 
+func TestAllowedProxyUnixDomainSocket(t *testing.T) {
+	b := Binding{
+		Address:      filepath.Join(os.TempDir(), "sock"),
+		ProxyAllowed: []string{"127.0.0.1", "127.0.1.1"},
+	}
+	err := b.parseAllowedProxy()
+	assert.NoError(t, err)
+	if assert.Len(t, b.allowHeadersFrom, 1) {
+		assert.True(t, b.allowHeadersFrom[0](nil))
+	}
+}
+
 func TestRemoteAddress(t *testing.T) {
 	remoteAddr1 := "100.100.100.100"
 	remoteAddr2 := "172.172.172.172"

+ 6 - 2
webdavd/server.go

@@ -331,8 +331,12 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo
 
 func (s *webDavServer) checkRemoteAddress(r *http.Request) string {
 	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
-	ip := net.ParseIP(ipAddr)
-	if ip != nil {
+	var ip net.IP
+	isUnixSocket := filepath.IsAbs(s.binding.Address)
+	if !isUnixSocket {
+		ip = net.ParseIP(ipAddr)
+	}
+	if isUnixSocket || ip != nil {
 		for _, allow := range s.binding.allowHeadersFrom {
 			if allow(ip) {
 				parsedIP := util.GetRealIP(r, s.binding.ClientIPProxyHeader, s.binding.ClientIPHeaderDepth)

+ 5 - 0
webdavd/webdavd.go

@@ -109,6 +109,11 @@ type Binding struct {
 }
 
 func (b *Binding) parseAllowedProxy() error {
+	if filepath.IsAbs(b.Address) && len(b.ProxyAllowed) > 0 {
+		// unix domain socket
+		b.allowHeadersFrom = []func(net.IP) bool{func(ip net.IP) bool { return true }}
+		return nil
+	}
 	allowedFuncs, err := util.ParseAllowedIPAndRanges(b.ProxyAllowed)
 	if err != nil {
 		return err