Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Session retains credentials when disk full #5056

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ import software.aws.toolkits.core.utils.tryOrNull
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.telemetry.AuthTelemetry
import software.aws.toolkits.telemetry.Result
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.nio.file.Path
Expand All @@ -46,6 +49,7 @@ import java.time.Instant
import java.time.ZoneOffset
import java.time.format.DateTimeFormatter.ISO_INSTANT
import java.util.TimeZone
import java.util.concurrent.ConcurrentHashMap

/**
* Caches the [AccessToken] to disk to allow it to be re-used with other tools such as the CLI.
Expand Down Expand Up @@ -98,24 +102,30 @@ class DiskCache(

override fun invalidateClientRegistration(ssoRegion: String) {
LOG.debug { "invalidateClientRegistration for $ssoRegion" }
InMemoryCache.remove(clientRegistrationCache(ssoRegion).toString())
clientRegistrationCache(ssoRegion).tryDeleteIfExists()
}

override fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey): ClientRegistration? {
LOG.debug { "loadClientRegistration for $cacheKey" }
val inputStream = clientRegistrationCache(cacheKey).tryInputStreamIfExists()
if (inputStream == null) {
val stage = LoadCredentialStage.ACCESS_FILE
LOG.warn { "Failed to load Client Registration: cache file does not exist" }
AuthTelemetry.modifyConnection(
action = "Load cache file",
source = "loadClientRegistration",
result = Result.Failed,
reason = "Failed to load Client Registration",
reasonDesc = "Load Step:$stage failed. Unable to load file"
)
return null
}
?: //try to load from in memory cache
return InMemoryCache.get(clientRegistrationCache(cacheKey).toString())?.let { data ->
ByteArrayInputStream(data).use { memoryStream ->
loadClientRegistration(memoryStream)
}
} ?: run {
val stage = LoadCredentialStage.ACCESS_FILE
LOG.warn { "Failed to load Client Registration: cache file does not exist" }
AuthTelemetry.modifyConnection(
action = "Load cache file",
source = "loadClientRegistration",
result = Result.Failed,
reason = "Failed to load Client Registration",
reasonDesc = "Load Step:$stage failed. Unable to load file"
)
null
}
return loadClientRegistration(inputStream)
}

Expand All @@ -130,6 +140,7 @@ class DiskCache(
override fun invalidateClientRegistration(cacheKey: ClientRegistrationCacheKey) {
LOG.debug { "invalidateClientRegistration for $cacheKey" }
try {
InMemoryCache.remove(clientRegistrationCache(cacheKey).toString())
clientRegistrationCache(cacheKey).tryDeleteIfExists()
} catch (e: Exception) {
AuthTelemetry.modifyConnection(
Expand All @@ -146,6 +157,7 @@ class DiskCache(
override fun invalidateAccessToken(ssoUrl: String) {
LOG.debug { "invalidateAccessToken for $ssoUrl" }
try {
InMemoryCache.remove(accessTokenCache(ssoUrl).toString())
accessTokenCache(ssoUrl).tryDeleteIfExists()
} catch (e: Exception) {
AuthTelemetry.modifyConnection(
Expand All @@ -162,11 +174,18 @@ class DiskCache(
override fun loadAccessToken(cacheKey: AccessTokenCacheKey): AccessToken? {
LOG.debug { "loadAccessToken for $cacheKey" }
val cacheFile = accessTokenCache(cacheKey)
val inputStream = cacheFile.tryInputStreamIfExists() ?: return null

val token = loadAccessToken(inputStream)

return token
// If file exists, returns InputStream, if not returns null
return cacheFile.tryInputStreamIfExists()
//try to load and parse access token, returns AccessToken or null if expired
?.let { loadAccessToken(it) }
// If file doesn't exist or loadAccessToken failed, try in-memory cache
samgst-amazon marked this conversation as resolved.
Show resolved Hide resolved
?: InMemoryCache.get(cacheFile.toString())?.let { data ->
// If in-memory cache has data, create stream and try to load token
ByteArrayInputStream(data).use { memoryStream ->
loadAccessToken(memoryStream)
}
// If both file system and in-memory cache attempts fail, returns null
}
}

override fun saveAccessToken(cacheKey: AccessTokenCacheKey, accessToken: AccessToken) {
Expand All @@ -180,6 +199,7 @@ class DiskCache(
override fun invalidateAccessToken(cacheKey: AccessTokenCacheKey) {
LOG.debug { "invalidateAccessToken for $cacheKey" }
try {
InMemoryCache.remove(accessTokenCache(cacheKey).toString())
accessTokenCache(cacheKey).tryDeleteIfExists()
} catch (e: Exception) {
AuthTelemetry.modifyConnection(
Expand Down Expand Up @@ -278,6 +298,14 @@ class DiskCache(
outputStream().use(consumer)
}
} catch (e: Exception) {
when {
e is IOException -> {
if (e.message?.contains("No space left on device") == true) {
LOG.warn { "Disk space full. Storing credentials in memory for this session" }
storeInMemory(path, consumer)
}
}
}
AuthTelemetry.modifyConnection(
action = "Write file",
source = "writeKey",
Expand All @@ -294,6 +322,27 @@ class DiskCache(

private fun AccessToken.isDefinitelyExpired(): Boolean = refreshToken == null && !expiresAt.isNotExpired()

private fun storeInMemory(path: Path, consumer: (OutputStream) -> Unit) {
val byteArrayOutputStream = ByteArrayOutputStream()
consumer(byteArrayOutputStream)
val data = byteArrayOutputStream.toByteArray()
InMemoryCache.put(path.toString(), data)
}

private object InMemoryCache {
samgst-amazon marked this conversation as resolved.
Show resolved Hide resolved
private val cache = ConcurrentHashMap<String, ByteArray>()

fun put(key: String, value: ByteArray) {
cache[key] = value
}

fun get(key: String): ByteArray? = cache[key]

fun remove(key: String) {
cache.remove(key)
}
}

private class CliCompatibleInstantDeserializer : StdDeserializer<Instant>(Instant::class.java) {
override fun deserialize(parser: JsonParser, context: DeserializationContext): Instant {
val dateString = parser.valueAsString
Expand Down
Loading