123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300 |
- //go:build linux
- package iptables
- import (
- "net"
- "os/exec"
- "strconv"
- "strings"
- "testing"
- "golang.org/x/sync/errgroup"
- )
const (
	// chainName is the name of the chain that the tests in this file create
	// in the nat and filter tables.
	chainName = "DOCKEREST"

	// bridgeName is the interface name the tests pass wherever a bridge
	// interface is expected; "lo" is the loopback interface name.
	bridgeName = "lo"
)
- func createNewChain(t *testing.T) (*IPTable, *ChainInfo, *ChainInfo) {
- t.Helper()
- iptable := GetIptable(IPv4)
- natChain, err := iptable.NewChain(chainName, Nat, false)
- if err != nil {
- t.Fatal(err)
- }
- err = iptable.ProgramChain(natChain, bridgeName, false, true)
- if err != nil {
- t.Fatal(err)
- }
- filterChain, err := iptable.NewChain(chainName, Filter, false)
- if err != nil {
- t.Fatal(err)
- }
- err = iptable.ProgramChain(filterChain, bridgeName, false, true)
- if err != nil {
- t.Fatal(err)
- }
- return iptable, natChain, filterChain
- }
- func TestNewChain(t *testing.T) {
- createNewChain(t)
- }
- func TestForward(t *testing.T) {
- iptable, natChain, filterChain := createNewChain(t)
- ip := net.ParseIP("192.168.1.1")
- port := 1234
- dstAddr := "172.17.0.1"
- dstPort := 4321
- proto := "tcp"
- err := natChain.Forward(Insert, ip, port, proto, dstAddr, dstPort, bridgeName)
- if err != nil {
- t.Fatal(err)
- }
- dnatRule := []string{
- "-d", ip.String(),
- "-p", proto,
- "--dport", strconv.Itoa(port),
- "-j", "DNAT",
- "--to-destination", dstAddr + ":" + strconv.Itoa(dstPort),
- "!", "-i", bridgeName,
- }
- if !iptable.Exists(natChain.Table, natChain.Name, dnatRule...) {
- t.Fatal("DNAT rule does not exist")
- }
- filterRule := []string{
- "!", "-i", bridgeName,
- "-o", bridgeName,
- "-d", dstAddr,
- "-p", proto,
- "--dport", strconv.Itoa(dstPort),
- "-j", "ACCEPT",
- }
- if !iptable.Exists(filterChain.Table, filterChain.Name, filterRule...) {
- t.Fatal("filter rule does not exist")
- }
- masqRule := []string{
- "-d", dstAddr,
- "-s", dstAddr,
- "-p", proto,
- "--dport", strconv.Itoa(dstPort),
- "-j", "MASQUERADE",
- }
- if !iptable.Exists(natChain.Table, "POSTROUTING", masqRule...) {
- t.Fatal("MASQUERADE rule does not exist")
- }
- }
- func TestLink(t *testing.T) {
- iptable, _, filterChain := createNewChain(t)
- ip1 := net.ParseIP("192.168.1.1")
- ip2 := net.ParseIP("192.168.1.2")
- port := 1234
- proto := "tcp"
- err := filterChain.Link(Append, ip1, ip2, port, proto, bridgeName)
- if err != nil {
- t.Fatal(err)
- }
- rule1 := []string{
- "-i", bridgeName,
- "-o", bridgeName,
- "-p", proto,
- "-s", ip1.String(),
- "-d", ip2.String(),
- "--dport", strconv.Itoa(port),
- "-j", "ACCEPT",
- }
- if !iptable.Exists(filterChain.Table, filterChain.Name, rule1...) {
- t.Fatal("rule1 does not exist")
- }
- rule2 := []string{
- "-i", bridgeName,
- "-o", bridgeName,
- "-p", proto,
- "-s", ip2.String(),
- "-d", ip1.String(),
- "--sport", strconv.Itoa(port),
- "-j", "ACCEPT",
- }
- if !iptable.Exists(filterChain.Table, filterChain.Name, rule2...) {
- t.Fatal("rule2 does not exist")
- }
- }
- func TestPrerouting(t *testing.T) {
- iptable, natChain, _ := createNewChain(t)
- args := []string{"-i", "lo", "-d", "192.168.1.1"}
- err := natChain.Prerouting(Insert, args...)
- if err != nil {
- t.Fatal(err)
- }
- if !iptable.Exists(natChain.Table, "PREROUTING", args...) {
- t.Fatal("rule does not exist")
- }
- delRule := append([]string{"-D", "PREROUTING", "-t", string(Nat)}, args...)
- if _, err = iptable.Raw(delRule...); err != nil {
- t.Fatal(err)
- }
- }
- func TestOutput(t *testing.T) {
- iptable, natChain, _ := createNewChain(t)
- args := []string{"-o", "lo", "-d", "192.168.1.1"}
- err := natChain.Output(Insert, args...)
- if err != nil {
- t.Fatal(err)
- }
- if !iptable.Exists(natChain.Table, "OUTPUT", args...) {
- t.Fatal("rule does not exist")
- }
- delRule := append([]string{
- "-D", "OUTPUT", "-t",
- string(natChain.Table),
- }, args...)
- if _, err = iptable.Raw(delRule...); err != nil {
- t.Fatal(err)
- }
- }
- func TestConcurrencyWithWait(t *testing.T) {
- RunConcurrencyTest(t, true)
- }
- func TestConcurrencyNoWait(t *testing.T) {
- RunConcurrencyTest(t, false)
- }
- // Runs 10 concurrent rule additions. This will fail if iptables
- // is actually invoked simultaneously without --wait.
- // Note that if iptables does not support the xtable lock on this
- // system, then allowXlock has no effect -- it will always be off.
- func RunConcurrencyTest(t *testing.T, allowXlock bool) {
- _, natChain, _ := createNewChain(t)
- if !allowXlock && supportsXlock {
- supportsXlock = false
- defer func() { supportsXlock = true }()
- }
- ip := net.ParseIP("192.168.1.1")
- port := 1234
- dstAddr := "172.17.0.1"
- dstPort := 4321
- proto := "tcp"
- group := new(errgroup.Group)
- for i := 0; i < 10; i++ {
- group.Go(func() error {
- return natChain.Forward(Append, ip, port, proto, dstAddr, dstPort, "lo")
- })
- }
- if err := group.Wait(); err != nil {
- t.Fatal(err)
- }
- }
- func TestCleanup(t *testing.T) {
- iptable, _, filterChain := createNewChain(t)
- var rules []byte
- // Cleanup filter/FORWARD first otherwise output of iptables-save is dirty
- link := []string{
- "-t", string(filterChain.Table),
- string(Delete), "FORWARD",
- "-o", bridgeName,
- "-j", filterChain.Name,
- }
- if _, err := iptable.Raw(link...); err != nil {
- t.Fatal(err)
- }
- filterChain.Remove()
- err := iptable.RemoveExistingChain(chainName, Nat)
- if err != nil {
- t.Fatal(err)
- }
- rules, err = exec.Command("iptables-save").Output()
- if err != nil {
- t.Fatal(err)
- }
- if strings.Contains(string(rules), chainName) {
- t.Fatalf("Removing chain failed. %s found in iptables-save", chainName)
- }
- }
- func TestExistsRaw(t *testing.T) {
- const testChain1 = "ABCD"
- const testChain2 = "EFGH"
- iptable := GetIptable(IPv4)
- _, err := iptable.NewChain(testChain1, Filter, false)
- if err != nil {
- t.Fatal(err)
- }
- defer func() {
- iptable.RemoveExistingChain(testChain1, Filter)
- }()
- _, err = iptable.NewChain(testChain2, Filter, false)
- if err != nil {
- t.Fatal(err)
- }
- defer func() {
- iptable.RemoveExistingChain(testChain2, Filter)
- }()
- // Test detection over full and truncated rule string
- input := []struct{ rule []string }{
- {[]string{"-s", "172.8.9.9/32", "-j", "ACCEPT"}},
- {[]string{"-d", "172.8.9.0/24", "-j", "DROP"}},
- {[]string{"-s", "172.0.3.0/24", "-d", "172.17.0.0/24", "-p", "tcp", "-m", "tcp", "--dport", "80", "-j", testChain2}},
- {[]string{"-j", "RETURN"}},
- }
- for i, r := range input {
- ruleAdd := append([]string{"-t", string(Filter), "-A", testChain1}, r.rule...)
- err = iptable.RawCombinedOutput(ruleAdd...)
- if err != nil {
- t.Fatalf("i=%d, err: %v", i, err)
- }
- if !iptable.exists(true, Filter, testChain1, r.rule...) {
- t.Fatalf("Failed to detect rule. i=%d", i)
- }
- // Truncate the rule
- trg := r.rule[len(r.rule)-1]
- trg = trg[:len(trg)-2]
- r.rule[len(r.rule)-1] = trg
- if iptable.exists(true, Filter, testChain1, r.rule...) {
- t.Fatalf("Invalid detection. i=%d", i)
- }
- }
- }
|