opts: use strings.Cut() and refactor parseDaemonHost()

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
This commit is contained in:
Sebastiaan van Stijn 2022-11-01 11:13:45 +01:00
parent d3cd746067
commit 774cd9a26c
No known key found for this signature in database
GPG key ID: 76698F39D527CE8C
6 changed files with 53 additions and 60 deletions

View file

@ -31,19 +31,17 @@ func (p *PoolsOpt) Set(value string) error {
poolsDef := types.NetworkToSplit{}
for _, field := range fields {
parts := strings.SplitN(field, "=", 2)
if len(parts) != 2 {
// TODO(thaJeztah): this should not be case-insensitive.
key, val, ok := strings.Cut(strings.ToLower(field), "=")
if !ok {
return fmt.Errorf("invalid field '%s' must be a key=value pair", field)
}
key := strings.ToLower(parts[0])
value := strings.ToLower(parts[1])
switch key {
case "base":
poolsDef.Base = value
poolsDef.Base = val
case "size":
size, err := strconv.Atoi(value)
size, err := strconv.Atoi(val)
if err != nil {
return fmt.Errorf("invalid size value: %q (must be integer): %v", value, err)
}

View file

@ -16,15 +16,15 @@ import (
//
// The only validation here is to check if name is empty, per #25099
func ValidateEnv(val string) (string, error) {
arr := strings.SplitN(val, "=", 2)
if arr[0] == "" {
k, _, ok := strings.Cut(val, "=")
if k == "" {
return "", errors.New("invalid environment variable: " + val)
}
if len(arr) > 1 {
if ok {
return val, nil
}
if envVal, ok := os.LookupEnv(arr[0]); ok {
return arr[0] + "=" + envVal, nil
if envVal, ok := os.LookupEnv(k); ok {
return k + "=" + envVal, nil
}
return val, nil
}

View file

@ -60,8 +60,7 @@ func ParseHost(defaultToTLS, defaultToUnixXDG bool, val string) (string, error)
if err != nil {
return "", err
}
socket := filepath.Join(runtimeDir, "docker.sock")
host = "unix://" + socket
host = "unix://" + filepath.Join(runtimeDir, "docker.sock")
} else {
host = DefaultHost
}
@ -77,23 +76,32 @@ func ParseHost(defaultToTLS, defaultToUnixXDG bool, val string) (string, error)
// parseDaemonHost parses the specified address and returns an address that will be used as the host.
// Depending on the address specified, this may return one of the global Default* strings defined in hosts.go.
func parseDaemonHost(addr string) (string, error) {
addrParts := strings.SplitN(addr, "://", 2)
if len(addrParts) == 1 && addrParts[0] != "" {
addrParts = []string{"tcp", addrParts[0]}
func parseDaemonHost(address string) (string, error) {
proto, addr, ok := strings.Cut(address, "://")
if !ok && proto != "" {
addr = proto
proto = "tcp"
}
switch addrParts[0] {
switch proto {
case "tcp":
return ParseTCPAddr(addr, DefaultTCPHost)
return ParseTCPAddr(address, DefaultTCPHost)
case "unix":
return parseSimpleProtoAddr("unix", addrParts[1], DefaultUnixSocket)
a, err := parseSimpleProtoAddr(proto, addr, DefaultUnixSocket)
if err != nil {
return "", errors.Wrapf(err, "invalid bind address (%s)", address)
}
return a, nil
case "npipe":
return parseSimpleProtoAddr("npipe", addrParts[1], DefaultNamedPipe)
a, err := parseSimpleProtoAddr(proto, addr, DefaultNamedPipe)
if err != nil {
return "", errors.Wrapf(err, "invalid bind address (%s)", address)
}
return a, nil
case "fd":
return addr, nil
return address, nil
default:
return "", errors.Errorf("invalid bind address (%s): unsupported proto '%s'", addr, addrParts[0])
return "", errors.Errorf("invalid bind address (%s): unsupported proto '%s'", address, proto)
}
}
@ -102,9 +110,8 @@ func parseDaemonHost(addr string) (string, error) {
// socket address, either using the address parsed from addr, or the contents of
// defaultAddr if addr is a blank string.
func parseSimpleProtoAddr(proto, addr, defaultAddr string) (string, error) {
addr = strings.TrimPrefix(addr, proto+"://")
if strings.Contains(addr, "://") {
return "", errors.Errorf("invalid proto, expected %s: %s", proto, addr)
return "", errors.Errorf("invalid %s address: %s", proto, addr)
}
if addr == "" {
addr = defaultAddr
@ -172,14 +179,14 @@ func parseTCPAddr(address string, strict bool) (*url.URL, error) {
// ExtraHost is in the form of name:ip where the ip has to be a valid ip (IPv4 or IPv6).
func ValidateExtraHost(val string) (string, error) {
// allow for IPv6 addresses in extra hosts by only splitting on first ":"
arr := strings.SplitN(val, ":", 2)
if len(arr) != 2 || len(arr[0]) == 0 {
name, ip, ok := strings.Cut(val, ":")
if !ok || name == "" {
return "", errors.Errorf("bad format for add-host: %q", val)
}
// Skip IPaddr validation for special "host-gateway" string
if arr[1] != HostGatewayName {
if _, err := ValidateIPAddress(arr[1]); err != nil {
return "", errors.Errorf("invalid IP address in add-host: %q", arr[1])
if ip != HostGatewayName {
if _, err := ValidateIPAddress(ip); err != nil {
return "", errors.Errorf("invalid IP address in add-host: %q", ip)
}
}
return val, nil

View file

@ -85,6 +85,8 @@ func TestParseDockerDaemonHost(t *testing.T) {
"[0:0:0:0:0:0:0:1]:5555/path": "invalid bind address ([0:0:0:0:0:0:0:1]:5555/path): should not contain a path element",
"tcp://:5555/path": "invalid bind address (tcp://:5555/path): should not contain a path element",
"localhost:5555/path": "invalid bind address (localhost:5555/path): should not contain a path element",
"unix://tcp://127.0.0.1": "invalid bind address (unix://tcp://127.0.0.1): invalid unix address: tcp://127.0.0.1",
"unix://unix://tcp://127.0.0.1": "invalid bind address (unix://unix://tcp://127.0.0.1): invalid unix address: unix://tcp://127.0.0.1",
}
valids := map[string]string{
":": DefaultTCPHost,
@ -130,7 +132,7 @@ func TestParseDockerDaemonHost(t *testing.T) {
t.Errorf(`unexpected error: "%v"`, err)
}
if addr != expectedAddr {
t.Errorf(`expected "%s", got "%s""`, expectedAddr, addr)
t.Errorf(`expected "%s", got "%s"`, expectedAddr, addr)
}
})
}
@ -210,18 +212,6 @@ func TestParseTCP(t *testing.T) {
}
}
func TestParseInvalidUnixAddrInvalid(t *testing.T) {
if _, err := parseSimpleProtoAddr("unix", "tcp://127.0.0.1", "unix:///var/run/docker.sock"); err == nil || err.Error() != "invalid proto, expected unix: tcp://127.0.0.1" {
t.Fatalf("Expected an error, got %v", err)
}
if _, err := parseSimpleProtoAddr("unix", "unix://tcp://127.0.0.1", "/var/run/docker.sock"); err == nil || err.Error() != "invalid proto, expected unix: tcp://127.0.0.1" {
t.Fatalf("Expected an error, got %v", err)
}
if v, err := parseSimpleProtoAddr("unix", "", "/var/run/docker.sock"); err != nil || v != "unix:///var/run/docker.sock" {
t.Fatalf("Expected an %v, got %v", v, "unix:///var/run/docker.sock")
}
}
func TestValidateExtraHosts(t *testing.T) {
valid := []string{
`myhost:192.168.0.1`,

View file

@ -162,12 +162,8 @@ func (opts *MapOpts) Set(value string) error {
}
value = v
}
vals := strings.SplitN(value, "=", 2)
if len(vals) == 1 {
(opts.values)[vals[0]] = ""
} else {
(opts.values)[vals[0]] = vals[1]
}
k, v, _ := strings.Cut(value, "=")
(opts.values)[k] = v
return nil
}

View file

@ -29,27 +29,29 @@ func (o *RuntimeOpt) Name() string {
// Set validates and updates the list of Runtimes
func (o *RuntimeOpt) Set(val string) error {
parts := strings.SplitN(val, "=", 2)
if len(parts) != 2 {
k, v, ok := strings.Cut(val, "=")
if !ok {
return fmt.Errorf("invalid runtime argument: %s", val)
}
parts[0] = strings.TrimSpace(parts[0])
parts[1] = strings.TrimSpace(parts[1])
if parts[0] == "" || parts[1] == "" {
// TODO(thaJeztah): this should not accept spaces.
k = strings.TrimSpace(k)
v = strings.TrimSpace(v)
if k == "" || v == "" {
return fmt.Errorf("invalid runtime argument: %s", val)
}
parts[0] = strings.ToLower(parts[0])
if parts[0] == o.stockRuntimeName {
// TODO(thaJeztah): this should not be case-insensitive.
k = strings.ToLower(k)
if k == o.stockRuntimeName {
return fmt.Errorf("runtime name '%s' is reserved", o.stockRuntimeName)
}
if _, ok := (*o.values)[parts[0]]; ok {
return fmt.Errorf("runtime '%s' was already defined", parts[0])
if _, ok := (*o.values)[k]; ok {
return fmt.Errorf("runtime '%s' was already defined", k)
}
(*o.values)[parts[0]] = types.Runtime{Path: parts[1]}
(*o.values)[k] = types.Runtime{Path: v}
return nil
}