Sfoglia il codice sorgente

Merge pull request #45987 from thaJeztah/cleanup_iptables_the_sequel

libnetwork/iptables: some cleanups and refactoring: the sequel
Sebastiaan van Stijn 2 anni fa
parent
commit
6025938ee9

+ 1 - 1
libnetwork/drivers/bridge/setup_ip_tables.go

@@ -376,7 +376,7 @@ func setINC(version iptables.IPVersion, iface string, enable bool) error {
 const oldIsolationChain = "DOCKER-ISOLATION"
 
 func removeIPChains(version iptables.IPVersion) {
-	ipt := iptables.IPTable{Version: version}
+	ipt := iptables.GetIptable(version)
 
 	// Remove obsolete rules from default chains
 	ipt.ProgramRule(iptables.Filter, "FORWARD", iptables.Delete, []string{"-j", oldIsolationChain})

+ 1 - 1
libnetwork/firewall_linux_test.go

@@ -110,6 +110,6 @@ func resetIptables(t *testing.T) {
 
 		_, err := iptable.Raw("-F", fwdChainName)
 		assert.Check(t, err)
-		_ = iptable.RemoveExistingChain(usrChainName, "")
+		_ = iptable.RemoveExistingChain(usrChainName, iptables.Filter)
 	}
 }

+ 2 - 2
libnetwork/iptables/firewalld.go

@@ -65,8 +65,8 @@ var (
 	onReloaded       []*func() // callbacks when Firewalld has been reloaded
 )
 
-// FirewalldInit initializes firewalld management code.
-func FirewalldInit() error {
+// firewalldInit initializes firewalld management code.
+func firewalldInit() error {
 	var err error
 
 	if connection, err = newConnection(); err != nil {

+ 1 - 1
libnetwork/iptables/firewalld_test.go

@@ -13,7 +13,7 @@ func TestFirewalldInit(t *testing.T) {
 	if !checkRunning() {
 		t.Skip("firewalld is not running")
 	}
-	if err := FirewalldInit(); err != nil {
+	if err := firewalldInit(); err != nil {
 		t.Fatal(err)
 	}
 }

+ 88 - 65
libnetwork/iptables/iptables.go

@@ -21,15 +21,6 @@ import (
 // Action signifies the iptable action.
 type Action string
 
-// Policy is the default iptable policies
-type Policy string
-
-// Table refers to Nat, Filter or Mangle.
-type Table string
-
-// IPVersion refers to IP version, v4 or v6
-type IPVersion string
-
 const (
 	// Append appends the rule at the end of the chain.
 	Append Action = "-A"
@@ -37,19 +28,37 @@ const (
 	Delete Action = "-D"
 	// Insert inserts the rule at the top of the chain.
 	Insert Action = "-I"
+)
+
+// Policy is the default iptable policies
+type Policy string
+
+const (
+	// Drop is the default iptables DROP policy.
+	Drop Policy = "DROP"
+	// Accept is the default iptables ACCEPT policy.
+	Accept Policy = "ACCEPT"
+)
+
+// Table refers to Nat, Filter or Mangle.
+type Table string
+
+const (
 	// Nat table is used for nat translation rules.
 	Nat Table = "nat"
 	// Filter table is used for filter rules.
 	Filter Table = "filter"
 	// Mangle table is used for mangling the packet.
 	Mangle Table = "mangle"
-	// Drop is the default iptables DROP policy
-	Drop Policy = "DROP"
-	// Accept is the default iptables ACCEPT policy
-	Accept Policy = "ACCEPT"
-	// IPv4 is version 4
+)
+
+// IPVersion refers to IP version, v4 or v6
+type IPVersion string
+
+const (
+	// IPv4 is version 4.
 	IPv4 IPVersion = "IPV4"
-	// IPv6 is version 6
+	// IPv6 is version 6.
 	IPv6 IPVersion = "IPV6"
 )
 
@@ -57,15 +66,14 @@ var (
 	iptablesPath  string
 	ip6tablesPath string
 	supportsXlock = false
-	xLockWaitMsg  = "Another app is currently holding the xtables lock"
 	// used to lock iptables commands if xtables lock is not supported
 	bestEffortLock sync.Mutex
 	initOnce       sync.Once
 )
 
-// IPTable defines struct with IPVersion
+// IPTable defines struct with [IPVersion].
 type IPTable struct {
-	Version IPVersion
+	ipVersion IPVersion
 }
 
 // ChainInfo defines the iptables chain.
@@ -86,6 +94,19 @@ func (e ChainError) Error() string {
 	return fmt.Sprintf("error iptables %s: %s", e.Chain, string(e.Output))
 }
 
+// loopbackAddress returns the loopback address for the given IP version.
+func loopbackAddress(version IPVersion) string {
+	switch version {
+	case IPv4, "":
+		// IPv4 (default for backward-compatibility)
+		return "127.0.0.0/8"
+	case IPv6:
+		return "::1/128"
+	default:
+		panic("unknown IP version: " + version)
+	}
+}
+
 func detectIptables() {
 	path, err := exec.LookPath("iptables")
 	if err != nil {
@@ -117,7 +138,7 @@ func initFirewalld() {
 		log.G(context.TODO()).Info("skipping firewalld management for rootless mode")
 		return
 	}
-	if err := FirewalldInit(); err != nil {
+	if err := firewalldInit(); err != nil {
 		log.G(context.TODO()).WithError(err).Debugf("unable to initialize firewalld; using raw iptables instead")
 	}
 }
@@ -136,15 +157,28 @@ func initCheck() error {
 	return nil
 }
 
-// GetIptable returns an instance of IPTable with specified version
+// GetIptable returns an instance of IPTable with specified version ([IPv4]
+// or [IPv6]). It panics if an invalid [IPVersion] is provided.
 func GetIptable(version IPVersion) *IPTable {
-	return &IPTable{Version: version}
+	switch version {
+	case IPv4, IPv6:
+		// valid version
+	case "":
+		// default is IPv4 for backward-compatibility
+		version = IPv4
+	default:
+		panic("unknown IP version: " + version)
+	}
+	return &IPTable{ipVersion: version}
 }
 
 // NewChain adds a new chain to ip table.
 func (iptable IPTable) NewChain(name string, table Table, hairpinMode bool) (*ChainInfo, error) {
+	if name == "" {
+		return nil, fmt.Errorf("could not create chain: chain name is empty")
+	}
 	if table == "" {
-		table = Filter
+		return nil, fmt.Errorf("could not create chain %s: invalid table name: table name is empty", name)
 	}
 	// Add chain if it doesn't exist
 	if _, err := iptable.Raw("-t", string(table), "-n", "-L", name); err != nil {
@@ -158,18 +192,10 @@ func (iptable IPTable) NewChain(name string, table Table, hairpinMode bool) (*Ch
 		Name:        name,
 		Table:       table,
 		HairpinMode: hairpinMode,
-		IPVersion:   iptable.Version,
+		IPVersion:   iptable.ipVersion,
 	}, nil
 }
 
-// LoopbackByVersion returns loopback address by version
-func (iptable IPTable) LoopbackByVersion() string {
-	if iptable.Version == IPv6 {
-		return "::1/128"
-	}
-	return "127.0.0.0/8"
-}
-
 // ProgramChain is used to add rules to a chain
 func (iptable IPTable) ProgramChain(c *ChainInfo, bridgeName string, hairpinMode, enable bool) error {
 	if c.Name == "" {
@@ -211,7 +237,7 @@ func (iptable IPTable) ProgramChain(c *ChainInfo, bridgeName string, hairpinMode
 			"-j", c.Name,
 		}
 		if !hairpinMode {
-			output = append(output, "!", "--dst", iptable.LoopbackByVersion())
+			output = append(output, "!", "--dst", loopbackAddress(iptable.ipVersion))
 		}
 		if !iptable.Exists(Nat, "OUTPUT", output...) && enable {
 			if err := c.Output(Append, output...); err != nil {
@@ -272,13 +298,16 @@ func (iptable IPTable) ProgramChain(c *ChainInfo, bridgeName string, hairpinMode
 
 // RemoveExistingChain removes existing chain from the table.
 func (iptable IPTable) RemoveExistingChain(name string, table Table) error {
+	if name == "" {
+		return fmt.Errorf("could not remove chain: chain name is empty")
+	}
 	if table == "" {
-		table = Filter
+		return fmt.Errorf("could not remove chain %s: invalid table name: table name is empty", name)
 	}
 	c := &ChainInfo{
 		Name:      name,
 		Table:     table,
-		IPVersion: iptable.Version,
+		IPVersion: iptable.ipVersion,
 	}
 	return c.Remove()
 }
@@ -419,15 +448,15 @@ func (c *ChainInfo) Output(action Action, args ...string) error {
 
 // Remove removes the chain.
 func (c *ChainInfo) Remove() error {
-	iptable := GetIptable(c.IPVersion)
 	// Ignore errors - This could mean the chains were never set up
 	if c.Table == Nat {
 		_ = c.Prerouting(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "-j", c.Name)
-		_ = c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "!", "--dst", iptable.LoopbackByVersion(), "-j", c.Name)
+		_ = c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "!", "--dst", loopbackAddress(c.IPVersion), "-j", c.Name)
 		_ = c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "-j", c.Name) // Created in versions <= 0.1.6
 		_ = c.Prerouting(Delete)
 		_ = c.Output(Delete)
 	}
+	iptable := GetIptable(c.IPVersion)
 	_, _ = iptable.Raw("-t", string(c.Table), "-F", c.Name)
 	_, _ = iptable.Raw("-t", string(c.Table), "-X", c.Name)
 	return nil
@@ -465,14 +494,17 @@ func (iptable IPTable) exists(native bool, table Table, chain string, rule ...st
 	return err == nil
 }
 
-// Maximum duration that an iptables operation can take
-// before flagging a warning.
-const opWarnTime = 2 * time.Second
+const (
+	// opWarnTime is the maximum duration that an iptables operation can take before flagging a warning.
+	opWarnTime = 2 * time.Second
+
+	// xLockWaitMsg is the iptables warning about xtables lock that can be suppressed.
+	xLockWaitMsg = "Another app is currently holding the xtables lock"
+)
 
 func filterOutput(start time.Time, output []byte, args ...string) []byte {
-	// Flag operations that have taken a long time to complete
-	opTime := time.Since(start)
-	if opTime > opWarnTime {
+	if opTime := time.Since(start); opTime > opWarnTime {
+		// Flag operations that have taken a long time to complete
 		log.G(context.TODO()).Warnf("xtables contention detected while running [%s]: Waited for %.2f seconds and received %q", strings.Join(args, " "), float64(opTime)/float64(time.Second), string(output))
 	}
 	// ignore iptables' message about xtables lock:
@@ -489,7 +521,7 @@ func (iptable IPTable) Raw(args ...string) ([]byte, error) {
 	if firewalldRunning {
 		// select correct IP version for firewalld
 		ipv := Iptables
-		if iptable.Version == IPv6 {
+		if iptable.ipVersion == IPv6 {
 			ipv = IP6Tables
 		}
 
@@ -506,16 +538,9 @@ func (iptable IPTable) raw(args ...string) ([]byte, error) {
 	if err := initCheck(); err != nil {
 		return nil, err
 	}
-	if supportsXlock {
-		args = append([]string{"--wait"}, args...)
-	} else {
-		bestEffortLock.Lock()
-		defer bestEffortLock.Unlock()
-	}
-
 	path := iptablesPath
 	commandName := "iptables"
-	if iptable.Version == IPv6 {
+	if iptable.ipVersion == IPv6 {
 		if ip6tablesPath == "" {
 			return nil, fmt.Errorf("ip6tables is missing")
 		}
@@ -523,6 +548,13 @@ func (iptable IPTable) raw(args ...string) ([]byte, error) {
 		commandName = "ip6tables"
 	}
 
+	if supportsXlock {
+		args = append([]string{"--wait"}, args...)
+	} else {
+		bestEffortLock.Lock()
+		defer bestEffortLock.Unlock()
+	}
+
 	log.G(context.TODO()).Debugf("%s, %v", path, args)
 
 	startTime := time.Now()
@@ -554,10 +586,8 @@ func (iptable IPTable) RawCombinedOutputNative(args ...string) error {
 
 // ExistChain checks if a chain exists
 func (iptable IPTable) ExistChain(chain string, table Table) bool {
-	if _, err := iptable.Raw("-t", string(table), "-nL", chain); err == nil {
-		return true
-	}
-	return false
+	_, err := iptable.Raw("-t", string(table), "-nL", chain)
+	return err == nil
 }
 
 // SetDefaultPolicy sets the passed default policy for the table/chain
@@ -573,28 +603,21 @@ func (iptable IPTable) AddReturnRule(chain string) error {
 	if iptable.Exists(Filter, chain, "-j", "RETURN") {
 		return nil
 	}
-
-	err := iptable.RawCombinedOutput("-A", chain, "-j", "RETURN")
-	if err != nil {
+	if err := iptable.RawCombinedOutput("-A", chain, "-j", "RETURN"); err != nil {
 		return fmt.Errorf("unable to add return rule in %s chain: %v", chain, err)
 	}
-
 	return nil
 }
 
 // EnsureJumpRule ensures the jump rule is on top
 func (iptable IPTable) EnsureJumpRule(fromChain, toChain string) error {
 	if iptable.Exists(Filter, fromChain, "-j", toChain) {
-		err := iptable.RawCombinedOutput("-D", fromChain, "-j", toChain)
-		if err != nil {
+		if err := iptable.RawCombinedOutput("-D", fromChain, "-j", toChain); err != nil {
 			return fmt.Errorf("unable to remove jump to %s rule in %s chain: %v", toChain, fromChain, err)
 		}
 	}
-
-	err := iptable.RawCombinedOutput("-I", fromChain, "-j", toChain)
-	if err != nil {
+	if err := iptable.RawCombinedOutput("-I", fromChain, "-j", toChain); err != nil {
 		return fmt.Errorf("unable to insert jump to %s rule in %s chain: %v", toChain, fromChain, err)
 	}
-
 	return nil
 }