Browse Source

libnetwork/iptables: implement passthrough as method

Move the Passthrough implementation to a method on firewalldConnection,
and add a check if firewalld is initialized and running.

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
Sebastiaan van Stijn 1 year ago
parent
commit
9efb1aabeb

+ 9 - 3
libnetwork/iptables/firewalld.go

@@ -215,8 +215,14 @@ func (fwd *firewalldConnection) isRunning() bool {
 	return fwd.running.Load()
 }
 
-// Passthrough method simply passes args through to iptables/ip6tables
-func Passthrough(ipVersion IPVersion, args ...string) ([]byte, error) {
+// passthrough passes args through to iptables or ip6tables.
+//
+// It is a no-op if firewalld is not running or not initialized.
+func (fwd *firewalldConnection) passthrough(ipVersion IPVersion, args ...string) ([]byte, error) {
+	if !fwd.isRunning() {
+		return []byte(""), nil
+	}
+
 	// select correct IP version for firewalld
 	ipv := ipTables
 	if ipVersion == IPv6 {
@@ -225,7 +231,7 @@ func Passthrough(ipVersion IPVersion, args ...string) ([]byte, error) {
 
 	var output string
 	log.G(context.TODO()).Debugf("Firewalld passthrough: %s, %s", ipv, args)
-	if err := firewalld.sysObj.Call(dbusInterface+".direct.passthrough", 0, ipv, args).Store(&output); err != nil {
+	if err := fwd.sysObj.Call(dbusInterface+".direct.passthrough", 0, ipv, args).Store(&output); err != nil {
 		return nil, err
 	}
 	return []byte(output), nil

+ 14 - 3
libnetwork/iptables/firewalld_test.go

@@ -93,6 +93,13 @@ func TestReloaded(t *testing.T) {
 
 func TestPassthrough(t *testing.T) {
 	skipIfNoFirewalld(t)
+
+	fwd, err := newConnection()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer fwd.conn.Close()
+
 	rule1 := []string{
 		"-i", "lo",
 		"-p", "udp",
@@ -100,12 +107,12 @@ func TestPassthrough(t *testing.T) {
 		"-j", "ACCEPT",
 	}
 
-	_, err := Passthrough(IPv4, append([]string{"-A"}, rule1...)...)
+	_, err = fwd.passthrough(IPv4, append([]string{"-A"}, rule1...)...)
 	if err != nil {
-		t.Fatal(err)
+		t.Error(err)
 	}
 	if !GetIptable(IPv4).Exists(Filter, "INPUT", rule1...) {
-		t.Fatal("rule1 does not exist")
+		t.Error("rule1 does not exist")
 	}
 }
 
@@ -126,4 +133,8 @@ func TestFirewalldUninitialized(t *testing.T) {
 		t.Errorf("unexpected error when calling delInterface on an uninitialized firewalldConnection: %v", err)
 	}
 	fwd.registerReloadCallback(func() {})
+	_, err = fwd.passthrough(IPv4)
+	if err != nil {
+		t.Errorf("unexpected error when calling passthrough on an uninitialized firewalldConnection: %v", err)
+	}
 }

+ 1 - 1
libnetwork/iptables/iptables.go

@@ -520,7 +520,7 @@ func filterOutput(start time.Time, output []byte, args ...string) []byte {
 func (iptable IPTable) Raw(args ...string) ([]byte, error) {
 	if firewalld.isRunning() {
 		startTime := time.Now()
-		output, err := Passthrough(iptable.ipVersion, args...)
+		output, err := firewalld.passthrough(iptable.ipVersion, args...)
 		if err == nil || !strings.Contains(err.Error(), "was not provided by any .service files") {
 			return filterOutput(startTime, output, args...), err
 		}