package etchosts import ( "bytes" "fmt" "os" "path/filepath" "testing" "golang.org/x/sync/errgroup" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" ) func TestBuildDefault(t *testing.T) { file, err := os.CreateTemp("", "") if err != nil { t.Fatal(err) } defer os.Remove(file.Name()) // check that /etc/hosts has consistent ordering for i := 0; i <= 5; i++ { err = Build(file.Name(), nil) if err != nil { t.Fatal(err) } content, err := os.ReadFile(file.Name()) if err != nil { t.Fatal(err) } expected := "127.0.0.1\tlocalhost\n::1\tlocalhost ip6-localhost ip6-loopback\nfe00::0\tip6-localnet\nff00::0\tip6-mcastprefix\nff02::1\tip6-allnodes\nff02::2\tip6-allrouters\n" if expected != string(content) { t.Fatalf("Expected to find '%s' got '%s'", expected, content) } } } func TestBuildNoIPv6(t *testing.T) { d := t.TempDir() filename := filepath.Join(d, "hosts") err := BuildNoIPv6(filename, []Record{ { Hosts: "another.example", IP: "fdbb:c59c:d015::3", }, { Hosts: "another.example", IP: "10.11.12.13", }, }) assert.NilError(t, err) content, err := os.ReadFile(filename) assert.NilError(t, err) assert.Check(t, is.DeepEqual(string(content), "127.0.0.1\tlocalhost\n10.11.12.13\tanother.example\n")) } func TestUpdate(t *testing.T) { file, err := os.CreateTemp("", "") if err != nil { t.Fatal(err) } defer os.Remove(file.Name()) if err := Build(file.Name(), []Record{ { "testhostname.testdomainname testhostname", "10.11.12.13", }, }); err != nil { t.Fatal(err) } content, err := os.ReadFile(file.Name()) if err != nil { t.Fatal(err) } if expected := "10.11.12.13\ttesthostname.testdomainname testhostname\n"; !bytes.Contains(content, []byte(expected)) { t.Fatalf("Expected to find '%s' got '%s'", expected, content) } if err := Update(file.Name(), "1.1.1.1", "testhostname"); err != nil { t.Fatal(err) } content, err = os.ReadFile(file.Name()) if err != nil { t.Fatal(err) } if expected := "1.1.1.1\ttesthostname.testdomainname testhostname\n"; !bytes.Contains(content, []byte(expected)) { t.Fatalf("Expected to find '%s' got '%s'", expected, content) } } // This regression test ensures that when a host is given a new IP // via the Update function that other hosts which start with the // same name as the targeted host are not erroneously updated as well. // In the test example, if updating a host called "prefix", unrelated // hosts named "prefixAndMore" or "prefix2" or anything else starting // with "prefix" should not be changed. For more information see // GitHub issue #603. func TestUpdateIgnoresPrefixedHostname(t *testing.T) { file, err := os.CreateTemp("", "") if err != nil { t.Fatal(err) } defer os.Remove(file.Name()) if err := Build(file.Name(), []Record{ { Hosts: "prefix", IP: "2.2.2.2", }, { Hosts: "prefixAndMore", IP: "3.3.3.3", }, { Hosts: "unaffectedHost", IP: "4.4.4.4", }, }); err != nil { t.Fatal(err) } content, err := os.ReadFile(file.Name()) if err != nil { t.Fatal(err) } if expected := "2.2.2.2\tprefix\n3.3.3.3\tprefixAndMore\n4.4.4.4\tunaffectedHost\n"; !bytes.Contains(content, []byte(expected)) { t.Fatalf("Expected to find '%s' got '%s'", expected, content) } if err := Update(file.Name(), "5.5.5.5", "prefix"); err != nil { t.Fatal(err) } content, err = os.ReadFile(file.Name()) if err != nil { t.Fatal(err) } if expected := "5.5.5.5\tprefix\n3.3.3.3\tprefixAndMore\n4.4.4.4\tunaffectedHost\n"; !bytes.Contains(content, []byte(expected)) { t.Fatalf("Expected to find '%s' got '%s'", expected, content) } } // This regression test covers the host prefix issue for the // Delete function. In the test example, if deleting a host called // "prefix", an unrelated host called "prefixAndMore" should not // be deleted. For more information see GitHub issue #603. func TestDeleteIgnoresPrefixedHostname(t *testing.T) { file, err := os.CreateTemp("", "") if err != nil { t.Fatal(err) } defer os.Remove(file.Name()) err = Build(file.Name(), nil) if err != nil { t.Fatal(err) } if err := Add(file.Name(), []Record{ { Hosts: "prefix", IP: "1.1.1.1", }, { Hosts: "prefixAndMore", IP: "2.2.2.2", }, }); err != nil { t.Fatal(err) } if err := Delete(file.Name(), []Record{ { Hosts: "prefix", IP: "1.1.1.1", }, }); err != nil { t.Fatal(err) } content, err := os.ReadFile(file.Name()) if err != nil { t.Fatal(err) } if expected := "2.2.2.2\tprefixAndMore\n"; !bytes.Contains(content, []byte(expected)) { t.Fatalf("Expected to find '%s' got '%s'", expected, content) } if expected := "1.1.1.1\tprefix\n"; bytes.Contains(content, []byte(expected)) { t.Fatalf("Did not expect to find '%s' got '%s'", expected, content) } } func TestAddEmpty(t *testing.T) { file, err := os.CreateTemp("", "") if err != nil { t.Fatal(err) } defer os.Remove(file.Name()) err = Build(file.Name(), nil) if err != nil { t.Fatal(err) } if err := Add(file.Name(), []Record{}); err != nil { t.Fatal(err) } } func TestAdd(t *testing.T) { file, err := os.CreateTemp("", "") if err != nil { t.Fatal(err) } defer os.Remove(file.Name()) err = Build(file.Name(), nil) if err != nil { t.Fatal(err) } if err := Add(file.Name(), []Record{ { Hosts: "testhostname", IP: "2.2.2.2", }, }); err != nil { t.Fatal(err) } content, err := os.ReadFile(file.Name()) if err != nil { t.Fatal(err) } if expected := "2.2.2.2\ttesthostname\n"; !bytes.Contains(content, []byte(expected)) { t.Fatalf("Expected to find '%s' got '%s'", expected, content) } } func TestDeleteEmpty(t *testing.T) { file, err := os.CreateTemp("", "") if err != nil { t.Fatal(err) } defer os.Remove(file.Name()) err = Build(file.Name(), nil) if err != nil { t.Fatal(err) } if err := Delete(file.Name(), []Record{}); err != nil { t.Fatal(err) } } func TestDeleteNewline(t *testing.T) { file, err := os.CreateTemp("", "") if err != nil { t.Fatal(err) } defer os.Remove(file.Name()) b := []byte("\n") if _, err := file.Write(b); err != nil { t.Fatal(err) } rec := []Record{ { Hosts: "prefix", IP: "2.2.2.2", }, } if err := Delete(file.Name(), rec); err != nil { t.Fatal(err) } } func TestDelete(t *testing.T) { file, err := os.CreateTemp("", "") if err != nil { t.Fatal(err) } defer os.Remove(file.Name()) err = Build(file.Name(), nil) if err != nil { t.Fatal(err) } if err := Add(file.Name(), []Record{ { Hosts: "testhostname1", IP: "1.1.1.1", }, { Hosts: "testhostname2", IP: "2.2.2.2", }, { Hosts: "testhostname3", IP: "3.3.3.3", }, }); err != nil { t.Fatal(err) } if err := Delete(file.Name(), []Record{ { Hosts: "testhostname1", IP: "1.1.1.1", }, { Hosts: "testhostname3", IP: "3.3.3.3", }, }); err != nil { t.Fatal(err) } content, err := os.ReadFile(file.Name()) if err != nil { t.Fatal(err) } if expected := "2.2.2.2\ttesthostname2\n"; !bytes.Contains(content, []byte(expected)) { t.Fatalf("Expected to find '%s' got '%s'", expected, content) } if expected := "1.1.1.1\ttesthostname1\n"; bytes.Contains(content, []byte(expected)) { t.Fatalf("Did not expect to find '%s' got '%s'", expected, content) } } func TestConcurrentWrites(t *testing.T) { file, err := os.CreateTemp("", "") if err != nil { t.Fatal(err) } defer os.Remove(file.Name()) err = Build(file.Name(), nil) if err != nil { t.Fatal(err) } if err := Add(file.Name(), []Record{ { Hosts: "inithostname", IP: "172.17.0.1", }, }); err != nil { t.Fatal(err) } group := new(errgroup.Group) for i := 0; i < 10; i++ { i := i group.Go(func() error { rec := []Record{ { IP: fmt.Sprintf("%d.%d.%d.%d", i, i, i, i), Hosts: fmt.Sprintf("testhostname%d", i), }, } for j := 0; j < 25; j++ { if err := Add(file.Name(), rec); err != nil { return err } if err := Delete(file.Name(), rec); err != nil { return err } } return nil }) } if err := group.Wait(); err != nil { t.Fatal(err) } content, err := os.ReadFile(file.Name()) if err != nil { t.Fatal(err) } if expected := "172.17.0.1\tinithostname\n"; !bytes.Contains(content, []byte(expected)) { t.Fatalf("Expected to find '%s' got '%s'", expected, content) } } func benchDelete(b *testing.B) { b.StopTimer() file, err := os.CreateTemp("", "") if err != nil { b.Fatal(err) } defer func() { b.StopTimer() file.Close() os.Remove(file.Name()) b.StartTimer() }() err = Build(file.Name(), nil) if err != nil { b.Fatal(err) } var records []Record var toDelete []Record for i := 0; i < 255; i++ { record := Record{ Hosts: fmt.Sprintf("testhostname%d", i), IP: fmt.Sprintf("%d.%d.%d.%d", i, i, i, i), } records = append(records, record) if i%2 == 0 { toDelete = append(records, record) } } if err := Add(file.Name(), records); err != nil { b.Fatal(err) } b.StartTimer() if err := Delete(file.Name(), toDelete); err != nil { b.Fatal(err) } } func BenchmarkDelete(b *testing.B) { for i := 0; i < b.N; i++ { benchDelete(b) } }