浏览代码

opts: simplify ValidateEnv to use os.LookupEnv

os.LookupEnv() was not available yet at the time that this was
implemented (9ab73260f8e4662e7321b257c636928892f023cf), but now
provides the functionality we need, so replacing our custom handling.

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
Sebastiaan van Stijn 4 年之前
父节点
当前提交
c255404a25
共有 2 个文件被更改,包括 31 次插入54 次删除
  1. 7 25
      opts/env.go
  2. 24 29
      opts/env_test.go

+ 7 - 25
opts/env.go

@@ -1,48 +1,30 @@
 package opts // import "github.com/docker/docker/opts"
 package opts // import "github.com/docker/docker/opts"
 
 
 import (
 import (
-	"fmt"
 	"os"
 	"os"
-	"runtime"
 	"strings"
 	"strings"
 
 
 	"github.com/pkg/errors"
 	"github.com/pkg/errors"
 )
 )
 
 
 // ValidateEnv validates an environment variable and returns it.
 // ValidateEnv validates an environment variable and returns it.
-// If no value is specified, it returns the current value using os.Getenv.
+// If no value is specified, it obtains its value from the current environment
 //
 //
 // As on ParseEnvFile and related to #16585, environment variable names
 // As on ParseEnvFile and related to #16585, environment variable names
-// are not validate what so ever, it's up to application inside docker
+// are not validate whatsoever, it's up to application inside docker
 // to validate them or not.
 // to validate them or not.
 //
 //
 // The only validation here is to check if name is empty, per #25099
 // The only validation here is to check if name is empty, per #25099
 func ValidateEnv(val string) (string, error) {
 func ValidateEnv(val string) (string, error) {
-	arr := strings.Split(val, "=")
+	arr := strings.SplitN(val, "=", 2)
 	if arr[0] == "" {
 	if arr[0] == "" {
-		return "", errors.Errorf("invalid environment variable: %s", val)
+		return "", errors.New("invalid environment variable: " + val)
 	}
 	}
 	if len(arr) > 1 {
 	if len(arr) > 1 {
 		return val, nil
 		return val, nil
 	}
 	}
-	if !doesEnvExist(val) {
-		return val, nil
-	}
-	return fmt.Sprintf("%s=%s", val, os.Getenv(val)), nil
-}
-
-func doesEnvExist(name string) bool {
-	for _, entry := range os.Environ() {
-		parts := strings.SplitN(entry, "=", 2)
-		if runtime.GOOS == "windows" {
-			// Environment variable are case-insensitive on Windows. PaTh, path and PATH are equivalent.
-			if strings.EqualFold(parts[0], name) {
-				return true
-			}
-		}
-		if parts[0] == name {
-			return true
-		}
+	if envVal, ok := os.LookupEnv(arr[0]); ok {
+		return arr[0] + "=" + envVal, nil
 	}
 	}
-	return false
+	return val, nil
 }
 }

+ 24 - 29
opts/env_test.go

@@ -5,14 +5,17 @@ import (
 	"os"
 	"os"
 	"runtime"
 	"runtime"
 	"testing"
 	"testing"
+
+	"gotest.tools/v3/assert"
 )
 )
 
 
 func TestValidateEnv(t *testing.T) {
 func TestValidateEnv(t *testing.T) {
-	testcase := []struct {
+	type testCase struct {
 		value    string
 		value    string
 		expected string
 		expected string
 		err      error
 		err      error
-	}{
+	}
+	tests := []testCase{
 		{
 		{
 			value:    "a",
 			value:    "a",
 			expected: "a",
 			expected: "a",
@@ -51,7 +54,11 @@ func TestValidateEnv(t *testing.T) {
 		},
 		},
 		{
 		{
 			value: "=a",
 			value: "=a",
-			err:   fmt.Errorf(fmt.Sprintf("invalid environment variable: %s", "=a")),
+			err:   fmt.Errorf("invalid environment variable: =a"),
+		},
+		{
+			value:    "PATH=",
+			expected: "PATH=",
 		},
 		},
 		{
 		{
 			value:    "PATH=something",
 			value:    "PATH=something",
@@ -83,42 +90,30 @@ func TestValidateEnv(t *testing.T) {
 		},
 		},
 		{
 		{
 			value: "=",
 			value: "=",
-			err:   fmt.Errorf(fmt.Sprintf("invalid environment variable: %s", "=")),
+			err:   fmt.Errorf("invalid environment variable: ="),
 		},
 		},
 	}
 	}
 
 
-	// Environment variables are case in-sensitive on Windows
 	if runtime.GOOS == "windows" {
 	if runtime.GOOS == "windows" {
-		tmp := struct {
-			value    string
-			expected string
-			err      error
-		}{
+		// Environment variables are case in-sensitive on Windows
+		tests = append(tests, testCase{
 			value:    "PaTh",
 			value:    "PaTh",
 			expected: fmt.Sprintf("PaTh=%v", os.Getenv("PATH")),
 			expected: fmt.Sprintf("PaTh=%v", os.Getenv("PATH")),
 			err:      nil,
 			err:      nil,
-		}
-		testcase = append(testcase, tmp)
+		})
 	}
 	}
 
 
-	for _, r := range testcase {
-		actual, err := ValidateEnv(r.value)
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.value, func(t *testing.T) {
+			actual, err := ValidateEnv(tc.value)
 
 
-		if err != nil {
-			if r.err == nil {
-				t.Fatalf("Expected err is nil, got err[%v]", err)
+			if tc.err == nil {
+				assert.NilError(t, err)
+			} else {
+				assert.Error(t, err, tc.err.Error())
 			}
 			}
-			if err.Error() != r.err.Error() {
-				t.Fatalf("Expected err[%v], got err[%v]", r.err, err)
-			}
-		}
-
-		if err == nil && r.err != nil {
-			t.Fatalf("Expected err[%v], but err is nil", r.err)
-		}
-
-		if actual != r.expected {
-			t.Fatalf("Expected [%v], got [%v]", r.expected, actual)
-		}
+			assert.Equal(t, actual, tc.expected)
+		})
 	}
 	}
 }
 }