Skip to content

Commit

Permalink
Shed the layers of indirection on udp listeners, get the full hostinf…
Browse files Browse the repository at this point in the history
…o to the lighthouse request handler
  • Loading branch information
nbrownus committed Sep 21, 2024
1 parent 28cd257 commit e681be2
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 159 deletions.
12 changes: 9 additions & 3 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,16 +254,22 @@ func (f *Interface) listenOut(i int) {
runtime.LockOSThread()

var li udp.Conn
// TODO clean this up with a coherent interface for each outside connection
if i > 0 {
li = f.writers[i]
} else {
li = f.outside
}

ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
plaintext := make([]byte, udp.MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)

li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
}

func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
Expand Down
78 changes: 36 additions & 42 deletions lighthouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -915,24 +915,18 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
return lhh.meta
}

func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
return func(rAddr netip.AddrPort, vpnAddrs []netip.Addr, p []byte) {
lhh.HandleRequest(rAddr, vpnAddrs, p, f)
}
}

func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnAddrs []netip.Addr, p []byte, w EncWriter) {
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, reqHostinfo *HostInfo, p []byte, w EncWriter) {
n := lhh.resetMeta()
err := n.Unmarshal(p)
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", rAddr).
lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).WithField("udpAddr", rAddr).
Error("Failed to unmarshal lighthouse packet")
//TODO: send recv_error?
return
}

if n.Details == nil {
lhh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", rAddr).
lhh.l.WithField("vpnAddrs", reqHostinfo.vpnAddrs).WithField("udpAddr", rAddr).
Error("Invalid lighthouse update")
//TODO: send recv_error?
return
Expand All @@ -942,24 +936,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnAddrs []net

switch n.Type {
case NebulaMeta_HostQuery:
lhh.handleHostQuery(n, vpnAddrs, rAddr, w)
lhh.handleHostQuery(n, reqHostinfo, rAddr, w)

case NebulaMeta_HostQueryReply:
lhh.handleHostQueryReply(n, vpnAddrs)
lhh.handleHostQueryReply(n, reqHostinfo)

case NebulaMeta_HostUpdateNotification:
lhh.handleHostUpdateNotification(n, vpnAddrs, w)
lhh.handleHostUpdateNotification(n, reqHostinfo, w)

case NebulaMeta_HostMovedNotification:
case NebulaMeta_HostPunchNotification:
lhh.handleHostPunchNotification(n, vpnAddrs, w)
lhh.handleHostPunchNotification(n, reqHostinfo, w)

case NebulaMeta_HostUpdateNotificationAck:
// noop
}
}

func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, reqHostinfo *HostInfo, addr netip.AddrPort, w EncWriter) {
// Exit if we don't answer queries
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
Expand Down Expand Up @@ -1007,15 +1001,15 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnAddrs []netip.Ad
}

if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", vpnAddrs).Error("Failed to marshal lighthouse host query reply")
lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Failed to marshal lighthouse host query reply")
return
}

lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
w.SendMessageToVpnIp(header.LightHouse, 0, vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
w.SendMessageToVpnIp(header.LightHouse, 0, reqHostinfo.vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])

// This signals the other side to punch some zero byte udp packets
found, ln, err = lhh.lh.queryAndPrepMessage(vpnAddrs[0], func(c *cache) (int, error) {
found, ln, err = lhh.lh.queryAndPrepMessage(reqHostinfo.vpnAddrs[0], func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification
//TODO: unsure which version to use. If we had access to the hostmap we could see if there is already a tunnel
Expand All @@ -1027,15 +1021,15 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnAddrs []netip.Ad
}

if useVersion == cert.Version1 {
if !vpnAddrs[0].Is4() {
if !reqHostinfo.vpnAddrs[0].Is4() {
return 0, fmt.Errorf("invalid vpn ip for v1 handleHostQuery")
}
b := vpnAddrs[0].As4()
b := reqHostinfo.vpnAddrs[0].As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
lhh.coalesceAnswers(useVersion, c, n)

} else if useVersion == cert.Version2 {
n.Details.VpnAddr = netAddrToProtoAddr(vpnAddrs[0])
n.Details.VpnAddr = netAddrToProtoAddr(reqHostinfo.vpnAddrs[0])
lhh.coalesceAnswers(useVersion, c, n)

} else {
Expand All @@ -1050,7 +1044,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnAddrs []netip.Ad
}

if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", vpnAddrs).Error("Failed to marshal lighthouse host was queried for")
lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Failed to marshal lighthouse host was queried for")
return
}

Expand Down Expand Up @@ -1100,9 +1094,9 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
}
}

func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnAddrs []netip.Addr) {
func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, reqHostinfo *HostInfo) {
//TODO: this is kind of dumb
if !lhh.lh.IsLighthouseIP(vpnAddrs[0]) {
if !lhh.lh.IsLighthouseIP(reqHostinfo.vpnAddrs[0]) {
return
}

Expand All @@ -1121,8 +1115,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnAddrs []net
am.Lock()
lhh.lh.Unlock()

am.unlockedSetV4(vpnAddrs[0], certVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnAddrs[0], certVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetV4(reqHostinfo.vpnAddrs[0], certVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(reqHostinfo.vpnAddrs[0], certVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)

var relays []netip.Addr
if len(n.Details.OldRelayVpnAddrs) > 0 {
Expand All @@ -1139,7 +1133,7 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnAddrs []net
}
}

am.unlockedSetRelay(vpnAddrs[0], certVpnIp, relays)
am.unlockedSetRelay(reqHostinfo.vpnAddrs[0], certVpnIp, relays)
am.Unlock()

// Non-blocking attempt to trigger, skip if it would block
Expand All @@ -1149,10 +1143,10 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnAddrs []net
}
}

func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnAddrs []netip.Addr, w EncWriter) {
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, reqHostinfo *HostInfo, w EncWriter) {
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnAddrs)
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", reqHostinfo.vpnAddrs)
}
return
}
Expand All @@ -1173,20 +1167,20 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnAdd
//todo hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
//todo why do we care about the vpnip in the packet? We know where it came from, right?

if detailsVpnIp != vpnAddrs[0] {
if detailsVpnIp != reqHostinfo.vpnAddrs[0] {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", vpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update")
lhh.l.WithField("vpnAddrs", reqHostinfo.vpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update")
}
return
}

lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList(vpnAddrs[0])
am := lhh.lh.unlockedGetRemoteList(reqHostinfo.vpnAddrs[0])
am.Lock()
lhh.lh.Unlock()

am.unlockedSetV4(vpnAddrs[0], detailsVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnAddrs[0], detailsVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetV4(reqHostinfo.vpnAddrs[0], detailsVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(reqHostinfo.vpnAddrs[0], detailsVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)

var relays []netip.Addr
if len(n.Details.OldRelayVpnAddrs) > 0 {
Expand All @@ -1203,40 +1197,40 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnAdd
}
}

am.unlockedSetRelay(vpnAddrs[0], detailsVpnIp, relays)
am.unlockedSetRelay(reqHostinfo.vpnAddrs[0], detailsVpnIp, relays)
am.Unlock()

n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck

if useVersion == cert.Version1 {
if !vpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", vpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
if !reqHostinfo.vpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
return
}
vpnIpB := vpnAddrs[0].As4()
vpnIpB := reqHostinfo.vpnAddrs[0].As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnIpB[:])

} else if useVersion == cert.Version2 {
n.Details.VpnAddr = netAddrToProtoAddr(vpnAddrs[0])
n.Details.VpnAddr = netAddrToProtoAddr(reqHostinfo.vpnAddrs[0])

} else {
panic("unsupported version")
}

ln, err := n.MarshalTo(lhh.pb)
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", vpnAddrs).Error("Failed to marshal lighthouse host update ack")
lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Failed to marshal lighthouse host update ack")
return
}

lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1)
w.SendMessageToVpnIp(header.LightHouse, 0, vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
w.SendMessageToVpnIp(header.LightHouse, 0, reqHostinfo.vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
}

func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnAddrs []netip.Addr, w EncWriter) {
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, reqHostinfo *HostInfo, w EncWriter) {
//TODO: this is kinda stupid
if !lhh.lh.IsLighthouseIP(vpnAddrs[0]) {
if !lhh.lh.IsLighthouseIP(reqHostinfo.vpnAddrs[0]) {
return
}

Expand Down
9 changes: 5 additions & 4 deletions lighthouse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {

mw := &mockEncWriter{}

hi := &HostInfo{vpnAddrs: []netip.Addr{vpnIp2}}
b.Run("notfound", func(b *testing.B) {
lhh := lh.NewRequestHandler()
req := &NebulaMeta{
Expand All @@ -147,7 +148,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
p, err := req.Marshal()
assert.NoError(b, err)
for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, []netip.Addr{vpnIp2}, p, mw)
lhh.HandleRequest(rAddr, hi, p, mw)
}
})
b.Run("found", func(b *testing.B) {
Expand All @@ -163,7 +164,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
assert.NoError(b, err)

for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, []netip.Addr{vpnIp2}, p, mw)
lhh.HandleRequest(rAddr, hi, p, mw)
}
})
}
Expand Down Expand Up @@ -324,7 +325,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
w := &testEncWriter{
metaFilter: &filter,
}
lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
lhh.HandleRequest(fromAddr, &HostInfo{vpnAddrs: []netip.Addr{myVpnIp}}, b, w)
return w.lastReply
}

Expand All @@ -349,7 +350,7 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
}

w := &testEncWriter{}
lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
lhh.HandleRequest(fromAddr, &HostInfo{vpnAddrs: []netip.Addr{vpnIp}}, b, w)
}

//TODO: this is a RemoteList test
Expand Down
22 changes: 2 additions & 20 deletions outside.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,14 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4"
)

const (
minFwPacketLen = 4
)

// TODO: IPV6-WORK this can likely be removed now
func readOutsidePackets(f *Interface) udp.EncReader {
return func(
addr netip.AddrPort,
out []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh udp.LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
) {
f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
}
}

func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// TODO: best if we return this and let caller log
Expand Down Expand Up @@ -163,7 +145,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}

lhf(ip, hostinfo.vpnAddrs, d)
lhf.HandleRequest(ip, hostinfo, d, f)

// Fallthrough to the bottom to record incoming traffic

Expand Down
15 changes: 3 additions & 12 deletions udp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,19 @@ import (
"net/netip"

"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
)

const MTU = 9001

type EncReader func(
addr netip.AddrPort,
out []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
payload []byte,
)

type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
Close() error
Expand All @@ -39,7 +30,7 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil
}
func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
func (NoopConn) ListenOut(_ EncReader) {
return
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
Expand Down
10 changes: 0 additions & 10 deletions udp/temp.go

This file was deleted.

Loading

0 comments on commit e681be2

Please sign in to comment.