Skip to content

Commit

Permalink
Setup sticky session for Kerberos and NTML HTTP Authentication
Browse files Browse the repository at this point in the history
When server responds with `WWW-Authenticate: Negotiate`, save VCAP_ID
cookie on response to client so that subsequent request with
`Authorization: Negotiate ...` will be directed to the same application
instance.

See [RFC-4559](https://www.ietf.org/rfc/rfc4559.txt)

Signed-off-by: Josh Russett <[email protected]>
  • Loading branch information
mariash committed Jan 31, 2024
1 parent 1b1e6ac commit 3de570d
Show file tree
Hide file tree
Showing 12 changed files with 294 additions and 151 deletions.
15 changes: 11 additions & 4 deletions handlers/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,24 @@ func EndpointIteratorForRequest(request *http.Request, loadBalanceMethod string,
if err != nil {
return nil, fmt.Errorf("could not find reqInfo in context")
}
return reqInfo.RoutePool.Endpoints(loadBalanceMethod, getStickySession(request, stickySessionCookieNames), azPreference, az), nil
stickyEndpointID, mustBeSticky := GetStickySession(request, stickySessionCookieNames)
return reqInfo.RoutePool.Endpoints(loadBalanceMethod, stickyEndpointID, mustBeSticky, azPreference, az), nil
}

func getStickySession(request *http.Request, stickySessionCookieNames config.StringSet) string {
func GetStickySession(request *http.Request, stickySessionCookieNames config.StringSet) (string, bool) {
containsAuthNegotiateHeader := strings.HasPrefix(strings.ToLower(request.Header.Get("Authorization")), "negotiate")
if containsAuthNegotiateHeader {
if sticky, err := request.Cookie(VcapCookieId); err == nil {
return sticky.Value, true
}
}
// Try choosing a backend using sticky session
for stickyCookieName, _ := range stickySessionCookieNames {
if _, err := request.Cookie(stickyCookieName); err == nil {
if sticky, err := request.Cookie(VcapCookieId); err == nil {
return sticky.Value
return sticky.Value, false
}
}
}
return ""
return "", false
}
21 changes: 4 additions & 17 deletions proxy/round_tripper/proxy_round_tripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response
return nil, errors.New("ProxyResponseWriter not set on context")
}

stickyEndpointID := getStickySession(request, rt.config.StickySessionCookieNames)
stickyEndpointID, mustBeSticky := handlers.GetStickySession(request, rt.config.StickySessionCookieNames)
numberOfEndpoints := reqInfo.RoutePool.NumEndpoints()
iter := reqInfo.RoutePool.Endpoints(rt.config.LoadBalance, stickyEndpointID, rt.config.LoadBalanceAZPreference, rt.config.Zone)
iter := reqInfo.RoutePool.Endpoints(rt.config.LoadBalance, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone)

// The selectEndpointErr needs to be tracked separately. If we get an error
// while selecting an endpoint we might just have run out of routes. In
Expand Down Expand Up @@ -389,9 +389,8 @@ func setupStickySession(

requestContainsStickySessionCookies := originalEndpointId != ""
requestNotSentToRequestedApp := originalEndpointId != endpoint.PrivateInstanceId
containsAuthNegotiateHeader := strings.ToLower(response.Header.Get("Authorization")) == "negotiate"

shouldSetVCAPID := containsAuthNegotiateHeader || (requestContainsStickySessionCookies && requestNotSentToRequestedApp)
containsAuthNegotiateHeader := strings.HasPrefix(strings.ToLower(response.Header.Get("WWW-Authenticate")), "negotiate")
shouldSetVCAPID := (containsAuthNegotiateHeader || requestContainsStickySessionCookies) && requestNotSentToRequestedApp

secure := false
maxAge := 0
Expand Down Expand Up @@ -443,18 +442,6 @@ func setupStickySession(
}
}

func getStickySession(request *http.Request, stickySessionCookieNames config.StringSet) string {
// Try choosing a backend using sticky session
for stickyCookieName, _ := range stickySessionCookieNames {
if _, err := request.Cookie(stickyCookieName); err == nil {
if sticky, err := request.Cookie(VcapCookieId); err == nil {
return sticky.Value
}
}
}
return ""
}

func requestSentToRouteService(request *http.Request) bool {
sigHeader := request.Header.Get(routeservice.HeaderKeySignature)
rsUrl := request.Header.Get(routeservice.HeaderKeyForwardedURL)
Expand Down
30 changes: 26 additions & 4 deletions proxy/round_tripper/proxy_round_tripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ var _ = Describe("ProxyRoundTripper", func() {
res, err := proxyRoundTripper.RoundTrip(req)
Expect(err).NotTo(HaveOccurred())

iter := routePool.Endpoints("", "", AZPreference, AZ)
iter := routePool.Endpoints("", "", false, AZPreference, AZ)
ep1 := iter.Next(0)
ep2 := iter.Next(1)
Expect(ep1.PrivateInstanceId).To(Equal(ep2.PrivateInstanceId))
Expand Down Expand Up @@ -427,7 +427,7 @@ var _ = Describe("ProxyRoundTripper", func() {
_, err := proxyRoundTripper.RoundTrip(req)
Expect(err).To(MatchError(ContainSubstring("tls: handshake failure")))

iter := routePool.Endpoints("", "", AZPreference, AZ)
iter := routePool.Endpoints("", "", false, AZPreference, AZ)
ep1 := iter.Next(0)
ep2 := iter.Next(1)
Expect(ep1).To(Equal(ep2))
Expand Down Expand Up @@ -1035,7 +1035,7 @@ var _ = Describe("ProxyRoundTripper", func() {
}

setAuthorizationNegotiateHeader := func(resp *http.Response) (response *http.Response) {
resp.Header.Add("Authorization", "Negotiate")
resp.Header.Add("WWW-Authenticate", "Negotiate SOME-TOKEN")
return resp
}

Expand Down Expand Up @@ -1161,7 +1161,7 @@ var _ = Describe("ProxyRoundTripper", func() {
})
})

Context("when there is an 'Authorization: Negotiate' header set on the response", func() {
Context("when there is an 'WWW-Authenticate: Negotiate ...' header set on the response", func() {
BeforeEach(func() {
transport.RoundTripStub = func(req *http.Request) (*http.Response, error) {
resp := &http.Response{StatusCode: http.StatusTeapot, Header: make(map[string][]string)}
Expand Down Expand Up @@ -1486,7 +1486,29 @@ var _ = Describe("ProxyRoundTripper", func() {
})
})

Context("when VCAP_ID cookie and 'Authorization: Negotiate ...' header are on the request", func() {
BeforeEach(func() {
req.AddCookie(&http.Cookie{
Name: round_tripper.VcapCookieId,
Value: "id-2",
})
req.Header.Add("Authorization", "Negotiate SOME-TOKEN")
transport.RoundTripStub = func(req *http.Request) (*http.Response, error) {
Expect(req.URL.Host).To(Equal("1.1.1.1:9092"))
resp := &http.Response{StatusCode: http.StatusTeapot, Header: make(map[string][]string)}
return resp, nil
}
})

It("will select the previous backend and VCAP_ID is set on the response", func() {
Consistently(func() error {
_, err := proxyRoundTripper.RoundTrip(req)
return err
}).ShouldNot(HaveOccurred())
})
})
})

Context("when endpoint timeout is not 0", func() {
var reqCh chan *http.Request
BeforeEach(func() {
Expand Down
2 changes: 1 addition & 1 deletion proxy/session_affinity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

const StickyCookieKey = "JSESSIONID"

var _ = Describe("Session Affinity", func() {
var _ = Describe("Session Affinity with JSESSIONID", func() {
var done chan bool
var jSessionIdCookie *http.Cookie

Expand Down
34 changes: 17 additions & 17 deletions registry/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ var _ = Describe("RouteRegistry", func() {
Expect(r.NumEndpoints()).To(Equal(1))

p := r.Lookup("foo.com")
Expect(p.Endpoints("", "", azPreference, az).Next(0).ModificationTag).To(Equal(modTag))
Expect(p.Endpoints("", "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag))
})
})

Expand All @@ -396,7 +396,7 @@ var _ = Describe("RouteRegistry", func() {
Expect(r.NumEndpoints()).To(Equal(1))

p := r.Lookup("foo.com")
Expect(p.Endpoints("", "", azPreference, az).Next(0).ModificationTag).To(Equal(modTag))
Expect(p.Endpoints("", "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag))
})

Context("updating an existing route with an older modification tag", func() {
Expand All @@ -416,7 +416,7 @@ var _ = Describe("RouteRegistry", func() {
Expect(r.NumEndpoints()).To(Equal(1))

p := r.Lookup("foo.com")
ep := p.Endpoints("", "", azPreference, az).Next(0)
ep := p.Endpoints("", "", false, azPreference, az).Next(0)
Expect(ep.ModificationTag).To(Equal(modTag))
Expect(ep).To(Equal(endpoint2))
})
Expand All @@ -435,7 +435,7 @@ var _ = Describe("RouteRegistry", func() {
Expect(r.NumEndpoints()).To(Equal(1))

p := r.Lookup("foo.com")
Expect(p.Endpoints("", "", azPreference, az).Next(0).ModificationTag).To(Equal(modTag))
Expect(p.Endpoints("", "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag))
})
})
})
Expand Down Expand Up @@ -703,7 +703,7 @@ var _ = Describe("RouteRegistry", func() {
Expect(r.NumUris()).To(Equal(1))

p1 := r.Lookup("foo/bar")
iter := p1.Endpoints("", "", azPreference, az)
iter := p1.Endpoints("", "", false, azPreference, az)
Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234"))

p2 := r.Lookup("foo")
Expand Down Expand Up @@ -799,7 +799,7 @@ var _ = Describe("RouteRegistry", func() {
p2 := r.Lookup("FOO")
Expect(p1).To(Equal(p2))

iter := p1.Endpoints("", "", azPreference, az)
iter := p1.Endpoints("", "", false, azPreference, az)
Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234"))
})

Expand All @@ -818,7 +818,7 @@ var _ = Describe("RouteRegistry", func() {

p := r.Lookup("bar")
Expect(p).ToNot(BeNil())
e := p.Endpoints("", "", azPreference, az).Next(0)
e := p.Endpoints("", "", false, azPreference, az).Next(0)
Expect(e).ToNot(BeNil())
Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:123[4|5]"))

Expand All @@ -833,13 +833,13 @@ var _ = Describe("RouteRegistry", func() {

p := r.Lookup("foo.wild.card")
Expect(p).ToNot(BeNil())
e := p.Endpoints("", "", azPreference, az).Next(0)
e := p.Endpoints("", "", false, azPreference, az).Next(0)
Expect(e).ToNot(BeNil())
Expect(e.CanonicalAddr()).To(Equal("192.168.1.2:1234"))

p = r.Lookup("foo.space.wild.card")
Expect(p).ToNot(BeNil())
e = p.Endpoints("", "", azPreference, az).Next(0)
e = p.Endpoints("", "", false, azPreference, az).Next(0)
Expect(e).ToNot(BeNil())
Expect(e.CanonicalAddr()).To(Equal("192.168.1.2:1234"))
})
Expand All @@ -853,7 +853,7 @@ var _ = Describe("RouteRegistry", func() {

p := r.Lookup("not.wild.card")
Expect(p).ToNot(BeNil())
e := p.Endpoints("", "", azPreference, az).Next(0)
e := p.Endpoints("", "", false, azPreference, az).Next(0)
Expect(e).ToNot(BeNil())
Expect(e.CanonicalAddr()).To(Equal("192.168.1.1:1234"))
})
Expand Down Expand Up @@ -885,7 +885,7 @@ var _ = Describe("RouteRegistry", func() {
p := r.Lookup("dora.app.com/env?foo=bar")

Expect(p).ToNot(BeNil())
iter := p.Endpoints("", "", azPreference, az)
iter := p.Endpoints("", "", false, azPreference, az)
Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234"))
})

Expand All @@ -894,7 +894,7 @@ var _ = Describe("RouteRegistry", func() {
p := r.Lookup("dora.app.com/env/abc?foo=bar&baz=bing")

Expect(p).ToNot(BeNil())
iter := p.Endpoints("", "", azPreference, az)
iter := p.Endpoints("", "", false, azPreference, az)
Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234"))
})
})
Expand All @@ -914,7 +914,7 @@ var _ = Describe("RouteRegistry", func() {
p1 := r.Lookup("foo/extra/paths")
Expect(p1).ToNot(BeNil())

iter := p1.Endpoints("", "", azPreference, az)
iter := p1.Endpoints("", "", false, azPreference, az)
Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234"))
})

Expand All @@ -926,7 +926,7 @@ var _ = Describe("RouteRegistry", func() {
p1 := r.Lookup("foo?fields=foo,bar")
Expect(p1).ToNot(BeNil())

iter := p1.Endpoints("", "", azPreference, az)
iter := p1.Endpoints("", "", false, azPreference, az)
Expect(iter.Next(0).CanonicalAddr()).To(Equal("192.168.1.1:1234"))
})

Expand Down Expand Up @@ -962,7 +962,7 @@ var _ = Describe("RouteRegistry", func() {
Expect(r.NumEndpoints()).To(Equal(2))

p := r.LookupWithInstance("bar.com/foo", appId, appIndex)
e := p.Endpoints("", "", azPreference, az).Next(0)
e := p.Endpoints("", "", false, azPreference, az).Next(0)

Expect(e).ToNot(BeNil())
Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:1234"))
Expand All @@ -976,7 +976,7 @@ var _ = Describe("RouteRegistry", func() {
Expect(r.NumEndpoints()).To(Equal(2))

p := r.LookupWithInstance("bar.com/foo", appId, appIndex)
e := p.Endpoints("", "", azPreference, az).Next(0)
e := p.Endpoints("", "", false, azPreference, az).Next(0)

Expect(e).ToNot(BeNil())
Expect(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:1234"))
Expand Down Expand Up @@ -1169,7 +1169,7 @@ var _ = Describe("RouteRegistry", func() {

p := r.Lookup("foo")
Expect(p).ToNot(BeNil())
Expect(p.Endpoints("", "", azPreference, az).Next(0)).To(Equal(endpoint))
Expect(p.Endpoints("", "", false, azPreference, az).Next(0)).To(Equal(endpoint))

p = r.Lookup("bar")
Expect(p).To(BeNil())
Expand Down
8 changes: 4 additions & 4 deletions route/endpoint_iterator_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ func setupEndpointIterator(total int, azDistribution int, strategy string) route
var lb route.EndpointIterator
switch strategy {
case "round-robin":
lb = route.NewRoundRobin(pool, "", false, localAZ)
lb = route.NewRoundRobin(pool, "", false, false, localAZ)
case "round-robin-locally-optimistic":
lb = route.NewRoundRobin(pool, "", true, localAZ)
lb = route.NewRoundRobin(pool, "", false, true, localAZ)
case "least-connection":
lb = route.NewLeastConnection(pool, "", false, localAZ)
lb = route.NewLeastConnection(pool, "", false, false, localAZ)
case "least-connection-locally-optimistic":
lb = route.NewLeastConnection(pool, "", true, localAZ)
lb = route.NewLeastConnection(pool, "", false, true, localAZ)
default:
panic("invalid load balancing strategy")
}
Expand Down
14 changes: 11 additions & 3 deletions route/leastconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@ import (
type LeastConnection struct {
pool *EndpointPool
initialEndpoint string
mustBeSticky bool
lastEndpoint *Endpoint
randomize *rand.Rand
locallyOptimistic bool
localAvailabilityZone string
}

func NewLeastConnection(p *EndpointPool, initial string, locallyOptimistic bool, localAvailabilityZone string) EndpointIterator {
func NewLeastConnection(p *EndpointPool, initial string, mustBeSticky bool, locallyOptimistic bool, localAvailabilityZone string) EndpointIterator {
return &LeastConnection{
pool: p,
initialEndpoint: initial,
mustBeSticky: mustBeSticky,
randomize: rand.New(rand.NewSource(time.Now().UnixNano())),
locallyOptimistic: locallyOptimistic,
localAvailabilityZone: localAvailabilityZone,
Expand All @@ -28,11 +30,17 @@ func (r *LeastConnection) Next(attempt int) *Endpoint {
var e *endpointElem
if r.initialEndpoint != "" {
e = r.pool.findById(r.initialEndpoint)
r.initialEndpoint = ""

if e != nil && e.isOverloaded() {
e = nil
}

if e == nil && r.mustBeSticky {
return nil
}

if !r.mustBeSticky {
r.initialEndpoint = ""
}
}

if e != nil {
Expand Down
Loading

0 comments on commit 3de570d

Please sign in to comment.