Explorar o código

Merge pull request #17185 from cpuguy83/use_finer_locking_for_volume_store

Fix potential races in the volume store
Alexander Morozov hai 9 anos
pai
achega
cc207aa136
Modificáronse 4 ficheiros con 324 adicións e 34 borrados
  1. 65 0
      pkg/locker/README.md
  2. 111 0
      pkg/locker/locker.go
  3. 90 0
      pkg/locker/locker_test.go
  4. 58 34
      volume/store/store.go

+ 65 - 0
pkg/locker/README.md

@@ -0,0 +1,65 @@
+Locker
+=====
+
+locker provides a mechanism for creating finer-grained locking to help
+free up more global locks to handle other tasks.
+
+The implementation resembles a sync.Mutex; however, the user must provide a
+name that identifies the underlying lock when locking and unlocking,
+and unlock may return an error.
+
+If a lock with a given name does not exist when `Lock` is called, one is
+created.
+Lock references are automatically cleaned up on `Unlock` if nothing else is
+waiting for the lock.
+
+
+## Usage
+
+```go
+package important
+
+import (
+	"sync"
+	"time"
+
+	"github.com/docker/docker/pkg/locker"
+)
+
+type important struct {
+	locks *locker.Locker         // per-name locks
+	data  map[string]interface{} // values keyed by name
+	mu    sync.Mutex             // global lock for data
+}
+
+// Get returns the value stored under name, or nil if there is none.
+func (i *important) Get(name string) interface{} {
+	// Serialize with any other call operating on the same name.
+	i.locks.Lock(name)
+	defer i.locks.Unlock(name)
+
+	// The per-name lock does not protect the shared map from writes made
+	// under other names, so the read itself must hold the global mutex.
+	i.mu.Lock()
+	defer i.mu.Unlock()
+	return i.data[name]
+}
+
+// Create runs the slow creation step for name and then records data under it.
+func (i *important) Create(name string, data interface{}) {
+	// Hold only the per-name lock during the slow step so that calls for
+	// other names are not blocked.
+	i.locks.Lock(name)
+	defer i.locks.Unlock(name)
+
+	i.createImportant(data)
+
+	// The map is shared by all names; guard the write with the global mutex.
+	i.mu.Lock()
+	i.data[name] = data
+	i.mu.Unlock()
+}
+
+func (i *important) createImportant(data interface{}) {
+	time.Sleep(10 * time.Second) // sleeps for 10 seconds; data is not used
+}
+```
+
+For functions dealing with a given name, always lock at the beginning of the
+function (or before doing anything with the underlying state); this ensures any
+other function that is dealing with the same name will block.
+
+When needing to modify the underlying data, use the global lock to ensure nothing
+else is modifying it at the same time.
+Since the name lock is already held, no reads of that name will occur while the
+modification is being performed.
+

+ 111 - 0
pkg/locker/locker.go

@@ -0,0 +1,111 @@
+/*
+Package locker provides a mechanism for creating finer-grained locking to help
+free up more global locks to handle other tasks.
+
+The implementation resembles a sync.Mutex; however, the user must provide a
+name that identifies the underlying lock when locking and unlocking,
+and unlock may return an error.
+
+If a lock with a given name does not exist when `Lock` is called, one is
+created.
+Lock references are automatically cleaned up on `Unlock` if nothing else is
+waiting for the lock.
+*/
+package locker
+
+import (
+	"errors"
+	"sync"
+	"sync/atomic"
+)
+
+// ErrNoSuchLock is returned when the requested lock does not exist
+var ErrNoSuchLock = errors.New("no such lock")
+
+// Locker provides a locking mechanism based on the passed in reference name.
+type Locker struct {
+	mu    sync.Mutex          // held by Lock and Unlock while they access locks
+	locks map[string]*lockCtr // per-name locks, keyed by reference name
+}
+
+// lockCtr is used by Locker to represent a lock with a given name.
+type lockCtr struct {
+	mu sync.Mutex // the mutex acquired and released via Lock and Unlock
+	// waiters is the number of callers waiting to acquire the lock; see inc, dec and count.
+	waiters uint32
+}
+
+// inc atomically adds one to the number of waiters waiting for the lock.
+func (l *lockCtr) inc() {
+	atomic.AddUint32(&l.waiters, 1)
+}
+
+// dec atomically decrements the number of waiters waiting on the lock by one.
+func (l *lockCtr) dec() {
+	// Adding ^uint32(0) (all bits set) wraps around and subtracts exactly one.
+	// Do not derive the operand from l.waiters: that is an unsynchronized read,
+	// and x + ^(x-1) is always 0, which would reset the counter to zero instead
+	// of decrementing it whenever more than one waiter is registered.
+	atomic.AddUint32(&l.waiters, ^uint32(0))
+}
+
+// count atomically loads and returns the current number of waiters.
+func (l *lockCtr) count() uint32 {
+	return atomic.LoadUint32(&l.waiters)
+}
+
+// Lock acquires the underlying mutex.
+func (l *lockCtr) Lock() {
+	l.mu.Lock()
+}
+
+// Unlock releases the underlying mutex.
+func (l *lockCtr) Unlock() {
+	l.mu.Unlock()
+}
+
+// New creates a new Locker with an empty, initialized lock map.
+func New() *Locker {
+	return &Locker{
+		locks: make(map[string]*lockCtr),
+	}
+}
+
+// Lock locks the mutex with the given name, creating it first if it doesn't exist.
+func (l *Locker) Lock(name string) {
+	l.mu.Lock()
+	if l.locks == nil { // the map is created on first use if it is still nil
+		l.locks = make(map[string]*lockCtr)
+	}
+
+	nameLock, exists := l.locks[name]
+	if !exists {
+		nameLock = &lockCtr{}
+		l.locks[name] = nameLock
+	}
+
+	// Increment the nameLock waiters while still holding the main mutex, so that
+	// a concurrent `Unlock` sees a non-zero count and does not delete the entry.
+	nameLock.inc()
+	l.mu.Unlock()
+
+	// Acquire the nameLock outside the main mutex so other operations are not blocked.
+	// Once it is held, this caller is no longer waiting, so decrement the waiters.
+	nameLock.Lock()
+	nameLock.dec()
+}
+
+// Unlock releases the mutex registered under name.
+// It returns ErrNoSuchLock when no lock is registered under that name. When no
+// other caller is waiting on the lock, its entry is removed from the Locker.
+func (l *Locker) Unlock(name string) error {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	nameLock, ok := l.locks[name]
+	if !ok {
+		return ErrNoSuchLock
+	}
+
+	// Nobody is queued behind the current holder, so the entry can be dropped
+	// before the mutex itself is released.
+	if waiting := nameLock.count(); waiting == 0 {
+		delete(l.locks, name)
+	}
+	nameLock.Unlock()
+	return nil
+}

+ 90 - 0
pkg/locker/locker_test.go

@@ -0,0 +1,90 @@
+package locker
+
+import (
+	"runtime"
+	"testing"
+)
+
+// TestLockCounter checks that inc followed by dec moves the waiter count
+// from zero to one and back to zero.
+func TestLockCounter(t *testing.T) {
+	ctr := new(lockCtr)
+
+	ctr.inc()
+	if got := ctr.waiters; got != 1 {
+		t.Fatal("counter inc failed")
+	}
+
+	ctr.dec()
+	if got := ctr.waiters; got != 0 {
+		t.Fatal("counter dec failed")
+	}
+}
+
+// TestLockerLock checks that a second Lock on a held name does not return,
+// that the pending caller is counted as a waiter, and that Unlock lets it
+// proceed and brings the waiter count back to zero.
+func TestLockerLock(t *testing.T) {
+	l := New()
+	l.Lock("test")
+	ctr := l.locks["test"]
+
+	// Read the counter through count() like the later checks do, rather than
+	// touching the field directly.
+	if ctr.count() != 0 {
+		t.Fatalf("expected waiters to be 0, got: %d", ctr.count())
+	}
+
+	chDone := make(chan struct{})
+	go func() {
+		l.Lock("test")
+		close(chDone)
+	}()
+
+	// Yield to give the goroutine a chance to run up to its Lock call.
+	runtime.Gosched()
+
+	select {
+	case <-chDone:
+		t.Fatal("lock should not have returned while it was still held")
+	default:
+	}
+
+	if ctr.count() != 1 {
+		t.Fatalf("expected waiters to be 1, got: %d", ctr.count())
+	}
+
+	if err := l.Unlock("test"); err != nil {
+		t.Fatal(err)
+	}
+	runtime.Gosched()
+
+	select {
+	case <-chDone:
+	default:
+		// one more time just to be sure
+		runtime.Gosched()
+		select {
+		case <-chDone:
+		default:
+			t.Fatalf("lock should have completed")
+		}
+	}
+
+	if ctr.count() != 0 {
+		t.Fatalf("expected waiters to be 0, got: %d", ctr.count())
+	}
+}
+
+// TestLockerUnlock checks that a name can be locked again without blocking
+// once it has been unlocked.
+func TestLockerUnlock(t *testing.T) {
+	l := New()
+
+	l.Lock("test")
+	// Unlock reports ErrNoSuchLock for unknown names; fail loudly instead of
+	// silently dropping the error.
+	if err := l.Unlock("test"); err != nil {
+		t.Fatal(err)
+	}
+
+	chDone := make(chan struct{})
+	go func() {
+		l.Lock("test")
+		close(chDone)
+	}()
+
+	// Yield to give the goroutine a chance to run.
+	runtime.Gosched()
+
+	select {
+	case <-chDone:
+	default:
+		t.Fatalf("lock should not be blocked")
+	}
+}

+ 58 - 34
volume/store/store.go

@@ -5,6 +5,7 @@ import (
 	"sync"
 	"sync"
 
 
 	"github.com/Sirupsen/logrus"
 	"github.com/Sirupsen/logrus"
+	"github.com/docker/docker/pkg/locker"
 	"github.com/docker/docker/volume"
 	"github.com/docker/docker/volume"
 	"github.com/docker/docker/volume/drivers"
 	"github.com/docker/docker/volume/drivers"
 )
 )
@@ -22,14 +23,35 @@ var (
 // reference counting of volumes in the system.
 // reference counting of volumes in the system.
 func New() *VolumeStore {
 func New() *VolumeStore {
 	return &VolumeStore{
 	return &VolumeStore{
-		vols: make(map[string]*volumeCounter),
+		vols:  make(map[string]*volumeCounter),
+		locks: &locker.Locker{},
 	}
 	}
 }
 }
 
 
+// get looks up the volume counter stored under name while holding the global
+// lock, and reports whether an entry was present.
+func (s *VolumeStore) get(name string) (*volumeCounter, bool) {
+	s.globalLock.Lock()
+	defer s.globalLock.Unlock()
+
+	counter, found := s.vols[name]
+	return counter, found
+}
+
+// set stores vc under name while holding the global lock.
+func (s *VolumeStore) set(name string, vc *volumeCounter) {
+	s.globalLock.Lock()
+	defer s.globalLock.Unlock()
+
+	s.vols[name] = vc
+}
+
+// remove deletes the entry stored under name while holding the global lock.
+func (s *VolumeStore) remove(name string) {
+	s.globalLock.Lock()
+	defer s.globalLock.Unlock()
+
+	delete(s.vols, name)
+}
+
 // VolumeStore is a struct that stores the list of volumes available and keeps track of their usage counts
 // VolumeStore is a struct that stores the list of volumes available and keeps track of their usage counts
 type VolumeStore struct {
 type VolumeStore struct {
-	vols map[string]*volumeCounter
-	mu   sync.Mutex
+	vols       map[string]*volumeCounter
+	locks      *locker.Locker
+	globalLock sync.Mutex
 }
 }
 
 
 // volumeCounter keeps track of references to a volume
 // volumeCounter keeps track of references to a volume
@@ -47,14 +69,14 @@ func (s *VolumeStore) AddAll(vols []volume.Volume) {
 
 
 // Create tries to find an existing volume with the given name or create a new one from the passed in driver
 // Create tries to find an existing volume with the given name or create a new one from the passed in driver
 func (s *VolumeStore) Create(name, driverName string, opts map[string]string) (volume.Volume, error) {
 func (s *VolumeStore) Create(name, driverName string, opts map[string]string) (volume.Volume, error) {
-	s.mu.Lock()
 	name = normaliseVolumeName(name)
 	name = normaliseVolumeName(name)
-	if vc, exists := s.vols[name]; exists {
+	s.locks.Lock(name)
+	defer s.locks.Unlock(name)
+
+	if vc, exists := s.get(name); exists {
 		v := vc.Volume
 		v := vc.Volume
-		s.mu.Unlock()
 		return v, nil
 		return v, nil
 	}
 	}
-	s.mu.Unlock()
 	logrus.Debugf("Registering new volume reference: driver %s, name %s", driverName, name)
 	logrus.Debugf("Registering new volume reference: driver %s, name %s", driverName, name)
 
 
 	vd, err := volumedrivers.GetDriver(driverName)
 	vd, err := volumedrivers.GetDriver(driverName)
@@ -76,19 +98,17 @@ func (s *VolumeStore) Create(name, driverName string, opts map[string]string) (v
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	s.mu.Lock()
-	s.vols[normaliseVolumeName(v.Name())] = &volumeCounter{v, 0}
-	s.mu.Unlock()
-
+	s.set(name, &volumeCounter{v, 0})
 	return v, nil
 	return v, nil
 }
 }
 
 
 // Get looks if a volume with the given name exists and returns it if so
 // Get looks if a volume with the given name exists and returns it if so
 func (s *VolumeStore) Get(name string) (volume.Volume, error) {
 func (s *VolumeStore) Get(name string) (volume.Volume, error) {
 	name = normaliseVolumeName(name)
 	name = normaliseVolumeName(name)
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	vc, exists := s.vols[name]
+	s.locks.Lock(name)
+	defer s.locks.Unlock(name)
+
+	vc, exists := s.get(name)
 	if !exists {
 	if !exists {
 		return nil, ErrNoSuchVolume
 		return nil, ErrNoSuchVolume
 	}
 	}
@@ -97,11 +117,12 @@ func (s *VolumeStore) Get(name string) (volume.Volume, error) {
 
 
 // Remove removes the requested volume. A volume is not removed if the usage count is > 0
 // Remove removes the requested volume. A volume is not removed if the usage count is > 0
 func (s *VolumeStore) Remove(v volume.Volume) error {
 func (s *VolumeStore) Remove(v volume.Volume) error {
-	s.mu.Lock()
-	defer s.mu.Unlock()
 	name := normaliseVolumeName(v.Name())
 	name := normaliseVolumeName(v.Name())
+	s.locks.Lock(name)
+	defer s.locks.Unlock(name)
+
 	logrus.Debugf("Removing volume reference: driver %s, name %s", v.DriverName(), name)
 	logrus.Debugf("Removing volume reference: driver %s, name %s", v.DriverName(), name)
-	vc, exists := s.vols[name]
+	vc, exists := s.get(name)
 	if !exists {
 	if !exists {
 		return ErrNoSuchVolume
 		return ErrNoSuchVolume
 	}
 	}
@@ -117,20 +138,21 @@ func (s *VolumeStore) Remove(v volume.Volume) error {
 	if err := vd.Remove(vc.Volume); err != nil {
 	if err := vd.Remove(vc.Volume); err != nil {
 		return err
 		return err
 	}
 	}
-	delete(s.vols, name)
+
+	s.remove(name)
 	return nil
 	return nil
 }
 }
 
 
 // Increment increments the usage count of the passed in volume by 1
 // Increment increments the usage count of the passed in volume by 1
 func (s *VolumeStore) Increment(v volume.Volume) {
 func (s *VolumeStore) Increment(v volume.Volume) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
 	name := normaliseVolumeName(v.Name())
 	name := normaliseVolumeName(v.Name())
-	logrus.Debugf("Incrementing volume reference: driver %s, name %s", v.DriverName(), name)
+	s.locks.Lock(name)
+	defer s.locks.Unlock(name)
 
 
-	vc, exists := s.vols[name]
+	logrus.Debugf("Incrementing volume reference: driver %s, name %s", v.DriverName(), v.Name())
+	vc, exists := s.get(name)
 	if !exists {
 	if !exists {
-		s.vols[name] = &volumeCounter{v, 1}
+		s.set(name, &volumeCounter{v, 1})
 		return
 		return
 	}
 	}
 	vc.count++
 	vc.count++
@@ -138,12 +160,12 @@ func (s *VolumeStore) Increment(v volume.Volume) {
 
 
 // Decrement decrements the usage count of the passed in volume by 1
 // Decrement decrements the usage count of the passed in volume by 1
 func (s *VolumeStore) Decrement(v volume.Volume) {
 func (s *VolumeStore) Decrement(v volume.Volume) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
 	name := normaliseVolumeName(v.Name())
 	name := normaliseVolumeName(v.Name())
-	logrus.Debugf("Decrementing volume reference: driver %s, name %s", v.DriverName(), name)
+	s.locks.Lock(name)
+	defer s.locks.Unlock(name)
+	logrus.Debugf("Decrementing volume reference: driver %s, name %s", v.DriverName(), v.Name())
 
 
-	vc, exists := s.vols[name]
+	vc, exists := s.get(name)
 	if !exists {
 	if !exists {
 		return
 		return
 	}
 	}
@@ -155,9 +177,11 @@ func (s *VolumeStore) Decrement(v volume.Volume) {
 
 
 // Count returns the usage count of the passed in volume
 // Count returns the usage count of the passed in volume
 func (s *VolumeStore) Count(v volume.Volume) uint {
 func (s *VolumeStore) Count(v volume.Volume) uint {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	vc, exists := s.vols[normaliseVolumeName(v.Name())]
+	name := normaliseVolumeName(v.Name())
+	s.locks.Lock(name)
+	defer s.locks.Unlock(name)
+
+	vc, exists := s.get(name)
 	if !exists {
 	if !exists {
 		return 0
 		return 0
 	}
 	}
@@ -166,8 +190,8 @@ func (s *VolumeStore) Count(v volume.Volume) uint {
 
 
 // List returns all the available volumes
 // List returns all the available volumes
 func (s *VolumeStore) List() []volume.Volume {
 func (s *VolumeStore) List() []volume.Volume {
-	s.mu.Lock()
-	defer s.mu.Unlock()
+	s.globalLock.Lock()
+	defer s.globalLock.Unlock()
 	var ls []volume.Volume
 	var ls []volume.Volume
 	for _, vc := range s.vols {
 	for _, vc := range s.vols {
 		ls = append(ls, vc.Volume)
 		ls = append(ls, vc.Volume)
@@ -192,8 +216,8 @@ func byDriver(name string) filterFunc {
 
 
 // filter returns the available volumes filtered by a filterFunc function
 // filter returns the available volumes filtered by a filterFunc function
 func (s *VolumeStore) filter(f filterFunc) []volume.Volume {
 func (s *VolumeStore) filter(f filterFunc) []volume.Volume {
-	s.mu.Lock()
-	defer s.mu.Unlock()
+	s.globalLock.Lock()
+	defer s.globalLock.Unlock()
 	var ls []volume.Volume
 	var ls []volume.Volume
 	for _, vc := range s.vols {
 	for _, vc := range s.vols {
 		if f(vc.Volume) {
 		if f(vc.Volume) {