From 3de570d8f296c1d938ed5901e3f931cb7af140ec Mon Sep 17 00:00:00 2001 From: Maria Shaldybin Date: Wed, 31 Jan 2024 18:57:04 +0000 Subject: [PATCH] Setup sticky session for Kerberos and NTML HTTP Authentication When server responds with `WWW-Authenticate: Negotiate`, save VCAP_ID cookie on response to client so that subsequent request with `Authorization: Negotiate ...` will be directed to the same application instance. See [RFC-4559](https://www.ietf.org/rfc/rfc4559.txt) Signed-off-by: Josh Russett --- handlers/helpers.go | 15 +- proxy/round_tripper/proxy_round_tripper.go | 21 +- .../round_tripper/proxy_round_tripper_test.go | 30 ++- proxy/session_affinity_test.go | 2 +- registry/registry_test.go | 34 +-- route/endpoint_iterator_benchmark_test.go | 8 +- route/leastconnection.go | 14 +- route/leastconnection_test.go | 79 ++++--- route/pool.go | 6 +- route/pool_test.go | 6 +- route/roundrobin.go | 14 +- route/roundrobin_test.go | 216 ++++++++++++------ 12 files changed, 294 insertions(+), 151 deletions(-) diff --git a/handlers/helpers.go b/handlers/helpers.go index beb4cc2f..c6b6cff6 100644 --- a/handlers/helpers.go +++ b/handlers/helpers.go @@ -67,17 +67,24 @@ func EndpointIteratorForRequest(request *http.Request, loadBalanceMethod string, if err != nil { return nil, fmt.Errorf("could not find reqInfo in context") } - return reqInfo.RoutePool.Endpoints(loadBalanceMethod, getStickySession(request, stickySessionCookieNames), azPreference, az), nil + stickyEndpointID, mustBeSticky := GetStickySession(request, stickySessionCookieNames) + return reqInfo.RoutePool.Endpoints(loadBalanceMethod, stickyEndpointID, mustBeSticky, azPreference, az), nil } -func getStickySession(request *http.Request, stickySessionCookieNames config.StringSet) string { +func GetStickySession(request *http.Request, stickySessionCookieNames config.StringSet) (string, bool) { + containsAuthNegotiateHeader := strings.HasPrefix(strings.ToLower(request.Header.Get("Authorization")), "negotiate") + if containsAuthNegotiateHeader { + if sticky, err := request.Cookie(VcapCookieId); err == nil { + return sticky.Value, true + } + } // Try choosing a backend using sticky session for stickyCookieName, _ := range stickySessionCookieNames { if _, err := request.Cookie(stickyCookieName); err == nil { if sticky, err := request.Cookie(VcapCookieId); err == nil { - return sticky.Value + return sticky.Value, false } } } - return "" + return "", false } diff --git a/proxy/round_tripper/proxy_round_tripper.go b/proxy/round_tripper/proxy_round_tripper.go index 82e8c8bc..5edca609 100644 --- a/proxy/round_tripper/proxy_round_tripper.go +++ b/proxy/round_tripper/proxy_round_tripper.go @@ -135,9 +135,9 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response return nil, errors.New("ProxyResponseWriter not set on context") } - stickyEndpointID := getStickySession(request, rt.config.StickySessionCookieNames) + stickyEndpointID, mustBeSticky := handlers.GetStickySession(request, rt.config.StickySessionCookieNames) numberOfEndpoints := reqInfo.RoutePool.NumEndpoints() - iter := reqInfo.RoutePool.Endpoints(rt.config.LoadBalance, stickyEndpointID, rt.config.LoadBalanceAZPreference, rt.config.Zone) + iter := reqInfo.RoutePool.Endpoints(rt.config.LoadBalance, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone) // The selectEndpointErr needs to be tracked separately. If we get an error // while selecting an endpoint we might just have run out of routes. In @@ -389,9 +389,8 @@ func setupStickySession( requestContainsStickySessionCookies := originalEndpointId != "" requestNotSentToRequestedApp := originalEndpointId != endpoint.PrivateInstanceId - containsAuthNegotiateHeader := strings.ToLower(response.Header.Get("Authorization")) == "negotiate" - - shouldSetVCAPID := containsAuthNegotiateHeader || (requestContainsStickySessionCookies && requestNotSentToRequestedApp) + containsAuthNegotiateHeader := strings.HasPrefix(strings.ToLower(response.Header.Get("WWW-Authenticate")), "negotiate") + shouldSetVCAPID := (containsAuthNegotiateHeader || requestContainsStickySessionCookies) && requestNotSentToRequestedApp secure := false maxAge := 0 @@ -443,18 +442,6 @@ func setupStickySession( } } -func getStickySession(request *http.Request, stickySessionCookieNames config.StringSet) string { - // Try choosing a backend using sticky session - for stickyCookieName, _ := range stickySessionCookieNames { - if _, err := request.Cookie(stickyCookieName); err == nil { - if sticky, err := request.Cookie(VcapCookieId); err == nil { - return sticky.Value - } - } - } - return "" -} - func requestSentToRouteService(request *http.Request) bool { sigHeader := request.Header.Get(routeservice.HeaderKeySignature) rsUrl := request.Header.Get(routeservice.HeaderKeyForwardedURL) diff --git a/proxy/round_tripper/proxy_round_tripper_test.go b/proxy/round_tripper/proxy_round_tripper_test.go index 2e7f0a9e..5f73982e 100644 --- a/proxy/round_tripper/proxy_round_tripper_test.go +++ b/proxy/round_tripper/proxy_round_tripper_test.go @@ -269,7 +269,7 @@ var _ = Describe("ProxyRoundTripper", func() { res, err := proxyRoundTripper.RoundTrip(req) Expect(err).NotTo(HaveOccurred()) - iter := routePool.Endpoints("", "", AZPreference, AZ) + iter := routePool.Endpoints("", "", false, AZPreference, AZ) ep1 := iter.Next(0) ep2 := iter.Next(1) Expect(ep1.PrivateInstanceId).To(Equal(ep2.PrivateInstanceId)) @@ -427,7 +427,7 @@ var _ = Describe("ProxyRoundTripper", func() { _, err := proxyRoundTripper.RoundTrip(req) Expect(err).To(MatchError(ContainSubstring("tls: handshake failure"))) - iter := routePool.Endpoints("", "", AZPreference, AZ) + iter := routePool.Endpoints("", "", false, AZPreference, AZ) ep1 := iter.Next(0) ep2 := iter.Next(1) Expect(ep1).To(Equal(ep2)) @@ -1035,7 +1035,7 @@ var _ = Describe("ProxyRoundTripper", func() { } setAuthorizationNegotiateHeader := func(resp *http.Response) (response *http.Response) { - resp.Header.Add("Authorization", "Negotiate") + resp.Header.Add("WWW-Authenticate", "Negotiate SOME-TOKEN") return resp } @@ -1161,7 +1161,7 @@ var _ = Describe("ProxyRoundTripper", func() { }) }) - Context("when there is an 'Authorization: Negotiate' header set on the response", func() { + Context("when there is an 'WWW-Authenticate: Negotiate ...' header set on the response", func() { BeforeEach(func() { transport.RoundTripStub = func(req *http.Request) (*http.Response, error) { resp := &http.Response{StatusCode: http.StatusTeapot, Header: make(map[string][]string)} @@ -1486,7 +1486,29 @@ var _ = Describe("ProxyRoundTripper", func() { }) }) + Context("when VCAP_ID cookie and 'Authorization: Negotiate ...' header are on the request", func() { + BeforeEach(func() { + req.AddCookie(&http.Cookie{ + Name: round_tripper.VcapCookieId, + Value: "id-2", + }) + req.Header.Add("Authorization", "Negotiate SOME-TOKEN") + transport.RoundTripStub = func(req *http.Request) (*http.Response, error) { + Expect(req.URL.Host).To(Equal("1.1.1.1:9092")) + resp := &http.Response{StatusCode: http.StatusTeapot, Header: make(map[string][]string)} + return resp, nil + } + }) + + It("will select the previous backend and VCAP_ID is set on the response", func() { + Consistently(func() error { + _, err := proxyRoundTripper.RoundTrip(req) + return err + }).ShouldNot(HaveOccurred()) + }) + }) }) + Context("when endpoint timeout is not 0", func() { var reqCh chan *http.Request BeforeEach(func() { diff --git a/proxy/session_affinity_test.go b/proxy/session_affinity_test.go index e2570a86..c2083a43 100644 --- a/proxy/session_affinity_test.go +++ b/proxy/session_affinity_test.go @@ -13,7 +13,7 @@ import ( const StickyCookieKey = "JSESSIONID" -var _ = Describe("Session Affinity", func() { +var _ = Describe("Session Affinity with JSESSIONID", func() { var done chan bool var jSessionIdCookie *http.Cookie diff --git a/registry/registry_test.go b/registry/registry_test.go index a2757ad2..481f191b 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -374,7 +374,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - Expect(p.Endpoints("", "", azPreference, az).Next(0).ModificationTag).To(Equal(modTag)) + Expect(p.Endpoints("", "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag)) }) }) @@ -396,7 +396,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - Expect(p.Endpoints("", "", azPreference, az).Next(0).ModificationTag).To(Equal(modTag)) + Expect(p.Endpoints("", "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag)) }) Context("updating an existing route with an older modification tag", func() { @@ -416,7 +416,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - ep := p.Endpoints("", "", azPreference, az).Next(0) + ep := p.Endpoints("", "", false, azPreference, az).Next(0) Expect(ep.ModificationTag).To(Equal(modTag)) Expect(ep).To(Equal(endpoint2)) }) @@ -435,7 +435,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(1)) p := r.Lookup("foo.com") - Expect(p.Endpoints("", "", azPreference, az).Next(0).ModificationTag).To(Equal(modTag)) + Expect(p.Endpoints("", "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag)) }) }) }) @@ -703,7 +703,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumUris()).To(Equal(1)) p1 := r.Lookup("foo/bar") - iter := p1.Endpoints("", "", azPreference, az) + iter := p1.Endpoints("", "", false, azPreference, az) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) p2 := r.Lookup("foo") @@ -799,7 +799,7 @@ var _ = Describe("RouteRegistry", func() { p2 := r.Lookup("FOO") Expect(p1).To(Equal(p2)) - iter := p1.Endpoints("", "", azPreference, az) + iter := p1.Endpoints("", "", false, azPreference, az) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -818,7 +818,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("bar") Expect(p).ToNot(BeNil()) - e := p.Endpoints("", "", azPreference, az).Next(0) + e := p.Endpoints("", "", false, azPreference, az).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:123[4|5]")) @@ -833,13 +833,13 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("foo.wild.card") Expect(p).ToNot(BeNil()) - e := p.Endpoints("", "", azPreference, az).Next(0) + e := p.Endpoints("", "", false, azPreference, az).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal("192.168.1.2:1234")) p = r.Lookup("foo.space.wild.card") Expect(p).ToNot(BeNil()) - e = p.Endpoints("", "", azPreference, az).Next(0) + e = p.Endpoints("", "", false, azPreference, az).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal("192.168.1.2:1234")) }) @@ -853,7 +853,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("not.wild.card") Expect(p).ToNot(BeNil()) - e := p.Endpoints("", "", azPreference, az).Next(0) + e := p.Endpoints("", "", false, azPreference, az).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -885,7 +885,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("dora.app.com/env?foo=bar") Expect(p).ToNot(BeNil()) - iter := p.Endpoints("", "", azPreference, az) + iter := p.Endpoints("", "", false, azPreference, az) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -894,7 +894,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("dora.app.com/env/abc?foo=bar&baz=bing") Expect(p).ToNot(BeNil()) - iter := p.Endpoints("", "", azPreference, az) + iter := p.Endpoints("", "", false, azPreference, az) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) }) @@ -914,7 +914,7 @@ var _ = Describe("RouteRegistry", func() { p1 := r.Lookup("foo/extra/paths") Expect(p1).ToNot(BeNil()) - iter := p1.Endpoints("", "", azPreference, az) + iter := p1.Endpoints("", "", false, azPreference, az) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -926,7 +926,7 @@ var _ = Describe("RouteRegistry", func() { p1 := r.Lookup("foo?fields=foo,bar") Expect(p1).ToNot(BeNil()) - iter := p1.Endpoints("", "", azPreference, az) + iter := p1.Endpoints("", "", false, azPreference, az) Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) @@ -962,7 +962,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(2)) p := r.LookupWithInstance("bar.com/foo", appId, appIndex) - e := p.Endpoints("", "", azPreference, az).Next(0) + e := p.Endpoints("", "", false, azPreference, az).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:1234")) @@ -976,7 +976,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(2)) p := r.LookupWithInstance("bar.com/foo", appId, appIndex) - e := p.Endpoints("", "", azPreference, az).Next(0) + e := p.Endpoints("", "", false, azPreference, az).Next(0) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:1234")) @@ -1169,7 +1169,7 @@ var _ = Describe("RouteRegistry", func() { p := r.Lookup("foo") Expect(p).ToNot(BeNil()) - Expect(p.Endpoints("", "", azPreference, az).Next(0)).To(Equal(endpoint)) + Expect(p.Endpoints("", "", false, azPreference, az).Next(0)).To(Equal(endpoint)) p = r.Lookup("bar") Expect(p).To(BeNil()) diff --git a/route/endpoint_iterator_benchmark_test.go b/route/endpoint_iterator_benchmark_test.go index 5b16350a..0b092eb6 100644 --- a/route/endpoint_iterator_benchmark_test.go +++ b/route/endpoint_iterator_benchmark_test.go @@ -69,13 +69,13 @@ func setupEndpointIterator(total int, azDistribution int, strategy string) route var lb route.EndpointIterator switch strategy { case "round-robin": - lb = route.NewRoundRobin(pool, "", false, localAZ) + lb = route.NewRoundRobin(pool, "", false, false, localAZ) case "round-robin-locally-optimistic": - lb = route.NewRoundRobin(pool, "", true, localAZ) + lb = route.NewRoundRobin(pool, "", false, true, localAZ) case "least-connection": - lb = route.NewLeastConnection(pool, "", false, localAZ) + lb = route.NewLeastConnection(pool, "", false, false, localAZ) case "least-connection-locally-optimistic": - lb = route.NewLeastConnection(pool, "", true, localAZ) + lb = route.NewLeastConnection(pool, "", false, true, localAZ) default: panic("invalid load balancing strategy") } diff --git a/route/leastconnection.go b/route/leastconnection.go index 14dd9293..fcf7fdeb 100644 --- a/route/leastconnection.go +++ b/route/leastconnection.go @@ -8,16 +8,18 @@ import ( type LeastConnection struct { pool *EndpointPool initialEndpoint string + mustBeSticky bool lastEndpoint *Endpoint randomize *rand.Rand locallyOptimistic bool localAvailabilityZone string } -func NewLeastConnection(p *EndpointPool, initial string, locallyOptimistic bool, localAvailabilityZone string) EndpointIterator { +func NewLeastConnection(p *EndpointPool, initial string, mustBeSticky bool, locallyOptimistic bool, localAvailabilityZone string) EndpointIterator { return &LeastConnection{ pool: p, initialEndpoint: initial, + mustBeSticky: mustBeSticky, randomize: rand.New(rand.NewSource(time.Now().UnixNano())), locallyOptimistic: locallyOptimistic, localAvailabilityZone: localAvailabilityZone, @@ -28,11 +30,17 @@ func (r *LeastConnection) Next(attempt int) *Endpoint { var e *endpointElem if r.initialEndpoint != "" { e = r.pool.findById(r.initialEndpoint) - r.initialEndpoint = "" - if e != nil && e.isOverloaded() { e = nil } + + if e == nil && r.mustBeSticky { + return nil + } + + if !r.mustBeSticky { + r.initialEndpoint = "" + } } if e != nil { diff --git a/route/leastconnection_test.go b/route/leastconnection_test.go index 67028825..81dfe791 100644 --- a/route/leastconnection_test.go +++ b/route/leastconnection_test.go @@ -25,10 +25,9 @@ var _ = Describe("LeastConnection", func() { }) Describe("Next", func() { - Context("when pool is empty", func() { It("does not select an endpoint", func() { - iter := route.NewLeastConnection(pool, "", false, "meow-az") + iter := route.NewLeastConnection(pool, "", false, false, "meow-az") Expect(iter.Next(0)).To(BeNil()) }) }) @@ -57,7 +56,7 @@ var _ = Describe("LeastConnection", func() { Context("when all endpoints have no statistics", func() { It("selects a random endpoint", func() { - iter := route.NewLeastConnection(pool, "", false, "meow-az") + iter := route.NewLeastConnection(pool, "", false, false, "meow-az") n := iter.Next(0) Expect(n).NotTo(BeNil()) }) @@ -74,7 +73,7 @@ var _ = Describe("LeastConnection", func() { for i := 0; i < 100; i++ { wg.Add(1) go func(attempt int) { - iter := route.NewLeastConnection(pool, "", false, "meow-az") + iter := route.NewLeastConnection(pool, "", false, false, "meow-az") n1 := iter.Next(attempt) Expect(n1).NotTo(BeNil()) @@ -90,10 +89,9 @@ var _ = Describe("LeastConnection", func() { }) Context("when endpoints have varying number of connections", func() { - It("selects endpoint with least connection", func() { setConnectionCount(endpoints, []int{0, 1, 1, 1, 1}) - iter := route.NewLeastConnection(pool, "", false, "meow-az") + iter := route.NewLeastConnection(pool, "", false, false, "meow-az") Expect(iter.Next(0)).To(Equal(endpoints[0])) setConnectionCount(endpoints, []int{1, 0, 1, 1, 1}) @@ -122,7 +120,7 @@ var _ = Describe("LeastConnection", func() { }) It("selects random endpoint from all with least connection", func() { - iter := route.NewLeastConnection(pool, "", false, "meow-az") + iter := route.NewLeastConnection(pool, "", false, false, "meow-az") setConnectionCount(endpoints, []int{1, 0, 0, 0, 0}) okRandoms := []string{ @@ -161,18 +159,20 @@ var _ = Describe("LeastConnection", func() { pool.Put(epOne) // epTwo is always overloaded epTwo = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.2.2", Port: 2222, PrivateInstanceId: "private-label-2"}) - epTwo.Stats.NumberConnections.Increment() - epTwo.Stats.NumberConnections.Increment() pool.Put(epTwo) }) Context("when there is no initial endpoint", func() { Context("when all endpoints are overloaded", func() { - It("returns nil", func() { + BeforeEach(func() { epOne.Stats.NumberConnections.Increment() epOne.Stats.NumberConnections.Increment() - iter := route.NewLeastConnection(pool, "", false, "meow-az") + epTwo.Stats.NumberConnections.Increment() + epTwo.Stats.NumberConnections.Increment() + }) + It("returns nil", func() { + iter := route.NewLeastConnection(pool, "", false, false, "meow-az") Consistently(func() *route.Endpoint { return iter.Next(0) }).Should(BeNil()) @@ -180,11 +180,15 @@ var _ = Describe("LeastConnection", func() { }) Context("when there is only one endpoint", func() { + BeforeEach(func() { + Expect(pool.Remove(epOne)).To(BeTrue()) + epTwo.Stats.NumberConnections.Increment() + epTwo.Stats.NumberConnections.Increment() + }) + Context("when that endpoint is overload", func() { It("returns no endpoint", func() { - Expect(pool.Remove(epOne)).To(BeTrue()) - iter := route.NewLeastConnection(pool, "", false, "meow-az") - + iter := route.NewLeastConnection(pool, "", false, false, "meow-az") Consistently(func() *route.Endpoint { return iter.Next(0) }).Should(BeNil()) @@ -195,23 +199,44 @@ var _ = Describe("LeastConnection", func() { Context("when there is an initial endpoint", func() { var iter route.EndpointIterator - BeforeEach(func() { - iter = route.NewLeastConnection(pool, "private-label-2", false, "meow-az") - }) Context("when the initial endpoint is overloaded", func() { - Context("when there is an unencumbered endpoint", func() { - It("returns the unencumbered endpoint", func() { - Expect(iter.Next(0)).To(Equal(epOne)) - Expect(iter.Next(1)).To(Equal(epOne)) + BeforeEach(func() { + epOne.Stats.NumberConnections.Increment() + epOne.Stats.NumberConnections.Increment() + }) + + Context("when the endpoint is not required to be sticky", func() { + BeforeEach(func() { + iter = route.NewLeastConnection(pool, "private-label-1", false, false, "meow-az") + }) + + Context("when there is an unencumbered endpoint", func() { + It("returns the unencumbered endpoint", func() { + Expect(iter.Next(0)).To(Equal(epTwo)) + Expect(iter.Next(1)).To(Equal(epTwo)) + }) + }) + + Context("when there isn't an unencumbered endpoint", func() { + BeforeEach(func() { + epTwo.Stats.NumberConnections.Increment() + epTwo.Stats.NumberConnections.Increment() + }) + + It("returns nil", func() { + Consistently(func() *route.Endpoint { + return iter.Next(0) + }).Should(BeNil()) + }) }) }) - Context("when there isn't an unencumbered endpoint", func() { + Context("when the endpoint must be be sticky", func() { BeforeEach(func() { - epOne.Stats.NumberConnections.Increment() - epOne.Stats.NumberConnections.Increment() + iter = route.NewLeastConnection(pool, "private-label-1", true, false, "meow-az") }) + It("returns nil", func() { Consistently(func() *route.Endpoint { return iter.Next(0) @@ -249,7 +274,7 @@ var _ = Describe("LeastConnection", func() { }) JustBeforeEach(func() { - iter = route.NewLeastConnection(pool, "", true, localAZ) + iter = route.NewLeastConnection(pool, "", false, true, localAZ) }) Context("on the first attempt", func() { @@ -424,7 +449,7 @@ var _ = Describe("LeastConnection", func() { Expect(endpointFoo.Stats.NumberConnections.Count()).To(Equal(int64(0))) pool.Put(endpointFoo) - iter := route.NewLeastConnection(pool, "foo", false, "meow-az") + iter := route.NewLeastConnection(pool, "foo", false, false, "meow-az") iter.PreRequest(endpointFoo) Expect(endpointFoo.Stats.NumberConnections.Count()).To(Equal(int64(1))) }) @@ -439,7 +464,7 @@ var _ = Describe("LeastConnection", func() { } Expect(endpointFoo.Stats.NumberConnections.Count()).To(Equal(int64(1))) pool.Put(endpointFoo) - iter := route.NewLeastConnection(pool, "foo", false, "meow-az") + iter := route.NewLeastConnection(pool, "foo", false, false, "meow-az") iter.PostRequest(endpointFoo) Expect(endpointFoo.Stats.NumberConnections.Count()).To(Equal(int64(0))) }) diff --git a/route/pool.go b/route/pool.go index 1e0d0ea9..88bc4ace 100644 --- a/route/pool.go +++ b/route/pool.go @@ -368,12 +368,12 @@ func (p *EndpointPool) removeEndpoint(e *endpointElem) { p.Update() } -func (p *EndpointPool) Endpoints(defaultLoadBalance, initial, azPreference, az string) EndpointIterator { +func (p *EndpointPool) Endpoints(defaultLoadBalance string, initial string, mustBeSticky bool, azPreference string, az string) EndpointIterator { switch defaultLoadBalance { case config.LOAD_BALANCE_LC: - return NewLeastConnection(p, initial, azPreference == config.AZ_PREF_LOCAL, az) + return NewLeastConnection(p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) default: - return NewRoundRobin(p, initial, azPreference == config.AZ_PREF_LOCAL, az) + return NewRoundRobin(p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) } } diff --git a/route/pool_test.go b/route/pool_test.go index 321dd5be..8b34fbe4 100644 --- a/route/pool_test.go +++ b/route/pool_test.go @@ -181,7 +181,7 @@ var _ = Describe("EndpointPool", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: modTag2}) Expect(pool.Put(endpoint)).To(Equal(route.UPDATED)) - Expect(pool.Endpoints("", "", azPreference, az).Next(0).ModificationTag).To(Equal(modTag2)) + Expect(pool.Endpoints("", "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag2)) }) Context("when modification_tag is older", func() { @@ -196,7 +196,7 @@ var _ = Describe("EndpointPool", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: olderModTag}) Expect(pool.Put(endpoint)).To(Equal(route.UNMODIFIED)) - Expect(pool.Endpoints("", "", azPreference, az).Next(0).ModificationTag).To(Equal(modTag2)) + Expect(pool.Endpoints("", "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag2)) }) }) }) @@ -302,7 +302,7 @@ var _ = Describe("EndpointPool", func() { azPreference := "none" connectionResetError := &net.OpError{Op: "read", Err: errors.New("read: connection reset by peer")} pool.EndpointFailed(failedEndpoint, connectionResetError) - i := pool.Endpoints("", "", azPreference, az) + i := pool.Endpoints("", "", false, azPreference, az) epOne := i.Next(0) epTwo := i.Next(1) Expect(epOne).To(Equal(epTwo)) diff --git a/route/roundrobin.go b/route/roundrobin.go index 6cf3b0e9..78094fb3 100644 --- a/route/roundrobin.go +++ b/route/roundrobin.go @@ -8,15 +8,17 @@ type RoundRobin struct { pool *EndpointPool initialEndpoint string + mustBeSticky bool lastEndpoint *Endpoint locallyOptimistic bool localAvailabilityZone string } -func NewRoundRobin(p *EndpointPool, initial string, locallyOptimistic bool, localAvailabilityZone string) EndpointIterator { +func NewRoundRobin(p *EndpointPool, initial string, mustBeSticky bool, locallyOptimistic bool, localAvailabilityZone string) EndpointIterator { return &RoundRobin{ pool: p, initialEndpoint: initial, + mustBeSticky: mustBeSticky, locallyOptimistic: locallyOptimistic, localAvailabilityZone: localAvailabilityZone, } @@ -26,11 +28,17 @@ func (r *RoundRobin) Next(attempt int) *Endpoint { var e *endpointElem if r.initialEndpoint != "" { e = r.pool.findById(r.initialEndpoint) - r.initialEndpoint = "" - if e != nil && e.isOverloaded() { e = nil } + + if e == nil && r.mustBeSticky { + return nil + } + + if !r.mustBeSticky { + r.initialEndpoint = "" + } } if e != nil { diff --git a/route/roundrobin_test.go b/route/roundrobin_test.go index b9110ce6..89fbf418 100644 --- a/route/roundrobin_test.go +++ b/route/roundrobin_test.go @@ -40,7 +40,7 @@ var _ = Describe("RoundRobin", func() { counts := make([]int, len(endpoints)) - iter := route.NewRoundRobin(pool, "", false, "meow-az") + iter := route.NewRoundRobin(pool, "", false, false, "meow-az") loops := 50 for i := 0; i < len(endpoints)*loops; i += 1 { @@ -66,7 +66,7 @@ var _ = Describe("RoundRobin", func() { DescribeTable("it returns nil when no endpoints exist", func(nextIdx int) { pool.NextIdx = nextIdx - iter := route.NewRoundRobin(pool, "", false, "meow-az") + iter := route.NewRoundRobin(pool, "", false, false, "meow-az") e := iter.Next(0) Expect(e).To(BeNil()) }, @@ -84,7 +84,7 @@ var _ = Describe("RoundRobin", func() { pool.Put(route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 1237})) for i := 0; i < 10; i++ { - iter := route.NewRoundRobin(pool, b.PrivateInstanceId, false, "meow-az") + iter := route.NewRoundRobin(pool, b.PrivateInstanceId, false, false, "meow-az") e := iter.Next(i) Expect(e).ToNot(BeNil()) Expect(e.PrivateInstanceId).To(Equal(b.PrivateInstanceId)) @@ -107,7 +107,7 @@ var _ = Describe("RoundRobin", func() { pool.Put(route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 1237})) for i := 0; i < 10; i++ { - iter := route.NewRoundRobin(pool, b.CanonicalAddr(), false, "meow-az") + iter := route.NewRoundRobin(pool, b.CanonicalAddr(), false, false, "meow-az") e := iter.Next(i) Expect(e).ToNot(BeNil()) Expect(e.CanonicalAddr()).To(Equal(b.CanonicalAddr())) @@ -129,12 +129,12 @@ var _ = Describe("RoundRobin", func() { pool.Put(endpointFoo) pool.Put(endpointBar) - iter := route.NewRoundRobin(pool, endpointFoo.PrivateInstanceId, false, "meow-az") + iter := route.NewRoundRobin(pool, endpointFoo.PrivateInstanceId, false, false, "meow-az") foundEndpoint := iter.Next(0) Expect(foundEndpoint).ToNot(BeNil()) Expect(foundEndpoint).To(Equal(endpointFoo)) - iter = route.NewRoundRobin(pool, endpointBar.PrivateInstanceId, false, "meow-az") + iter = route.NewRoundRobin(pool, endpointBar.PrivateInstanceId, false, false, "meow-az") foundEndpoint = iter.Next(1) Expect(foundEndpoint).ToNot(BeNil()) Expect(foundEndpoint).To(Equal(endpointBar)) @@ -144,20 +144,38 @@ var _ = Describe("RoundRobin", func() { Entry("When the next index is 1", 1), ) - DescribeTable("it returns the next available endpoint when the initial is not found", - func(nextIdx int) { - pool.NextIdx = nextIdx - endpointFoo := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 1234, PrivateInstanceId: "foo"}) - pool.Put(endpointFoo) + Context("when endpoint is not required to be sticky", func() { + DescribeTable("it returns the next available endpoint when the initial is not found", + func(nextIdx int) { + pool.NextIdx = nextIdx + endpointFoo := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 1234, PrivateInstanceId: "foo"}) + pool.Put(endpointFoo) - iter := route.NewRoundRobin(pool, "bogus", false, "meow-az") - e := iter.Next(0) - Expect(e).ToNot(BeNil()) - Expect(e).To(Equal(endpointFoo)) - }, - Entry("When the next index is -1", -1), - Entry("When the next index is 0", 0), - ) + iter := route.NewRoundRobin(pool, "bogus", false, false, "meow-az") + e := iter.Next(0) + Expect(e).ToNot(BeNil()) + Expect(e).To(Equal(endpointFoo)) + }, + Entry("When the next index is -1", -1), + Entry("When the next index is 0", 0), + ) + }) + + Context("when endpoint must be sticky", func() { + DescribeTable("it returns nil when the initial is not found", + func(nextIdx int) { + pool.NextIdx = nextIdx + endpointFoo := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 1234, PrivateInstanceId: "foo"}) + pool.Put(endpointFoo) + + iter := route.NewRoundRobin(pool, "bogus", true, false, "meow-az") + e := iter.Next(0) + Expect(e).To(BeNil()) + }, + Entry("When the next index is -1", -1), + Entry("When the next index is 0", 0), + ) + }) DescribeTable("it finds the correct endpoint when private ids change", func(nextIdx int) { @@ -165,7 +183,7 @@ var _ = Describe("RoundRobin", func() { endpointFoo := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 1234, PrivateInstanceId: "foo"}) pool.Put(endpointFoo) - iter := route.NewRoundRobin(pool, endpointFoo.PrivateInstanceId, false, "meow-az") + iter := route.NewRoundRobin(pool, endpointFoo.PrivateInstanceId, false, false, "meow-az") foundEndpoint := iter.Next(0) Expect(foundEndpoint).ToNot(BeNil()) Expect(foundEndpoint).To(Equal(endpointFoo)) @@ -173,11 +191,11 @@ var _ = Describe("RoundRobin", func() { endpointBar := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 1234, PrivateInstanceId: "bar"}) pool.Put(endpointBar) - iter = route.NewRoundRobin(pool, "foo", false, "meow-az") + iter = route.NewRoundRobin(pool, "foo", false, false, "meow-az") foundEndpoint = iter.Next(0) Expect(foundEndpoint).ToNot(Equal(endpointFoo)) - iter = route.NewRoundRobin(pool, "bar", false, "meow-az") + iter = route.NewRoundRobin(pool, "bar", false, false, "meow-az") foundEndpoint = iter.Next(0) Expect(foundEndpoint).To(Equal(endpointBar)) }, @@ -197,7 +215,7 @@ var _ = Describe("RoundRobin", func() { iterateLoop := func(pool *route.EndpointPool) { defer GinkgoRecover() for j := 0; j < numReaders; j++ { - iter := route.NewRoundRobin(pool, "", false, "meow-az") + iter := route.NewRoundRobin(pool, "", false, false, "meow-az") Expect(iter.Next(j)).NotTo(BeNil()) } wg.Done() @@ -235,7 +253,7 @@ var _ = Describe("RoundRobin", func() { MaxConnsPerBackend: 2, }) - epOne = route.NewEndpoint(&route.EndpointOpts{Host: "5.5.5.5", Port: 5555, PrivateInstanceId: "private-label-1"}) + epOne = route.NewEndpoint(&route.EndpointOpts{Host: "5.5.5.5", Port: 5555, PrivateInstanceId: "private-label-1", UseTLS: true}) pool.Put(epOne) epTwo = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.2.2", Port: 2222, PrivateInstanceId: "private-label-2"}) pool.Put(epTwo) @@ -247,7 +265,7 @@ var _ = Describe("RoundRobin", func() { pool.NextIdx = nextIdx epTwo.Stats.NumberConnections.Increment() epTwo.Stats.NumberConnections.Increment() - iter := route.NewRoundRobin(pool, "", false, "meow-az") + iter := route.NewRoundRobin(pool, "", false, false, "meow-az") foundEndpoint := iter.Next(0) Expect(foundEndpoint).To(Equal(epOne)) @@ -268,7 +286,7 @@ var _ = Describe("RoundRobin", func() { epOne.Stats.NumberConnections.Increment() epTwo.Stats.NumberConnections.Increment() epTwo.Stats.NumberConnections.Increment() - iter := route.NewRoundRobin(pool, "", false, "meow-az") + iter := route.NewRoundRobin(pool, "", false, false, "meow-az") Consistently(func() *route.Endpoint { return iter.Next(0) @@ -287,7 +305,7 @@ var _ = Describe("RoundRobin", func() { epThree := route.NewEndpoint(&route.EndpointOpts{Host: "3.3.3.3", Port: 2222, PrivateInstanceId: "private-label-2"}) pool.Put(epThree) - iter := route.NewRoundRobin(pool, "", false, "meow-az") + iter := route.NewRoundRobin(pool, "", false, false, "meow-az") Expect(iter.Next(0)).To(Equal(epOne)) iter.EndpointFailed(&net.OpError{Op: "dial"}) @@ -313,46 +331,114 @@ var _ = Describe("RoundRobin", func() { Context("when there is an initial endpoint", func() { var iter route.EndpointIterator - BeforeEach(func() { - iter = route.NewRoundRobin(pool, "private-label-1", false, "meow-az") - }) - Context("when the initial endpoint is overloaded", func() { + Context("when the endpoint is not required to be sticky", func() { BeforeEach(func() { - epOne.Stats.NumberConnections.Increment() - epOne.Stats.NumberConnections.Increment() + iter = route.NewRoundRobin(pool, "private-label-1", false, false, "meow-az") }) - Context("when there is an unencumbered endpoint", func() { - DescribeTable("it returns the unencumbered endpoint", - func(nextIdx int) { - pool.NextIdx = nextIdx - Expect(iter.Next(0)).To(Equal(epTwo)) - Expect(iter.Next(1)).To(Equal(epTwo)) - }, - Entry("When the next index is -1", -1), - Entry("When the next index is 0", 0), - Entry("When the next index is 1", 1), - ) + Context("when the initial endpoint is overloaded", func() { + BeforeEach(func() { + epOne.Stats.NumberConnections.Increment() + epOne.Stats.NumberConnections.Increment() + }) + + Context("when there is an unencumbered endpoint", func() { + DescribeTable("it returns the unencumbered endpoint", + func(nextIdx int) { + pool.NextIdx = nextIdx + Expect(iter.Next(0)).To(Equal(epTwo)) + Expect(iter.Next(1)).To(Equal(epTwo)) + }, + Entry("When the next index is -1", -1), + Entry("When the next index is 0", 0), + Entry("When the next index is 1", 1), + ) + }) + + Context("when there isn't an unencumbered endpoint", func() { + BeforeEach(func() { + epTwo.Stats.NumberConnections.Increment() + epTwo.Stats.NumberConnections.Increment() + }) + + DescribeTable("it returns nil", + func(nextIdx int) { + pool.NextIdx = nextIdx + Consistently(func() *route.Endpoint { + return iter.Next(0) + }).Should(BeNil()) + }, + Entry("When the next index is -1", -1), + Entry("When the next index is 0", 0), + Entry("When the next index is 1", 1), + ) + }) + }) + }) + + Context("when the endpoint must to be sticky", func() { + BeforeEach(func() { + iter = route.NewRoundRobin(pool, "private-label-1", true, false, "meow-az") }) - Context("when there isn't an unencumbered endpoint", func() { + Context("when the initial endpoint is overloaded", func() { BeforeEach(func() { - epTwo.Stats.NumberConnections.Increment() - epTwo.Stats.NumberConnections.Increment() + epOne.Stats.NumberConnections.Increment() + epOne.Stats.NumberConnections.Increment() }) - DescribeTable("it returns nil", - func(nextIdx int) { - pool.NextIdx = nextIdx - Consistently(func() *route.Endpoint { - return iter.Next(0) - }).Should(BeNil()) - }, - Entry("When the next index is -1", -1), - Entry("When the next index is 0", 0), - Entry("When the next index is 1", 1), - ) + Context("when there is an unencumbered endpoint", func() { + DescribeTable("it returns nil", + func(nextIdx int) { + pool.NextIdx = nextIdx + Consistently(func() *route.Endpoint { + return iter.Next(0) + }).Should(BeNil()) + }, + Entry("When the next index is -1", -1), + Entry("When the next index is 0", 0), + Entry("When the next index is 1", 1), + ) + }) + + Context("when there isn't an unencumbered endpoint", func() { + BeforeEach(func() { + epTwo.Stats.NumberConnections.Increment() + epTwo.Stats.NumberConnections.Increment() + }) + + DescribeTable("it returns nil", + func(nextIdx int) { + pool.NextIdx = nextIdx + Consistently(func() *route.Endpoint { + return iter.Next(0) + }).Should(BeNil()) + }, + Entry("When the next index is -1", -1), + Entry("When the next index is 0", 0), + Entry("When the next index is 1", 1), + ) + }) + }) + + Context("when initial endpoint becomes overloaded", func() { + It("doesn't mark endpoint as failed", func() { + Expect(pool.NumEndpoints()).To(Equal(2)) + Expect(iter.Next(0)).To(Equal(epOne)) + + epOne.Stats.NumberConnections.Increment() + epOne.Stats.NumberConnections.Increment() + + Expect(iter.Next(0)).To(BeNil()) + + Expect(pool.NumEndpoints()).To(Equal(2)) + + epOne.Stats.NumberConnections.Decrement() + epOne.Stats.NumberConnections.Decrement() + + Expect(iter.Next(0)).To(Equal(epOne)) + }) }) }) }) @@ -384,7 +470,7 @@ var _ = Describe("RoundRobin", func() { }) JustBeforeEach(func() { - iter = route.NewRoundRobin(pool, "", true, localAZ) + iter = route.NewRoundRobin(pool, "", false, true, localAZ) }) Context("on the first attempt", func() { @@ -631,7 +717,7 @@ var _ = Describe("RoundRobin", func() { counts := make([]int, len(endpoints)) - iter := route.NewRoundRobin(pool, "", true, localAZ) + iter := route.NewRoundRobin(pool, "", false, true, localAZ) loops := 50 for i := 0; i < len(endpoints)*loops; i += 1 { @@ -671,7 +757,7 @@ var _ = Describe("RoundRobin", func() { pool.Put(e1) pool.Put(e2) - iter := route.NewRoundRobin(pool, "", false, "meow-az") + iter := route.NewRoundRobin(pool, "", false, false, "meow-az") n := iter.Next(0) Expect(n).ToNot(BeNil()) @@ -697,7 +783,7 @@ var _ = Describe("RoundRobin", func() { pool.Put(e1) pool.Put(e2) - iter := route.NewRoundRobin(pool, "", false, "meow-az") + iter := route.NewRoundRobin(pool, "", false, false, "meow-az") n1 := iter.Next(0) iter.EndpointFailed(&net.OpError{Op: "dial"}) n2 := iter.Next(1) @@ -729,7 +815,7 @@ var _ = Describe("RoundRobin", func() { pool.Put(e1) pool.Put(e2) - iter := route.NewRoundRobin(pool, "", false, "meow-az") + iter := route.NewRoundRobin(pool, "", false, false, "meow-az") n1 := iter.Next(0) n2 := iter.Next(1) Expect(n1).ToNot(Equal(n2)) @@ -757,7 +843,7 @@ var _ = Describe("RoundRobin", func() { endpointFoo := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 1234, PrivateInstanceId: "foo"}) Expect(endpointFoo.Stats.NumberConnections.Count()).To(Equal(int64(0))) pool.Put(endpointFoo) - iter := route.NewRoundRobin(pool, "foo", false, "meow-az") + iter := route.NewRoundRobin(pool, "foo", false, false, "meow-az") iter.PreRequest(endpointFoo) Expect(endpointFoo.Stats.NumberConnections.Count()).To(Equal(int64(1))) }) @@ -771,7 +857,7 @@ var _ = Describe("RoundRobin", func() { } Expect(endpointFoo.Stats.NumberConnections.Count()).To(Equal(int64(1))) pool.Put(endpointFoo) - iter := route.NewRoundRobin(pool, "foo", false, "meow-az") + iter := route.NewRoundRobin(pool, "foo", false, false, "meow-az") iter.PostRequest(endpointFoo) Expect(endpointFoo.Stats.NumberConnections.Count()).To(Equal(int64(0))) })