Browse Source

Use a slice instead of a map of io.WriteClosers in broadcastwriter

Maps rely on the keys being comparable.
Using an interface type as the map key is dangerous,
because some interface types are not comparable.
I talked about this in my "Stupid Gopher Tricks" talk:
	https://talks.golang.org/2015/tricks.slide

In this case, if the user-provided writer is backed by a slice
(such as io.MultiWriter) then the code will panic at run time.

Signed-off-by: Andrew Gerrand <adg@golang.org>
Andrew Gerrand 10 years ago
parent
commit
31cbf76d0c
2 changed files with 37 additions and 17 deletions
  1. 19 17
      pkg/broadcastwriter/broadcastwriter.go
  2. 18 0
      pkg/broadcastwriter/broadcastwriter_test.go

+ 19 - 17
pkg/broadcastwriter/broadcastwriter.go

@@ -7,46 +7,48 @@ import (
 
 // BroadcastWriter accumulate multiple io.WriteCloser by stream.
 type BroadcastWriter struct {
-	sync.Mutex
-	writers map[io.WriteCloser]struct{}
+	mu      sync.Mutex
+	writers []io.WriteCloser
 }
 
 // AddWriter adds new io.WriteCloser.
 func (w *BroadcastWriter) AddWriter(writer io.WriteCloser) {
-	w.Lock()
-	w.writers[writer] = struct{}{}
-	w.Unlock()
+	w.mu.Lock()
+	w.writers = append(w.writers, writer)
+	w.mu.Unlock()
 }
 
 // Write writes bytes to all writers. Failed writers will be evicted during
 // this call.
 func (w *BroadcastWriter) Write(p []byte) (n int, err error) {
-	w.Lock()
-	for sw := range w.writers {
+	w.mu.Lock()
+	var evict []int
+	for i, sw := range w.writers {
 		if n, err := sw.Write(p); err != nil || n != len(p) {
 			// On error, evict the writer
-			delete(w.writers, sw)
+			evict = append(evict, i)
 		}
 	}
-	w.Unlock()
+	for n, i := range evict {
+		w.writers = append(w.writers[:i-n], w.writers[i-n+1:]...)
+	}
+	w.mu.Unlock()
 	return len(p), nil
 }
 
 // Clean closes and removes all writers. Last non-eol-terminated part of data
 // will be saved.
 func (w *BroadcastWriter) Clean() error {
-	w.Lock()
-	for w := range w.writers {
-		w.Close()
+	w.mu.Lock()
+	for _, sw := range w.writers {
+		sw.Close()
 	}
-	w.writers = make(map[io.WriteCloser]struct{})
-	w.Unlock()
+	w.writers = nil
+	w.mu.Unlock()
 	return nil
 }
 
 // New creates a new BroadcastWriter.
 func New() *BroadcastWriter {
-	return &BroadcastWriter{
-		writers: make(map[io.WriteCloser]struct{}),
-	}
+	return &BroadcastWriter{}
 }

+ 18 - 0
pkg/broadcastwriter/broadcastwriter_test.go

@@ -3,6 +3,7 @@ package broadcastwriter
 import (
 	"bytes"
 	"errors"
+	"strings"
 
 	"testing"
 )
@@ -82,6 +83,23 @@ func TestBroadcastWriter(t *testing.T) {
 		t.Errorf("Buffer contains %v", bufferC.String())
 	}
 
+	// Test4: Test eviction on multiple simultaneous failures
+	bufferB.failOnWrite = true
+	bufferC.failOnWrite = true
+	bufferD := &dummyWriter{}
+	writer.AddWriter(bufferD)
+	writer.Write([]byte("yo"))
+	writer.Write([]byte("ink"))
+	if strings.Contains(bufferB.String(), "yoink") {
+		t.Errorf("bufferB received write. contents: %q", bufferB)
+	}
+	if strings.Contains(bufferC.String(), "yoink") {
+		t.Errorf("bufferC received write. contents: %q", bufferC)
+	}
+	if g, w := bufferD.String(), "yoink"; g != w {
+		t.Errorf("bufferD = %q, want %q", g, w)
+	}
+
 	writer.Clean()
 }