浏览代码

Merge pull request #45986 from thaJeztah/libnetwork_cleanup_config

libnetwork/config: add Config.DriverConfig() and un-export DriverCfg
Bjorn Neergaard 2 年之前
父节点
当前提交
3f8ca3553d
共有 3 个文件被更改,包括 31 次插入36 次删除
  1. 8 4
      libnetwork/config/config.go
  2. 5 13
      libnetwork/controller.go
  3. 18 19
      libnetwork/firewall_linux_test.go

+ 8 - 4
libnetwork/config/config.go

@@ -25,7 +25,7 @@ type Config struct {
 	DefaultNetwork         string
 	DefaultDriver          string
 	Labels                 []string
-	DriverCfg              map[string]interface{}
+	driverCfg              map[string]map[string]any
 	ClusterProvider        cluster.Provider
 	NetworkControlPlaneMTU int
 	DefaultAddressPool     []*ipamutils.NetworkToSplit
@@ -37,7 +37,7 @@ type Config struct {
 // New creates a new Config and initializes it with the given Options.
 func New(opts ...Option) *Config {
 	cfg := &Config{
-		DriverCfg: make(map[string]interface{}),
+		driverCfg: make(map[string]map[string]any),
 	}
 
 	for _, opt := range opts {
@@ -53,6 +53,10 @@ func New(opts ...Option) *Config {
 	return cfg
 }
 
+func (c *Config) DriverConfig(name string) map[string]any {
+	return c.driverCfg[name]
+}
+
 // Option is an option setter function type used to pass various configurations
 // to the controller
 type Option func(c *Config)
@@ -81,9 +85,9 @@ func OptionDefaultAddressPoolConfig(addressPool []*ipamutils.NetworkToSplit) Opt
 }
 
 // OptionDriverConfig returns an option setter for driver configuration.
-func OptionDriverConfig(networkType string, config map[string]interface{}) Option {
+func OptionDriverConfig(networkType string, config map[string]any) Option {
 	return func(c *Config) {
-		c.DriverCfg[networkType] = config
+		c.driverCfg[networkType] = config
 	}
 }
 

+ 5 - 13
libnetwork/controller.go

@@ -335,11 +335,9 @@ func (c *Controller) makeDriverConfig(ntype string) map[string]interface{} {
 		cfg[key] = val
 	}
 
-	drvCfg, ok := c.cfg.DriverCfg[ntype]
-	if ok {
-		for k, v := range drvCfg.(map[string]interface{}) {
-			cfg[k] = v
-		}
+	// Merge in the existing config for this driver.
+	for k, v := range c.cfg.DriverConfig(ntype) {
+		cfg[k] = v
 	}
 
 	if c.cfg.Scope.IsValid() {
@@ -1146,10 +1144,7 @@ func (c *Controller) iptablesEnabled() bool {
 		return false
 	}
 	// parse map cfg["bridge"]["generic"]["EnableIPTable"]
-	cfgBridge, ok := c.cfg.DriverCfg["bridge"].(map[string]interface{})
-	if !ok {
-		return false
-	}
+	cfgBridge := c.cfg.DriverConfig("bridge")
 	cfgGeneric, ok := cfgBridge[netlabel.GenericData].(options.Generic)
 	if !ok {
 		return false
@@ -1170,10 +1165,7 @@ func (c *Controller) ip6tablesEnabled() bool {
 		return false
 	}
 	// parse map cfg["bridge"]["generic"]["EnableIP6Table"]
-	cfgBridge, ok := c.cfg.DriverCfg["bridge"].(map[string]interface{})
-	if !ok {
-		return false
-	}
+	cfgBridge := c.cfg.DriverConfig("bridge")
 	cfgGeneric, ok := cfgBridge[netlabel.GenericData].(options.Generic)
 	if !ok {
 		return false

+ 18 - 19
libnetwork/firewall_linux_test.go

@@ -5,11 +5,13 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/docker/docker/libnetwork/config"
 	"github.com/docker/docker/libnetwork/iptables"
 	"github.com/docker/docker/libnetwork/netlabel"
 	"github.com/docker/docker/libnetwork/options"
 	"github.com/docker/docker/libnetwork/testutils"
 	"gotest.tools/v3/assert"
+	is "gotest.tools/v3/assert/cmp"
 )
 
 const (
@@ -52,46 +54,43 @@ func TestUserChain(t *testing.T) {
 			defer testutils.SetupTestOSContext(t)()
 			defer resetIptables(t)
 
-			c, err := New()
-			assert.NilError(t, err)
-			defer c.Stop()
-			c.cfg.DriverCfg["bridge"] = map[string]interface{}{
+			c, err := New(config.OptionDriverConfig("bridge", map[string]any{
 				netlabel.GenericData: options.Generic{
 					"EnableIPTables":  tc.iptables,
 					"EnableIP6Tables": tc.iptables,
 				},
-			}
+			}))
+			assert.NilError(t, err)
+			defer c.Stop()
 
 			// init. condition, FORWARD chain empty DOCKER-USER not exist
-			assert.DeepEqual(t, getRules(t, iptables.IPv4, fwdChainName), []string{"-P FORWARD ACCEPT"})
-			assert.DeepEqual(t, getRules(t, iptables.IPv6, fwdChainName), []string{"-P FORWARD ACCEPT"})
+			assert.Check(t, is.DeepEqual(getRules(t, iptable4, fwdChainName), []string{"-P FORWARD ACCEPT"}))
+			assert.Check(t, is.DeepEqual(getRules(t, iptable6, fwdChainName), []string{"-P FORWARD ACCEPT"}))
 
 			if tc.insert {
 				_, err = iptable4.Raw("-A", fwdChainName, "-j", "DROP")
-				assert.NilError(t, err)
+				assert.Check(t, err)
 				_, err = iptable6.Raw("-A", fwdChainName, "-j", "DROP")
-				assert.NilError(t, err)
+				assert.Check(t, err)
 			}
 			arrangeUserFilterRule()
 
-			assert.DeepEqual(t, getRules(t, iptables.IPv4, fwdChainName), tc.fwdChain)
-			assert.DeepEqual(t, getRules(t, iptables.IPv6, fwdChainName), tc.fwdChain)
+			assert.Check(t, is.DeepEqual(getRules(t, iptable4, fwdChainName), tc.fwdChain))
+			assert.Check(t, is.DeepEqual(getRules(t, iptable6, fwdChainName), tc.fwdChain))
 			if tc.userChain != nil {
-				assert.DeepEqual(t, getRules(t, iptables.IPv4, usrChainName), tc.userChain)
-				assert.DeepEqual(t, getRules(t, iptables.IPv6, usrChainName), tc.userChain)
+				assert.Check(t, is.DeepEqual(getRules(t, iptable4, usrChainName), tc.userChain))
+				assert.Check(t, is.DeepEqual(getRules(t, iptable6, usrChainName), tc.userChain))
 			} else {
-				_, err := iptable4.Raw("-S", usrChainName)
-				assert.Assert(t, err != nil, "ipv4 chain %v: created unexpectedly", usrChainName)
+				_, err = iptable4.Raw("-S", usrChainName)
+				assert.Check(t, is.ErrorContains(err, "No chain/target/match by that name"), "ipv4 chain %v: created unexpectedly", usrChainName)
 				_, err = iptable6.Raw("-S", usrChainName)
-				assert.Assert(t, err != nil, "ipv6 chain %v: created unexpectedly", usrChainName)
+				assert.Check(t, is.ErrorContains(err, "No chain/target/match by that name"), "ipv6 chain %v: created unexpectedly", usrChainName)
 			}
 		})
 	}
 }
 
-func getRules(t *testing.T, ipVer iptables.IPVersion, chain string) []string {
-	iptable := iptables.GetIptable(ipVer)
-
+func getRules(t *testing.T, iptable *iptables.IPTable, chain string) []string {
 	t.Helper()
 	output, err := iptable.Raw("-S", chain)
 	assert.NilError(t, err, "chain %s: failed to get rules", chain)