join_test.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package safepath
  2. import (
  3. "context"
  4. "os"
  5. "path/filepath"
  6. "runtime"
  7. "strings"
  8. "testing"
  9. "gotest.tools/v3/assert"
  10. is "gotest.tools/v3/assert/cmp"
  11. )
  12. func TestJoinEscapingSymlink(t *testing.T) {
  13. type testCase struct {
  14. name string
  15. target string
  16. }
  17. var cases []testCase
  18. if runtime.GOOS == "windows" {
  19. cases = []testCase{
  20. {name: "root", target: `C:\`},
  21. {name: "absolute file", target: `C:\Windows\System32\cmd.exe`},
  22. }
  23. } else {
  24. cases = []testCase{
  25. {name: "root", target: "/"},
  26. {name: "absolute file", target: "/etc/passwd"},
  27. }
  28. }
  29. cases = append(cases, testCase{name: "relative", target: "../../"})
  30. for _, tc := range cases {
  31. t.Run(tc.name, func(t *testing.T) {
  32. tempDir := t.TempDir()
  33. dir, err := filepath.EvalSymlinks(tempDir)
  34. assert.NilError(t, err, "filepath.EvalSymlinks failed for temporary directory %s", tempDir)
  35. err = os.Symlink(tc.target, filepath.Join(dir, "link"))
  36. assert.NilError(t, err, "failed to create symlink to %s", tc.target)
  37. safe, err := Join(context.Background(), dir, "link")
  38. if err == nil {
  39. safe.Close(context.Background())
  40. }
  41. assert.ErrorType(t, err, &ErrEscapesBase{})
  42. })
  43. }
  44. }
  45. func TestJoinGoodSymlink(t *testing.T) {
  46. tempDir := t.TempDir()
  47. dir, err := filepath.EvalSymlinks(tempDir)
  48. assert.NilError(t, err, "filepath.EvalSymlinks failed for temporary directory %s", tempDir)
  49. assert.Assert(t, os.WriteFile(filepath.Join(dir, "foo"), []byte("bar"), 0o744), "failed to create file 'foo'")
  50. assert.Assert(t, os.Mkdir(filepath.Join(dir, "subdir"), 0o744), "failed to create directory 'subdir'")
  51. assert.Assert(t, os.WriteFile(filepath.Join(dir, "subdir/hello.txt"), []byte("world"), 0o744), "failed to create file 'subdir/hello.txt'")
  52. assert.Assert(t, os.Symlink(filepath.Join(dir, "subdir"), filepath.Join(dir, "subdir_link_absolute")), "failed to create absolute symlink to directory 'subdir'")
  53. assert.Assert(t, os.Symlink("subdir", filepath.Join(dir, "subdir_link_relative")), "failed to create relative symlink to directory 'subdir'")
  54. assert.Assert(t, os.Symlink(filepath.Join(dir, "foo"), filepath.Join(dir, "foo_link_absolute")), "failed to create absolute symlink to file 'foo'")
  55. assert.Assert(t, os.Symlink("foo", filepath.Join(dir, "foo_link_relative")), "failed to create relative symlink to file 'foo'")
  56. for _, target := range []string{
  57. "foo", "subdir",
  58. "subdir_link_absolute", "foo_link_absolute",
  59. "subdir_link_relative", "foo_link_relative",
  60. } {
  61. t.Run(target, func(t *testing.T) {
  62. safe, err := Join(context.Background(), dir, target)
  63. assert.NilError(t, err)
  64. defer safe.Close(context.Background())
  65. if strings.HasPrefix(target, "subdir") {
  66. data, err := os.ReadFile(filepath.Join(safe.Path(), "hello.txt"))
  67. assert.NilError(t, err)
  68. assert.Assert(t, is.Equal(string(data), "world"))
  69. }
  70. })
  71. }
  72. }
  73. func TestJoinWithSymlinkReplace(t *testing.T) {
  74. tempDir := t.TempDir()
  75. dir, err := filepath.EvalSymlinks(tempDir)
  76. assert.NilError(t, err, "filepath.EvalSymlinks failed for temporary directory %s", tempDir)
  77. link := filepath.Join(dir, "link")
  78. target := filepath.Join(dir, "foo")
  79. err = os.WriteFile(target, []byte("bar"), 0o744)
  80. assert.NilError(t, err, "failed to create test file")
  81. err = os.Symlink(target, link)
  82. assert.Check(t, err, "failed to create symlink to foo")
  83. safe, err := Join(context.Background(), dir, "link")
  84. assert.NilError(t, err)
  85. defer safe.Close(context.Background())
  86. // Delete the link target.
  87. err = os.Remove(target)
  88. if runtime.GOOS == "windows" {
  89. // On Windows it shouldn't be possible.
  90. assert.Assert(t, is.ErrorType(err, &os.PathError{}), "link shouldn't be deletable before cleanup")
  91. } else {
  92. // On Linux we can delete it just fine.
  93. assert.NilError(t, err, "failed to remove symlink")
  94. // Replace target with a symlink to /etc/paswd
  95. err = os.Symlink("/etc/passwd", target)
  96. assert.NilError(t, err, "failed to create symlink")
  97. }
  98. // The returned safe path should still point to the old file.
  99. data, err := os.ReadFile(safe.Path())
  100. assert.NilError(t, err, "failed to read file")
  101. assert.Check(t, is.Equal(string(data), "bar"))
  102. }
  103. func TestJoinCloseInvalidates(t *testing.T) {
  104. tempDir := t.TempDir()
  105. dir, err := filepath.EvalSymlinks(tempDir)
  106. assert.NilError(t, err)
  107. foo := filepath.Join(dir, "foo")
  108. err = os.WriteFile(foo, []byte("bar"), 0o744)
  109. assert.NilError(t, err, "failed to create test file")
  110. safe, err := Join(context.Background(), dir, "foo")
  111. assert.NilError(t, err)
  112. assert.Check(t, safe.IsValid())
  113. assert.NilError(t, safe.Close(context.Background()))
  114. assert.Check(t, !safe.IsValid())
  115. }