diff --git a/libnetwork/controller.go b/libnetwork/controller.go index a01e4fa3d1..87315c1291 100644 --- a/libnetwork/controller.go +++ b/libnetwork/controller.go @@ -1247,3 +1247,23 @@ func (c *Controller) iptablesEnabled() bool { } return enabled } + +func (c *controller) ip6tablesEnabled() bool { + c.Lock() + defer c.Unlock() + + if c.cfg == nil { + return false + } + // parse map cfg["bridge"]["generic"]["EnableIP6Table"] + cfgBridge, ok := c.cfg.DriverCfg["bridge"].(map[string]interface{}) + if !ok { + return false + } + cfgGeneric, ok := cfgBridge[netlabel.GenericData].(options.Generic) + if !ok { + return false + } + enabled, _ := cfgGeneric["EnableIP6Tables"].(bool) + return enabled +} diff --git a/libnetwork/firewall_linux.go b/libnetwork/firewall_linux.go index 0379744c11..4eef752d19 100644 --- a/libnetwork/firewall_linux.go +++ b/libnetwork/firewall_linux.go @@ -21,24 +21,40 @@ func setupArrangeUserFilterRule(c *Controller) { // IPTableForwarding is disabled, because it contains rules configured by user that // are beyond docker engine's control. func arrangeUserFilterRule() { - if ctrl == nil || !ctrl.iptablesEnabled() { - return - } - // TODO IPv6 support - iptable := iptables.GetIptable(iptables.IPv4) - _, err := iptable.NewChain(userChain, iptables.Filter, false) - if err != nil { - logrus.Warnf("Failed to create %s chain: %v", userChain, err) + if ctrl == nil { return } - if err = iptable.AddReturnRule(userChain); err != nil { - logrus.Warnf("Failed to add the RETURN rule for %s: %v", userChain, err) - return + conds := []struct { + ipVer iptables.IPVersion + cond bool + }{ + {ipVer: iptables.IPv4, cond: ctrl.iptablesEnabled()}, + {ipVer: iptables.IPv6, cond: ctrl.ip6tablesEnabled()}, } - err = iptable.EnsureJumpRule("FORWARD", userChain) - if err != nil { - logrus.Warnf("Failed to ensure the jump rule for %s: %v", userChain, err) + for _, ipVerCond := range conds { + cond := ipVerCond.cond + if !cond { + continue + } + + ipVer := ipVerCond.ipVer + iptable := iptables.GetIptable(ipVer) + _, err := iptable.NewChain(userChain, iptables.Filter, false) + if err != nil { + logrus.WithError(err).Warnf("Failed to create %s %v chain", userChain, ipVer) + return + } + + if err = iptable.AddReturnRule(userChain); err != nil { + logrus.WithError(err).Warnf("Failed to add the RETURN rule for %s %v", userChain, ipVer) + return + } + + err = iptable.EnsureJumpRule("FORWARD", userChain) + if err != nil { + logrus.WithError(err).Warnf("Failed to ensure the jump rule for %s %v", userChain, ipVer) + } } } diff --git a/libnetwork/firewall_linux_test.go b/libnetwork/firewall_linux_test.go index 226bf4dfa5..091a91c6b7 100644 --- a/libnetwork/firewall_linux_test.go +++ b/libnetwork/firewall_linux_test.go @@ -18,7 +18,8 @@ const ( ) func TestUserChain(t *testing.T) { - iptable := iptables.GetIptable(iptables.IPv4) + iptable4 := iptables.GetIptable(iptables.IPv4) + iptable6 := iptables.GetIptable(iptables.IPv6) tests := []struct { iptables bool @@ -56,32 +57,40 @@ func TestUserChain(t *testing.T) { defer c.Stop() c.cfg.DriverCfg["bridge"] = map[string]interface{}{ netlabel.GenericData: options.Generic{ - "EnableIPTables": tc.iptables, + "EnableIPTables": tc.iptables, + "EnableIP6Tables": tc.iptables, }, } // init. condition, FORWARD chain empty DOCKER-USER not exist - assert.DeepEqual(t, getRules(t, fwdChainName), []string{"-P FORWARD ACCEPT"}) + assert.DeepEqual(t, getRules(t, iptables.IPv4, fwdChainName), []string{"-P FORWARD ACCEPT"}) + assert.DeepEqual(t, getRules(t, iptables.IPv6, fwdChainName), []string{"-P FORWARD ACCEPT"}) if tc.insert { - _, err = iptable.Raw("-A", fwdChainName, "-j", "DROP") + _, err = iptable4.Raw("-A", fwdChainName, "-j", "DROP") + assert.NilError(t, err) + _, err = iptable6.Raw("-A", fwdChainName, "-j", "DROP") assert.NilError(t, err) } arrangeUserFilterRule() - assert.DeepEqual(t, getRules(t, fwdChainName), tc.fwdChain) + assert.DeepEqual(t, getRules(t, iptables.IPv4, fwdChainName), tc.fwdChain) + assert.DeepEqual(t, getRules(t, iptables.IPv6, fwdChainName), tc.fwdChain) if tc.userChain != nil { - assert.DeepEqual(t, getRules(t, usrChainName), tc.userChain) + assert.DeepEqual(t, getRules(t, iptables.IPv4, usrChainName), tc.userChain) + assert.DeepEqual(t, getRules(t, iptables.IPv6, usrChainName), tc.userChain) } else { - _, err := iptable.Raw("-S", usrChainName) - assert.Assert(t, err != nil, "chain %v: created unexpectedly", usrChainName) + _, err := iptable4.Raw("-S", usrChainName) + assert.Assert(t, err != nil, "ipv4 chain %v: created unexpectedly", usrChainName) + _, err = iptable6.Raw("-S", usrChainName) + assert.Assert(t, err != nil, "ipv6 chain %v: created unexpectedly", usrChainName) } }) } } -func getRules(t *testing.T, chain string) []string { - iptable := iptables.GetIptable(iptables.IPv4) +func getRules(t *testing.T, ipVer iptables.IPVersion, chain string) []string { + iptable := iptables.GetIptable(ipVer) t.Helper() output, err := iptable.Raw("-S", chain) @@ -95,10 +104,13 @@ func getRules(t *testing.T, chain string) []string { } func resetIptables(t *testing.T) { - iptable := iptables.GetIptable(iptables.IPv4) - t.Helper() - _, err := iptable.Raw("-F", fwdChainName) - assert.NilError(t, err) - _ = iptable.RemoveExistingChain(usrChainName, "") + + for _, ipVer := range []iptables.IPVersion{iptables.IPv4, iptables.IPv6} { + iptable := iptables.GetIptable(ipVer) + + _, err := iptable.Raw("-F", fwdChainName) + assert.Check(t, err) + _ = iptable.RemoveExistingChain(usrChainName, "") + } }