Skip to content

Commit

Permalink
add metrics interceptor for okhttp
Browse files Browse the repository at this point in the history
  • Loading branch information
aajtodd committed Jul 6, 2023
1 parent 2e97f3a commit 276429e
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ extra["moduleName"] = "aws.smithy.kotlin.runtime.http.engine.okhttp"

val coroutinesVersion: String by project
val okHttpVersion: String by project
val otelVersion: String by project

kotlin {
sourceSets {
Expand All @@ -31,6 +32,9 @@ kotlin {
jvmTest {
dependencies {
implementation(project(":runtime:testing"))
// use otel testing capabilities
implementation(project(":runtime:observability:telemetry-provider-otel"))
implementation("io.opentelemetry:opentelemetry-sdk-testing:$otelVersion")
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package aws.smithy.kotlin.runtime.http.engine.okhttp

import aws.smithy.kotlin.runtime.telemetry.metrics.MonotonicCounter
import aws.smithy.kotlin.runtime.util.Attributes
import aws.smithy.kotlin.runtime.util.attributesOf
import okhttp3.*
import okio.*

/**
* Instrument the HTTP throughput metrics (e.g. bytes rcvd/sent)
*/
internal object MetricsInterceptor : Interceptor {
override fun intercept(chain: Interceptor.Chain): Response {
val originalRequest = chain.request()
val metrics = originalRequest.tag<SdkRequestTag>()?.metrics ?: return chain.proceed(originalRequest)

val attrs = attributesOf { "server.address" to "${originalRequest.url.host}:${originalRequest.url.port}" }
val request = if (originalRequest.body != null) {
originalRequest.newBuilder()
.method(originalRequest.method, originalRequest.body?.instrument(metrics.bytesSent, attrs))
.build()
} else {
originalRequest
}

val originalResponse = chain.proceed(request)
val response = if (originalResponse.body.contentLength() != 0L) {
originalResponse.newBuilder()
.body(originalResponse.body.instrument(metrics.bytesReceived, attrs))
.build()
} else {
originalResponse
}

return response
}
}

internal class InstrumentedSink(
private val delegate: BufferedSink,
private val counter: MonotonicCounter,
private val attributes: Attributes,
) : Sink by delegate {
override fun write(source: Buffer, byteCount: Long) {
delegate.write(source, byteCount)
counter.add(byteCount, attributes)
}
override fun close() {
delegate.emit()
delegate.close()
}
}

internal class InstrumentedRequestBody(
private val delegate: RequestBody,
private val counter: MonotonicCounter,
private val attributes: Attributes,
) : RequestBody() {
override fun contentType(): MediaType? = delegate.contentType()
override fun isOneShot(): Boolean = delegate.isOneShot()
override fun isDuplex(): Boolean = delegate.isDuplex()
override fun contentLength(): Long = delegate.contentLength()
override fun writeTo(sink: BufferedSink) {
val metricsSink = InstrumentedSink(sink, counter, attributes).buffer()
delegate.writeTo(metricsSink)
metricsSink.close()
}
}

internal fun RequestBody.instrument(counter: MonotonicCounter, attributes: Attributes): RequestBody =
InstrumentedRequestBody(this, counter, attributes)

internal class InstrumentedSource(
private val delegate: Source,
private val counter: MonotonicCounter,
private val attributes: Attributes,
) : Source by delegate {
override fun timeout(): Timeout = delegate.timeout()
override fun read(sink: Buffer, byteCount: Long): Long {
val rc = delegate.read(sink, byteCount)
if (rc > 0L) {
counter.add(rc, attributes)
}
return rc
}
override fun close() {
delegate.close()
}
}

internal class InstrumentedResponseBody(
private val delegate: ResponseBody,
private val counter: MonotonicCounter,
private val attributes: Attributes,
) : ResponseBody() {
override fun contentType(): MediaType? = delegate.contentType()
override fun contentLength(): Long = delegate.contentLength()
override fun source(): BufferedSource =
InstrumentedSource(delegate.source(), counter, attributes).buffer()
}

internal fun ResponseBody.instrument(counter: MonotonicCounter, attributes: Attributes): ResponseBody =
InstrumentedResponseBody(this, counter, attributes)
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class OkHttpEngine(
override suspend fun roundTrip(context: ExecutionContext, request: HttpRequest): HttpCall {
val callContext = callContext()

val engineRequest = request.toOkHttpRequest(context, callContext)
val engineRequest = request.toOkHttpRequest(context, callContext, metrics)
val engineCall = client.newCall(engineRequest)
val engineResponse = mapOkHttpExceptions { engineCall.executeAsync() }

Expand Down Expand Up @@ -132,6 +132,8 @@ private fun OkHttpEngineConfig.buildClient(metrics: HttpClientMetrics): OkHttpCl
proxyAuthenticator(OkHttpProxyAuthenticator(config.proxySelector))

dns(OkHttpDns(config.hostResolver))

addInterceptor(MetricsInterceptor)
}.build()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@ package aws.smithy.kotlin.runtime.http.engine.okhttp

import aws.smithy.kotlin.runtime.http.*
import aws.smithy.kotlin.runtime.http.engine.ProxyConfig
import aws.smithy.kotlin.runtime.http.engine.internal.HttpClientMetrics
import aws.smithy.kotlin.runtime.http.request.HttpRequest
import aws.smithy.kotlin.runtime.http.response.HttpResponse
import aws.smithy.kotlin.runtime.io.SdkSource
import aws.smithy.kotlin.runtime.io.internal.toSdk
import aws.smithy.kotlin.runtime.net.*
import aws.smithy.kotlin.runtime.operation.ExecutionContext
import kotlinx.coroutines.*
import okhttp3.*
import okhttp3.Authenticator
import okhttp3.Credentials
import okhttp3.Dns
import okhttp3.RequestBody.Companion.toRequestBody
import okhttp3.Route
import okhttp3.internal.http.HttpMethod
import java.io.IOException
import java.net.*
Expand All @@ -31,17 +30,18 @@ import okhttp3.Response as OkHttpResponse
/**
* SDK specific "tag" attached to an [okhttp3.Request] instance
*/
internal data class SdkRequestTag(val execContext: ExecutionContext, val callContext: CoroutineContext)
internal data class SdkRequestTag(val execContext: ExecutionContext, val callContext: CoroutineContext, val metrics: HttpClientMetrics)

/**
* Convert SDK [HttpRequest] to an [okhttp3.Request] instance
*/
internal fun HttpRequest.toOkHttpRequest(
execContext: ExecutionContext,
callContext: CoroutineContext,
metrics: HttpClientMetrics,
): OkHttpRequest {
val builder = OkHttpRequest.Builder()
builder.tag(SdkRequestTag::class, SdkRequestTag(execContext, callContext))
builder.tag(SdkRequestTag::class, SdkRequestTag(execContext, callContext, metrics))

builder.url(url.toString())

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package aws.smithy.kotlin.runtime.http.engine.okhttp

import aws.smithy.kotlin.runtime.ExperimentalApi
import aws.smithy.kotlin.runtime.http.engine.internal.HttpClientMetrics
import aws.smithy.kotlin.runtime.operation.ExecutionContext
import aws.smithy.kotlin.runtime.telemetry.otel.OpenTelemetryProvider
import aws.smithy.kotlin.runtime.util.emptyAttributes
import io.opentelemetry.api.common.AttributeKey
import io.opentelemetry.sdk.metrics.data.MetricData
import io.opentelemetry.sdk.testing.junit5.OpenTelemetryExtension
import okhttp3.*
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.RequestBody.Companion.toRequestBody
import okhttp3.ResponseBody.Companion.toResponseBody
import okio.blackholeSink
import okio.buffer
import org.junit.jupiter.api.extension.RegisterExtension
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.fail

@OptIn(ExperimentalApi::class)
class MetricsInterceptorTest {
companion object {
@JvmField
@RegisterExtension
val otelTesting = OpenTelemetryExtension.create()
}

private val provider = OpenTelemetryProvider(otelTesting.openTelemetry)
private val meter = provider.meterProvider.getOrCreateMeter("test")

@Test
fun testInstrumentedSource() {
val source = okio.Buffer()
val data = "a".repeat(15 * 1024)
source.writeUtf8(data)

val sink = okio.Buffer()
val counter = meter.createMonotonicCounter("TestCounter", "By")
val instrumented = InstrumentedSource(source, counter, emptyAttributes())
do {
val rc = instrumented.read(sink, 399)
} while (rc >= 0L)

assertEquals(data.length.toLong(), sink.size)

val counted = otelTesting.metrics.first().longCounterSum()
assertEquals(data.length.toLong(), counted)
}

@Test
fun testInstrumentedSink() {
val source = okio.Buffer()
val data = "b".repeat(13 * 1024)
source.writeUtf8(data)

val sink = okio.Buffer()
val counter = meter.createMonotonicCounter("TestCounter", "By")
val instrumented = InstrumentedSink(sink, counter, emptyAttributes())

val buffered = instrumented.buffer()
buffered.writeAll(source)
buffered.close()

assertEquals(data.length.toLong(), sink.size)

val counted = otelTesting.metrics.first().longCounterSum()
assertEquals(data.length.toLong(), counted)
}

@Test
fun testMetricsInterceptor() {
val reqData = "a".repeat(15 * 1024)
val reqBody = reqData.toRequestBody()
val metrics = HttpClientMetrics("test", provider)
val tag = SdkRequestTag(ExecutionContext(), EmptyCoroutineContext, metrics)
val request = Request.Builder()
.url("https://localhost:1/")
.method("PUT", reqBody)
.tag<SdkRequestTag>(tag)
.build()

val respData = "b".repeat(13 * 1024)
val respBody = respData.toResponseBody("text/plain; charset=utf-8".toMediaType())
val mockResp = Response.Builder()
.request(request)
.protocol(Protocol.HTTP_1_1)
.code(200)
.message("Intercepted")
.body(respBody)
.build()

val client = OkHttpClient.Builder()
.addInterceptor(MetricsInterceptor)
.addInterceptor { chain ->
// consume the body and short circuit with a mock response
chain.request().body?.writeTo(blackholeSink().buffer())
mockResp
}
.build()

val resp = client.newCall(request).execute()
val actualRespData = resp.body.source().readByteArray().decodeToString()
assertEquals(respData, actualRespData)

val actualBytesSent = otelTesting.metrics
.find { it.name == "smithy.client.http.bytes_sent" } ?: fail("expected bytes_sent")

val actualBytesReceived = otelTesting.metrics
.find { it.name == "smithy.client.http.bytes_received" } ?: fail("expected bytes_received")

assertEquals(reqData.length.toLong(), actualBytesSent.longCounterSum())
assertEquals(respData.length.toLong(), actualBytesReceived.longCounterSum())

val bytesSentAttr = actualBytesSent.longSumData.points.first().attributes.get(AttributeKey.stringKey("server.address"))
val bytesRecvAttr = actualBytesSent.longSumData.points.first().attributes.get(AttributeKey.stringKey("server.address"))
val expectedAttr = "localhost:1"
assertEquals(expectedAttr, bytesRecvAttr)
assertEquals(expectedAttr, bytesSentAttr)
}

private fun MetricData.longCounterSum(): Long = longSumData.points.sumOf { it.value }
}
Loading

0 comments on commit 276429e

Please sign in to comment.