Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor refresh of access and refresh tokens #21

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 18 additions & 39 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ var (

// accessTokenTimeout shows the access token expiry time.
// After the access token expires, one is required to obtain a new one
accessTokenTimeout = 60 * time.Minute

// refreshTokenTimeout shows the refresh token expiry time
refreshTokenTimeout = 24 * time.Hour
accessTokenTimeout = 59 * time.Minute
)

// AuthServerImpl defines the methods provided by
Expand All @@ -44,8 +41,7 @@ type client struct {
authServer AuthServerImpl
client *http.Client

refreshToken string
refreshTokenTicker *time.Ticker
refreshToken string

accessToken string
accessTokenTicker *time.Ticker
Expand Down Expand Up @@ -86,49 +82,30 @@ func mustNewClient(authServer AuthServerImpl) *client {
return client
}

// executed as a go routine to update the api tokens when they timeout
// executed as a go routine to update access and refresh token
func (s *client) background() {
for {
select {
case t := <-s.refreshTokenTicker.C:
logrus.Println("SIL Comms Refresh Token updated at: ", t)
err := s.login()
if err != nil {
s.authFailed = true
}
s.authFailed = false

case t := <-s.accessTokenTicker.C:
logrus.Println("SIL Comms Access Token updated at: ", t)
err := s.refreshAccessToken()
if err != nil {
s.authFailed = true
}
for t := range s.accessTokenTicker.C {
logrus.Println("SIL Comms Access Token updated at: ", t)
err := s.refreshAccessToken()
if err != nil {
s.authFailed = true
} else {
s.authFailed = false
}
}
}

// setAccessToken sets the access token and updates the ticker timer
func (s *client) setAccessToken(token string) {
s.accessToken = token
func (s *client) setRefreshAndAccessToken(token *TokenResponse) {
s.accessToken = token.Access
s.refreshToken = token.Refresh
if s.accessTokenTicker != nil {
s.accessTokenTicker.Reset(accessTokenTimeout)
} else {
s.accessTokenTicker = time.NewTicker(accessTokenTimeout)
}
}

// setRefreshToken sets the access token and updates the ticker timer
func (s *client) setRefreshToken(token string) {
s.refreshToken = token
if s.refreshTokenTicker != nil {
s.refreshTokenTicker.Reset(refreshTokenTimeout)
} else {
s.refreshTokenTicker = time.NewTicker(refreshTokenTimeout)
}
}

// login uses the provided credentials to login to the authserver backend
// It obtains the necessary tokens required to make authenticated requests
func (s *client) login() error {
Expand All @@ -149,12 +126,13 @@ func (s *client) login() error {
Refresh: resp.RefreshToken,
}

s.setRefreshToken(tokens.Refresh)
s.setAccessToken(tokens.Access)
s.setRefreshAndAccessToken(&tokens)

return nil
}

// refreshAccessToken makes a request to get
// new access and refresh tokens
func (s *client) refreshAccessToken() error {
ctx := context.Background()
resp, err := s.authServer.RefreshToken(ctx, s.refreshToken)
Expand All @@ -163,10 +141,11 @@ func (s *client) refreshAccessToken() error {
}

tokens := TokenResponse{
Access: resp.AccessToken,
Access: resp.AccessToken,
Refresh: resp.RefreshToken,
}

s.setAccessToken(tokens.Access)
s.setRefreshAndAccessToken(&tokens)

return nil
}
Expand Down
Loading