socket.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. //go:build windows
  2. package socket
  3. import (
  4. "errors"
  5. "fmt"
  6. "net"
  7. "sync"
  8. "syscall"
  9. "unsafe"
  10. "github.com/Microsoft/go-winio/pkg/guid"
  11. "golang.org/x/sys/windows"
  12. )
  13. //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
  14. //sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
  15. //sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
  16. //sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
  // socketError is the all-ones 32-bit value (Winsock's SOCKET_ERROR, -1 as an int32). The //sys
// directives in this file use it as the failretval of the generated ws2_32 wrappers.
const socketError = uintptr(^uint32(0))

// Sentinel errors for this package.
var (
	// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?

	// ErrBufferSize is a sentinel for buffer-size problems.
	// NOTE(review): not returned by the functions in this chunk; presumably produced by
	// RawSockaddr implementations elsewhere — confirm.
	ErrBufferSize = errors.New("buffer size")
	// ErrAddrFamily is a sentinel for address-family problems (see NOTE on ErrBufferSize).
	ErrAddrFamily = errors.New("address family")
	// ErrInvalidPointer is a sentinel for invalid pointer arguments (see NOTE on ErrBufferSize).
	ErrInvalidPointer = errors.New("invalid pointer")
	// ErrSocketClosed wraps net.ErrClosed, so errors.Is(ErrSocketClosed, net.ErrClosed) is true.
	ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
)
  25. // todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
  26. // GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
  27. // If rsa is not large enough, the [windows.WSAEFAULT] is returned.
  28. func GetSockName(s windows.Handle, rsa RawSockaddr) error {
  29. ptr, l, err := rsa.Sockaddr()
  30. if err != nil {
  31. return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
  32. }
  33. // although getsockname returns WSAEFAULT if the buffer is too small, it does not set
  34. // &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
  35. return getsockname(s, ptr, &l)
  36. }
  37. // GetPeerName returns the remote address the socket is connected to.
  38. //
  39. // See [GetSockName] for more information.
  40. func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
  41. ptr, l, err := rsa.Sockaddr()
  42. if err != nil {
  43. return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
  44. }
  45. return getpeername(s, ptr, &l)
  46. }
  47. func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
  48. ptr, l, err := rsa.Sockaddr()
  49. if err != nil {
  50. return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
  51. }
  52. return bind(s, ptr, l)
  53. }
  54. // "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of the
  55. // their sockaddr interface, so they cannot be used with HvsockAddr
  56. // Replicate functionality here from
  57. // https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
  58. // The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
  59. // runtime via a WSAIoctl call:
  60. // https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
  61. type runtimeFunc struct {
  62. id guid.GUID
  63. once sync.Once
  64. addr uintptr
  65. err error
  66. }
  67. func (f *runtimeFunc) Load() error {
  68. f.once.Do(func() {
  69. var s windows.Handle
  70. s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
  71. if f.err != nil {
  72. return
  73. }
  74. defer windows.CloseHandle(s) //nolint:errcheck
  75. var n uint32
  76. f.err = windows.WSAIoctl(s,
  77. windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
  78. (*byte)(unsafe.Pointer(&f.id)),
  79. uint32(unsafe.Sizeof(f.id)),
  80. (*byte)(unsafe.Pointer(&f.addr)),
  81. uint32(unsafe.Sizeof(f.addr)),
  82. &n,
  83. nil, // overlapped
  84. 0, // completionRoutine
  85. )
  86. })
  87. return f.err
  88. }
  89. var (
  90. // todo: add `AcceptEx` and `GetAcceptExSockaddrs`
  91. WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
  92. Data1: 0x25a207b9,
  93. Data2: 0xddf3,
  94. Data3: 0x4660,
  95. Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
  96. }
  97. connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
  98. )
  99. func ConnectEx(
  100. fd windows.Handle,
  101. rsa RawSockaddr,
  102. sendBuf *byte,
  103. sendDataLen uint32,
  104. bytesSent *uint32,
  105. overlapped *windows.Overlapped,
  106. ) error {
  107. if err := connectExFunc.Load(); err != nil {
  108. return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
  109. }
  110. ptr, n, err := rsa.Sockaddr()
  111. if err != nil {
  112. return err
  113. }
  114. return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
  115. }
  116. // BOOL LpfnConnectex(
  117. // [in] SOCKET s,
  118. // [in] const sockaddr *name,
  119. // [in] int namelen,
  120. // [in, optional] PVOID lpSendBuffer,
  121. // [in] DWORD dwSendDataLength,
  122. // [out] LPDWORD lpdwBytesSent,
  123. // [in] LPOVERLAPPED lpOverlapped
  124. // )
  125. func connectEx(
  126. s windows.Handle,
  127. name unsafe.Pointer,
  128. namelen int32,
  129. sendBuf *byte,
  130. sendDataLen uint32,
  131. bytesSent *uint32,
  132. overlapped *windows.Overlapped,
  133. ) (err error) {
  134. // todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN
  135. r1, _, e1 := syscall.Syscall9(connectExFunc.addr,
  136. 7,
  137. uintptr(s),
  138. uintptr(name),
  139. uintptr(namelen),
  140. uintptr(unsafe.Pointer(sendBuf)),
  141. uintptr(sendDataLen),
  142. uintptr(unsafe.Pointer(bytesSent)),
  143. uintptr(unsafe.Pointer(overlapped)),
  144. 0,
  145. 0)
  146. if r1 == 0 {
  147. if e1 != 0 {
  148. err = error(e1)
  149. } else {
  150. err = syscall.EINVAL
  151. }
  152. }
  153. return err
  154. }