From ba56ed44080dbd27872fb6bbebc9b9197307f163 Mon Sep 17 00:00:00 2001 From: Marco Cieno Date: Tue, 30 Apr 2024 00:13:20 +0200 Subject: [PATCH] feat: allow user-defined client context (#110) --- cmd/aws-lambda-rie/handlers.go | 9 ++++++ .../local_lambda/test_end_to_end.py | 30 ++++++++++++++++--- test/integration/testdata/main.py | 4 +++ 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/cmd/aws-lambda-rie/handlers.go b/cmd/aws-lambda-rie/handlers.go index 42032cf..2cca12d 100644 --- a/cmd/aws-lambda-rie/handlers.go +++ b/cmd/aws-lambda-rie/handlers.go @@ -5,6 +5,7 @@ package main import ( "bytes" + "encoding/base64" "fmt" "io/ioutil" "math" @@ -81,6 +82,13 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs i return } + rawClientContext, err := base64.StdEncoding.DecodeString(r.Header.Get("X-Amz-Client-Context")) + if err != nil { + log.Errorf("Failed to decode X-Amz-Client-Context: %s", err) + w.WriteHeader(500) + return + } + initDuration := "" inv := GetenvWithDefault("AWS_LAMBDA_FUNCTION_TIMEOUT", "300") timeoutDuration, _ := time.ParseDuration(inv + "s") @@ -114,6 +122,7 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs i TraceID: r.Header.Get("X-Amzn-Trace-Id"), LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), Payload: bytes.NewReader(bodyBytes), + ClientContext: string(rawClientContext), } fmt.Println("START RequestId: " + invokePayload.ID + " Version: " + functionVersion) diff --git a/test/integration/local_lambda/test_end_to_end.py b/test/integration/local_lambda/test_end_to_end.py index 7c5486f..8e34b77 100644 --- a/test/integration/local_lambda/test_end_to_end.py +++ b/test/integration/local_lambda/test_end_to_end.py @@ -4,6 +4,8 @@ from subprocess import Popen, PIPE from unittest import TestCase, main from pathlib import Path +import base64 +import json import time import os import requests @@ -62,12 +64,14 @@ def run_command(self, cmd): def sleep_1s(self): time.sleep(SLEEP_TIME) - - def invoke_function(self): + + def invoke_function(self, json={}, headers={}): return requests.post( - f"http://localhost:{self.PORT}/2015-03-31/functions/function/invocations", json={} + f"http://localhost:{self.PORT}/2015-03-31/functions/function/invocations", + json=json, + headers=headers, ) - + @contextmanager def create_container(self, param, image): try: @@ -234,6 +238,24 @@ def test_port_override(self): self.assertEqual(b'"My lambda ran succesfully"', r.content) + def test_custom_client_context(self): + image, rie, image_name = self.tagged_name("custom_client_context") + + params = f"--name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.custom_client_context_handler" + + with self.create_container(params, image): + r = self.invoke_function(headers={ + "X-Amz-Client-Context": base64.b64encode(json.dumps({ + "custom": { + "foo": "bar", + "baz": 123, + } + }).encode('utf8')).decode('utf8'), + }) + content = json.loads(r.content) + self.assertEqual("bar", content["foo"]) + self.assertEqual(123, content["baz"]) + if __name__ == "__main__": main() diff --git a/test/integration/testdata/main.py b/test/integration/testdata/main.py index b6b527d..9757be8 100644 --- a/test/integration/testdata/main.py +++ b/test/integration/testdata/main.py @@ -41,3 +41,7 @@ def check_remaining_time_handler(event, context): # Wait 1s to see if the remaining time changes time.sleep(1) return context.get_remaining_time_in_millis() + + +def custom_client_context_handler(event, context): + return context.client_context.custom