provider.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. package processcreds
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "os"
  9. "os/exec"
  10. "runtime"
  11. "time"
  12. "github.com/aws/aws-sdk-go-v2/aws"
  13. "github.com/aws/aws-sdk-go-v2/internal/sdkio"
  14. )
  const (
	// ProviderName is the Source label attached to every aws.Credentials
	// value this provider returns, including the otherwise-empty values
	// returned alongside errors.
	ProviderName = "ProcessProvider"

	// DefaultTimeout is the maximum time the credential process may run when
	// no option overrides Options.Timeout.
	DefaultTimeout = time.Minute
)
  // ProviderError reports a failure while preparing the credential process,
// running it, or interpreting its output. The underlying cause is kept in Err
// and is exposed through Unwrap, so errors.Is and errors.As can inspect it.
type ProviderError struct {
	// Err is the underlying cause of the failure.
	Err error
}

// Error formats the wrapped cause behind a "process provider error: " prefix.
func (pe *ProviderError) Error() string {
	return "process provider error: " + fmt.Sprint(pe.Err)
}

// Unwrap returns the wrapped cause, which may be nil.
func (pe *ProviderError) Unwrap() error {
	return pe.Err
}
  35. // Provider satisfies the credentials.Provider interface, and is a
  36. // client to retrieve credentials from a process.
  37. type Provider struct {
  38. // Provides a constructor for exec.Cmd that are invoked by the provider for
  39. // retrieving credentials. Use this to provide custom creation of exec.Cmd
  40. // with things like environment variables, or other configuration.
  41. //
  42. // The provider defaults to the DefaultNewCommand function.
  43. commandBuilder NewCommandBuilder
  44. options Options
  45. }
  // Options configures a Provider.
type Options struct {
	// Timeout bounds how long the credential process may run. It is applied
	// through the context handed to the command builder.
	// NewProviderCommand defaults it to DefaultTimeout.
	// Any negative value disables the limit.
	// A value of zero produces a deadline that has already passed, so the
	// process is effectively cancelled immediately.
	Timeout time.Duration
}
  // NewCommandBuilder abstracts how the Provider creates the exec.Cmd it runs to
// obtain credentials.
type NewCommandBuilder interface {
	// NewCommand returns the command to execute. The Provider calls it with
	// the retrieval context, which carries the configured timeout when one is
	// set.
	NewCommand(ctx context.Context) (*exec.Cmd, error)
}

// NewCommandBuilderFunc adapts a plain function to the NewCommandBuilder
// interface.
type NewCommandBuilderFunc func(context.Context) (*exec.Cmd, error)

// NewCommand invokes the wrapped function with ctx and returns its results
// unchanged.
func (build NewCommandBuilderFunc) NewCommand(ctx context.Context) (*exec.Cmd, error) {
	cmd, err := build(ctx)
	return cmd, err
}
  // DefaultNewCommandBuilder is the NewCommandBuilder that NewProvider uses.
// Its NewCommand runs Args through the platform shell ("sh -c", or
// "cmd.exe /C" on Windows). The child process is given the current
// environment, stdin, and stderr.
type DefaultNewCommandBuilder struct {
	// Args is appended after the shell invocation. NewCommand fails when it is
	// empty.
	Args []string
}
  70. // NewCommand returns an initialized exec.Cmd with the builder's initialized
  71. // Args. The command is also initialized current process environment variables,
  72. // stderr, and stdin pipes.
  73. func (b DefaultNewCommandBuilder) NewCommand(ctx context.Context) (*exec.Cmd, error) {
  74. var cmdArgs []string
  75. if runtime.GOOS == "windows" {
  76. cmdArgs = []string{"cmd.exe", "/C"}
  77. } else {
  78. cmdArgs = []string{"sh", "-c"}
  79. }
  80. if len(b.Args) == 0 {
  81. return nil, &ProviderError{
  82. Err: fmt.Errorf("failed to prepare command: command must not be empty"),
  83. }
  84. }
  85. cmdArgs = append(cmdArgs, b.Args...)
  86. cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
  87. cmd.Env = os.Environ()
  88. cmd.Stderr = os.Stderr // display stderr on console for MFA
  89. cmd.Stdin = os.Stdin // enable stdin for MFA
  90. return cmd, nil
  91. }
  92. // NewProvider returns a pointer to a new Credentials object wrapping the
  93. // Provider.
  94. //
  95. // The provider defaults to the DefaultNewCommandBuilder for creating command
  96. // the Provider will use to retrieve credentials with.
  97. func NewProvider(command string, options ...func(*Options)) *Provider {
  98. var args []string
  99. // Ensure that the command arguments are not set if the provided command is
  100. // empty. This will error out when the command is executed since no
  101. // arguments are specified.
  102. if len(command) > 0 {
  103. args = []string{command}
  104. }
  105. commanBuilder := DefaultNewCommandBuilder{
  106. Args: args,
  107. }
  108. return NewProviderCommand(commanBuilder, options...)
  109. }
  110. // NewProviderCommand returns a pointer to a new Credentials object with the
  111. // specified command, and default timeout duration. Use this to provide custom
  112. // creation of exec.Cmd for options like environment variables, or other
  113. // configuration.
  114. func NewProviderCommand(builder NewCommandBuilder, options ...func(*Options)) *Provider {
  115. p := &Provider{
  116. commandBuilder: builder,
  117. options: Options{
  118. Timeout: DefaultTimeout,
  119. },
  120. }
  121. for _, option := range options {
  122. option(&p.options)
  123. }
  124. return p
  125. }
  // CredentialProcessResponse is the JSON document an external credential_process
// must write to stdout. Retrieve requires Version to be 1, and AccessKeyID and
// SecretAccessKey to be non-empty.
type CredentialProcessResponse struct {
	// Version of the response format. It must currently be 1.
	Version int

	// AccessKeyID identifies the credentials. It is decoded from the
	// "AccessKeyId" JSON key.
	AccessKeyID string `json:"AccessKeyId"`

	// SecretAccessKey is the secret used to sign requests.
	SecretAccessKey string

	// SessionToken accompanies temporary credentials. Retrieve does not
	// require it to be set.
	SessionToken string

	// Expiration, when present, marks the credentials as expiring at this time.
	Expiration *time.Time
}
  141. // Retrieve executes the credential process command and returns the
  142. // credentials, or error if the command fails.
  143. func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
  144. out, err := p.executeCredentialProcess(ctx)
  145. if err != nil {
  146. return aws.Credentials{Source: ProviderName}, err
  147. }
  148. // Serialize and validate response
  149. resp := &CredentialProcessResponse{}
  150. if err = json.Unmarshal(out, resp); err != nil {
  151. return aws.Credentials{Source: ProviderName}, &ProviderError{
  152. Err: fmt.Errorf("parse failed of process output: %s, error: %w", out, err),
  153. }
  154. }
  155. if resp.Version != 1 {
  156. return aws.Credentials{Source: ProviderName}, &ProviderError{
  157. Err: fmt.Errorf("wrong version in process output (not 1)"),
  158. }
  159. }
  160. if len(resp.AccessKeyID) == 0 {
  161. return aws.Credentials{Source: ProviderName}, &ProviderError{
  162. Err: fmt.Errorf("missing AccessKeyId in process output"),
  163. }
  164. }
  165. if len(resp.SecretAccessKey) == 0 {
  166. return aws.Credentials{Source: ProviderName}, &ProviderError{
  167. Err: fmt.Errorf("missing SecretAccessKey in process output"),
  168. }
  169. }
  170. creds := aws.Credentials{
  171. Source: ProviderName,
  172. AccessKeyID: resp.AccessKeyID,
  173. SecretAccessKey: resp.SecretAccessKey,
  174. SessionToken: resp.SessionToken,
  175. }
  176. // Handle expiration
  177. if resp.Expiration != nil {
  178. creds.CanExpire = true
  179. creds.Expires = *resp.Expiration
  180. }
  181. return creds, nil
  182. }
  183. // executeCredentialProcess starts the credential process on the OS and
  184. // returns the results or an error.
  185. func (p *Provider) executeCredentialProcess(ctx context.Context) ([]byte, error) {
  186. if p.options.Timeout >= 0 {
  187. var cancelFunc func()
  188. ctx, cancelFunc = context.WithTimeout(ctx, p.options.Timeout)
  189. defer cancelFunc()
  190. }
  191. cmd, err := p.commandBuilder.NewCommand(ctx)
  192. if err != nil {
  193. return nil, err
  194. }
  195. // get creds json on process's stdout
  196. output := bytes.NewBuffer(make([]byte, 0, int(8*sdkio.KibiByte)))
  197. if cmd.Stdout != nil {
  198. cmd.Stdout = io.MultiWriter(cmd.Stdout, output)
  199. } else {
  200. cmd.Stdout = output
  201. }
  202. execCh := make(chan error, 1)
  203. go executeCommand(cmd, execCh)
  204. select {
  205. case execError := <-execCh:
  206. if execError == nil {
  207. break
  208. }
  209. select {
  210. case <-ctx.Done():
  211. return output.Bytes(), &ProviderError{
  212. Err: fmt.Errorf("credential process timed out: %w", execError),
  213. }
  214. default:
  215. return output.Bytes(), &ProviderError{
  216. Err: fmt.Errorf("error in credential_process: %w", execError),
  217. }
  218. }
  219. }
  220. out := output.Bytes()
  221. if runtime.GOOS == "windows" {
  222. // windows adds slashes to quotes
  223. out = bytes.ReplaceAll(out, []byte(`\"`), []byte(`"`))
  224. }
  225. return out, nil
  226. }
  // executeCommand runs cmd to completion and sends the outcome on result. It
// sends nil on success; otherwise it sends the error from starting or waiting
// on the process.
//
// Exactly one value is sent, so the caller should provide a channel with room
// for it (or be ready to receive) to avoid blocking this goroutine. The
// parameter is deliberately not named "exec", which would shadow the os/exec
// package inside this function.
func executeCommand(cmd *exec.Cmd, result chan<- error) {
	// Run is Start followed by Wait, returning the first error encountered.
	result <- cmd.Run()
}