hvsock.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575
  1. //go:build windows
  2. // +build windows
  3. package winio
  4. import (
  5. "context"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net"
  10. "os"
  11. "syscall"
  12. "time"
  13. "unsafe"
  14. "golang.org/x/sys/windows"
  15. "github.com/Microsoft/go-winio/internal/socket"
  16. "github.com/Microsoft/go-winio/pkg/guid"
  17. )
// afHVSock is the address family value (AF_HYPERV) used when creating Hyper-V
// sockets and when validating raw socket addresses in this file.
const afHVSock = 34 // AF_HYPERV
  19. // Well known Service and VM IDs
  20. // https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards
  21. // HvsockGUIDWildcard is the wildcard VmId for accepting connections from all partitions.
  22. func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000
  23. return guid.GUID{}
  24. }
  25. // HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions.
  26. func HvsockGUIDBroadcast() guid.GUID { // ffffffff-ffff-ffff-ffff-ffffffffffff
  27. return guid.GUID{
  28. Data1: 0xffffffff,
  29. Data2: 0xffff,
  30. Data3: 0xffff,
  31. Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
  32. }
  33. }
  34. // HvsockGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector.
  35. func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838
  36. return guid.GUID{
  37. Data1: 0xe0e16197,
  38. Data2: 0xdd56,
  39. Data3: 0x4a10,
  40. Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38},
  41. }
  42. }
  43. // HvsockGUIDSiloHost is the address of a silo's host partition:
  44. // - The silo host of a hosted silo is the utility VM.
  45. // - The silo host of a silo on a physical host is the physical host.
  46. func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568
  47. return guid.GUID{
  48. Data1: 0x36bd0c5c,
  49. Data2: 0x7276,
  50. Data3: 0x4223,
  51. Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68},
  52. }
  53. }
  54. // HvsockGUIDChildren is the wildcard VmId for accepting connections from the connector's child partitions.
  55. func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd
  56. return guid.GUID{
  57. Data1: 0x90db8b89,
  58. Data2: 0xd35,
  59. Data3: 0x4f79,
  60. Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd},
  61. }
  62. }
  63. // HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition.
  64. // Listening on this VmId accepts connection from:
  65. // - Inside silos: silo host partition.
  66. // - Inside hosted silo: host of the VM.
  67. // - Inside VM: VM host.
  68. // - Physical host: Not supported.
  69. func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878
  70. return guid.GUID{
  71. Data1: 0xa42e7cda,
  72. Data2: 0xd03f,
  73. Data3: 0x480c,
  74. Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78},
  75. }
  76. }
  77. // hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol.
  78. func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3
  79. return guid.GUID{
  80. Data2: 0xfacb,
  81. Data3: 0x11e6,
  82. Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3},
  83. }
  84. }
// An HvsockAddr is an address for an AF_HYPERV socket.
//
// It is formatted by String as "<VMID>:<ServiceID>" and reports "hvsock" as
// its network name.
type HvsockAddr struct {
	// VMID is the VmId half of the address (see the HvsockGUID* helpers for
	// the well-known wildcard values).
	VMID guid.GUID
	// ServiceID is the service GUID half of the address.
	ServiceID guid.GUID
}
  90. type rawHvsockAddr struct {
  91. Family uint16
  92. _ uint16
  93. VMID guid.GUID
  94. ServiceID guid.GUID
  95. }
  96. var _ socket.RawSockaddr = &rawHvsockAddr{}
  97. // Network returns the address's network name, "hvsock".
  98. func (*HvsockAddr) Network() string {
  99. return "hvsock"
  100. }
  101. func (addr *HvsockAddr) String() string {
  102. return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
  103. }
  104. // VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
  105. func VsockServiceID(port uint32) guid.GUID {
  106. g := hvsockVsockServiceTemplate() // make a copy
  107. g.Data1 = port
  108. return g
  109. }
  110. func (addr *HvsockAddr) raw() rawHvsockAddr {
  111. return rawHvsockAddr{
  112. Family: afHVSock,
  113. VMID: addr.VMID,
  114. ServiceID: addr.ServiceID,
  115. }
  116. }
  117. func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
  118. addr.VMID = raw.VMID
  119. addr.ServiceID = raw.ServiceID
  120. }
  121. // Sockaddr returns a pointer to and the size of this struct.
  122. //
  123. // Implements the [socket.RawSockaddr] interface, and allows use in
  124. // [socket.Bind] and [socket.ConnectEx].
  125. func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) {
  126. return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil
  127. }
  128. // Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()`.
  129. func (r *rawHvsockAddr) FromBytes(b []byte) error {
  130. n := int(unsafe.Sizeof(rawHvsockAddr{}))
  131. if len(b) < n {
  132. return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize)
  133. }
  134. copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n])
  135. if r.Family != afHVSock {
  136. return fmt.Errorf("got %d, want %d: %w", r.Family, afHVSock, socket.ErrAddrFamily)
  137. }
  138. return nil
  139. }
  140. // HvsockListener is a socket listener for the AF_HYPERV address family.
  141. type HvsockListener struct {
  142. sock *win32File
  143. addr HvsockAddr
  144. }
  145. var _ net.Listener = &HvsockListener{}
  146. // HvsockConn is a connected socket of the AF_HYPERV address family.
  147. type HvsockConn struct {
  148. sock *win32File
  149. local, remote HvsockAddr
  150. }
  151. var _ net.Conn = &HvsockConn{}
  152. func newHVSocket() (*win32File, error) {
  153. fd, err := syscall.Socket(afHVSock, syscall.SOCK_STREAM, 1)
  154. if err != nil {
  155. return nil, os.NewSyscallError("socket", err)
  156. }
  157. f, err := makeWin32File(fd)
  158. if err != nil {
  159. syscall.Close(fd)
  160. return nil, err
  161. }
  162. f.socket = true
  163. return f, nil
  164. }
  165. // ListenHvsock listens for connections on the specified hvsock address.
  166. func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
  167. l := &HvsockListener{addr: *addr}
  168. sock, err := newHVSocket()
  169. if err != nil {
  170. return nil, l.opErr("listen", err)
  171. }
  172. sa := addr.raw()
  173. err = socket.Bind(windows.Handle(sock.handle), &sa)
  174. if err != nil {
  175. return nil, l.opErr("listen", os.NewSyscallError("socket", err))
  176. }
  177. err = syscall.Listen(sock.handle, 16)
  178. if err != nil {
  179. return nil, l.opErr("listen", os.NewSyscallError("listen", err))
  180. }
  181. return &HvsockListener{sock: sock, addr: *addr}, nil
  182. }
  183. func (l *HvsockListener) opErr(op string, err error) error {
  184. return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
  185. }
  186. // Addr returns the listener's network address.
  187. func (l *HvsockListener) Addr() net.Addr {
  188. return &l.addr
  189. }
// Accept waits for the next connection and returns it.
//
// A fresh socket is created for the incoming connection and handed to AcceptEx
// together with the listening socket. On any failure the fresh socket is closed
// and a *net.OpError with Op "accept" is returned.
func (l *HvsockListener) Accept() (_ net.Conn, err error) {
	sock, err := newHVSocket()
	if err != nil {
		return nil, l.opErr("accept", err)
	}
	// Close the new socket on every error path; sock is set to nil below once
	// ownership has passed to the returned connection.
	defer func() {
		if sock != nil {
			sock.Close()
		}
	}()
	c, err := l.sock.prepareIO()
	if err != nil {
		return nil, l.opErr("accept", err)
	}
	defer l.sock.wg.Done()

	// AcceptEx, per documentation, requires an extra 16 bytes per address.
	//
	// https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex
	const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
	// addrbuf holds two address slots of addrlen bytes each: the local address
	// at offset 0 and the remote address at offset addrlen.
	var addrbuf [addrlen * 2]byte

	var bytes uint32
	err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /* rxdatalen */, addrlen, addrlen, &bytes, &c.o)
	// No deadline is passed, so this waits until the accept completes or fails.
	if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil {
		return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
	}

	conn := &HvsockConn{
		sock: sock,
	}
	// The local address returned in the AcceptEx buffer is the same as the Listener socket's
	// address. However, the service GUID reported by GetSockName is different from the Listener's
	// socket, and is sometimes the same as the local address of the socket that dialed the
	// address, with the service GUID.Data1 incremented, but other times is different.
	// todo: does the local address matter? is the listener's address or the actual address appropriate?
	conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
	conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))

	// initialize the accepted socket and update its properties with those of the listening socket
	if err = windows.Setsockopt(windows.Handle(sock.handle),
		windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT,
		(*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil {
		return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err))
	}

	// success: the connection now owns the socket, so disarm the deferred Close
	sock = nil
	return conn, nil
}
  235. // Close closes the listener, causing any pending Accept calls to fail.
  236. func (l *HvsockListener) Close() error {
  237. return l.sock.Close()
  238. }
// HvsockDialer configures and dials a Hyper-V Socket (ie, [HvsockConn]).
//
// The zero value dials once, with no deadline and no retries.
type HvsockDialer struct {
	// Deadline is the time the Dial operation must connect before erroring.
	// The zero value means no deadline.
	Deadline time.Time

	// Retries is the number of additional connects to try if the connection times out, is refused,
	// or the host is unreachable
	Retries uint

	// RetryWait is the time to wait after a connection error to retry.
	// A zero value retries immediately.
	RetryWait time.Duration

	// rt is created lazily by redialWait and reused across retries.
	rt *time.Timer // redial wait timer
}
  250. // Dial the Hyper-V socket at addr.
  251. //
  252. // See [HvsockDialer.Dial] for more information.
  253. func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
  254. return (&HvsockDialer{}).Dial(ctx, addr)
  255. }
// Dial attempts to connect to the Hyper-V socket at addr, and returns a connection if successful.
// Will attempt (HvsockDialer).Retries if dialing fails, waiting (HvsockDialer).RetryWait between
// retries.
//
// Dialing can be cancelled either by providing (HvsockDialer).Deadline, or cancelling ctx.
// Cancellation is checked before connecting, while waiting between retries, and
// once more after the connection is established; an in-flight connect attempt
// is not interrupted by ctx.
//
// All errors are returned as *net.OpError with Op "dial", and the socket is
// closed on every error path.
func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
	op := "dial"
	// create the conn early to use opErr()
	conn = &HvsockConn{
		remote: *addr,
	}

	// fold the dialer's deadline into ctx so one context governs cancellation
	if !d.Deadline.IsZero() {
		var cancel context.CancelFunc
		ctx, cancel = context.WithDeadline(ctx, d.Deadline)
		defer cancel()
	}

	// preemptive timeout/cancellation check
	if err = ctx.Err(); err != nil {
		return nil, conn.opErr(op, err)
	}

	sock, err := newHVSocket()
	if err != nil {
		return nil, conn.opErr(op, err)
	}
	// Close the socket on every error path; sock is set to nil at the end once
	// ownership has passed to conn.
	defer func() {
		if sock != nil {
			sock.Close()
		}
	}()

	// NOTE(review): the socket is bound to the same raw address that is passed
	// to ConnectEx below; presumably ConnectEx needs a bound socket (confirm).
	sa := addr.raw()
	err = socket.Bind(windows.Handle(sock.handle), &sa)
	if err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("bind", err))
	}

	c, err := sock.prepareIO()
	if err != nil {
		return nil, conn.opErr(op, err)
	}
	defer sock.wg.Done()
	var bytes uint32
	// One initial attempt plus up to d.Retries more. A retry happens only for
	// errors canRedial accepts, and only if the wait was not cut short by ctx.
	for i := uint(0); i <= d.Retries; i++ {
		err = socket.ConnectEx(
			windows.Handle(sock.handle),
			&sa,
			nil, // sendBuf
			0,   // sendDataLen
			&bytes,
			(*windows.Overlapped)(unsafe.Pointer(&c.o)))
		_, err = sock.asyncIO(c, nil, bytes, err)
		if i < d.Retries && canRedial(err) {
			// a non-nil error here is ctx's error and replaces the connect error
			if err = d.redialWait(ctx); err == nil {
				continue
			}
		}
		break
	}
	if err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("connectex", err))
	}

	// update the connection properties, so shutdown can be used
	if err = windows.Setsockopt(
		windows.Handle(sock.handle),
		windows.SOL_SOCKET,
		windows.SO_UPDATE_CONNECT_CONTEXT,
		nil, // optvalue
		0,   // optlen
	); err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err))
	}

	// get the local name
	var sal rawHvsockAddr
	err = socket.GetSockName(windows.Handle(sock.handle), &sal)
	if err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("getsockname", err))
	}
	conn.local.fromRaw(&sal)

	// one last check for timeout, since asyncIO doesn't check the context
	if err = ctx.Err(); err != nil {
		return nil, conn.opErr(op, err)
	}

	// success: hand the socket to conn and disarm the deferred Close
	conn.sock = sock
	sock = nil

	return conn, nil
}
  340. // redialWait waits before attempting to redial, resetting the timer as appropriate.
  341. func (d *HvsockDialer) redialWait(ctx context.Context) (err error) {
  342. if d.RetryWait == 0 {
  343. return nil
  344. }
  345. if d.rt == nil {
  346. d.rt = time.NewTimer(d.RetryWait)
  347. } else {
  348. // should already be stopped and drained
  349. d.rt.Reset(d.RetryWait)
  350. }
  351. select {
  352. case <-ctx.Done():
  353. case <-d.rt.C:
  354. return nil
  355. }
  356. // stop and drain the timer
  357. if !d.rt.Stop() {
  358. <-d.rt.C
  359. }
  360. return ctx.Err()
  361. }
  362. // assumes error is a plain, unwrapped syscall.Errno provided by direct syscall.
  363. func canRedial(err error) bool {
  364. //nolint:errorlint // guaranteed to be an Errno
  365. switch err {
  366. case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT,
  367. windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL:
  368. return true
  369. default:
  370. return false
  371. }
  372. }
  373. func (conn *HvsockConn) opErr(op string, err error) error {
  374. // translate from "file closed" to "socket closed"
  375. if errors.Is(err, ErrFileClosed) {
  376. err = socket.ErrSocketClosed
  377. }
  378. return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
  379. }
  380. func (conn *HvsockConn) Read(b []byte) (int, error) {
  381. c, err := conn.sock.prepareIO()
  382. if err != nil {
  383. return 0, conn.opErr("read", err)
  384. }
  385. defer conn.sock.wg.Done()
  386. buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
  387. var flags, bytes uint32
  388. err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
  389. n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err)
  390. if err != nil {
  391. var eno windows.Errno
  392. if errors.As(err, &eno) {
  393. err = os.NewSyscallError("wsarecv", eno)
  394. }
  395. return 0, conn.opErr("read", err)
  396. } else if n == 0 {
  397. err = io.EOF
  398. }
  399. return n, err
  400. }
  401. func (conn *HvsockConn) Write(b []byte) (int, error) {
  402. t := 0
  403. for len(b) != 0 {
  404. n, err := conn.write(b)
  405. if err != nil {
  406. return t + n, err
  407. }
  408. t += n
  409. b = b[n:]
  410. }
  411. return t, nil
  412. }
  413. func (conn *HvsockConn) write(b []byte) (int, error) {
  414. c, err := conn.sock.prepareIO()
  415. if err != nil {
  416. return 0, conn.opErr("write", err)
  417. }
  418. defer conn.sock.wg.Done()
  419. buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
  420. var bytes uint32
  421. err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
  422. n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err)
  423. if err != nil {
  424. var eno windows.Errno
  425. if errors.As(err, &eno) {
  426. err = os.NewSyscallError("wsasend", eno)
  427. }
  428. return 0, conn.opErr("write", err)
  429. }
  430. return n, err
  431. }
  432. // Close closes the socket connection, failing any pending read or write calls.
  433. func (conn *HvsockConn) Close() error {
  434. return conn.sock.Close()
  435. }
  436. func (conn *HvsockConn) IsClosed() bool {
  437. return conn.sock.IsClosed()
  438. }
  439. // shutdown disables sending or receiving on a socket.
  440. func (conn *HvsockConn) shutdown(how int) error {
  441. if conn.IsClosed() {
  442. return socket.ErrSocketClosed
  443. }
  444. err := syscall.Shutdown(conn.sock.handle, how)
  445. if err != nil {
  446. // If the connection was closed, shutdowns fail with "not connected"
  447. if errors.Is(err, windows.WSAENOTCONN) ||
  448. errors.Is(err, windows.WSAESHUTDOWN) {
  449. err = socket.ErrSocketClosed
  450. }
  451. return os.NewSyscallError("shutdown", err)
  452. }
  453. return nil
  454. }
  455. // CloseRead shuts down the read end of the socket, preventing future read operations.
  456. func (conn *HvsockConn) CloseRead() error {
  457. err := conn.shutdown(syscall.SHUT_RD)
  458. if err != nil {
  459. return conn.opErr("closeread", err)
  460. }
  461. return nil
  462. }
  463. // CloseWrite shuts down the write end of the socket, preventing future write operations and
  464. // notifying the other endpoint that no more data will be written.
  465. func (conn *HvsockConn) CloseWrite() error {
  466. err := conn.shutdown(syscall.SHUT_WR)
  467. if err != nil {
  468. return conn.opErr("closewrite", err)
  469. }
  470. return nil
  471. }
  472. // LocalAddr returns the local address of the connection.
  473. func (conn *HvsockConn) LocalAddr() net.Addr {
  474. return &conn.local
  475. }
  476. // RemoteAddr returns the remote address of the connection.
  477. func (conn *HvsockConn) RemoteAddr() net.Addr {
  478. return &conn.remote
  479. }
  480. // SetDeadline implements the net.Conn SetDeadline method.
  481. func (conn *HvsockConn) SetDeadline(t time.Time) error {
  482. // todo: implement `SetDeadline` for `win32File`
  483. if err := conn.SetReadDeadline(t); err != nil {
  484. return fmt.Errorf("set read deadline: %w", err)
  485. }
  486. if err := conn.SetWriteDeadline(t); err != nil {
  487. return fmt.Errorf("set write deadline: %w", err)
  488. }
  489. return nil
  490. }
  491. // SetReadDeadline implements the net.Conn SetReadDeadline method.
  492. func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
  493. return conn.sock.SetReadDeadline(t)
  494. }
  495. // SetWriteDeadline implements the net.Conn SetWriteDeadline method.
  496. func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
  497. return conn.sock.SetWriteDeadline(t)
  498. }