Skip to content

Commit

Permalink
feat: allow extensions to be disabled
Browse files Browse the repository at this point in the history
  • Loading branch information
Sudhakar Reddy committed Oct 5, 2024
1 parent 71388dd commit 41ea0b1
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 5 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ You can build RIE into a base image. Download the RIE from GitHub to your local
```sh
#!/bin/sh
if [ -z "${AWS_LAMBDA_RUNTIME_API}" ]; then
exec /usr/local/bin/aws-lambda-rie /usr/bin/npx aws-lambda-ric
exec /usr/local/bin/aws-lambda-rie /var/lang/bin/npx aws-lambda-ric $1
else
exec /usr/bin/npx aws-lambda-ric
exec /var/lang/bin/npx aws-lambda-ric $1
fi
```

Expand Down
9 changes: 8 additions & 1 deletion cmd/aws-lambda-rie/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,18 @@ func main() {
log.WithError(err).Fatalf("The command line value for \"--runtime-interface-emulator-address\" is not a valid network address %q.", opts.RuntimeInterfaceEmulatorAddress)
}

enableExtensions := true
envDisableExtensionValue, envDisableExtensionSet := os.LookupEnv("AWS_LAMBDA_RIE_DISABLE_EXTENSIONS")
if envDisableExtensionSet && envDisableExtensionValue != "FALSE" {
enableExtensions = false
log.Info("Disabled extensions")
}

bootstrap, handler := getBootstrap(args, opts)
sandbox := rapidcore.
NewSandboxBuilder().
AddShutdownFunc(context.CancelFunc(func() { os.Exit(0) })).
SetExtensionsFlag(true).
SetExtensionsFlag(enableExtensions).
SetInitCachingFlag(opts.InitCachingEnabled)

if len(handler) > 0 {
Expand Down
36 changes: 35 additions & 1 deletion test/integration/local_lambda/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def test_env_var_with_equal_sign(self):

self.assertEqual(b'"4=4"', r.content)


def test_two_invokes(self):
image, rie, image_name = self.tagged_name("twoinvokes")

Expand Down Expand Up @@ -255,7 +254,42 @@ def test_custom_client_context(self):
content = json.loads(r.content)
self.assertEqual("bar", content["foo"])
self.assertEqual(123, content["baz"])

def test_disable_extension_with_empty_env_val(self):
image, rie, image_name = self.tagged_name("disable_extension_check_with_empty_value")
params = f"--name {image} -d --env AWS_LAMBDA_RIE_DISABLE_EXTENSIONS= -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_extension_is_enabled_handler"

with self.create_container(params, image):
r = self.invoke_function()

self.assertEqual(b'"false"', r.content)

def test_disable_extension_with_non_empty_env_val(self):
image, rie, image_name = self.tagged_name("disable_extension_check_with_non-empty_value")
params = f"--name {image} -d --env AWS_LAMBDA_RIE_DISABLE_EXTENSIONS=somevalue -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_extension_is_enabled_handler"

with self.create_container(params, image):
r = self.invoke_function()

self.assertEqual(b'"false"', r.content)

def test_enable_extension_with_env_var(self):
image, rie, image_name = self.tagged_name("enable_extension_check_with_env_var")
params = f"--name {image} -d --env AWS_LAMBDA_RIE_DISABLE_EXTENSIONS=FALSE -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_extension_is_enabled_handler"

with self.create_container(params, image):
r = self.invoke_function()

self.assertEqual(b'"true"', r.content)

def test_enable_extension_without_env_var(self):
image, rie, image_name = self.tagged_name("enable_extension_without_env_var")
params = f"--name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_extension_is_enabled_handler"

with self.create_container(params, image):
r = self.invoke_function()

self.assertEqual(b'"true"', r.content)

if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion test/integration/testdata/Dockerfile-allinone
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ FROM public.ecr.aws/lambda/python:3.12-$IMAGE_ARCH

WORKDIR /var/task
COPY ./ ./

# Copy extension
ADD bash-extension /opt/extensions/
# This is to verify env vars are parsed correctly before executing the function
ENV MyEnv="4=4"
42 changes: 42 additions & 0 deletions test/integration/testdata/bash-extension
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#!/bin/bash

# Name of the extension
EXTENSION_NAME="bash-extension"

# Log file path
LOG_FILE="/tmp/extension.log"


# Function to register the extension with the Lambda service
register_extension() {
curl -s -D /tmp/headers -X POST "http://${AWS_LAMBDA_RUNTIME_API}/2020-01-01/extension/register" \
-H "Content-Type: application/json" \
-H "Lambda-Extension-Name: $EXTENSION_NAME" \
-d '{"events": ["INVOKE"]}'
EXTENSION_ID=$(cat /tmp/headers | grep "Lambda-Extension-Identifier" | grep -oP '[a-f0-9\-]{36}')
echo "Extension Id: $EXTENSION_ID" >> $LOG_FILE
}

# Function to process events
process_events() {
# Main loop
while true; do
echo "Waiting for next event"
EVENT_DATA=$(curl -s -X GET \
-H "Lambda-Extension-Identifier: $EXTENSION_ID" \
"http://${AWS_LAMBDA_RUNTIME_API}/2020-01-01/extension/event/next")

# Check if the event is an invocation
if [[ $(echo "$EVENT_DATA" | jq -r '.eventType') == "INVOKE" ]]; then
echo "Invocation event received: $EVENT_DATA"
# Log the invocation event data
echo "$EVENT_DATA" >> "$LOG_FILE"
fi
done
}

# Register the extension
register_extension

# Process events
process_events
4 changes: 4 additions & 0 deletions test/integration/testdata/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def success_handler(event, context):
def check_env_var_handler(event, context):
return os.environ.get("MyEnv")

def check_extension_is_enabled_handler(event, context):
if os.path.isfile("/tmp/extension.log"):
return "true"
return "false"

def assert_env_var_is_overwritten(event, context):
print(os.environ.get("AWS_LAMBDA_FUNCTION_NAME"))
Expand Down

0 comments on commit 41ea0b1

Please sign in to comment.