Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Session retains credentials when disk full #5056

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import com.fasterxml.jackson.databind.util.StdDateFormat
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue
import org.jetbrains.annotations.VisibleForTesting
import software.aws.toolkits.core.utils.createParentDirectories
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.deleteIfExists
Expand All @@ -35,6 +36,9 @@ import software.aws.toolkits.core.utils.tryOrNull
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.telemetry.AuthTelemetry
import software.aws.toolkits.telemetry.Result
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.nio.file.Path
Expand All @@ -46,6 +50,7 @@ import java.time.Instant
import java.time.ZoneOffset
import java.time.format.DateTimeFormatter.ISO_INSTANT
import java.util.TimeZone
import java.util.concurrent.ConcurrentHashMap

/**
* Caches the [AccessToken] to disk to allow it to be re-used with other tools such as the CLI.
Expand Down Expand Up @@ -98,12 +103,21 @@ class DiskCache(

override fun invalidateClientRegistration(ssoRegion: String) {
LOG.debug { "invalidateClientRegistration for $ssoRegion" }
InMemoryCache.remove(clientRegistrationCache(ssoRegion).toString())
clientRegistrationCache(ssoRegion).tryDeleteIfExists()
}

override fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey): ClientRegistration? {
LOG.debug { "loadClientRegistration for $cacheKey" }
val inputStream = clientRegistrationCache(cacheKey).tryInputStreamIfExists()
val cacheFile = clientRegistrationCache(cacheKey)
// try InMemoryCacheFirst in case of stale registration on full disk
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't the disk be the source of truth?
consider the case where multiple different JetBrains IDE instances are running

InMemoryCache.get(cacheFile.toString())?.let { data ->
ByteArrayInputStream(data).use { memoryStream ->
return loadClientRegistration(memoryStream)
}
}

val inputStream = cacheFile.tryInputStreamIfExists()
if (inputStream == null) {
val stage = LoadCredentialStage.ACCESS_FILE
LOG.warn { "Failed to load Client Registration: cache file does not exist" }
Expand All @@ -130,6 +144,7 @@ class DiskCache(
override fun invalidateClientRegistration(cacheKey: ClientRegistrationCacheKey) {
LOG.debug { "invalidateClientRegistration for $cacheKey" }
try {
InMemoryCache.remove(clientRegistrationCache(cacheKey).toString())
clientRegistrationCache(cacheKey).tryDeleteIfExists()
} catch (e: Exception) {
AuthTelemetry.modifyConnection(
Expand All @@ -146,6 +161,7 @@ class DiskCache(
override fun invalidateAccessToken(ssoUrl: String) {
LOG.debug { "invalidateAccessToken for $ssoUrl" }
try {
InMemoryCache.remove(accessTokenCache(ssoUrl).toString())
accessTokenCache(ssoUrl).tryDeleteIfExists()
} catch (e: Exception) {
AuthTelemetry.modifyConnection(
Expand All @@ -159,9 +175,17 @@ class DiskCache(
}
}


override fun loadAccessToken(cacheKey: AccessTokenCacheKey): AccessToken? {
LOG.debug { "loadAccessToken for $cacheKey" }
val cacheFile = accessTokenCache(cacheKey)
// try InMemoryCacheFirst in case of stale token on full disk
InMemoryCache.get(cacheFile.toString())?.let { data ->
ByteArrayInputStream(data).use { memoryStream ->
return loadAccessToken(memoryStream)
}
}

val inputStream = cacheFile.tryInputStreamIfExists() ?: return null

val token = loadAccessToken(inputStream)
Expand All @@ -180,6 +204,7 @@ class DiskCache(
override fun invalidateAccessToken(cacheKey: AccessTokenCacheKey) {
LOG.debug { "invalidateAccessToken for $cacheKey" }
try {
InMemoryCache.remove(accessTokenCache(cacheKey).toString())
accessTokenCache(cacheKey).tryDeleteIfExists()
} catch (e: Exception) {
AuthTelemetry.modifyConnection(
Expand Down Expand Up @@ -278,6 +303,14 @@ class DiskCache(
outputStream().use(consumer)
}
} catch (e: Exception) {
when {
e is IOException -> {
if (e.message?.contains("No space left on device") == true) {
LOG.warn { "Disk space full. Storing credentials in memory for this session" }
storeInMemory(path, consumer)
}
}
}
AuthTelemetry.modifyConnection(
action = "Write file",
source = "writeKey",
Expand All @@ -294,6 +327,28 @@ class DiskCache(

private fun AccessToken.isDefinitelyExpired(): Boolean = refreshToken == null && !expiresAt.isNotExpired()

private fun storeInMemory(path: Path, consumer: (OutputStream) -> Unit) {
val byteArrayOutputStream = ByteArrayOutputStream()
consumer(byteArrayOutputStream)
val data = byteArrayOutputStream.toByteArray()
InMemoryCache.put(path.toString(), data)
}

@VisibleForTesting
internal object InMemoryCache {
private val cache = ConcurrentHashMap<String, ByteArray>()

fun put(key: String, value: ByteArray) {
cache[key] = value
}

fun get(key: String): ByteArray? = cache[key]

fun remove(key: String) {
cache.remove(key)
}
}

private class CliCompatibleInstantDeserializer : StdDeserializer<Instant>(Instant::class.java) {
override fun deserialize(parser: JsonParser, context: DeserializationContext): Instant {
val dateString = parser.valueAsString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,22 @@ import com.intellij.openapi.util.SystemInfo
import com.intellij.openapi.util.io.NioFiles
import com.intellij.testFramework.ApplicationExtension
import org.assertj.core.api.Assertions.assertThat
import org.junit.Assert.assertEquals
import org.junit.Assert.assertNotNull
import org.junit.Assert.assertNull
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.DisabledOnOs
import org.junit.jupiter.api.condition.OS
import org.junit.jupiter.api.extension.ExtendWith
import org.junit.jupiter.api.io.TempDir
import org.mockito.Mockito.mockStatic
import org.mockito.kotlin.any
import org.mockito.kotlin.whenever
import software.aws.toolkits.core.utils.readText
import software.aws.toolkits.core.utils.test.assertPosixPermissions
import software.aws.toolkits.core.utils.writeText
import java.io.IOException
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
Expand Down Expand Up @@ -49,6 +56,14 @@ class DiskCacheTest {
sut = DiskCache(cacheLocation, clock)
}

fun setupMockOutputStreamThrowingIOException() {
mockStatic(Files::class.java).use { mockedFiles ->
whenever(Files.newOutputStream(any<Path>())).thenThrow(
IOException("No space left on device")
)
}
}

@Test
fun nonExistentClientRegistrationReturnsNull() {
assertThat(
Expand Down Expand Up @@ -713,4 +728,207 @@ class DiskCacheTest {
.usingRecursiveComparison()
.isEqualTo(sut.loadAccessToken(key2))
}

@Test
fun `saveAccessToken falls back to InMemoryCache when disk is full`() {
setupMockOutputStreamThrowingIOException()
// Mock the writeKey method to simulate a disk full scenario
val key = PKCEAccessTokenCacheKey(ssoUrl, ssoRegion, scopes)
val testToken = PKCEAuthorizationGrantToken(
ssoUrl,
ssoRegion,
"test_access_token",
"test_refresh_token",
Instant.now().plusSeconds(3600),
Instant.now()
)

sut.saveAccessToken(key, testToken)

val loadedToken = sut.loadAccessToken(key)
assertNotNull(loadedToken)
assertEquals(testToken, loadedToken)
}

@Test
fun `saveClientRegistration falls back to InMemoryCache when disk is full`() {
setupMockOutputStreamThrowingIOException()
val key = PKCEClientRegistrationCacheKey(
issuerUrl = ssoUrl,
scopes = scopes,
region = ssoRegion,
clientType = "public",
grantTypes = listOf("authorization_code", "refresh_token"),
redirectUris = listOf("http://127.0.0.1/oauth/callback")
)
val testRegistration = PKCEClientRegistration(
"test_client_id",
"test_client_secret",
Instant.now().plusSeconds(3600),
scopes,
ssoUrl,
ssoRegion,
"public",
listOf("authorization_code", "refresh_token"),
listOf("http://127.0.0.1/oauth/callback")
)

sut.saveClientRegistration(key, testRegistration)

val loadedRegistration = sut.loadClientRegistration(key)
assertNotNull(loadedRegistration)
assertEquals(testRegistration, loadedRegistration)
}

@Test
fun `invalidateAccessToken removes token from InMemoryCache when disk is full`() {
setupMockOutputStreamThrowingIOException()
val key = PKCEAccessTokenCacheKey(ssoUrl, ssoRegion, scopes)
val testToken = PKCEAuthorizationGrantToken(
ssoUrl,
ssoRegion,
"test_access_token",
"test_refresh_token",
Instant.now().plusSeconds(3600),
Instant.now()
)

sut.saveAccessToken(key, testToken)
sut.invalidateAccessToken(key)

val loadedToken = sut.loadAccessToken(key)
assertNull(loadedToken)
}

@Test
fun `invalidateClientRegistration removes registration from InMemoryCache when disk is full`() {
setupMockOutputStreamThrowingIOException()
val key = PKCEClientRegistrationCacheKey(
issuerUrl = ssoUrl,
scopes = scopes,
region = ssoRegion,
clientType = "public",
grantTypes = listOf("authorization_code", "refresh_token"),
redirectUris = listOf("http://127.0.0.1/oauth/callback")
)
val testRegistration = PKCEClientRegistration(
"test_client_id",
"test_client_secret",
Instant.now().plusSeconds(3600),
scopes,
ssoUrl,
ssoRegion,
"public",
listOf("authorization_code", "refresh_token"),
listOf("http://127.0.0.1/oauth/callback")
)

sut.saveClientRegistration(key, testRegistration)
sut.invalidateClientRegistration(key)

val loadedRegistration = sut.loadClientRegistration(key)
assertNull(loadedRegistration)
}

@Test
fun `test client registration update with disk error falls back to memory cache`() {
// Create a cache key and initial registration
val key = PKCEClientRegistrationCacheKey(
issuerUrl = ssoUrl,
scopes = scopes,
region = ssoRegion,
clientType = "public",
grantTypes = listOf("authorization_code", "refresh_token"),
redirectUris = listOf("http://127.0.0.1/oauth/callback")
)
val initialRegistration = PKCEClientRegistration(
"stale_id",
"stale_client_secret",
Instant.now().plusSeconds(3600),
scopes,
ssoUrl,
ssoRegion,
"public",
listOf("authorization_code", "refresh_token"),
listOf("http://127.0.0.1/oauth/callback")
)

// Save initial registration (should save to disk)
sut.saveClientRegistration(key, initialRegistration)

// Verify initial save
val loadedInitial = sut.loadClientRegistration(key)
assertNotNull(loadedInitial)
assertEquals(initialRegistration, loadedInitial)

// Setup mock to throw IOException for future disk writes
setupMockOutputStreamThrowingIOException()

// Create updated registration
val updatedRegistration = PKCEClientRegistration(
"fresh_ID",
"fresh_client_secret",
Instant.now().plusSeconds(3600),
scopes,
ssoUrl,
ssoRegion,
"public",
listOf("authorization_code", "refresh_token"),
listOf("http://127.0.0.1/oauth/callback")
)

// Try to save updated registration (should fall back to in-memory cache)
sut.saveClientRegistration(key, updatedRegistration)

// Load registration again
val loadedUpdated = sut.loadClientRegistration(key)

// Verify that we get the updated registration, not the initial one
assertNotNull(loadedUpdated)
assertEquals(updatedRegistration, loadedUpdated)
}

@Test
fun `test access token update with disk error falls back to memory cache`() {
val key = PKCEAccessTokenCacheKey(ssoUrl, ssoRegion, scopes)
val initialToken = PKCEAuthorizationGrantToken(
ssoUrl,
ssoRegion,
"stale_access_token",
"stale_refresh_token",
Instant.now().plusSeconds(3600),
Instant.now()
)

// Save initial token (should save to disk)
sut.saveAccessToken(key, initialToken)

// Verify initial save
val loadedInitial = sut.loadAccessToken(key)
assertNotNull(loadedInitial)
assertEquals(initialToken, loadedInitial)

// Setup mock to throw IOException for future disk writes
setupMockOutputStreamThrowingIOException()

// Create updated token
val updatedToken = PKCEAuthorizationGrantToken(
ssoUrl,
ssoRegion,
"fresh_access_token",
"fresh_refresh_token",
Instant.now().plusSeconds(3600),
Instant.now()
)

// Try to save updated token (should fall back to in-memory cache)
sut.saveAccessToken(key, updatedToken)

// Load token again
val loadedUpdated = sut.loadAccessToken(key)

// Verify that we get the updated token, not the initial one
assertNotNull(loadedUpdated)
assertEquals(updatedToken, loadedUpdated)
}
}
Loading