Parcourir la source

Merge pull request #591 from WeiZhang555/iptables-clean

Cleanup iptables after bridge network is removed
aboch il y a 9 ans
Parent
commit
5c562e2c33

+ 24 - 22
libnetwork/drivers/bridge/bridge.go

@@ -41,6 +41,9 @@ const (
 	DefaultGatewayV6AuxKey = "DefaultGatewayIPv6"
 	DefaultGatewayV6AuxKey = "DefaultGatewayIPv6"
 )
 )
 
 
+type iptableCleanFunc func() error
+type iptablesCleanFuncs []iptableCleanFunc
+
 // configuration info for the "bridge" driver.
 // configuration info for the "bridge" driver.
 type configuration struct {
 type configuration struct {
 	EnableIPForwarding  bool
 	EnableIPForwarding  bool
@@ -92,12 +95,13 @@ type bridgeEndpoint struct {
 }
 }
 
 
 type bridgeNetwork struct {
 type bridgeNetwork struct {
-	id         string
-	bridge     *bridgeInterface // The bridge's L3 interface
-	config     *networkConfiguration
-	endpoints  map[string]*bridgeEndpoint // key: endpoint id
-	portMapper *portmapper.PortMapper
-	driver     *driver // The network's driver
+	id            string
+	bridge        *bridgeInterface // The bridge's L3 interface
+	config        *networkConfiguration
+	endpoints     map[string]*bridgeEndpoint // key: endpoint id
+	portMapper    *portmapper.PortMapper
+	driver        *driver // The network's driver
+	iptCleanFuncs iptablesCleanFuncs
 	sync.Mutex
 	sync.Mutex
 }
 }
 
 
@@ -236,6 +240,10 @@ func parseErr(label, value, errString string) error {
 	return types.BadRequestErrorf("failed to parse %s value: %v (%s)", label, value, errString)
 	return types.BadRequestErrorf("failed to parse %s value: %v (%s)", label, value, errString)
 }
 }
 
 
+func (n *bridgeNetwork) registerIptCleanFunc(clean iptableCleanFunc) {
+	n.iptCleanFuncs = append(n.iptCleanFuncs, clean)
+}
+
 func (n *bridgeNetwork) getDriverChains() (*iptables.ChainInfo, *iptables.ChainInfo, error) {
 func (n *bridgeNetwork) getDriverChains() (*iptables.ChainInfo, *iptables.ChainInfo, error) {
 	n.Lock()
 	n.Lock()
 	defer n.Unlock()
 	defer n.Unlock()
@@ -604,6 +612,10 @@ func (d *driver) createNetwork(config *networkConfiguration) error {
 			}
 			}
 			return err
 			return err
 		}
 		}
+		network.registerIptCleanFunc(func() error {
+			nwList := d.getNetworks()
+			return network.isolateNetwork(nwList, false)
+		})
 		return nil
 		return nil
 	}
 	}
 
 
@@ -722,22 +734,6 @@ func (d *driver) DeleteNetwork(nid string) error {
 		return err
 		return err
 	}
 	}
 
 
-	// In case of failures after this point, restore the network isolation rules
-	nwList := d.getNetworks()
-	defer func() {
-		if err != nil {
-			if err := n.isolateNetwork(nwList, true); err != nil {
-				logrus.Warnf("Failed on restoring the inter-network iptables rules on cleanup: %v", err)
-			}
-		}
-	}()
-
-	// Remove inter-network communication rules.
-	err = n.isolateNetwork(nwList, false)
-	if err != nil {
-		return err
-	}
-
 	// We only delete the bridge when it's not the default bridge. This is keep the backward compatible behavior.
 	// We only delete the bridge when it's not the default bridge. This is keep the backward compatible behavior.
 	if !config.DefaultBridge {
 	if !config.DefaultBridge {
 		if err := netlink.LinkDel(n.bridge.Link); err != nil {
 		if err := netlink.LinkDel(n.bridge.Link); err != nil {
@@ -745,6 +741,12 @@ func (d *driver) DeleteNetwork(nid string) error {
 		}
 		}
 	}
 	}
 
 
+	// clean all relevant iptables rules
+	for _, cleanFunc := range n.iptCleanFuncs {
+		if errClean := cleanFunc(); errClean != nil {
+			logrus.Warnf("Failed to clean iptables rules for bridge network: %v", errClean)
+		}
+	}
 	return d.storeDelete(config)
 	return d.storeDelete(config)
 }
 }
 
 

+ 8 - 2
libnetwork/drivers/bridge/setup_ip_tables.go

@@ -68,21 +68,27 @@ func (n *bridgeNetwork) setupIPTables(config *networkConfiguration, i *bridgeInt
 	if err = setupIPTablesInternal(config.BridgeName, maskedAddrv4, config.EnableICC, config.EnableIPMasquerade, hairpinMode, true); err != nil {
 	if err = setupIPTablesInternal(config.BridgeName, maskedAddrv4, config.EnableICC, config.EnableIPMasquerade, hairpinMode, true); err != nil {
 		return fmt.Errorf("Failed to Setup IP tables: %s", err.Error())
 		return fmt.Errorf("Failed to Setup IP tables: %s", err.Error())
 	}
 	}
+	n.registerIptCleanFunc(func() error {
+		return setupIPTablesInternal(config.BridgeName, maskedAddrv4, config.EnableICC, config.EnableIPMasquerade, hairpinMode, false)
+	})
 
 
 	natChain, filterChain, err := n.getDriverChains()
 	natChain, filterChain, err := n.getDriverChains()
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("Failed to setup IP tables, cannot acquire chain info %s", err.Error())
 		return fmt.Errorf("Failed to setup IP tables, cannot acquire chain info %s", err.Error())
 	}
 	}
 
 
-	err = iptables.ProgramChain(natChain, config.BridgeName, hairpinMode)
+	err = iptables.ProgramChain(natChain, config.BridgeName, hairpinMode, true)
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("Failed to program NAT chain: %s", err.Error())
 		return fmt.Errorf("Failed to program NAT chain: %s", err.Error())
 	}
 	}
 
 
-	err = iptables.ProgramChain(filterChain, config.BridgeName, hairpinMode)
+	err = iptables.ProgramChain(filterChain, config.BridgeName, hairpinMode, true)
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("Failed to program FILTER chain: %s", err.Error())
 		return fmt.Errorf("Failed to program FILTER chain: %s", err.Error())
 	}
 	}
+	n.registerIptCleanFunc(func() error {
+		return iptables.ProgramChain(filterChain, config.BridgeName, hairpinMode, false)
+	})
 
 
 	n.portMapper.SetIptablesChain(filterChain, n.getNetworkBridgeName())
 	n.portMapper.SetIptablesChain(filterChain, n.getNetworkBridgeName())
 
 

+ 1 - 1
libnetwork/iptables/firewalld_test.go

@@ -22,7 +22,7 @@ func TestReloaded(t *testing.T) {
 	fwdChain, err = NewChain("FWD", Filter, false)
 	fwdChain, err = NewChain("FWD", Filter, false)
 	bridgeName := "lo"
 	bridgeName := "lo"
 
 
-	err = ProgramChain(fwdChain, bridgeName, false)
+	err = ProgramChain(fwdChain, bridgeName, false, true)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}

+ 20 - 4
libnetwork/iptables/iptables.go

@@ -95,7 +95,7 @@ func NewChain(name string, table Table, hairpinMode bool) (*ChainInfo, error) {
 }
 }
 
 
 // ProgramChain is used to add rules to a chain
 // ProgramChain is used to add rules to a chain
-func ProgramChain(c *ChainInfo, bridgeName string, hairpinMode bool) error {
+func ProgramChain(c *ChainInfo, bridgeName string, hairpinMode, enable bool) error {
 	if c.Name == "" {
 	if c.Name == "" {
 		return fmt.Errorf("Could not program chain, missing chain name.")
 		return fmt.Errorf("Could not program chain, missing chain name.")
 	}
 	}
@@ -106,10 +106,14 @@ func ProgramChain(c *ChainInfo, bridgeName string, hairpinMode bool) error {
 			"-m", "addrtype",
 			"-m", "addrtype",
 			"--dst-type", "LOCAL",
 			"--dst-type", "LOCAL",
 			"-j", c.Name}
 			"-j", c.Name}
-		if !Exists(Nat, "PREROUTING", preroute...) {
+		if !Exists(Nat, "PREROUTING", preroute...) && enable {
 			if err := c.Prerouting(Append, preroute...); err != nil {
 			if err := c.Prerouting(Append, preroute...); err != nil {
 				return fmt.Errorf("Failed to inject docker in PREROUTING chain: %s", err)
 				return fmt.Errorf("Failed to inject docker in PREROUTING chain: %s", err)
 			}
 			}
+		} else if Exists(Nat, "PREROUTING", preroute...) && !enable {
+			if err := c.Prerouting(Delete, preroute...); err != nil {
+				return fmt.Errorf("Failed to remove docker in PREROUTING chain: %s", err)
+			}
 		}
 		}
 		output := []string{
 		output := []string{
 			"-m", "addrtype",
 			"-m", "addrtype",
@@ -118,10 +122,14 @@ func ProgramChain(c *ChainInfo, bridgeName string, hairpinMode bool) error {
 		if !hairpinMode {
 		if !hairpinMode {
 			output = append(output, "!", "--dst", "127.0.0.0/8")
 			output = append(output, "!", "--dst", "127.0.0.0/8")
 		}
 		}
-		if !Exists(Nat, "OUTPUT", output...) {
+		if !Exists(Nat, "OUTPUT", output...) && enable {
 			if err := c.Output(Append, output...); err != nil {
 			if err := c.Output(Append, output...); err != nil {
 				return fmt.Errorf("Failed to inject docker in OUTPUT chain: %s", err)
 				return fmt.Errorf("Failed to inject docker in OUTPUT chain: %s", err)
 			}
 			}
+		} else if Exists(Nat, "OUTPUT", output...) && !enable {
+			if err := c.Output(Delete, output...); err != nil {
+				return fmt.Errorf("Failed to inject docker in OUTPUT chain: %s", err)
+			}
 		}
 		}
 	case Filter:
 	case Filter:
 		if bridgeName == "" {
 		if bridgeName == "" {
@@ -131,13 +139,21 @@ func ProgramChain(c *ChainInfo, bridgeName string, hairpinMode bool) error {
 		link := []string{
 		link := []string{
 			"-o", bridgeName,
 			"-o", bridgeName,
 			"-j", c.Name}
 			"-j", c.Name}
-		if !Exists(Filter, "FORWARD", link...) {
+		if !Exists(Filter, "FORWARD", link...) && enable {
 			insert := append([]string{string(Insert), "FORWARD"}, link...)
 			insert := append([]string{string(Insert), "FORWARD"}, link...)
 			if output, err := Raw(insert...); err != nil {
 			if output, err := Raw(insert...); err != nil {
 				return err
 				return err
 			} else if len(output) != 0 {
 			} else if len(output) != 0 {
 				return fmt.Errorf("Could not create linking rule to %s/%s: %s", c.Table, c.Name, output)
 				return fmt.Errorf("Could not create linking rule to %s/%s: %s", c.Table, c.Name, output)
 			}
 			}
+		} else if Exists(Filter, "FORWARD", link...) && !enable {
+			del := append([]string{string(Delete), "FORWARD"}, link...)
+			if output, err := Raw(del...); err != nil {
+				return err
+			} else if len(output) != 0 {
+				return fmt.Errorf("Could not delete linking rule from %s/%s: %s", c.Table, c.Name, output)
+			}
+
 		}
 		}
 	}
 	}
 	return nil
 	return nil

+ 2 - 2
libnetwork/iptables/iptables_test.go

@@ -22,13 +22,13 @@ func TestNewChain(t *testing.T) {
 
 
 	bridgeName = "lo"
 	bridgeName = "lo"
 	natChain, err = NewChain(chainName, Nat, false)
 	natChain, err = NewChain(chainName, Nat, false)
-	err = ProgramChain(natChain, bridgeName, false)
+	err = ProgramChain(natChain, bridgeName, false, true)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
 	filterChain, err = NewChain(chainName, Filter, false)
 	filterChain, err = NewChain(chainName, Filter, false)
-	err = ProgramChain(filterChain, bridgeName, false)
+	err = ProgramChain(filterChain, bridgeName, false, true)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}