Browse Source

Improve localhost address validation (#9634)

* :art: Improv localhost address validation

* :bug: Compatible with browser extension
Yingyi / 颖逸 1 year ago
parent
commit
c90072c3cf
2 changed files with 81 additions and 44 deletions
  1. 28 44
      kernel/model/session.go
  2. 53 0
      kernel/util/net.go

+ 28 - 44
kernel/model/session.go

@@ -98,7 +98,7 @@ func LoginAuth(c *gin.Context) {
 
 		if err := session.Save(c); nil != err {
 			logging.LogErrorf("save session failed: " + err.Error())
-			c.Status(500)
+			c.Status(http.StatusInternalServerError)
 			return
 		}
 		return
@@ -109,7 +109,7 @@ func LoginAuth(c *gin.Context) {
 	workspaceSession.Captcha = gulu.Rand.String(7)
 	if err := session.Save(c); nil != err {
 		logging.LogErrorf("save session failed: " + err.Error())
-		c.Status(500)
+		c.Status(http.StatusInternalServerError)
 		return
 	}
 }
@@ -123,7 +123,7 @@ func GetCaptcha(c *gin.Context) {
 	})
 	if nil != err {
 		logging.LogErrorf("generates captcha failed: " + err.Error())
-		c.Status(500)
+		c.Status(http.StatusInternalServerError)
 		return
 	}
 
@@ -132,16 +132,16 @@ func GetCaptcha(c *gin.Context) {
 	workspaceSession.Captcha = img.Text
 	if err = session.Save(c); nil != err {
 		logging.LogErrorf("save session failed: " + err.Error())
-		c.Status(500)
+		c.Status(http.StatusInternalServerError)
 		return
 	}
 
 	if err = img.WriteImage(c.Writer); nil != err {
 		logging.LogErrorf("writes captcha image failed: " + err.Error())
-		c.Status(500)
+		c.Status(http.StatusInternalServerError)
 		return
 	}
-	c.Status(200)
+	c.Status(http.StatusOK)
 }
 
 func CheckReadonly(c *gin.Context) {
@@ -150,7 +150,7 @@ func CheckReadonly(c *gin.Context) {
 		result.Code = -1
 		result.Msg = Conf.Language(34)
 		result.Data = map[string]interface{}{"closeTimeout": 5000}
-		c.JSON(200, result)
+		c.JSON(http.StatusOK, result)
 		c.Abort()
 		return
 	}
@@ -158,38 +158,21 @@ func CheckReadonly(c *gin.Context) {
 
 func CheckAuth(c *gin.Context) {
 	//logging.LogInfof("check auth for [%s]", c.Request.RequestURI)
+	localhost := util.IsLocalHost(c.Request.RemoteAddr)
 
+	// 未设置访问授权码
 	if "" == Conf.AccessAuthCode {
-		if origin := c.GetHeader("Origin"); "" != origin {
-			// Authenticate requests with the Origin header other than 127.0.0.1 https://github.com/siyuan-note/siyuan/issues/9180
-			u, parseErr := url.Parse(origin)
-			if nil != parseErr {
-				logging.LogWarnf("parse origin [%s] failed: %s", origin, parseErr)
-				c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed: parse req header [Origin] failed"})
-				c.Abort()
-				return
-
-			}
-
-			if "chrome-extension" == strings.ToLower(u.Scheme) {
-				c.Next()
-				return
-			}
-
-			if !strings.HasPrefix(u.Host, util.LocalHost) && !strings.HasPrefix(u.Host, "[::1]") {
-				c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed: for security reasons, please set [Access authorization code] when using non-127.0.0.1 access\n\n为安全起见,使用非 127.0.0.1 访问时请设置 [访问授权码]"})
-				c.Abort()
-				return
-			}
-		}
-
-		if !strings.HasPrefix(c.Request.RemoteAddr, util.LocalHost) && !strings.HasPrefix(c.Request.RemoteAddr, "[::1]") {
-			// Authenticate requests of assets other than 127.0.0.1 https://github.com/siyuan-note/siyuan/issues/9388
-			if strings.HasPrefix(c.Request.RequestURI, "/assets/") {
-				c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed: for security reasons, please set [Access authorization code] when using non-127.0.0.1 access\n\n为安全起见,使用非 127.0.0.1 访问时请设置 [访问授权码]"})
-				c.Abort()
-				return
-			}
+		// Authenticate requests with the Origin header other than 127.0.0.1 https://github.com/siyuan-note/siyuan/issues/9180
+		host := c.GetHeader("Host")
+		origin := c.GetHeader("Origin")
+		forwardedHost := c.GetHeader("X-Forwarded-Host")
+		if !localhost ||
+			("" != host && !util.IsLocalHost(host)) ||
+			("" != origin && !util.IsLocalOrigin(origin) && !strings.HasPrefix(origin, "chrome-extension://")) ||
+			("" != forwardedHost && !util.IsLocalHost(forwardedHost)) {
+			c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed: for security reasons, please set [Access authorization code] when using non-127.0.0.1 access\n\n为安全起见,使用非 127.0.0.1 访问时请设置 [访问授权码]"})
+			c.Abort()
+			return
 		}
 
 		c.Next()
@@ -206,7 +189,7 @@ func CheckAuth(c *gin.Context) {
 	}
 
 	// 放过来自本机的某些请求
-	if strings.HasPrefix(c.Request.RemoteAddr, util.LocalHost) || strings.HasPrefix(c.Request.RemoteAddr, "[::1]") {
+	if localhost {
 		if strings.HasPrefix(c.Request.RequestURI, "/assets/") {
 			c.Next()
 			return
@@ -234,7 +217,7 @@ func CheckAuth(c *gin.Context) {
 				return
 			}
 
-			c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed"})
+			c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed"})
 			c.Abort()
 			return
 		}
@@ -247,7 +230,7 @@ func CheckAuth(c *gin.Context) {
 			return
 		}
 
-		c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed"})
+		c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed"})
 		c.Abort()
 		return
 	}
@@ -261,7 +244,7 @@ func CheckAuth(c *gin.Context) {
 		userAgentHeader := c.GetHeader("User-Agent")
 		if strings.HasPrefix(userAgentHeader, "SiYuan/") || strings.HasPrefix(userAgentHeader, "Mozilla/") {
 			if "GET" != c.Request.Method {
-				c.JSON(401, map[string]interface{}{"code": -1, "msg": Conf.Language(156)})
+				c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": Conf.Language(156)})
 				c.Abort()
 				return
 			}
@@ -271,12 +254,13 @@ func CheckAuth(c *gin.Context) {
 			queryParams.Set("to", c.Request.URL.String())
 			location.RawQuery = queryParams.Encode()
 			location.Path = "/check-auth"
-			c.Redirect(302, location.String())
+
+			c.Redirect(http.StatusFound, location.String())
 			c.Abort()
 			return
 		}
 
-		c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed"})
+		c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed"})
 		c.Abort()
 		return
 	}
@@ -316,7 +300,7 @@ func Timing(c *gin.Context) {
 func Recover(c *gin.Context) {
 	defer func() {
 		logging.Recover()
-		c.Status(500)
+		c.Status(http.StatusInternalServerError)
 	}()
 
 	c.Next()

+ 53 - 0
kernel/util/net.go

@@ -17,6 +17,7 @@
 package util
 
 import (
+	"net"
 	"net/http"
 	"net/url"
 	"strings"
@@ -31,6 +32,58 @@ import (
 	"github.com/siyuan-note/logging"
 )
 
+func ValidOptionalPort(port string) bool {
+	if port == "" {
+		return true
+	}
+	if port[0] != ':' {
+		return false
+	}
+	for _, b := range port[1:] {
+		if b < '0' || b > '9' {
+			return false
+		}
+	}
+	return true
+}
+
+func SplitHost(host string) (hostname, port string) {
+	hostname = host
+
+	colon := strings.LastIndexByte(hostname, ':')
+	if colon != -1 && ValidOptionalPort(hostname[colon:]) {
+		hostname, port = hostname[:colon], hostname[colon+1:]
+	}
+
+	if strings.HasPrefix(hostname, "[") && strings.HasSuffix(hostname, "]") {
+		hostname = hostname[1 : len(hostname)-1]
+	}
+
+	return
+}
+
+func IsLocalHostname(hostname string) bool {
+	if "localhost" == hostname {
+		return true
+	}
+	if ip := net.ParseIP(hostname); nil != ip {
+		return ip.IsLoopback()
+	}
+	return false
+}
+
+func IsLocalHost(host string) bool {
+	hostname, _ := SplitHost(host)
+	return IsLocalHostname(hostname)
+}
+
+func IsLocalOrigin(origin string) bool {
+	if url, err := url.Parse(origin); nil == err {
+		return IsLocalHostname(url.Hostname())
+	}
+	return false
+}
+
 func IsOnline(checkURL string, skipTlsVerify bool) bool {
 	_, err := url.Parse(checkURL)
 	if nil != err {