|
@@ -331,24 +331,63 @@ func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {
|
|
// moves back into it when done. If newNs is close, the socket will be opened
|
|
// moves back into it when done. If newNs is close, the socket will be opened
|
|
// in the current network namespace.
|
|
// in the current network namespace.
|
|
func GetNetlinkSocketAt(newNs, curNs netns.NsHandle, protocol int) (*NetlinkSocket, error) {
|
|
func GetNetlinkSocketAt(newNs, curNs netns.NsHandle, protocol int) (*NetlinkSocket, error) {
|
|
- var err error
|
|
|
|
|
|
+ c, err := executeInNetns(newNs, curNs)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, err
|
|
|
|
+ }
|
|
|
|
+ defer c()
|
|
|
|
+ return getNetlinkSocket(protocol)
|
|
|
|
+}
|
|
|
|
|
|
|
|
+// executeInNetns sets execution of the code following this call to the
|
|
|
|
+// network namespace newNs, then moves the thread back to curNs if open,
|
|
|
|
+// otherwise to the current netns at the time the function was invoked
|
|
|
|
+// In case of success, the caller is expected to execute the returned function
|
|
|
|
+// at the end of the code that needs to be executed in the network namespace.
|
|
|
|
+// Example:
|
|
|
|
+// func jobAt(...) error {
|
|
|
|
+// d, err := executeInNetns(...)
|
|
|
|
+// if err != nil { return err}
|
|
|
|
+// defer d()
|
|
|
|
+// < code which needs to be executed in specific netns>
|
|
|
|
+// }
|
|
|
|
+// TODO: his function probably belongs to netns pkg.
|
|
|
|
+func executeInNetns(newNs, curNs netns.NsHandle) (func(), error) {
|
|
|
|
+ var (
|
|
|
|
+ err error
|
|
|
|
+ moveBack func(netns.NsHandle) error
|
|
|
|
+ closeNs func() error
|
|
|
|
+ unlockThd func()
|
|
|
|
+ )
|
|
|
|
+ restore := func() {
|
|
|
|
+ // order matters
|
|
|
|
+ if moveBack != nil {
|
|
|
|
+ moveBack(curNs)
|
|
|
|
+ }
|
|
|
|
+ if closeNs != nil {
|
|
|
|
+ closeNs()
|
|
|
|
+ }
|
|
|
|
+ if unlockThd != nil {
|
|
|
|
+ unlockThd()
|
|
|
|
+ }
|
|
|
|
+ }
|
|
if newNs.IsOpen() {
|
|
if newNs.IsOpen() {
|
|
runtime.LockOSThread()
|
|
runtime.LockOSThread()
|
|
- defer runtime.UnlockOSThread()
|
|
|
|
|
|
+ unlockThd = runtime.UnlockOSThread
|
|
if !curNs.IsOpen() {
|
|
if !curNs.IsOpen() {
|
|
if curNs, err = netns.Get(); err != nil {
|
|
if curNs, err = netns.Get(); err != nil {
|
|
|
|
+ restore()
|
|
return nil, fmt.Errorf("could not get current namespace while creating netlink socket: %v", err)
|
|
return nil, fmt.Errorf("could not get current namespace while creating netlink socket: %v", err)
|
|
}
|
|
}
|
|
- defer curNs.Close()
|
|
|
|
|
|
+ closeNs = curNs.Close
|
|
}
|
|
}
|
|
if err := netns.Set(newNs); err != nil {
|
|
if err := netns.Set(newNs); err != nil {
|
|
|
|
+ restore()
|
|
return nil, fmt.Errorf("failed to set into network namespace %d while creating netlink socket: %v", newNs, err)
|
|
return nil, fmt.Errorf("failed to set into network namespace %d while creating netlink socket: %v", newNs, err)
|
|
}
|
|
}
|
|
- defer netns.Set(curNs)
|
|
|
|
|
|
+ moveBack = netns.Set
|
|
}
|
|
}
|
|
-
|
|
|
|
- return getNetlinkSocket(protocol)
|
|
|
|
|
|
+ return restore, nil
|
|
}
|
|
}
|
|
|
|
|
|
// Create a netlink socket with a given protocol (e.g. NETLINK_ROUTE)
|
|
// Create a netlink socket with a given protocol (e.g. NETLINK_ROUTE)
|
|
@@ -377,6 +416,18 @@ func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) {
|
|
return s, nil
|
|
return s, nil
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+// SubscribeAt works like Subscribe plus let's the caller choose the network
|
|
|
|
+// namespace in which the socket would be opened (newNs). Then control goes back
|
|
|
|
+// to curNs if open, otherwise to the netns at the time this function was called.
|
|
|
|
+func SubscribeAt(newNs, curNs netns.NsHandle, protocol int, groups ...uint) (*NetlinkSocket, error) {
|
|
|
|
+ c, err := executeInNetns(newNs, curNs)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, err
|
|
|
|
+ }
|
|
|
|
+ defer c()
|
|
|
|
+ return Subscribe(protocol, groups...)
|
|
|
|
+}
|
|
|
|
+
|
|
func (s *NetlinkSocket) Close() {
|
|
func (s *NetlinkSocket) Close() {
|
|
syscall.Close(s.fd)
|
|
syscall.Close(s.fd)
|
|
s.fd = -1
|
|
s.fd = -1
|