123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425 |
- package utils
- import (
- "context"
- "net/http"
- "time"
- "net"
- "strings"
- "fmt"
- "sync"
- "sync/atomic"
- "github.com/mxk/go-flowrate/flowrate"
- "github.com/oschwald/geoip2-golang"
- )
// MiddlewareTimeout below is adapted from https://github.com/go-chi/chi/blob/master/middleware/timeout.go
// PushShieldMetrics is called by the shield middlewares in this file each time
// they reject a request; the argument names the rule that fired ("geo",
// "referer", "hostname", "ip-whitelists"). It is nil until assigned
// (presumably by the metrics setup elsewhere — verify), and the middlewares
// call it without a nil check, so it must be set before they serve traffic.
var PushShieldMetrics func(reason string)
// safeInt is an abuse counter that is only read and incremented through
// sync/atomic, so concurrently running request handlers can share it.
type safeInt struct {
	val int64
}

// BannedIPs maps a client IP (string) to its *safeInt abuse counter. An entry
// is created the first time an abuse is recorded for that IP; CleanBannedIPs
// empties the map.
var BannedIPs sync.Map

// IncrementIPAbuseCounter records one more abuse for ip. BlockBannedIPs reads
// the resulting count to warn about, and later drop, requests from that IP.
func IncrementIPAbuseCounter(ip string) {
	// Fast path: the IP usually already has a counter, so look it up first
	// instead of allocating a fresh *safeInt that LoadOrStore would discard.
	actual, ok := BannedIPs.Load(ip)
	if !ok {
		actual, _ = BannedIPs.LoadOrStore(ip, &safeInt{})
	}
	atomic.AddInt64(&actual.(*safeInt).val, 1)
}

// getIPAbuseCounter returns how many abuses have been recorded for ip, or 0
// if none have been.
func getIPAbuseCounter(ip string) int64 {
	actual, ok := BannedIPs.Load(ip)
	if !ok {
		return 0
	}
	return atomic.LoadInt64(&actual.(*safeInt).val)
}
- func BlockBannedIPs(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- ip, _, err := net.SplitHostPort(r.RemoteAddr)
- if err != nil {
- if hj, ok := w.(http.Hijacker); ok {
- conn, _, err := hj.Hijack()
- if err == nil {
- conn.Close()
- }
- }
- return
- }
- nbAbuse := getIPAbuseCounter(ip)
- if nbAbuse > 275 {
- Warn("IP " + ip + " has " + fmt.Sprintf("%d", nbAbuse) + " abuse(s) and will soon be banned.")
- }
- if nbAbuse > 300 {
- if hj, ok := w.(http.Hijacker); ok {
- conn, _, err := hj.Hijack()
- if err == nil {
- conn.Close()
- }
- }
- return
- }
- next.ServeHTTP(w, r)
- })
- }
- func CleanBannedIPs() {
- BannedIPs.Range(func(key, value interface{}) bool {
- BannedIPs.Delete(key)
- return true
- })
- }
- func MiddlewareTimeout(timeout time.Duration) func(next http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- fn := func(w http.ResponseWriter, r *http.Request) {
- ctx, cancel := context.WithTimeout(r.Context(), timeout)
- defer func() {
- cancel()
- if ctx.Err() == context.DeadlineExceeded {
- Error("Request Timeout. Cancelling.", ctx.Err())
- HTTPError(w, "Gateway Timeout",
- http.StatusGatewayTimeout, "HTTP002")
- return
- }
- }()
- w.Header().Set("X-Timeout-Duration", timeout.String())
- r = r.WithContext(ctx)
- next.ServeHTTP(w, r)
- }
- return http.HandlerFunc(fn)
- }
- }
- type responseWriter struct {
- http.ResponseWriter
- *flowrate.Writer
- }
- func (w *responseWriter) Write(b []byte) (int, error) {
- return w.Writer.Write(b)
- }
- func BandwithLimiterMiddleware(max int64) func(next http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if(max > 0) {
- fw := flowrate.NewWriter(w, max)
- w = &responseWriter{w, fw}
- }
-
- next.ServeHTTP(w, r)
- })
- }
- }
- func SetSecurityHeaders(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if(IsHTTPS) {
- // TODO: Add preload if we have a valid certificate
- w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
- }
-
- w.Header().Set("X-Content-Type-Options", "nosniff")
- w.Header().Set("X-XSS-Protection", "1; mode=block")
- w.Header().Set("Content-Security-Policy", "frame-ancestors 'self'")
-
- w.Header().Set("X-Served-By-Cosmos", "1")
-
- next.ServeHTTP(w, r)
- })
- }
// CORSHeader returns a middleware that lets a single origin make credentialed
// cross-origin requests. It does this by setting Access-Control-Allow-Origin
// to origin and Access-Control-Allow-Credentials to "true". An empty origin
// leaves the response headers untouched. next is always called.
func CORSHeader(origin string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if origin == "" {
				next.ServeHTTP(w, r)
				return
			}
			h := w.Header()
			h.Set("Access-Control-Allow-Origin", origin)
			h.Set("Access-Control-Allow-Credentials", "true")
			next.ServeHTTP(w, r)
		})
	}
}
// AcceptHeader returns a middleware that sets the response Content-Type to
// accept before calling next. Despite its name, it sets Content-Type, not the
// Accept header.
func AcceptHeader(accept string) func(next http.Handler) http.Handler {
	setContentType := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", accept)
			next.ServeHTTP(w, r)
		})
	}
	return setContentType
}
- // GetIPLocation returns the ISO country code for a given IP address.
- func GetIPLocation(ip string) (string, error) {
- geoDB, err := geoip2.Open("GeoLite2-Country.mmdb")
- if err != nil {
- return "", err
- }
- defer geoDB.Close()
- parsedIP := net.ParseIP(ip)
- record, err := geoDB.Country(parsedIP)
- if err != nil {
- return "", err
- }
- return record.Country.IsoCode, nil
- }
- // BlockByCountryMiddleware returns a middleware function that blocks requests from specified countries.
- func BlockByCountryMiddleware(blockedCountries []string, CountryBlacklistIsWhitelist bool) func(http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- ip, _, err := net.SplitHostPort(r.RemoteAddr)
- if err != nil {
- http.Error(w, "Invalid request", http.StatusBadRequest)
- return
- }
- countryCode, err := GetIPLocation(ip)
- if err == nil {
- if countryCode == "" {
- Debug("Country code is empty")
- } else {
- Debug("Country code: " + countryCode)
- }
- config := GetMainConfig()
- if CountryBlacklistIsWhitelist {
- if countryCode != "" {
- blocked := true
- for _, blockedCountry := range blockedCountries {
- if config.ServerCountry != countryCode && countryCode == blockedCountry {
- blocked = false
- }
- }
- if blocked {
- PushShieldMetrics("geo")
- IncrementIPAbuseCounter(ip)
- http.Error(w, "Access denied", http.StatusForbidden)
- return
- }
- } else {
- Warn("Missing geolocation information to block IPs")
- }
- } else {
- for _, blockedCountry := range blockedCountries {
- if config.ServerCountry != countryCode && countryCode == blockedCountry {
- http.Error(w, "Access denied", http.StatusForbidden)
- return
- }
- }
- }
- } else {
- Warn("Missing geolocation information to block IPs")
- }
- next.ServeHTTP(w, r)
- })
- }
- }
- // blockPostWithoutReferer blocks POST requests without a Referer header
- func BlockPostWithoutReferer(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" || r.Method == "DELETE" {
- referer := r.Header.Get("Referer")
- if referer == "" {
- PushShieldMetrics("referer")
- Error("Blocked POST request without Referer header", nil)
- http.Error(w, "Bad Request: Invalid request.", http.StatusBadRequest)
- ip, _, _ := net.SplitHostPort(r.RemoteAddr)
- if ip != "" {
- IncrementIPAbuseCounter(ip)
- }
- return
- }
- }
- // If it's not a POST request or the POST request has a Referer header, pass the request to the next handler
- next.ServeHTTP(w, r)
- })
- }
- func EnsureHostname(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- Debug("Ensuring origin for requested resource from : " + r.Host)
- og := GetMainConfig().HTTPConfig.Hostname
- ni := GetMainConfig().NewInstall
- if ni || og == "0.0.0.0" {
- next.ServeHTTP(w, r)
- return
- }
- hostnames := GetAllHostnames(false, false)
- reqHostNoPort := strings.Split(r.Host, ":")[0]
- isOk := false
- for _, hostname := range hostnames {
- hostnameNoPort := strings.Split(hostname, ":")[0]
- if reqHostNoPort == hostnameNoPort {
- isOk = true
- }
- }
- if !isOk {
- PushShieldMetrics("hostname")
- Error("Invalid Hostname " + r.Host + " for request. Expecting one of " + fmt.Sprintf("%v", hostnames), nil)
- w.WriteHeader(http.StatusBadRequest)
- http.Error(w, "Bad Request: Invalid hostname. Use your domain instead of your IP to access your server. Check logs if more details are needed.", http.StatusBadRequest)
-
- ip, _, _ := net.SplitHostPort(r.RemoteAddr)
- if ip != "" {
- IncrementIPAbuseCounter(ip)
- }
- return
- }
- next.ServeHTTP(w, r)
- })
- }
- func IsValidHostname(hostname string) bool {
- og := GetMainConfig().HTTPConfig.Hostname
- ni := GetMainConfig().NewInstall
- if ni || og == "0.0.0.0" {
- return true
- }
- hostnames := GetAllHostnames(false, false)
- reqHostNoPort := strings.Split(hostname, ":")[0]
- reqHostNoPortNoSubdomain := ""
- if parts := strings.Split(reqHostNoPort, "."); len(parts) < 2 {
- reqHostNoPortNoSubdomain = reqHostNoPort
- } else {
- reqHostNoPortNoSubdomain = parts[len(parts)-2] + "." + parts[len(parts)-1]
- }
- for _, hostname := range hostnames {
- hostnameNoPort := strings.Split(hostname, ":")[0]
- hostnameNoPortNoSubdomain := ""
- if parts := strings.Split(hostnameNoPort, "."); len(parts) < 2 {
- hostnameNoPortNoSubdomain = hostnameNoPort
- } else {
- hostnameNoPortNoSubdomain = parts[len(parts)-2] + "." + parts[len(parts)-1]
- }
- if reqHostNoPortNoSubdomain == hostnameNoPortNoSubdomain {
- return true
- }
- }
- return false
- }
// IPInRange reports whether the IP address ipStr lies inside the CIDR block
// cidrStr (for example "10.0.0.0/8"). It returns an error if either string
// fails to parse. The CIDR is validated first, so its error is the one
// reported when both are invalid.
func IPInRange(ipStr, cidrStr string) (bool, error) {
	_, network, err := net.ParseCIDR(cidrStr)
	addr := net.ParseIP(ipStr)
	switch {
	case err != nil:
		return false, fmt.Errorf("parse CIDR range: %w", err)
	case addr == nil:
		return false, fmt.Errorf("parse IP: invalid IP address")
	default:
		return network.Contains(addr), nil
	}
}
- func Restrictions(RestrictToConstellation bool, WhitelistInboundIPs []string) func(next http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- ip, _, err := net.SplitHostPort(r.RemoteAddr)
- if err != nil {
- http.Error(w, "Invalid request", http.StatusBadRequest)
- return
- }
- isUsingWhiteList := len(WhitelistInboundIPs) > 0
- isInWhitelist := false
- isInConstellation := strings.HasPrefix(ip, "192.168.201.") || strings.HasPrefix(ip, "192.168.202.")
- for _, ipRange := range WhitelistInboundIPs {
- Debug("Checking if " + ip + " is in " + ipRange)
- if strings.Contains(ipRange, "/") {
- if ok, _ := IPInRange(ip, ipRange); ok {
- isInWhitelist = true
- }
- } else {
- if ip == ipRange {
- isInWhitelist = true
- }
- }
- }
- if(RestrictToConstellation) {
- if(!isInConstellation) {
- if(!isUsingWhiteList) {
- PushShieldMetrics("ip-whitelists")
- IncrementIPAbuseCounter(ip)
- Error("Request from " + ip + " is blocked because of restrictions", nil)
- Debug("Blocked by RestrictToConstellation isInConstellation isUsingWhiteList")
- http.Error(w, "Access denied", http.StatusForbidden)
- return
- } else if (!isInWhitelist) {
- PushShieldMetrics("ip-whitelists")
- IncrementIPAbuseCounter(ip)
- Error("Request from " + ip + " is blocked because of restrictions", nil)
- Debug("Blocked by RestrictToConstellation isInConstellation isInWhitelist")
- http.Error(w, "Access denied", http.StatusForbidden)
- return
- }
- }
- } else if(isUsingWhiteList && !isInWhitelist) {
- PushShieldMetrics("ip-whitelists")
- IncrementIPAbuseCounter(ip)
- Error("Request from " + ip + " is blocked because of restrictions", nil)
- Debug("Blocked by RestrictToConstellation isInConstellation isUsingWhiteList isInWhitelist")
- http.Error(w, "Access denied", http.StatusForbidden)
- return
- }
- next.ServeHTTP(w, r)
- })
- }
- }
|