Przeglądaj źródła

State refactoring and add waiting functions

Docker-DCO-1.1-Signed-off-by: Alexandr Morozov <lk4d4math@gmail.com> (github: LK4D4)
Alexandr Morozov 11 lat temu
rodzic
commit
47065b9045
2 zmienionych plików z 194 dodań i 23 usunięć
  1. 92 23
      daemon/state.go
  2. 102 0
      daemon/state_test.go

+ 92 - 23
daemon/state.go

@@ -16,6 +16,13 @@ type State struct {
 	ExitCode   int
 	StartedAt  time.Time
 	FinishedAt time.Time
+	waitChan   chan struct{}
+}
+
+func NewState() *State {
+	return &State{
+		waitChan: make(chan struct{}),
+	}
 }
 
 // String returns a human-readable description of the state
@@ -35,56 +42,118 @@ func (s *State) String() string {
 	return fmt.Sprintf("Exited (%d) %s ago", s.ExitCode, units.HumanDuration(time.Now().UTC().Sub(s.FinishedAt)))
 }
 
+func wait(waitChan <-chan struct{}, timeout time.Duration) error {
+	if timeout < 0 {
+		<-waitChan
+		return nil
+	}
+	select {
+	case <-time.After(timeout):
+		return fmt.Errorf("Timed out: %v", timeout)
+	case <-waitChan:
+		return nil
+	}
+}
+
+// WaitRunning waits until state is running. If state already running it returns
+// immediatly. If you want wait forever you must supply negative timeout.
+// Returns pid, that was passed to SetRunning
+func (s *State) WaitRunning(timeout time.Duration) (int, error) {
+	s.RLock()
+	if s.IsRunning() {
+		pid := s.Pid
+		s.RUnlock()
+		return pid, nil
+	}
+	waitChan := s.waitChan
+	s.RUnlock()
+	if err := wait(waitChan, timeout); err != nil {
+		return -1, err
+	}
+	return s.GetPid(), nil
+}
+
+// WaitStop waits until state is stopped. If state already stopped it returns
+// immediatly. If you want wait forever you must supply negative timeout.
+// Returns exit code, that was passed to SetRunning
+func (s *State) WaitStop(timeout time.Duration) (int, error) {
+	s.RLock()
+	if !s.Running {
+		exitCode := s.ExitCode
+		s.RUnlock()
+		return exitCode, nil
+	}
+	waitChan := s.waitChan
+	s.RUnlock()
+	if err := wait(waitChan, timeout); err != nil {
+		return -1, err
+	}
+	return s.GetExitCode(), nil
+}
+
 func (s *State) IsRunning() bool {
 	s.RLock()
-	defer s.RUnlock()
+	res := s.Running
+	s.RUnlock()
+	return res
+}
 
-	return s.Running
+func (s *State) GetPid() int {
+	s.RLock()
+	res := s.Pid
+	s.RUnlock()
+	return res
 }
 
 func (s *State) GetExitCode() int {
 	s.RLock()
-	defer s.RUnlock()
-
-	return s.ExitCode
+	res := s.ExitCode
+	s.RUnlock()
+	return res
 }
 
 func (s *State) SetRunning(pid int) {
 	s.Lock()
-	defer s.Unlock()
-
-	s.Running = true
-	s.Paused = false
-	s.ExitCode = 0
-	s.Pid = pid
-	s.StartedAt = time.Now().UTC()
+	if !s.Running {
+		s.Running = true
+		s.Paused = false
+		s.ExitCode = 0
+		s.Pid = pid
+		s.StartedAt = time.Now().UTC()
+		close(s.waitChan) // fire waiters for start
+		s.waitChan = make(chan struct{})
+	}
+	s.Unlock()
 }
 
 func (s *State) SetStopped(exitCode int) {
 	s.Lock()
-	defer s.Unlock()
-
-	s.Running = false
-	s.Pid = 0
-	s.FinishedAt = time.Now().UTC()
-	s.ExitCode = exitCode
+	if s.Running {
+		s.Running = false
+		s.Pid = 0
+		s.FinishedAt = time.Now().UTC()
+		s.ExitCode = exitCode
+		close(s.waitChan) // fire waiters for stop
+		s.waitChan = make(chan struct{})
+	}
+	s.Unlock()
 }
 
 func (s *State) SetPaused() {
 	s.Lock()
-	defer s.Unlock()
 	s.Paused = true
+	s.Unlock()
 }
 
 func (s *State) SetUnpaused() {
 	s.Lock()
-	defer s.Unlock()
 	s.Paused = false
+	s.Unlock()
 }
 
 func (s *State) IsPaused() bool {
 	s.RLock()
-	defer s.RUnlock()
-
-	return s.Paused
+	res := s.Paused
+	s.RUnlock()
+	return res
 }

+ 102 - 0
daemon/state_test.go

@@ -0,0 +1,102 @@
+package daemon
+
+import (
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+func TestStateRunStop(t *testing.T) {
+	s := NewState()
+	for i := 1; i < 3; i++ { // full lifecycle two times
+		started := make(chan struct{})
+		var pid int64
+		go func() {
+			runPid, _ := s.WaitRunning(-1 * time.Second)
+			atomic.StoreInt64(&pid, int64(runPid))
+			close(started)
+		}()
+		s.SetRunning(i + 100)
+		if !s.IsRunning() {
+			t.Fatal("State not running")
+		}
+		if s.Pid != i+100 {
+			t.Fatalf("Pid %v, expected %v", s.Pid, i+100)
+		}
+		if s.ExitCode != 0 {
+			t.Fatalf("ExitCode %v, expected 0", s.ExitCode)
+		}
+		select {
+		case <-time.After(100 * time.Millisecond):
+			t.Fatal("Start callback doesn't fire in 100 milliseconds")
+		case <-started:
+			t.Log("Start callback fired")
+		}
+		runPid := int(atomic.LoadInt64(&pid))
+		if runPid != i+100 {
+			t.Fatalf("Pid %v, expected %v", runPid, i+100)
+		}
+		if pid, err := s.WaitRunning(-1 * time.Second); err != nil || pid != i+100 {
+			t.Fatal("WaitRunning returned pid: %v, err: %v, expected pid: %v, err: %v", pid, err, i+100, nil)
+		}
+
+		stopped := make(chan struct{})
+		var exit int64
+		go func() {
+			exitCode, _ := s.WaitStop(-1 * time.Second)
+			atomic.StoreInt64(&exit, int64(exitCode))
+			close(stopped)
+		}()
+		s.SetStopped(i)
+		if s.IsRunning() {
+			t.Fatal("State is running")
+		}
+		if s.ExitCode != i {
+			t.Fatalf("ExitCode %v, expected %v", s.ExitCode, i)
+		}
+		if s.Pid != 0 {
+			t.Fatalf("Pid %v, expected 0", s.Pid)
+		}
+		select {
+		case <-time.After(100 * time.Millisecond):
+			t.Fatal("Stop callback doesn't fire in 100 milliseconds")
+		case <-stopped:
+			t.Log("Stop callback fired")
+		}
+		exitCode := int(atomic.LoadInt64(&exit))
+		if exitCode != i {
+			t.Fatalf("ExitCode %v, expected %v", exitCode, i)
+		}
+		if exitCode, err := s.WaitStop(-1 * time.Second); err != nil || exitCode != i {
+			t.Fatal("WaitStop returned exitCode: %v, err: %v, expected exitCode: %v, err: %v", exitCode, err, i, nil)
+		}
+	}
+}
+
+func TestStateTimeoutWait(t *testing.T) {
+	s := NewState()
+	started := make(chan struct{})
+	go func() {
+		s.WaitRunning(100 * time.Millisecond)
+		close(started)
+	}()
+	select {
+	case <-time.After(200 * time.Millisecond):
+		t.Fatal("Start callback doesn't fire in 100 milliseconds")
+	case <-started:
+		t.Log("Start callback fired")
+	}
+	s.SetRunning(42)
+	stopped := make(chan struct{})
+	go func() {
+		s.WaitRunning(100 * time.Millisecond)
+		close(stopped)
+	}()
+	select {
+	case <-time.After(200 * time.Millisecond):
+		t.Fatal("Start callback doesn't fire in 100 milliseconds")
+	case <-stopped:
+		t.Log("Start callback fired")
+	}
+
+}