Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support reloading preferred_ranges #1043

Merged
merged 1 commit into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
}

if n.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) {
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) {
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, addr)
})
Expand Down
11 changes: 8 additions & 3 deletions connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}

// Very incomplete mock objects
hostMap := NewHostMap(l, vpncidr, preferredRanges)
hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)

cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Expand Down Expand Up @@ -122,7 +124,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}

// Very incomplete mock objects
hostMap := NewHostMap(l, vpncidr, preferredRanges)
hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)

cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Expand Down Expand Up @@ -209,7 +213,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
hostMap := NewHostMap(l, vpncidr, preferredRanges)
hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)

// Generate keys for CA and peer's cert.
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
Expand Down
4 changes: 2 additions & 2 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
return nil
}

ch := copyHostInfo(h, c.f.hostMap.preferredRanges)
ch := copyHostInfo(h, c.f.hostMap.GetPreferredRanges())
return &ch
}

Expand All @@ -157,7 +157,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control
}

hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges)
ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
return &ch
}

Expand Down
4 changes: 3 additions & 1 deletion control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0))
hm := newHostMap(l, &net.IPNet{})
hm.preferredRanges.Store(&[]*net.IPNet{})

remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{
Expand Down
2 changes: 1 addition & 1 deletion handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)

f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes")

// Swap the packet store to benefit the original intended recipient
Expand Down
10 changes: 5 additions & 5 deletions handshake_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
hostinfo := hh.hostinfo
// If we are out of time, clean up
if hh.counter >= hm.config.retries {
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)).
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
WithField("initiatorIndex", hh.hostinfo.localIndexId).
WithField("remoteIndex", hh.hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Expand Down Expand Up @@ -211,7 +211,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
}

remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)

// We only care about a lighthouse trigger if we have new remotes to send to.
Expand All @@ -235,7 +235,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger

// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udp.Addr
hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) {
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
Expand Down Expand Up @@ -362,7 +362,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
hm.mainHostMap.RUnlock()
// Do not attempt promotion if you are a lighthouse
if !hm.lightHouse.amLighthouse {
h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f)
h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f)
}
return h, true
}
Expand Down Expand Up @@ -600,7 +600,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
}

func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
return c.mainHostMap.preferredRanges
return c.mainHostMap.GetPreferredRanges()
}

func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
Expand Down
4 changes: 3 additions & 1 deletion handshake_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap(l, vpncidr, preferredRanges)
mainHM := newHostMap(l, vpncidr)
mainHM.preferredRanges.Store(&preferredRanges)

lh := newTestLighthouse()

cs := &CertState{
Expand Down
71 changes: 52 additions & 19 deletions hostmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
Expand Down Expand Up @@ -57,9 +58,8 @@ type HostMap struct {
Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
RemoteIndexes map[uint32]*HostInfo
Hosts map[iputil.VpnIp]*HostInfo
preferredRanges []*net.IPNet
preferredRanges atomic.Pointer[[]*net.IPNet]
vpnCIDR *net.IPNet
metricsEnabled bool
l *logrus.Logger
}

Expand Down Expand Up @@ -254,21 +254,53 @@ type cachedPacketMetrics struct {
dropped metrics.Counter
}

func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[iputil.VpnIp]*HostInfo{}
i := map[uint32]*HostInfo{}
r := map[uint32]*HostInfo{}
relays := map[uint32]*HostInfo{}
m := HostMap{
Indexes: i,
Relays: relays,
RemoteIndexes: r,
Hosts: h,
preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR,
l: l,
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap {
hm := newHostMap(l, vpnCIDR)

hm.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) {
hm.reload(c, false)
})

l.WithField("network", hm.vpnCIDR.String()).
WithField("preferredRanges", hm.GetPreferredRanges()).
Info("Main HostMap created")

return hm
}

func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
return &HostMap{
Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{},
RemoteIndexes: map[uint32]*HostInfo{},
Hosts: map[iputil.VpnIp]*HostInfo{},
vpnCIDR: vpnCIDR,
l: l,
}
}

func (hm *HostMap) reload(c *config.C, initial bool) {
if initial || c.HasChanged("preferred_ranges") {
var preferredRanges []*net.IPNet
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})

for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)

if err != nil {
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
continue
}

preferredRanges = append(preferredRanges, preferredRange)
}

oldRanges := hm.preferredRanges.Swap(&preferredRanges)
if !initial {
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
}
}
return &m
}

// EmitStats reports host, index, and relay counts to the stats collection system
Expand Down Expand Up @@ -457,7 +489,7 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostI
hm.RUnlock()
// Do not attempt promotion if you are a lighthouse
if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
h.TryPromoteBest(hm.preferredRanges, promoteIfce)
h.TryPromoteBest(hm.GetPreferredRanges(), promoteIfce)
}
return h

Expand Down Expand Up @@ -504,7 +536,8 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
}

func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
return hm.preferredRanges
//NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
return *hm.preferredRanges.Load()
}

func (hm *HostMap) ForEachVpnIp(f controlEach) {
Expand Down Expand Up @@ -596,7 +629,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
// NOTE: We do this loop here instead of calling `isPreferred` in
// remote_list.go so that we only have to loop over preferredRanges once.
newIsPreferred := false
for _, l := range hm.preferredRanges {
for _, l := range hm.GetPreferredRanges() {
// return early if we are already on a preferred remote
if l.Contains(currentRemote.IP) {
return false
Expand Down
37 changes: 33 additions & 4 deletions hostmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@ import (
"net"
"testing"

"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)

func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger()
hm := NewHostMap(
hm := newHostMap(
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
[]*net.IPNet{},
)

f := &Interface{}
Expand Down Expand Up @@ -91,13 +91,12 @@ func TestHostMap_MakePrimary(t *testing.T) {

func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger()
hm := NewHostMap(
hm := newHostMap(
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
[]*net.IPNet{},
)

f := &Interface{}
Expand Down Expand Up @@ -205,3 +204,33 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
prim = hm.QueryVpnIp(1)
assert.Nil(t, prim)
}

func TestHostMap_reload(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)

hm := NewHostMapFromConfig(
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
c,
)

toS := func(ipn []*net.IPNet) []string {
var s []string
for _, n := range ipn {
s = append(s, n.String())
}
return s
}

assert.Empty(t, hm.GetPreferredRanges())

c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))

c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
}
47 changes: 1 addition & 46 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,52 +173,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}

// Set up my internal host map
var preferredRanges []*net.IPNet
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
// First, check if 'preferred_ranges' is set and fallback to 'local_range'
if len(rawPreferredRanges) > 0 {
for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err)
}
preferredRanges = append(preferredRanges, preferredRange)
}
}

// local_range was superseded by preferred_ranges. If it is still present,
// merge the local_range setting into preferred_ranges. We will probably
// deprecate local_range and remove in the future.
rawLocalRange := c.GetString("local_range", "")
if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err)
}

// Check if the entry for local_range was already specified in
// preferred_ranges. Don't put it into the slice twice if so.
var found bool
for _, r := range preferredRanges {
if r.String() == localRange.String() {
found = true
break
}
}
if !found {
preferredRanges = append(preferredRanges, localRange)
}
}

hostMap := NewHostMap(l, tunCidr, preferredRanges)
hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)

l.
WithField("network", hostMap.vpnCIDR.String()).
WithField("preferredRanges", hostMap.preferredRanges).
Info("Main HostMap created")

hostMap := NewHostMapFromConfig(l, tunCidr, c)
punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
enc.SetIndent("", " ")
}

return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges))
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
}

func sshReload(c *config.C, w sshd.StringWriter) error {
Expand Down