middleware.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. package utils
  2. import (
  3. "context"
  4. "net/http"
  5. "time"
  6. "net"
  7. "strings"
  8. "fmt"
  9. "sync"
  10. "sync/atomic"
  11. "github.com/mxk/go-flowrate/flowrate"
  12. "github.com/oschwald/geoip2-golang"
  13. )
// MiddlewareTimeout in this file follows chi's timeout middleware:
// https://github.com/go-chi/chi/blob/master/middleware/timeout.go
  15. var PushShieldMetrics func(string)
  16. type safeInt struct {
  17. val int64
  18. }
  19. var BannedIPs = sync.Map{}
  20. // Close connection right away if banned (save resources)
  21. func IncrementIPAbuseCounter(ip string) {
  22. // Load or store a new *safeInt
  23. actual, _ := BannedIPs.LoadOrStore(ip, &safeInt{})
  24. counter := actual.(*safeInt)
  25. // Increment the counter using atomic for concurrent access
  26. atomic.AddInt64(&counter.val, 1)
  27. }
  28. func getIPAbuseCounter(ip string) int64 {
  29. // Load the *safeInt
  30. actual, ok := BannedIPs.Load(ip)
  31. if !ok {
  32. return 0
  33. }
  34. counter := actual.(*safeInt)
  35. // Load the value using atomic for concurrent access
  36. return atomic.LoadInt64(&counter.val)
  37. }
  38. func BlockBannedIPs(next http.Handler) http.Handler {
  39. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  40. ip, _, err := net.SplitHostPort(r.RemoteAddr)
  41. if err != nil {
  42. if hj, ok := w.(http.Hijacker); ok {
  43. conn, _, err := hj.Hijack()
  44. if err == nil {
  45. conn.Close()
  46. }
  47. }
  48. return
  49. }
  50. nbAbuse := getIPAbuseCounter(ip)
  51. if nbAbuse > 275 {
  52. Warn("IP " + ip + " has " + fmt.Sprintf("%d", nbAbuse) + " abuse(s) and will soon be banned.")
  53. }
  54. if nbAbuse > 300 {
  55. if hj, ok := w.(http.Hijacker); ok {
  56. conn, _, err := hj.Hijack()
  57. if err == nil {
  58. conn.Close()
  59. }
  60. }
  61. return
  62. }
  63. next.ServeHTTP(w, r)
  64. })
  65. }
  66. func CleanBannedIPs() {
  67. BannedIPs.Range(func(key, value interface{}) bool {
  68. BannedIPs.Delete(key)
  69. return true
  70. })
  71. }
  72. func MiddlewareTimeout(timeout time.Duration) func(next http.Handler) http.Handler {
  73. return func(next http.Handler) http.Handler {
  74. fn := func(w http.ResponseWriter, r *http.Request) {
  75. ctx, cancel := context.WithTimeout(r.Context(), timeout)
  76. defer func() {
  77. cancel()
  78. if ctx.Err() == context.DeadlineExceeded {
  79. Error("Request Timeout. Cancelling.", ctx.Err())
  80. HTTPError(w, "Gateway Timeout",
  81. http.StatusGatewayTimeout, "HTTP002")
  82. return
  83. }
  84. }()
  85. w.Header().Set("X-Timeout-Duration", timeout.String())
  86. r = r.WithContext(ctx)
  87. next.ServeHTTP(w, r)
  88. }
  89. return http.HandlerFunc(fn)
  90. }
  91. }
  92. type responseWriter struct {
  93. http.ResponseWriter
  94. *flowrate.Writer
  95. }
  96. func (w *responseWriter) Write(b []byte) (int, error) {
  97. return w.Writer.Write(b)
  98. }
  99. func BandwithLimiterMiddleware(max int64) func(next http.Handler) http.Handler {
  100. return func(next http.Handler) http.Handler {
  101. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  102. if(max > 0) {
  103. fw := flowrate.NewWriter(w, max)
  104. w = &responseWriter{w, fw}
  105. }
  106. next.ServeHTTP(w, r)
  107. })
  108. }
  109. }
  110. func SetSecurityHeaders(next http.Handler) http.Handler {
  111. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  112. if(IsHTTPS) {
  113. // TODO: Add preload if we have a valid certificate
  114. w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
  115. }
  116. w.Header().Set("X-Content-Type-Options", "nosniff")
  117. w.Header().Set("X-XSS-Protection", "1; mode=block")
  118. w.Header().Set("Content-Security-Policy", "frame-ancestors 'self'")
  119. w.Header().Set("X-Served-By-Cosmos", "1")
  120. next.ServeHTTP(w, r)
  121. })
  122. }
  123. func CORSHeader(origin string) func(next http.Handler) http.Handler {
  124. return func(next http.Handler) http.Handler {
  125. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  126. if origin != "" {
  127. w.Header().Set("Access-Control-Allow-Origin", origin)
  128. w.Header().Set("Access-Control-Allow-Credentials", "true")
  129. }
  130. next.ServeHTTP(w, r)
  131. })
  132. }
  133. }
  134. func AcceptHeader(accept string) func(next http.Handler) http.Handler {
  135. return func(next http.Handler) http.Handler {
  136. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  137. w.Header().Set("Content-Type", accept)
  138. next.ServeHTTP(w, r)
  139. })
  140. }
  141. }
  142. // GetIPLocation returns the ISO country code for a given IP address.
  143. func GetIPLocation(ip string) (string, error) {
  144. geoDB, err := geoip2.Open("GeoLite2-Country.mmdb")
  145. if err != nil {
  146. return "", err
  147. }
  148. defer geoDB.Close()
  149. parsedIP := net.ParseIP(ip)
  150. record, err := geoDB.Country(parsedIP)
  151. if err != nil {
  152. return "", err
  153. }
  154. return record.Country.IsoCode, nil
  155. }
  156. // BlockByCountryMiddleware returns a middleware function that blocks requests from specified countries.
  157. func BlockByCountryMiddleware(blockedCountries []string, CountryBlacklistIsWhitelist bool) func(http.Handler) http.Handler {
  158. return func(next http.Handler) http.Handler {
  159. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  160. ip, _, err := net.SplitHostPort(r.RemoteAddr)
  161. if err != nil {
  162. http.Error(w, "Invalid request", http.StatusBadRequest)
  163. return
  164. }
  165. countryCode, err := GetIPLocation(ip)
  166. if err == nil {
  167. if countryCode == "" {
  168. Debug("Country code is empty")
  169. } else {
  170. Debug("Country code: " + countryCode)
  171. }
  172. config := GetMainConfig()
  173. if CountryBlacklistIsWhitelist {
  174. if countryCode != "" {
  175. blocked := true
  176. for _, blockedCountry := range blockedCountries {
  177. if config.ServerCountry != countryCode && countryCode == blockedCountry {
  178. blocked = false
  179. }
  180. }
  181. if blocked {
  182. PushShieldMetrics("geo")
  183. IncrementIPAbuseCounter(ip)
  184. http.Error(w, "Access denied", http.StatusForbidden)
  185. return
  186. }
  187. } else {
  188. Warn("Missing geolocation information to block IPs")
  189. }
  190. } else {
  191. for _, blockedCountry := range blockedCountries {
  192. if config.ServerCountry != countryCode && countryCode == blockedCountry {
  193. http.Error(w, "Access denied", http.StatusForbidden)
  194. return
  195. }
  196. }
  197. }
  198. } else {
  199. Warn("Missing geolocation information to block IPs")
  200. }
  201. next.ServeHTTP(w, r)
  202. })
  203. }
  204. }
  205. // blockPostWithoutReferer blocks POST requests without a Referer header
  206. func BlockPostWithoutReferer(next http.Handler) http.Handler {
  207. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  208. if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" || r.Method == "DELETE" {
  209. referer := r.Header.Get("Referer")
  210. if referer == "" {
  211. PushShieldMetrics("referer")
  212. Error("Blocked POST request without Referer header", nil)
  213. http.Error(w, "Bad Request: Invalid request.", http.StatusBadRequest)
  214. ip, _, _ := net.SplitHostPort(r.RemoteAddr)
  215. if ip != "" {
  216. IncrementIPAbuseCounter(ip)
  217. }
  218. return
  219. }
  220. }
  221. // If it's not a POST request or the POST request has a Referer header, pass the request to the next handler
  222. next.ServeHTTP(w, r)
  223. })
  224. }
  225. func EnsureHostname(next http.Handler) http.Handler {
  226. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  227. Debug("Ensuring origin for requested resource from : " + r.Host)
  228. og := GetMainConfig().HTTPConfig.Hostname
  229. ni := GetMainConfig().NewInstall
  230. if ni || og == "0.0.0.0" {
  231. next.ServeHTTP(w, r)
  232. return
  233. }
  234. hostnames := GetAllHostnames(false, false)
  235. reqHostNoPort := strings.Split(r.Host, ":")[0]
  236. isOk := false
  237. for _, hostname := range hostnames {
  238. hostnameNoPort := strings.Split(hostname, ":")[0]
  239. if reqHostNoPort == hostnameNoPort {
  240. isOk = true
  241. }
  242. }
  243. if !isOk {
  244. PushShieldMetrics("hostname")
  245. Error("Invalid Hostname " + r.Host + " for request. Expecting one of " + fmt.Sprintf("%v", hostnames), nil)
  246. w.WriteHeader(http.StatusBadRequest)
  247. http.Error(w, "Bad Request: Invalid hostname. Use your domain instead of your IP to access your server. Check logs if more details are needed.", http.StatusBadRequest)
  248. ip, _, _ := net.SplitHostPort(r.RemoteAddr)
  249. if ip != "" {
  250. IncrementIPAbuseCounter(ip)
  251. }
  252. return
  253. }
  254. next.ServeHTTP(w, r)
  255. })
  256. }
  257. func IsValidHostname(hostname string) bool {
  258. og := GetMainConfig().HTTPConfig.Hostname
  259. ni := GetMainConfig().NewInstall
  260. if ni || og == "0.0.0.0" {
  261. return true
  262. }
  263. hostnames := GetAllHostnames(false, false)
  264. reqHostNoPort := strings.Split(hostname, ":")[0]
  265. reqHostNoPortNoSubdomain := ""
  266. if parts := strings.Split(reqHostNoPort, "."); len(parts) < 2 {
  267. reqHostNoPortNoSubdomain = reqHostNoPort
  268. } else {
  269. reqHostNoPortNoSubdomain = parts[len(parts)-2] + "." + parts[len(parts)-1]
  270. }
  271. for _, hostname := range hostnames {
  272. hostnameNoPort := strings.Split(hostname, ":")[0]
  273. hostnameNoPortNoSubdomain := ""
  274. if parts := strings.Split(hostnameNoPort, "."); len(parts) < 2 {
  275. hostnameNoPortNoSubdomain = hostnameNoPort
  276. } else {
  277. hostnameNoPortNoSubdomain = parts[len(parts)-2] + "." + parts[len(parts)-1]
  278. }
  279. if reqHostNoPortNoSubdomain == hostnameNoPortNoSubdomain {
  280. return true
  281. }
  282. }
  283. return false
  284. }
  285. func IPInRange(ipStr, cidrStr string) (bool, error) {
  286. _, cidrNet, err := net.ParseCIDR(cidrStr)
  287. if err != nil {
  288. return false, fmt.Errorf("parse CIDR range: %w", err)
  289. }
  290. ip := net.ParseIP(ipStr)
  291. if ip == nil {
  292. return false, fmt.Errorf("parse IP: invalid IP address")
  293. }
  294. return cidrNet.Contains(ip), nil
  295. }
  296. func Restrictions(RestrictToConstellation bool, WhitelistInboundIPs []string) func(next http.Handler) http.Handler {
  297. return func(next http.Handler) http.Handler {
  298. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  299. ip, _, err := net.SplitHostPort(r.RemoteAddr)
  300. if err != nil {
  301. http.Error(w, "Invalid request", http.StatusBadRequest)
  302. return
  303. }
  304. isUsingWhiteList := len(WhitelistInboundIPs) > 0
  305. isInWhitelist := false
  306. isInConstellation := strings.HasPrefix(ip, "192.168.201.") || strings.HasPrefix(ip, "192.168.202.")
  307. for _, ipRange := range WhitelistInboundIPs {
  308. Debug("Checking if " + ip + " is in " + ipRange)
  309. if strings.Contains(ipRange, "/") {
  310. if ok, _ := IPInRange(ip, ipRange); ok {
  311. isInWhitelist = true
  312. }
  313. } else {
  314. if ip == ipRange {
  315. isInWhitelist = true
  316. }
  317. }
  318. }
  319. if(RestrictToConstellation) {
  320. if(!isInConstellation) {
  321. if(!isUsingWhiteList) {
  322. PushShieldMetrics("ip-whitelists")
  323. IncrementIPAbuseCounter(ip)
  324. Error("Request from " + ip + " is blocked because of restrictions", nil)
  325. Debug("Blocked by RestrictToConstellation isInConstellation isUsingWhiteList")
  326. http.Error(w, "Access denied", http.StatusForbidden)
  327. return
  328. } else if (!isInWhitelist) {
  329. PushShieldMetrics("ip-whitelists")
  330. IncrementIPAbuseCounter(ip)
  331. Error("Request from " + ip + " is blocked because of restrictions", nil)
  332. Debug("Blocked by RestrictToConstellation isInConstellation isInWhitelist")
  333. http.Error(w, "Access denied", http.StatusForbidden)
  334. return
  335. }
  336. }
  337. } else if(isUsingWhiteList && !isInWhitelist) {
  338. PushShieldMetrics("ip-whitelists")
  339. IncrementIPAbuseCounter(ip)
  340. Error("Request from " + ip + " is blocked because of restrictions", nil)
  341. Debug("Blocked by RestrictToConstellation isInConstellation isUsingWhiteList isInWhitelist")
  342. http.Error(w, "Access denied", http.StatusForbidden)
  343. return
  344. }
  345. next.ServeHTTP(w, r)
  346. })
  347. }
  348. }