Selaa lähdekoodia

opts: use strings.Cut() and refactor parseDaemonHost()

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
Sebastiaan van Stijn 2 vuotta sitten
vanhempi
commit
774cd9a26c
6 muutettua tiedostoa jossa 53 lisäystä ja 60 poistoa
  1. 5 7
      opts/address_pools.go
  2. 5 5
      opts/env.go
  3. 26 19
      opts/hosts.go
  4. 3 13
      opts/hosts_test.go
  5. 2 6
      opts/opts.go
  6. 12 10
      opts/runtime.go

+ 5 - 7
opts/address_pools.go

@@ -31,19 +31,17 @@ func (p *PoolsOpt) Set(value string) error {
 	poolsDef := types.NetworkToSplit{}
 
 	for _, field := range fields {
-		parts := strings.SplitN(field, "=", 2)
-		if len(parts) != 2 {
+		// TODO(thaJeztah): this should not be case-insensitive.
+		key, val, ok := strings.Cut(strings.ToLower(field), "=")
+		if !ok {
 			return fmt.Errorf("invalid field '%s' must be a key=value pair", field)
 		}
 
-		key := strings.ToLower(parts[0])
-		value := strings.ToLower(parts[1])
-
 		switch key {
 		case "base":
-			poolsDef.Base = value
+			poolsDef.Base = val
 		case "size":
-			size, err := strconv.Atoi(value)
+			size, err := strconv.Atoi(val)
 			if err != nil {
 				return fmt.Errorf("invalid size value: %q (must be integer): %v", value, err)
 			}

+ 5 - 5
opts/env.go

@@ -16,15 +16,15 @@ import (
 //
 // The only validation here is to check if name is empty, per #25099
 func ValidateEnv(val string) (string, error) {
-	arr := strings.SplitN(val, "=", 2)
-	if arr[0] == "" {
+	k, _, ok := strings.Cut(val, "=")
+	if k == "" {
 		return "", errors.New("invalid environment variable: " + val)
 	}
-	if len(arr) > 1 {
+	if ok {
 		return val, nil
 	}
-	if envVal, ok := os.LookupEnv(arr[0]); ok {
-		return arr[0] + "=" + envVal, nil
+	if envVal, ok := os.LookupEnv(k); ok {
+		return k + "=" + envVal, nil
 	}
 	return val, nil
 }

+ 26 - 19
opts/hosts.go

@@ -60,8 +60,7 @@ func ParseHost(defaultToTLS, defaultToUnixXDG bool, val string) (string, error)
 			if err != nil {
 				return "", err
 			}
-			socket := filepath.Join(runtimeDir, "docker.sock")
-			host = "unix://" + socket
+			host = "unix://" + filepath.Join(runtimeDir, "docker.sock")
 		} else {
 			host = DefaultHost
 		}
@@ -77,23 +76,32 @@ func ParseHost(defaultToTLS, defaultToUnixXDG bool, val string) (string, error)
 
 // parseDaemonHost parses the specified address and returns an address that will be used as the host.
 // Depending on the address specified, this may return one of the global Default* strings defined in hosts.go.
-func parseDaemonHost(addr string) (string, error) {
-	addrParts := strings.SplitN(addr, "://", 2)
-	if len(addrParts) == 1 && addrParts[0] != "" {
-		addrParts = []string{"tcp", addrParts[0]}
+func parseDaemonHost(address string) (string, error) {
+	proto, addr, ok := strings.Cut(address, "://")
+	if !ok && proto != "" {
+		addr = proto
+		proto = "tcp"
 	}
 
-	switch addrParts[0] {
+	switch proto {
 	case "tcp":
-		return ParseTCPAddr(addr, DefaultTCPHost)
+		return ParseTCPAddr(address, DefaultTCPHost)
 	case "unix":
-		return parseSimpleProtoAddr("unix", addrParts[1], DefaultUnixSocket)
+		a, err := parseSimpleProtoAddr(proto, addr, DefaultUnixSocket)
+		if err != nil {
+			return "", errors.Wrapf(err, "invalid bind address (%s)", address)
+		}
+		return a, nil
 	case "npipe":
-		return parseSimpleProtoAddr("npipe", addrParts[1], DefaultNamedPipe)
+		a, err := parseSimpleProtoAddr(proto, addr, DefaultNamedPipe)
+		if err != nil {
+			return "", errors.Wrapf(err, "invalid bind address (%s)", address)
+		}
+		return a, nil
 	case "fd":
-		return addr, nil
+		return address, nil
 	default:
-		return "", errors.Errorf("invalid bind address (%s): unsupported proto '%s'", addr, addrParts[0])
+		return "", errors.Errorf("invalid bind address (%s): unsupported proto '%s'", address, proto)
 	}
 }
 
@@ -102,9 +110,8 @@ func parseDaemonHost(addr string) (string, error) {
 // socket address, either using the address parsed from addr, or the contents of
 // defaultAddr if addr is a blank string.
 func parseSimpleProtoAddr(proto, addr, defaultAddr string) (string, error) {
-	addr = strings.TrimPrefix(addr, proto+"://")
 	if strings.Contains(addr, "://") {
-		return "", errors.Errorf("invalid proto, expected %s: %s", proto, addr)
+		return "", errors.Errorf("invalid %s address: %s", proto, addr)
 	}
 	if addr == "" {
 		addr = defaultAddr
@@ -172,14 +179,14 @@ func parseTCPAddr(address string, strict bool) (*url.URL, error) {
 // ExtraHost is in the form of name:ip where the ip has to be a valid ip (IPv4 or IPv6).
 func ValidateExtraHost(val string) (string, error) {
 	// allow for IPv6 addresses in extra hosts by only splitting on first ":"
-	arr := strings.SplitN(val, ":", 2)
-	if len(arr) != 2 || len(arr[0]) == 0 {
+	name, ip, ok := strings.Cut(val, ":")
+	if !ok || name == "" {
 		return "", errors.Errorf("bad format for add-host: %q", val)
 	}
 	// Skip IPaddr validation for special "host-gateway" string
-	if arr[1] != HostGatewayName {
-		if _, err := ValidateIPAddress(arr[1]); err != nil {
-			return "", errors.Errorf("invalid IP address in add-host: %q", arr[1])
+	if ip != HostGatewayName {
+		if _, err := ValidateIPAddress(ip); err != nil {
+			return "", errors.Errorf("invalid IP address in add-host: %q", ip)
 		}
 	}
 	return val, nil

+ 3 - 13
opts/hosts_test.go

@@ -85,6 +85,8 @@ func TestParseDockerDaemonHost(t *testing.T) {
 		"[0:0:0:0:0:0:0:1]:5555/path":   "invalid bind address ([0:0:0:0:0:0:0:1]:5555/path): should not contain a path element",
 		"tcp://:5555/path":              "invalid bind address (tcp://:5555/path): should not contain a path element",
 		"localhost:5555/path":           "invalid bind address (localhost:5555/path): should not contain a path element",
+		"unix://tcp://127.0.0.1":        "invalid bind address (unix://tcp://127.0.0.1): invalid unix address: tcp://127.0.0.1",
+		"unix://unix://tcp://127.0.0.1": "invalid bind address (unix://unix://tcp://127.0.0.1): invalid unix address: unix://tcp://127.0.0.1",
 	}
 	valids := map[string]string{
 		":":                       DefaultTCPHost,
@@ -130,7 +132,7 @@ func TestParseDockerDaemonHost(t *testing.T) {
 				t.Errorf(`unexpected error: "%v"`, err)
 			}
 			if addr != expectedAddr {
-				t.Errorf(`expected "%s", got "%s""`, expectedAddr, addr)
+				t.Errorf(`expected "%s", got "%s"`, expectedAddr, addr)
 			}
 		})
 	}
@@ -210,18 +212,6 @@ func TestParseTCP(t *testing.T) {
 	}
 }
 
-func TestParseInvalidUnixAddrInvalid(t *testing.T) {
-	if _, err := parseSimpleProtoAddr("unix", "tcp://127.0.0.1", "unix:///var/run/docker.sock"); err == nil || err.Error() != "invalid proto, expected unix: tcp://127.0.0.1" {
-		t.Fatalf("Expected an error, got %v", err)
-	}
-	if _, err := parseSimpleProtoAddr("unix", "unix://tcp://127.0.0.1", "/var/run/docker.sock"); err == nil || err.Error() != "invalid proto, expected unix: tcp://127.0.0.1" {
-		t.Fatalf("Expected an error, got %v", err)
-	}
-	if v, err := parseSimpleProtoAddr("unix", "", "/var/run/docker.sock"); err != nil || v != "unix:///var/run/docker.sock" {
-		t.Fatalf("Expected an %v, got %v", v, "unix:///var/run/docker.sock")
-	}
-}
-
 func TestValidateExtraHosts(t *testing.T) {
 	valid := []string{
 		`myhost:192.168.0.1`,

+ 2 - 6
opts/opts.go

@@ -162,12 +162,8 @@ func (opts *MapOpts) Set(value string) error {
 		}
 		value = v
 	}
-	vals := strings.SplitN(value, "=", 2)
-	if len(vals) == 1 {
-		(opts.values)[vals[0]] = ""
-	} else {
-		(opts.values)[vals[0]] = vals[1]
-	}
+	k, v, _ := strings.Cut(value, "=")
+	(opts.values)[k] = v
 	return nil
 }
 

+ 12 - 10
opts/runtime.go

@@ -29,27 +29,29 @@ func (o *RuntimeOpt) Name() string {
 
 // Set validates and updates the list of Runtimes
 func (o *RuntimeOpt) Set(val string) error {
-	parts := strings.SplitN(val, "=", 2)
-	if len(parts) != 2 {
+	k, v, ok := strings.Cut(val, "=")
+	if !ok {
 		return fmt.Errorf("invalid runtime argument: %s", val)
 	}
 
-	parts[0] = strings.TrimSpace(parts[0])
-	parts[1] = strings.TrimSpace(parts[1])
-	if parts[0] == "" || parts[1] == "" {
+	// TODO(thaJeztah): this should not accept spaces.
+	k = strings.TrimSpace(k)
+	v = strings.TrimSpace(v)
+	if k == "" || v == "" {
 		return fmt.Errorf("invalid runtime argument: %s", val)
 	}
 
-	parts[0] = strings.ToLower(parts[0])
-	if parts[0] == o.stockRuntimeName {
+	// TODO(thaJeztah): this should not be case-insensitive.
+	k = strings.ToLower(k)
+	if k == o.stockRuntimeName {
 		return fmt.Errorf("runtime name '%s' is reserved", o.stockRuntimeName)
 	}
 
-	if _, ok := (*o.values)[parts[0]]; ok {
-		return fmt.Errorf("runtime '%s' was already defined", parts[0])
+	if _, ok := (*o.values)[k]; ok {
+		return fmt.Errorf("runtime '%s' was already defined", k)
 	}
 
-	(*o.values)[parts[0]] = types.Runtime{Path: parts[1]}
+	(*o.values)[k] = types.Runtime{Path: v}
 
 	return nil
 }