diff --git a/libnetwork/controller.go b/libnetwork/controller.go index f4b5dd781d..724c6717cf 100644 --- a/libnetwork/controller.go +++ b/libnetwork/controller.go @@ -71,10 +71,10 @@ type NetworkController interface { WalkNetworks(walker NetworkWalker) // NetworkByName returns the Network which has the passed name, if it exists otherwise nil is returned - NetworkByName(name string) Network + NetworkByName(name string) (Network, error) // NetworkByID returns the Network which has the passed id, if it exists otherwise nil is returned - NetworkByID(id string) Network + NetworkByID(id string) (Network, error) } // NetworkWalker is a client provided function which will be used to walk the Networks. @@ -129,7 +129,7 @@ func (c *controller) RegisterDriver(networkType string, driver driverapi.Driver) // are network specific and modeled in a generic way. func (c *controller) NewNetwork(networkType, name string, options ...NetworkOption) (Network, error) { if name == "" { - return nil, ErrInvalidNetworkName + return nil, ErrInvalidName } // Check if a driver for the specified network type is available c.Lock() @@ -192,31 +192,35 @@ func (c *controller) WalkNetworks(walker NetworkWalker) { } } -func (c *controller) NetworkByName(name string) Network { +func (c *controller) NetworkByName(name string) (Network, error) { + if name == "" { + return nil, ErrInvalidName + } var n Network - if name != "" { - s := func(current Network) bool { - if current.Name() == name { - n = current - return true - } - return false + s := func(current Network) bool { + if current.Name() == name { + n = current + return true } - - c.WalkNetworks(s) + return false } - return n + c.WalkNetworks(s) + + return n, nil } -func (c *controller) NetworkByID(id string) Network { +func (c *controller) NetworkByID(id string) (Network, error) { + if id == "" { + return nil, ErrInvalidID + } c.Lock() defer c.Unlock() if n, ok := c.networks[types.UUID(id)]; ok { - return n + return n, nil } - return nil + return nil, nil } func (c *controller) sandboxAdd(key string, create bool) (sandbox.Sandbox, error) { diff --git a/libnetwork/error.go b/libnetwork/error.go index 491c1bff39..b3b99fafa3 100644 --- a/libnetwork/error.go +++ b/libnetwork/error.go @@ -14,16 +14,16 @@ var ( ErrInvalidNetworkDriver = errors.New("invalid driver bound to network") // ErrInvalidJoin is returned if a join is attempted on an endpoint // which already has a container joined. - ErrInvalidJoin = errors.New("A container has already joined the endpoint") + ErrInvalidJoin = errors.New("a container has already joined the endpoint") // ErrNoContainer is returned when the endpoint has no container // attached to it. ErrNoContainer = errors.New("no container attached to the endpoint") - // ErrInvalidEndpointName is returned if an invalid endpoint name - // is passed when creating an endpoint - ErrInvalidEndpointName = errors.New("invalid endpoint name") - // ErrInvalidNetworkName is returned if an invalid network name - // is passed when creating a network - ErrInvalidNetworkName = errors.New("invalid network name") + // ErrInvalidID is returned when a query-by-id method is being invoked + // with an empty id parameter + ErrInvalidID = errors.New("invalid ID") + // ErrInvalidName is returned when a query-by-name or resource create method is + // invoked with an empty name parameter + ErrInvalidName = errors.New("invalid Name") ) // NetworkTypeError type is returned when the network type string is not diff --git a/libnetwork/libnetwork_test.go b/libnetwork/libnetwork_test.go index f5edbd8bb6..330184a2c0 100644 --- a/libnetwork/libnetwork_test.go +++ b/libnetwork/libnetwork_test.go @@ -281,8 +281,8 @@ func TestNetworkName(t *testing.T) { if err == nil { t.Fatal("Expected to fail. But instead succeeded") } - if err != libnetwork.ErrInvalidNetworkName { - t.Fatal("Expected to fail with ErrInvalidNetworkName error") + if err != libnetwork.ErrInvalidName { + t.Fatal("Expected to fail with ErrInvalidName error") } networkName := "testnetwork" @@ -404,8 +404,8 @@ func TestUnknownEndpoint(t *testing.T) { if err == nil { t.Fatal("Expected to fail. But instead succeeded") } - if err != libnetwork.ErrInvalidEndpointName { - t.Fatal("Expected to fail with ErrInvalidEndpointName error") + if err != libnetwork.ErrInvalidName { + t.Fatal("Expected to fail with ErrInvalidName error") } ep, err := network.CreateEndpoint("testep") @@ -526,30 +526,46 @@ func TestControllerQuery(t *testing.T) { t.Fatal(err) } - g := controller.NetworkByName("") - if g != nil { + _, err = controller.NetworkByName("") + if err == nil { t.Fatalf("NetworkByName() succeeded with invalid target name") } - - g = controller.NetworkByID("") - if g != nil { - t.Fatalf("NetworkByID() succeeded with invalid target id: %v", g) + if err != libnetwork.ErrInvalidName { + t.Fatalf("NetworkByName() failed with unexpected error: %v", err) } - g = controller.NetworkByID("network1") - if g != nil { - t.Fatalf("NetworkByID() succeeded with invalid target name") + _, err = controller.NetworkByID("") + if err == nil { + t.Fatalf("NetworkByID() succeeded with invalid target id") + } + if err != libnetwork.ErrInvalidID { + t.Fatalf("NetworkByID() failed with unexpected error: %v", err) } - g = controller.NetworkByName("network1") + g, err := controller.NetworkByID("network1") + if err != nil { + t.Fatalf("Unexpected failure for NetworkByID(): %v", err) + } + if g != nil { + t.Fatalf("NetworkByID() succeeded with unknown target id") + } + + g, err = controller.NetworkByName("network1") + if err != nil { + t.Fatalf("Unexpected failure for NetworkByName(): %v", err) + } if g == nil { t.Fatalf("NetworkByName() did not find the network") } + if g != net1 { t.Fatalf("NetworkByName() returned the wrong network") } - g = controller.NetworkByID(net1.ID()) + g, err = controller.NetworkByID(net1.ID()) + if err != nil { + t.Fatalf("Unexpected failure for NetworkByID(): %v", err) + } if net1 != g { t.Fatalf("NetworkByID() returned unexpected element: %v", g) } @@ -578,31 +594,42 @@ func TestNetworkQuery(t *testing.T) { t.Fatal(err) } - e := net1.EndpointByName("ep11") + e, err := net1.EndpointByName("ep11") + if err != nil { + t.Fatal(err) + } if ep11 != e { t.Fatalf("EndpointByName() returned %v instead of %v", e, ep11) } - e = net1.EndpointByName("") + e, err = net1.EndpointByName("") + if err == nil { + t.Fatalf("EndpointByName() succeeded with invalid target name") + } + if err != libnetwork.ErrInvalidName { + t.Fatalf("EndpointByName() failed with unexpected error: %v", err) + } + + e, err = net1.EndpointByName("IamNotAnEndpoint") + if err != nil { + t.Fatal(err) + } if e != nil { t.Fatalf("EndpointByName(): expected nil, got %v", e) } - e = net1.EndpointByName("IamNotAnEndpoint") - if e != nil { - t.Fatalf("EndpointByName(): expected nil, got %v", e) - } - - e = net1.EndpointByID(ep12.ID()) + e, err = net1.EndpointByID(ep12.ID()) if ep12 != e { t.Fatalf("EndpointByID() returned %v instead of %v", e, ep12) } - e = net1.EndpointByID("") - if e != nil { - t.Fatalf("EndpointByID(): expected nil, got %v", e) + e, err = net1.EndpointByID("") + if err == nil { + t.Fatalf("EndpointByID() succeeded with invalid target id") + } + if err != libnetwork.ErrInvalidID { + t.Fatalf("EndpointByID() failed with unexpected error: %v", err) } - } const containerID = "valid_container" @@ -1100,12 +1127,18 @@ func runParallelTests(t *testing.T, thrNumber int) { } defer netns.Set(origns) - net := ctrlr.NetworkByName("network1") + net, err := ctrlr.NetworkByName("network1") + if err != nil { + t.Fatal(err) + } if net == nil { t.Fatal("Could not find network1") } - ep := net.EndpointByName("ep1") + ep, err := net.EndpointByName("ep1") + if err != nil { + t.Fatal(err) + } if ep == nil { t.Fatal("Could not find ep1") } diff --git a/libnetwork/network.go b/libnetwork/network.go index df89b8a3ae..0a4425e3ee 100644 --- a/libnetwork/network.go +++ b/libnetwork/network.go @@ -37,10 +37,10 @@ type Network interface { WalkEndpoints(walker EndpointWalker) // EndpointByName returns the Endpoint which has the passed name, if it exists otherwise nil is returned - EndpointByName(name string) Endpoint + EndpointByName(name string) (Endpoint, error) // EndpointByID returns the Endpoint which has the passed id, if it exists otherwise nil is returned - EndpointByID(id string) Endpoint + EndpointByID(id string) (Endpoint, error) } // EndpointWalker is a client provided function which will be used to walk the Endpoints. @@ -133,7 +133,7 @@ func (n *network) Delete() error { func (n *network) CreateEndpoint(name string, options ...EndpointOption) (Endpoint, error) { if name == "" { - return nil, ErrInvalidEndpointName + return nil, ErrInvalidName } ep := &endpoint{name: name, generic: make(map[string]interface{})} ep.id = types.UUID(stringid.GenerateRandomID()) @@ -172,29 +172,33 @@ func (n *network) WalkEndpoints(walker EndpointWalker) { } } -func (n *network) EndpointByName(name string) Endpoint { +func (n *network) EndpointByName(name string) (Endpoint, error) { + if name == "" { + return nil, ErrInvalidName + } var e Endpoint - if name != "" { - s := func(current Endpoint) bool { - if current.Name() == name { - e = current - return true - } - return false + s := func(current Endpoint) bool { + if current.Name() == name { + e = current + return true } - - n.WalkEndpoints(s) + return false } - return e + n.WalkEndpoints(s) + + return e, nil } -func (n *network) EndpointByID(id string) Endpoint { +func (n *network) EndpointByID(id string) (Endpoint, error) { + if id == "" { + return nil, ErrInvalidID + } n.Lock() defer n.Unlock() if e, ok := n.endpoints[types.UUID(id)]; ok { - return e + return e, nil } - return nil + return nil, nil }