diff --git a/pkg/pidfile/pidfile.go b/pkg/pidfile/pidfile.go index 8017d9a784..cdf6f8f39e 100644 --- a/pkg/pidfile/pidfile.go +++ b/pkg/pidfile/pidfile.go @@ -12,19 +12,24 @@ import ( "github.com/docker/docker/pkg/process" ) -func checkPIDFileAlreadyExists(path string) error { +// Read reads the "PID file" at path, and returns the PID if it contains a +// valid PID of a running process, or 0 otherwise. It returns an error when +// failing to read the file, or if the file doesn't exist, but malformed content +// is ignored. Consumers should therefore check if the returned PID is a non-zero +// value before use. +func Read(path string) (pid int, err error) { pidByte, err := os.ReadFile(path) if err != nil { - if os.IsNotExist(err) { - return nil - } - return err + return 0, err } - pid, err := strconv.Atoi(string(bytes.TrimSpace(pidByte))) - if err == nil && process.Alive(pid) { - return fmt.Errorf("pid file found, ensure docker is not running or delete %s", path) + pid, err = strconv.Atoi(string(bytes.TrimSpace(pidByte))) + if err != nil { + return 0, nil } - return nil + if pid != 0 && process.Alive(pid) { + return pid, nil + } + return 0, nil } // Write writes a "PID file" at the specified path. It returns an error if the @@ -36,8 +41,12 @@ func Write(path string, pid int) error { // but 0 or negative PIDs are not acceptable. return fmt.Errorf("invalid PID (%d): only positive PIDs are allowed", pid) } - if err := checkPIDFileAlreadyExists(path); err != nil { + oldPID, err := Read(path) + if err != nil && !os.IsNotExist(err) { return err } + if oldPID != 0 { + return fmt.Errorf("process with PID %d is still running", oldPID) + } return os.WriteFile(path, []byte(strconv.Itoa(pid)), 0o644) } diff --git a/pkg/pidfile/pidfile_test.go b/pkg/pidfile/pidfile_test.go index 09996c6765..2179aa0d4b 100644 --- a/pkg/pidfile/pidfile_test.go +++ b/pkg/pidfile/pidfile_test.go @@ -1,8 +1,12 @@ package pidfile // import "github.com/docker/docker/pkg/pidfile" import ( + "errors" "os" + "os/exec" "path/filepath" + "runtime" + "strconv" "testing" ) @@ -21,6 +25,119 @@ func TestWrite(t *testing.T) { err = Write(path, os.Getpid()) if err == nil { - t.Fatal("Test file creation not blocked") + t.Error("Test file creation not blocked") + } + + pid, err := Read(path) + if err != nil { + t.Error(err) + } + if pid != os.Getpid() { + t.Errorf("expected pid %d, got %d", os.Getpid(), pid) } } + +func TestRead(t *testing.T) { + tmpDir := t.TempDir() + + t.Run("non-existing pidFile", func(t *testing.T) { + _, err := Read(filepath.Join(tmpDir, "nosuchfile")) + if !errors.Is(err, os.ErrNotExist) { + t.Errorf("expected an os.ErrNotExist, got: %+v", err) + } + }) + + // Verify that we ignore a malformed PID in the file. + t.Run("malformed pid", func(t *testing.T) { + // Not using Write here, to test Read in isolation. + pidFile := filepath.Join(tmpDir, "pidfile-malformed") + err := os.WriteFile(pidFile, []byte("something that's not an integer"), 0o644) + if err != nil { + t.Fatal(err) + } + pid, err := Read(pidFile) + if err != nil { + t.Error(err) + } + if pid != 0 { + t.Errorf("expected pid %d, got %d", 0, pid) + } + }) + + t.Run("zero pid", func(t *testing.T) { + // Not using Write here, to test Read in isolation. + pidFile := filepath.Join(tmpDir, "pidfile-zero") + err := os.WriteFile(pidFile, []byte(strconv.Itoa(0)), 0o644) + if err != nil { + t.Fatal(err) + } + pid, err := Read(pidFile) + if err != nil { + t.Error(err) + } + if pid != 0 { + t.Errorf("expected pid %d, got %d", 0, pid) + } + }) + + t.Run("negative pid", func(t *testing.T) { + // Not using Write here, to test Read in isolation. + pidFile := filepath.Join(tmpDir, "pidfile-negative") + err := os.WriteFile(pidFile, []byte(strconv.Itoa(-1)), 0o644) + if err != nil { + t.Fatal(err) + } + pid, err := Read(pidFile) + if err != nil { + t.Error(err) + } + if pid != 0 { + t.Errorf("expected pid %d, got %d", 0, pid) + } + }) + + t.Run("current process pid", func(t *testing.T) { + // Not using Write here, to test Read in isolation. + pidFile := filepath.Join(tmpDir, "pidfile") + err := os.WriteFile(pidFile, []byte(strconv.Itoa(os.Getpid())), 0o644) + if err != nil { + t.Fatal(err) + } + pid, err := Read(pidFile) + if err != nil { + t.Error(err) + } + if pid != os.Getpid() { + t.Errorf("expected pid %d, got %d", os.Getpid(), pid) + } + }) + + // Verify that we don't return a PID if the process exited. + t.Run("exited process", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("TODO: make this work on Windows") + } + + // Get a PID of an exited process. + cmd := exec.Command("echo", "hello world") + err := cmd.Run() + if err != nil { + t.Fatal(err) + } + exitedPID := cmd.ProcessState.Pid() + + // Not using Write here, to test Read in isolation. + pidFile := filepath.Join(tmpDir, "pidfile-exited") + err = os.WriteFile(pidFile, []byte(strconv.Itoa(exitedPID)), 0o644) + if err != nil { + t.Fatal(err) + } + pid, err := Read(pidFile) + if err != nil { + t.Error(err) + } + if pid != 0 { + t.Errorf("expected pid %d, got %d", 0, pid) + } + }) +}