Browse code

State refactoring and add waiting functions

Docker-DCO-1.1-Signed-off-by: Alexandr Morozov <lk4d4math@gmail.com> (github: LK4D4)

Alexandr Morozov authored on 2014/06/06 20:28:12
Showing 2 changed files
... ...
@@ -16,6 +16,13 @@ type State struct {
16 16
 	ExitCode   int
17 17
 	StartedAt  time.Time
18 18
 	FinishedAt time.Time
19
+	waitChan   chan struct{}
20
+}
21
+
22
+func NewState() *State {
23
+	return &State{
24
+		waitChan: make(chan struct{}),
25
+	}
19 26
 }
20 27
 
21 28
 // String returns a human-readable description of the state
... ...
@@ -35,56 +42,118 @@ func (s *State) String() string {
35 35
 	return fmt.Sprintf("Exited (%d) %s ago", s.ExitCode, units.HumanDuration(time.Now().UTC().Sub(s.FinishedAt)))
36 36
 }
37 37
 
38
+func wait(waitChan <-chan struct{}, timeout time.Duration) error {
39
+	if timeout < 0 {
40
+		<-waitChan
41
+		return nil
42
+	}
43
+	select {
44
+	case <-time.After(timeout):
45
+		return fmt.Errorf("Timed out: %v", timeout)
46
+	case <-waitChan:
47
+		return nil
48
+	}
49
+}
50
+
51
+// WaitRunning waits until state is running. If state already running it returns
52
+// immediatly. If you want wait forever you must supply negative timeout.
53
+// Returns pid, that was passed to SetRunning
54
+func (s *State) WaitRunning(timeout time.Duration) (int, error) {
55
+	s.RLock()
56
+	if s.IsRunning() {
57
+		pid := s.Pid
58
+		s.RUnlock()
59
+		return pid, nil
60
+	}
61
+	waitChan := s.waitChan
62
+	s.RUnlock()
63
+	if err := wait(waitChan, timeout); err != nil {
64
+		return -1, err
65
+	}
66
+	return s.GetPid(), nil
67
+}
68
+
69
+// WaitStop waits until state is stopped. If state already stopped it returns
70
+// immediatly. If you want wait forever you must supply negative timeout.
71
+// Returns exit code, that was passed to SetRunning
72
+func (s *State) WaitStop(timeout time.Duration) (int, error) {
73
+	s.RLock()
74
+	if !s.Running {
75
+		exitCode := s.ExitCode
76
+		s.RUnlock()
77
+		return exitCode, nil
78
+	}
79
+	waitChan := s.waitChan
80
+	s.RUnlock()
81
+	if err := wait(waitChan, timeout); err != nil {
82
+		return -1, err
83
+	}
84
+	return s.GetExitCode(), nil
85
+}
86
+
38 87
 func (s *State) IsRunning() bool {
39 88
 	s.RLock()
40
-	defer s.RUnlock()
89
+	res := s.Running
90
+	s.RUnlock()
91
+	return res
92
+}
41 93
 
42
-	return s.Running
94
+func (s *State) GetPid() int {
95
+	s.RLock()
96
+	res := s.Pid
97
+	s.RUnlock()
98
+	return res
43 99
 }
44 100
 
45 101
 func (s *State) GetExitCode() int {
46 102
 	s.RLock()
47
-	defer s.RUnlock()
48
-
49
-	return s.ExitCode
103
+	res := s.ExitCode
104
+	s.RUnlock()
105
+	return res
50 106
 }
51 107
 
52 108
 func (s *State) SetRunning(pid int) {
53 109
 	s.Lock()
54
-	defer s.Unlock()
55
-
56
-	s.Running = true
57
-	s.Paused = false
58
-	s.ExitCode = 0
59
-	s.Pid = pid
60
-	s.StartedAt = time.Now().UTC()
110
+	if !s.Running {
111
+		s.Running = true
112
+		s.Paused = false
113
+		s.ExitCode = 0
114
+		s.Pid = pid
115
+		s.StartedAt = time.Now().UTC()
116
+		close(s.waitChan) // fire waiters for start
117
+		s.waitChan = make(chan struct{})
118
+	}
119
+	s.Unlock()
61 120
 }
62 121
 
63 122
 func (s *State) SetStopped(exitCode int) {
64 123
 	s.Lock()
65
-	defer s.Unlock()
66
-
67
-	s.Running = false
68
-	s.Pid = 0
69
-	s.FinishedAt = time.Now().UTC()
70
-	s.ExitCode = exitCode
124
+	if s.Running {
125
+		s.Running = false
126
+		s.Pid = 0
127
+		s.FinishedAt = time.Now().UTC()
128
+		s.ExitCode = exitCode
129
+		close(s.waitChan) // fire waiters for stop
130
+		s.waitChan = make(chan struct{})
131
+	}
132
+	s.Unlock()
71 133
 }
72 134
 
73 135
 func (s *State) SetPaused() {
74 136
 	s.Lock()
75
-	defer s.Unlock()
76 137
 	s.Paused = true
138
+	s.Unlock()
77 139
 }
78 140
 
79 141
 func (s *State) SetUnpaused() {
80 142
 	s.Lock()
81
-	defer s.Unlock()
82 143
 	s.Paused = false
144
+	s.Unlock()
83 145
 }
84 146
 
85 147
 func (s *State) IsPaused() bool {
86 148
 	s.RLock()
87
-	defer s.RUnlock()
88
-
89
-	return s.Paused
149
+	res := s.Paused
150
+	s.RUnlock()
151
+	return res
90 152
 }
91 153
new file mode 100644
... ...
@@ -0,0 +1,102 @@
0
+package daemon
1
+
2
+import (
3
+	"sync/atomic"
4
+	"testing"
5
+	"time"
6
+)
7
+
8
+func TestStateRunStop(t *testing.T) {
9
+	s := NewState()
10
+	for i := 1; i < 3; i++ { // full lifecycle two times
11
+		started := make(chan struct{})
12
+		var pid int64
13
+		go func() {
14
+			runPid, _ := s.WaitRunning(-1 * time.Second)
15
+			atomic.StoreInt64(&pid, int64(runPid))
16
+			close(started)
17
+		}()
18
+		s.SetRunning(i + 100)
19
+		if !s.IsRunning() {
20
+			t.Fatal("State not running")
21
+		}
22
+		if s.Pid != i+100 {
23
+			t.Fatalf("Pid %v, expected %v", s.Pid, i+100)
24
+		}
25
+		if s.ExitCode != 0 {
26
+			t.Fatalf("ExitCode %v, expected 0", s.ExitCode)
27
+		}
28
+		select {
29
+		case <-time.After(100 * time.Millisecond):
30
+			t.Fatal("Start callback doesn't fire in 100 milliseconds")
31
+		case <-started:
32
+			t.Log("Start callback fired")
33
+		}
34
+		runPid := int(atomic.LoadInt64(&pid))
35
+		if runPid != i+100 {
36
+			t.Fatalf("Pid %v, expected %v", runPid, i+100)
37
+		}
38
+		if pid, err := s.WaitRunning(-1 * time.Second); err != nil || pid != i+100 {
39
+			t.Fatal("WaitRunning returned pid: %v, err: %v, expected pid: %v, err: %v", pid, err, i+100, nil)
40
+		}
41
+
42
+		stopped := make(chan struct{})
43
+		var exit int64
44
+		go func() {
45
+			exitCode, _ := s.WaitStop(-1 * time.Second)
46
+			atomic.StoreInt64(&exit, int64(exitCode))
47
+			close(stopped)
48
+		}()
49
+		s.SetStopped(i)
50
+		if s.IsRunning() {
51
+			t.Fatal("State is running")
52
+		}
53
+		if s.ExitCode != i {
54
+			t.Fatalf("ExitCode %v, expected %v", s.ExitCode, i)
55
+		}
56
+		if s.Pid != 0 {
57
+			t.Fatalf("Pid %v, expected 0", s.Pid)
58
+		}
59
+		select {
60
+		case <-time.After(100 * time.Millisecond):
61
+			t.Fatal("Stop callback doesn't fire in 100 milliseconds")
62
+		case <-stopped:
63
+			t.Log("Stop callback fired")
64
+		}
65
+		exitCode := int(atomic.LoadInt64(&exit))
66
+		if exitCode != i {
67
+			t.Fatalf("ExitCode %v, expected %v", exitCode, i)
68
+		}
69
+		if exitCode, err := s.WaitStop(-1 * time.Second); err != nil || exitCode != i {
70
+			t.Fatal("WaitStop returned exitCode: %v, err: %v, expected exitCode: %v, err: %v", exitCode, err, i, nil)
71
+		}
72
+	}
73
+}
74
+
75
+func TestStateTimeoutWait(t *testing.T) {
76
+	s := NewState()
77
+	started := make(chan struct{})
78
+	go func() {
79
+		s.WaitRunning(100 * time.Millisecond)
80
+		close(started)
81
+	}()
82
+	select {
83
+	case <-time.After(200 * time.Millisecond):
84
+		t.Fatal("Start callback doesn't fire in 100 milliseconds")
85
+	case <-started:
86
+		t.Log("Start callback fired")
87
+	}
88
+	s.SetRunning(42)
89
+	stopped := make(chan struct{})
90
+	go func() {
91
+		s.WaitRunning(100 * time.Millisecond)
92
+		close(stopped)
93
+	}()
94
+	select {
95
+	case <-time.After(200 * time.Millisecond):
96
+		t.Fatal("Start callback doesn't fire in 100 milliseconds")
97
+	case <-stopped:
98
+		t.Log("Start callback fired")
99
+	}
100
+
101
+}