regstate.go 6.3 KB


  1. //go:build windows
  2. package regstate
  3. import (
  4. "encoding/json"
  5. "fmt"
  6. "net/url"
  7. "os"
  8. "path/filepath"
  9. "reflect"
  10. "syscall"
  11. "golang.org/x/sys/windows"
  12. "golang.org/x/sys/windows/registry"
  13. )
  14. //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go regstate.go
  15. //sys regCreateKeyEx(key syscall.Handle, subkey *uint16, reserved uint32, class *uint16, options uint32, desired uint32, sa *syscall.SecurityAttributes, result *syscall.Handle, disposition *uint32) (regerrno error) = advapi32.RegCreateKeyExW
  // _REG_OPTION_VOLATILE is the Win32 REG_OPTION_VOLATILE option passed to
  // RegCreateKeyExW: the key is kept in memory only rather than persisted.
  // It is left untyped so it converts implicitly to the uint32 options
  // argument of regCreateKeyEx.
  const _REG_OPTION_VOLATILE = 0x00000001

  // _REG_OPENED_EXISTING_KEY is the RegCreateKeyExW disposition reported when
  // the key already existed and was opened instead of created. Untyped so it
  // compares directly with the uint32 disposition value.
  const _REG_OPENED_EXISTING_KEY = 0x00000002
  20. type Key struct {
  21. registry.Key
  22. Name string
  23. }
  24. var localMachine = &Key{registry.LOCAL_MACHINE, "HKEY_LOCAL_MACHINE"}
  25. var localUser = &Key{registry.CURRENT_USER, "HKEY_CURRENT_USER"}
  26. var rootPath = `SOFTWARE\Microsoft\runhcs`
  // NotFoundError is returned when no key exists for ID.
  type NotFoundError struct {
  	ID string // the ID as supplied by the caller, before escaping
  }

  // Error describes which ID could not be found.
  func (e *NotFoundError) Error() string {
  	return fmt.Sprintf("ID '%s' was not found", e.ID)
  }

  // IsNotFoundError reports whether err itself is a *NotFoundError. It does
  // not look through wrapped errors.
  func IsNotFoundError(err error) bool {
  	switch err.(type) {
  	case *NotFoundError:
  		return true
  	}
  	return false
  }
  // NoStateError reports that no value named Key was found for ID.
  type NoStateError struct {
  	ID  string // entry that was looked up
  	Key string // value name that was absent
  }

  // Error names both the missing value and the entry it was expected in.
  func (e *NoStateError) Error() string {
  	const format = "state '%s' is not present for ID '%s'"
  	return fmt.Sprintf(format, e.Key, e.ID)
  }
  44. func createVolatileKey(k *Key, path string, access uint32) (newk *Key, openedExisting bool, err error) {
  45. var (
  46. h syscall.Handle
  47. d uint32
  48. )
  49. fullpath := filepath.Join(k.Name, path)
  50. pathPtr, _ := windows.UTF16PtrFromString(path)
  51. err = regCreateKeyEx(syscall.Handle(k.Key), pathPtr, 0, nil, _REG_OPTION_VOLATILE, access, nil, &h, &d)
  52. if err != nil {
  53. return nil, false, &os.PathError{Op: "RegCreateKeyEx", Path: fullpath, Err: err}
  54. }
  55. return &Key{registry.Key(h), fullpath}, d == _REG_OPENED_EXISTING_KEY, nil
  56. }
  57. func hive(perUser bool) *Key {
  58. r := localMachine
  59. if perUser {
  60. r = localUser
  61. }
  62. return r
  63. }
  64. func Open(root string, perUser bool) (*Key, error) {
  65. k, _, err := createVolatileKey(hive(perUser), rootPath, registry.ALL_ACCESS)
  66. if err != nil {
  67. return nil, err
  68. }
  69. defer k.Close()
  70. k2, _, err := createVolatileKey(k, url.PathEscape(root), registry.ALL_ACCESS)
  71. if err != nil {
  72. return nil, err
  73. }
  74. return k2, nil
  75. }
  76. func RemoveAll(root string, perUser bool) error {
  77. k, err := hive(perUser).open(rootPath)
  78. if err != nil {
  79. return err
  80. }
  81. defer k.Close()
  82. r, err := k.open(url.PathEscape(root))
  83. if err != nil {
  84. return err
  85. }
  86. defer r.Close()
  87. ids, err := r.Enumerate()
  88. if err != nil {
  89. return err
  90. }
  91. for _, id := range ids {
  92. err = r.Remove(id)
  93. if err != nil {
  94. return err
  95. }
  96. }
  97. r.Close()
  98. return k.Remove(root)
  99. }
  100. func (k *Key) Close() error {
  101. err := k.Key.Close()
  102. k.Key = 0
  103. return err
  104. }
  105. func (k *Key) Enumerate() ([]string, error) {
  106. escapedIDs, err := k.ReadSubKeyNames(0)
  107. if err != nil {
  108. return nil, err
  109. }
  110. var ids []string
  111. for _, e := range escapedIDs {
  112. id, err := url.PathUnescape(e)
  113. if err == nil {
  114. ids = append(ids, id)
  115. }
  116. }
  117. return ids, nil
  118. }
  119. func (k *Key) open(name string) (*Key, error) {
  120. fullpath := filepath.Join(k.Name, name)
  121. nk, err := registry.OpenKey(k.Key, name, registry.ALL_ACCESS)
  122. if err != nil {
  123. return nil, &os.PathError{Op: "RegOpenKey", Path: fullpath, Err: err}
  124. }
  125. return &Key{nk, fullpath}, nil
  126. }
  127. func (k *Key) openid(id string) (*Key, error) {
  128. escaped := url.PathEscape(id)
  129. fullpath := filepath.Join(k.Name, escaped)
  130. nk, err := k.open(escaped)
  131. if perr, ok := err.(*os.PathError); ok && perr.Err == syscall.ERROR_FILE_NOT_FOUND {
  132. return nil, &NotFoundError{id}
  133. }
  134. if err != nil {
  135. return nil, &os.PathError{Op: "RegOpenKey", Path: fullpath, Err: err}
  136. }
  137. return nk, nil
  138. }
  139. func (k *Key) Remove(id string) error {
  140. escaped := url.PathEscape(id)
  141. err := registry.DeleteKey(k.Key, escaped)
  142. if err != nil {
  143. if err == syscall.ERROR_FILE_NOT_FOUND {
  144. return &NotFoundError{id}
  145. }
  146. return &os.PathError{Op: "RegDeleteKey", Path: filepath.Join(k.Name, escaped), Err: err}
  147. }
  148. return nil
  149. }
  150. func (k *Key) set(id string, create bool, key string, state interface{}) error {
  151. var sk *Key
  152. var err error
  153. if create {
  154. var existing bool
  155. eid := url.PathEscape(id)
  156. sk, existing, err = createVolatileKey(k, eid, registry.ALL_ACCESS)
  157. if err != nil {
  158. return err
  159. }
  160. defer sk.Close()
  161. if existing {
  162. sk.Close()
  163. return fmt.Errorf("container %s already exists", id)
  164. }
  165. } else {
  166. sk, err = k.openid(id)
  167. if err != nil {
  168. return err
  169. }
  170. defer sk.Close()
  171. }
  172. switch reflect.TypeOf(state).Kind() {
  173. case reflect.Bool:
  174. v := uint32(0)
  175. if state.(bool) {
  176. v = 1
  177. }
  178. err = sk.SetDWordValue(key, v)
  179. case reflect.Int:
  180. err = sk.SetQWordValue(key, uint64(state.(int)))
  181. case reflect.String:
  182. err = sk.SetStringValue(key, state.(string))
  183. default:
  184. var js []byte
  185. js, err = json.Marshal(state)
  186. if err != nil {
  187. return err
  188. }
  189. err = sk.SetBinaryValue(key, js)
  190. }
  191. if err != nil {
  192. if err == syscall.ERROR_FILE_NOT_FOUND {
  193. return &NoStateError{id, key}
  194. }
  195. return &os.PathError{Op: "RegSetValueEx", Path: sk.Name + ":" + key, Err: err}
  196. }
  197. return nil
  198. }
  199. func (k *Key) Create(id, key string, state interface{}) error {
  200. return k.set(id, true, key, state)
  201. }
  202. func (k *Key) Set(id, key string, state interface{}) error {
  203. return k.set(id, false, key, state)
  204. }
  205. func (k *Key) Clear(id, key string) error {
  206. sk, err := k.openid(id)
  207. if err != nil {
  208. return err
  209. }
  210. defer sk.Close()
  211. err = sk.DeleteValue(key)
  212. if err != nil {
  213. if err == syscall.ERROR_FILE_NOT_FOUND {
  214. return &NoStateError{id, key}
  215. }
  216. return &os.PathError{Op: "RegDeleteValue", Path: sk.Name + ":" + key, Err: err}
  217. }
  218. return nil
  219. }
  220. func (k *Key) Get(id, key string, state interface{}) error {
  221. sk, err := k.openid(id)
  222. if err != nil {
  223. return err
  224. }
  225. defer sk.Close()
  226. var js []byte
  227. switch reflect.TypeOf(state).Elem().Kind() {
  228. case reflect.Bool:
  229. var v uint64
  230. v, _, err = sk.GetIntegerValue(key)
  231. if err == nil {
  232. *state.(*bool) = v != 0
  233. }
  234. case reflect.Int:
  235. var v uint64
  236. v, _, err = sk.GetIntegerValue(key)
  237. if err == nil {
  238. *state.(*int) = int(v)
  239. }
  240. case reflect.String:
  241. var v string
  242. v, _, err = sk.GetStringValue(key)
  243. if err == nil {
  244. *state.(*string) = string(v)
  245. }
  246. default:
  247. js, _, err = sk.GetBinaryValue(key)
  248. }
  249. if err != nil {
  250. if err == syscall.ERROR_FILE_NOT_FOUND {
  251. return &NoStateError{id, key}
  252. }
  253. return &os.PathError{Op: "RegQueryValueEx", Path: sk.Name + ":" + key, Err: err}
  254. }
  255. if js != nil {
  256. err = json.Unmarshal(js, state)
  257. }
  258. return err
  259. }