From 2e652c4fcef3acce182740502bdea41ce786cde3 Mon Sep 17 00:00:00 2001 From: Renato Valenzuela Date: Wed, 10 Nov 2021 20:05:09 +0000 Subject: [PATCH] Pulling upstream changes 2021-11 --- cmd/aws-lambda-rie/main.go | 11 +- lambda/agents/agent.go | 7 +- lambda/agents/agent_test.go | 62 ++++++- lambda/appctx/appctxutil.go | 88 +++++++-- lambda/appctx/appctxutil_test.go | 77 ++++++++ lambda/core/credentials.go | 119 +++++++++++++ lambda/core/credentials_test.go | 98 ++++++++++ lambda/core/directinvoke/directinvoke.go | 6 + lambda/core/runtime_state_names.go | 1 + lambda/core/states.go | 167 ++++++------------ lambda/core/states_test.go | 77 +++++++- lambda/interop/model.go | 29 ++- lambda/logging/internal_log.go | 33 ++++ lambda/logging/internal_log_test.go | 88 +++++++++ lambda/metering/time.go | 7 +- lambda/rapi/handler/agentregister_test.go | 15 +- lambda/rapi/handler/credentials.go | 40 +++++ lambda/rapi/handler/credentials_test.go | 91 ++++++++++ lambda/rapi/handler/initerror.go | 13 +- lambda/rapi/handler/invocationerror.go | 14 +- lambda/rapi/handler/invocationerror_test.go | 1 + lambda/rapi/handler/invocationresponse.go | 15 +- .../rapi/handler/invocationresponse_test.go | 4 +- lambda/rapi/handler/runtimelogs.go | 24 +-- lambda/rapi/handler/runtimelogs_test.go | 52 +++--- lambda/rapi/router.go | 18 +- lambda/rapi/server.go | 11 +- lambda/rapid/sandbox.go | 45 +++-- lambda/rapid/start.go | 164 ++++++++++++++--- lambda/rapid/start_test.go | 60 ++++++- lambda/rapidcore/bootstrap.go | 5 + lambda/rapidcore/bootstrap_test.go | 7 + lambda/rapidcore/env/environment.go | 13 ++ lambda/rapidcore/env/environment_test.go | 20 +++ lambda/rapidcore/sandbox.go | 65 ++++--- lambda/rapidcore/server.go | 28 +-- lambda/rapidcore/server_test.go | 11 +- lambda/rapidcore/standalone/executeHandler.go | 8 + lambda/rapidcore/standalone/invokeHandler.go | 16 ++ lambda/rapidcore/standalone/util.go | 6 +- lambda/runtimecmd/runtime_command.go | 6 +- lambda/runtimecmd/runtime_command_test.go | 6 +- lambda/telemetry/events_api.go | 14 ++ lambda/telemetry/logs_api.go | 19 -- lambda/telemetry/logs_egress_api.go | 26 +++ lambda/telemetry/logs_subscription_api.go | 37 ++++ lambda/testdata/agents/bash_echo.sh | 4 - lambda/testdata/agents/bash_stderr.sh | 5 + lambda/testdata/agents/bash_stdout.sh | 5 + .../testdata/agents/bash_stdout_and_stderr.sh | 8 + lambda/testdata/flowtesting.go | 47 +++-- 51 files changed, 1429 insertions(+), 364 deletions(-) create mode 100644 lambda/core/credentials.go create mode 100644 lambda/core/credentials_test.go create mode 100644 lambda/rapi/handler/credentials.go create mode 100644 lambda/rapi/handler/credentials_test.go create mode 100644 lambda/telemetry/events_api.go delete mode 100644 lambda/telemetry/logs_api.go create mode 100644 lambda/telemetry/logs_egress_api.go create mode 100644 lambda/telemetry/logs_subscription_api.go delete mode 100755 lambda/testdata/agents/bash_echo.sh create mode 100755 lambda/testdata/agents/bash_stderr.sh create mode 100755 lambda/testdata/agents/bash_stdout.sh create mode 100755 lambda/testdata/agents/bash_stdout_and_stderr.sh diff --git a/cmd/aws-lambda-rie/main.go b/cmd/aws-lambda-rie/main.go index a151ae7..3a87e46 100644 --- a/cmd/aws-lambda-rie/main.go +++ b/cmd/aws-lambda-rie/main.go @@ -21,7 +21,8 @@ const ( ) type options struct { - LogLevel string `long:"log-level" default:"info" description:"log level"` + LogLevel string `long:"log-level" default:"info" description:"log level"` + InitCachingEnabled bool `long:"enable-init-caching" description:"Enable support for Init Caching"` } func main() { @@ -32,7 +33,11 @@ func main() { rapidcore.SetLogLevel(opts.LogLevel) bootstrap, handler := getBootstrap(args, opts) - sandbox := rapidcore.NewSandboxBuilder(bootstrap).AddShutdownFunc(context.CancelFunc(func() { os.Exit(0) })).SetExtensionsFlag(true) + sandbox := rapidcore. + NewSandboxBuilder(bootstrap). + AddShutdownFunc(context.CancelFunc(func() { os.Exit(0) })). + SetExtensionsFlag(true). + SetInitCachingFlag(opts.InitCachingEnabled) if len(handler) > 0 { sandbox.SetHandler(handler) @@ -72,7 +77,7 @@ func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) { fmt.Sprintf("%s/bootstrap", currentWorkingDir), } - if !isBootstrapFileExist(bootstrapLookupCmd[0]) { + if !isBootstrapFileExist(bootstrapLookupCmd[0]) { var bootstrapCmdCandidates = []string{ optBootstrap, runtimeBootstrap, diff --git a/lambda/agents/agent.go b/lambda/agents/agent.go index 0e0ec19..16625c2 100644 --- a/lambda/agents/agent.go +++ b/lambda/agents/agent.go @@ -25,13 +25,12 @@ type ExternalAgentProcess struct { } // NewExternalAgentProcess returns a new external agent process -func NewExternalAgentProcess(path string, env []string, logWriter io.Writer) ExternalAgentProcess { +func NewExternalAgentProcess(path string, env []string, stdoutWriter io.Writer, stderrWriter io.Writer) ExternalAgentProcess { command := exec.Command(path) command.Env = env - w := NewNewlineSplitWriter(logWriter) - command.Stdout = w - command.Stderr = w + command.Stdout = NewNewlineSplitWriter(stdoutWriter) + command.Stderr = NewNewlineSplitWriter(stderrWriter) command.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} return ExternalAgentProcess{ diff --git a/lambda/agents/agent_test.go b/lambda/agents/agent_test.go index efcc567..d314a76 100644 --- a/lambda/agents/agent_test.go +++ b/lambda/agents/agent_test.go @@ -204,7 +204,7 @@ func TestFindAgentMixed(t *testing.T) { // Test our ability to start agents func TestAgentStart(t *testing.T) { assert := assert.New(t) - agent := NewExternalAgentProcess("../testdata/agents/bash_true.sh", []string{}, &mockWriter{}) + agent := NewExternalAgentProcess("../testdata/agents/bash_true.sh", []string{}, &mockWriter{}, &mockWriter{}) assert.Nil(agent.Start()) assert.Nil(agent.Wait()) } @@ -212,22 +212,68 @@ func TestAgentStart(t *testing.T) { // Test that execution of invalid agents is correctly reported func TestInvalidAgentStart(t *testing.T) { assert := assert.New(t) - agent := NewExternalAgentProcess("/bin/none", []string{}, &mockWriter{}) + agent := NewExternalAgentProcess("/bin/none", []string{}, &mockWriter{}, &mockWriter{}) assert.True(os.IsNotExist(agent.Start())) } -// Test that execution of invalid agents is correctly reported -func TestAgentTelemetry(t *testing.T) { +func TestAgentStdoutWriter(t *testing.T) { + // Given + assert := assert.New(t) + + stdout := &mockWriter{} + stderr := &mockWriter{} + expectedStdout := "stdout line 1\nstdout line 2\nstdout line 3\n" + expectedStderr := "" + + agent := NewExternalAgentProcess("../testdata/agents/bash_stdout.sh", []string{}, stdout, stderr) + + // When + assert.NoError(agent.Start()) + assert.NoError(agent.Wait()) + + // Then + assert.Equal(expectedStdout, string(bytes.Join(stdout.bytesReceived, []byte("")))) + assert.Equal(expectedStderr, string(bytes.Join(stderr.bytesReceived, []byte("")))) +} + +func TestAgentStderrWriter(t *testing.T) { + // Given assert := assert.New(t) - buffer := &mockWriter{} - agent := NewExternalAgentProcess("../testdata/agents/bash_echo.sh", []string{}, buffer) + stdout := &mockWriter{} + stderr := &mockWriter{} + expectedStdout := "" + expectedStderr := "stderr line 1\nstderr line 2\nstderr line 3\n" + + agent := NewExternalAgentProcess("../testdata/agents/bash_stderr.sh", []string{}, stdout, stderr) + + // When + assert.NoError(agent.Start()) + assert.NoError(agent.Wait()) + + // Then + assert.Equal(expectedStdout, string(bytes.Join(stdout.bytesReceived, []byte("")))) + assert.Equal(expectedStderr, string(bytes.Join(stderr.bytesReceived, []byte("")))) +} + +func TestAgentStdoutAndStderrSeperateWriters(t *testing.T) { + // Given + assert := assert.New(t) + + stdout := &mockWriter{} + stderr := &mockWriter{} + expectedStdout := "stdout line 1\nstdout line 2\nstdout line 3\n" + expectedStderr := "stderr line 1\nstderr line 2\nstderr line 3\n" + + agent := NewExternalAgentProcess("../testdata/agents/bash_stdout_and_stderr.sh", []string{}, stdout, stderr) + // When assert.NoError(agent.Start()) assert.NoError(agent.Wait()) - message := "hello world\n|barbaz\n|hello world\n|barbaz2" - assert.Equal(message, string(bytes.Join(buffer.bytesReceived, []byte("|")))) + // Then + assert.Equal(expectedStdout, string(bytes.Join(stdout.bytesReceived, []byte("")))) + assert.Equal(expectedStderr, string(bytes.Join(stderr.bytesReceived, []byte("")))) } type mockWriter struct { diff --git a/lambda/appctx/appctxutil.go b/lambda/appctx/appctxutil.go index a5f7266..a3e652f 100644 --- a/lambda/appctx/appctxutil.go +++ b/lambda/appctx/appctxutil.go @@ -5,11 +5,10 @@ package appctx import ( "context" - "net/http" - "strings" - "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" + "net/http" + "strings" log "github.com/sirupsen/logrus" ) @@ -24,6 +23,9 @@ type ReqCtxKey int // context object into request context. const ReqCtxApplicationContextKey ReqCtxKey = iota +// MaxRuntimeReleaseLength Max length for user agent string. +const MaxRuntimeReleaseLength = 128 + // FromRequest retrieves application context from the request context. func FromRequest(request *http.Request) ApplicationContext { return request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) @@ -39,24 +41,78 @@ func GetRuntimeRelease(appCtx ApplicationContext) string { return appCtx.GetOrDefault(AppCtxRuntimeReleaseKey, "").(string) } -// UpdateAppCtxWithRuntimeRelease extracts runtime release info from user agent header and put it into appCtx. +// GetUserAgentFromRequest Returns the first token -seperated by a space- +// from request header 'User-Agent'. +func GetUserAgentFromRequest(request *http.Request) string { + runtimeRelease := "" + userAgent := request.Header.Get("User-Agent") + // Split around spaces and use only the first token. + if fields := strings.Fields(userAgent); len(fields) > 0 && len(fields[0]) > 0 { + runtimeRelease = fields[0] + } + return runtimeRelease +} + +// CreateRuntimeReleaseFromRequest Gets runtime features from request header +// 'Lambda-Runtime-Features', and append it to the given runtime release. +func CreateRuntimeReleaseFromRequest(request *http.Request, runtimeRelease string) string { + lambdaRuntimeFeaturesHeader := request.Header.Get("Lambda-Runtime-Features") + + // "(", ")" are not valid token characters, and potentially could invalidate runtime_release + lambdaRuntimeFeaturesHeader = strings.ReplaceAll(lambdaRuntimeFeaturesHeader, "(", "") + lambdaRuntimeFeaturesHeader = strings.ReplaceAll(lambdaRuntimeFeaturesHeader, ")", "") + + numberOfAppendedFeatures := 0 + // Available length is a maximum length available for runtime features (including delimiters). From maximal runtime + // release length we subtract what we already have plus 3 additional bytes for a space and a pair of brackets for + // list of runtime features that is added later. + runtimeReleaseLength := len(runtimeRelease) + if runtimeReleaseLength == 0 { + runtimeReleaseLength = len("Unknown") + } + availableLength := MaxRuntimeReleaseLength - runtimeReleaseLength - 3 + var lambdaRuntimeFeatures []string + + for _, feature := range strings.Fields(lambdaRuntimeFeaturesHeader) { + featureLength := len(feature) + // If featureLength <= availableLength - numberOfAppendedFeatures + // (where numberOfAppendedFeatures is equal to number of delimiters needed). + if featureLength <= availableLength-numberOfAppendedFeatures { + availableLength -= featureLength + lambdaRuntimeFeatures = append(lambdaRuntimeFeatures, feature) + numberOfAppendedFeatures++ + } + } + // Append valid features to runtime release. + if len(lambdaRuntimeFeatures) > 0 { + if runtimeRelease == "" { + runtimeRelease = "Unknown" + } + runtimeRelease += " (" + strings.Join(lambdaRuntimeFeatures, " ") + ")" + } + + return runtimeRelease +} + +// UpdateAppCtxWithRuntimeRelease extracts runtime release info from user agent & lambda runtime features +// headers and update it into appCtx. // Sample UA: // Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:47.0) Gecko/20100101 Firefox/47.0 func UpdateAppCtxWithRuntimeRelease(request *http.Request, appCtx ApplicationContext) bool { - // If appCtx has runtime release value already, skip updating for consistency. - if len(GetRuntimeRelease(appCtx)) > 0 { - return false - } - - userAgent := request.Header.Get("User-Agent") - if len(userAgent) == 0 { + // If appCtx has runtime release value already, just append the runtime features. + if appCtxRuntimeRelease := GetRuntimeRelease(appCtx); len(appCtxRuntimeRelease) > 0 { + // if the runtime features are not appended before append them, otherwise ignore + if runtimeReleaseWithFeatures := CreateRuntimeReleaseFromRequest(request, appCtxRuntimeRelease); len(runtimeReleaseWithFeatures) > len(appCtxRuntimeRelease) && + appCtxRuntimeRelease[len(appCtxRuntimeRelease)-1] != ')' { + appCtx.Store(AppCtxRuntimeReleaseKey, runtimeReleaseWithFeatures) + return true + } return false } - - // Split around spaces and use only the first token. - if fields := strings.Fields(userAgent); len(fields) > 0 && len(fields[0]) > 0 { - appCtx.Store(AppCtxRuntimeReleaseKey, - fields[0]) + // If appCtx doesn't have runtime release value, update it with user agent and runtime features. + if runtimeReleaseWithFeatures := CreateRuntimeReleaseFromRequest(request, + GetUserAgentFromRequest(request)); runtimeReleaseWithFeatures != "" { + appCtx.Store(AppCtxRuntimeReleaseKey, runtimeReleaseWithFeatures) return true } return false diff --git a/lambda/appctx/appctxutil_test.go b/lambda/appctx/appctxutil_test.go index 50a48ac..a8a4761 100644 --- a/lambda/appctx/appctxutil_test.go +++ b/lambda/appctx/appctxutil_test.go @@ -5,6 +5,7 @@ package appctx import ( "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -30,6 +31,63 @@ func runTestRequestWithUserAgent(t *testing.T, userAgent string, expectedRuntime assert.Equal(t, expectedRuntimeRelease, ctxRuntimeRelease, "failed to extract runtime_release token") } +func TestCreateRuntimeReleaseFromRequest(t *testing.T) { + tests := map[string]struct { + userAgentHeader string + lambdaRuntimeFeaturesHeader string + expectedRuntimeRelease string + }{ + "No User-Agent header": { + userAgentHeader: "", + lambdaRuntimeFeaturesHeader: "httpcl/2.0 execwr", + expectedRuntimeRelease: "Unknown (httpcl/2.0 execwr)", + }, + "No Lambda-Runtime-Features header": { + userAgentHeader: "Node.js/14.16.0", + lambdaRuntimeFeaturesHeader: "", + expectedRuntimeRelease: "Node.js/14.16.0", + }, + "Lambda-Runtime-Features header with additional spaces": { + userAgentHeader: "Node.js/14.16.0", + lambdaRuntimeFeaturesHeader: "httpcl/2.0 execwr", + expectedRuntimeRelease: "Node.js/14.16.0 (httpcl/2.0 execwr)", + }, + "Lambda-Runtime-Features header with special characters": { + userAgentHeader: "Node.js/14.16.0", + lambdaRuntimeFeaturesHeader: "httpcl/2.0@execwr-1 abcd?efg nodewr/(4.33)) nodewr/4.3", + expectedRuntimeRelease: "Node.js/14.16.0 (httpcl/2.0@execwr-1 abcd?efg nodewr/4.33 nodewr/4.3)", + }, + "Lambda-Runtime-Features header with long Lambda-Runtime-Features header": { + userAgentHeader: "Node.js/14.16.0", + lambdaRuntimeFeaturesHeader: strings.Repeat("abcdef ", MaxRuntimeReleaseLength/7), + expectedRuntimeRelease: "Node.js/14.16.0 (" + strings.Repeat("abcdef ", (MaxRuntimeReleaseLength-18-6)/7) + "abcdef)", + }, + "Lambda-Runtime-Features header with long Lambda-Runtime-Features header with UTF-8 characters": { + userAgentHeader: "Node.js/14.16.0", + lambdaRuntimeFeaturesHeader: strings.Repeat("我爱亚马逊 ", MaxRuntimeReleaseLength/16), + expectedRuntimeRelease: "Node.js/14.16.0 (" + strings.Repeat("我爱亚马逊 ", (MaxRuntimeReleaseLength-18-15)/16) + "我爱亚马逊)", + }, + } + + for _, tc := range tests { + req := httptest.NewRequest("", "/", nil) + if tc.userAgentHeader != "" { + req.Header.Set("User-Agent", tc.userAgentHeader) + } + if tc.lambdaRuntimeFeaturesHeader != "" { + req.Header.Set("Lambda-Runtime-Features", tc.lambdaRuntimeFeaturesHeader) + } + appCtx := NewApplicationContext() + request := RequestWithAppCtx(req, appCtx) + + UpdateAppCtxWithRuntimeRelease(request, appCtx) + runtimeRelease := GetRuntimeRelease(appCtx) + + assert.LessOrEqual(t, len(runtimeRelease), MaxRuntimeReleaseLength) + assert.Equal(t, tc.expectedRuntimeRelease, runtimeRelease) + } +} + func TestUpdateAppCtxWithRuntimeRelease(t *testing.T) { type pair struct { in, wanted string @@ -74,6 +132,25 @@ func TestUpdateAppCtxWithRuntimeReleaseWithBlankUserAgent(t *testing.T) { assert.False(t, ok) } +func TestUpdateAppCtxWithRuntimeReleaseWithLambdaRuntimeFeatures(t *testing.T) { + // GIVEN + // Simple LambdaRuntimeFeatures passed. + req := httptest.NewRequest("", "/", nil) + req.Header.Set("User-Agent", "Node.js/14.16.0") + req.Header.Set("Lambda-Runtime-Features", "httpcl/2.0 execwr nodewr/4.3") + request := RequestWithAppCtx(req, NewApplicationContext()) + appCtx := request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) + + // DO + ok := UpdateAppCtxWithRuntimeRelease(request, appCtx) + + //ASSERT + assert.True(t, ok, "runtime_release updated based only on User-Agent and valid features") + ctxRuntimeRelease, ok := appCtx.Load(AppCtxRuntimeReleaseKey) + assert.True(t, ok) + assert.Equal(t, "Node.js/14.16.0 (httpcl/2.0 execwr nodewr/4.3)", ctxRuntimeRelease) +} + // Test that RAPID allows updating runtime_release only once func TestUpdateAppCtxWithRuntimeReleaseMultipleTimes(t *testing.T) { // GIVEN diff --git a/lambda/core/credentials.go b/lambda/core/credentials.go new file mode 100644 index 0000000..7b1bf14 --- /dev/null +++ b/lambda/core/credentials.go @@ -0,0 +1,119 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "fmt" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + UNBLOCKED = iota + BLOCKED +) + +var ErrCredentialsNotFound = fmt.Errorf("credentials not found for the provided token") + +type Credentials struct { + AwsKey string `json:"AccessKeyId"` + AwsSecret string `json:"SecretAccessKey"` + AwsSession string `json:"Token"` + Expiration time.Time `json:"Expiration"` +} + +type CredentialsService interface { + SetCredentials(token, awsKey, awsSecret, awsSession string) + GetCredentials(token string) (*Credentials, error) + UpdateCredentials(awsKey, awsSecret, awsSession string) error + BlockService() + UnblockService() +} + +type credentialsServiceImpl struct { + credentials map[string]Credentials + contentMutex *sync.Mutex + serviceMutex *sync.Mutex + currentState int +} + +func NewCredentialsService() CredentialsService { + credentialsService := &credentialsServiceImpl{ + credentials: make(map[string]Credentials), + contentMutex: &sync.Mutex{}, + serviceMutex: &sync.Mutex{}, + currentState: UNBLOCKED, + } + + return credentialsService +} + +func (c *credentialsServiceImpl) SetCredentials(token, awsKey, awsSecret, awsSession string) { + c.contentMutex.Lock() + defer c.contentMutex.Unlock() + + c.credentials[token] = Credentials{ + AwsKey: awsKey, + AwsSecret: awsSecret, + AwsSession: awsSession, + Expiration: time.Now().Add(16 * time.Minute), + } +} + +func (c *credentialsServiceImpl) GetCredentials(token string) (*Credentials, error) { + c.serviceMutex.Lock() + defer c.serviceMutex.Unlock() + + c.contentMutex.Lock() + defer c.contentMutex.Unlock() + + if credentials, ok := c.credentials[token]; ok { + return &credentials, nil + } + + return nil, ErrCredentialsNotFound +} + +func (c *credentialsServiceImpl) BlockService() { + if c.currentState == BLOCKED { + return + } + log.Info("blocking the credentials service") + c.serviceMutex.Lock() + + c.contentMutex.Lock() + defer c.contentMutex.Unlock() + + c.currentState = BLOCKED +} + +func (c *credentialsServiceImpl) UnblockService() { + if c.currentState == UNBLOCKED { + return + } + log.Info("unblocking the credentials service") + + c.contentMutex.Lock() + defer c.contentMutex.Unlock() + + c.currentState = UNBLOCKED + c.serviceMutex.Unlock() +} + +func (c *credentialsServiceImpl) UpdateCredentials(awsKey, awsSecret, awsSession string) error { + mapSize := len(c.credentials) + if mapSize != 1 { + return fmt.Errorf("there are %d set of credentials", mapSize) + } + + var token string + for key := range c.credentials { + token = key + } + + c.SetCredentials(token, awsKey, awsSecret, awsSession) + return nil +} diff --git a/lambda/core/credentials_test.go b/lambda/core/credentials_test.go new file mode 100644 index 0000000..ab0b247 --- /dev/null +++ b/lambda/core/credentials_test.go @@ -0,0 +1,98 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +const ( + Token string = "sampleToken" + AwsKey string = "sampleKey" + AwsSecret string = "sampleSecret" + AwsSession string = "sampleSession" +) + +func TestGetSetCredentialsHappy(t *testing.T) { + credentialsService := NewCredentialsService() + + credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession) + + credentials, err := credentialsService.GetCredentials(Token) + + assert.NoError(t, err) + assert.Equal(t, AwsKey, credentials.AwsKey) + assert.Equal(t, AwsSecret, credentials.AwsSecret) + assert.Equal(t, AwsSession, credentials.AwsSession) +} + +func TestGetCredentialsFail(t *testing.T) { + credentialsService := NewCredentialsService() + + _, err := credentialsService.GetCredentials("unknownToken") + + assert.Error(t, err) +} + +func TestUpdateCredentialsHappy(t *testing.T) { + credentialsService := NewCredentialsService() + + credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession) + err := credentialsService.UpdateCredentials("sampleKey1", "sampleSecret1", "sampleSession1") + assert.NoError(t, err) + + credentials, err := credentialsService.GetCredentials(Token) + + assert.NoError(t, err) + assert.Equal(t, "sampleKey1", credentials.AwsKey) + assert.Equal(t, "sampleSecret1", credentials.AwsSecret) + assert.Equal(t, "sampleSession1", credentials.AwsSession) +} + +func TestUpdateCredentialsFail(t *testing.T) { + credentialsService := NewCredentialsService() + + err := credentialsService.UpdateCredentials("unknownKey", "unknownSecret", "unknownSession") + + assert.Error(t, err) +} + +func TestUpdateCredentialsOfBlockedService(t *testing.T) { + credentialsService := NewCredentialsService() + credentialsService.BlockService() + credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession) + err := credentialsService.UpdateCredentials("sampleKey1", "sampleSecret1", "sampleSession1") + assert.NoError(t, err) +} + +func TestConsecutiveBlockService(t *testing.T) { + credentialsService := NewCredentialsService() + + timeout := time.After(1 * time.Second) + done := make(chan bool) + + go func() { + for i := 0; i < 10; i++ { + credentialsService.BlockService() + } + done <- true + }() + + select { + case <-timeout: + t.Fatal("BlockService should not block the calling thread.") + case <-done: + } +} + +// unlocking a mutex twice causes panic +// the assertion here is basically not having panic +func TestConsecutiveUnblockService(t *testing.T) { + credentialsService := NewCredentialsService() + + credentialsService.UnblockService() + credentialsService.UnblockService() +} diff --git a/lambda/core/directinvoke/directinvoke.go b/lambda/core/directinvoke/directinvoke.go index ab1075d..1699121 100644 --- a/lambda/core/directinvoke/directinvoke.go +++ b/lambda/core/directinvoke/directinvoke.go @@ -109,6 +109,12 @@ func SendDirectInvokeResponse(additionalHeaders map[string]string, payload io.Re w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) } else if n == MaxDirectResponseSize+1 { w.Header().Set(EndOfResponseTrailer, EndOfResponseOversized) + err = &interop.ErrorResponseTooLargeDI{ + ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ + ResponseSize: int(n), + MaxResponseSize: int(MaxDirectResponseSize), + }, + } } else { w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) } diff --git a/lambda/core/runtime_state_names.go b/lambda/core/runtime_state_names.go index a30ed96..b20b9f8 100644 --- a/lambda/core/runtime_state_names.go +++ b/lambda/core/runtime_state_names.go @@ -8,6 +8,7 @@ const ( RuntimeStartedStateName = "Started" RuntimeInitErrorStateName = "InitError" RuntimeReadyStateName = "Ready" + RuntimeRunningStateName = "Running" RuntimeInvocationResponseStateName = "InvocationResponse" RuntimeInvocationErrorResponseStateName = "InvocationErrorResponse" RuntimeResponseSentStateName = "RuntimeResponseSentState" diff --git a/lambda/core/states.go b/lambda/core/states.go index e6df068..bc7359d 100644 --- a/lambda/core/states.go +++ b/lambda/core/states.go @@ -78,6 +78,14 @@ type RuntimeState interface { Name() string } +type disallowEveryTransitionByDefault struct{} + +func (s *disallowEveryTransitionByDefault) InitError() error { return ErrNotAllowed } +func (s *disallowEveryTransitionByDefault) Ready() error { return ErrNotAllowed } +func (s *disallowEveryTransitionByDefault) InvocationResponse() error { return ErrNotAllowed } +func (s *disallowEveryTransitionByDefault) InvocationErrorResponse() error { return ErrNotAllowed } +func (s *disallowEveryTransitionByDefault) ResponseSent() error { return ErrNotAllowed } + // Runtime is runtime object. type Runtime struct { ManagedThread Suspendable @@ -90,6 +98,7 @@ type Runtime struct { RuntimeStartedState RuntimeState RuntimeInitErrorState RuntimeState RuntimeReadyState RuntimeState + RuntimeRunningState RuntimeState RuntimeInvocationResponseState RuntimeState RuntimeInvocationErrorResponseState RuntimeState RuntimeResponseSentState RuntimeState @@ -182,7 +191,8 @@ func NewRuntime(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchroni runtime.RuntimeStartedState = &RuntimeStartedState{runtime: runtime, initFlow: initFlow} runtime.RuntimeInitErrorState = &RuntimeInitErrorState{runtime: runtime, initFlow: initFlow} - runtime.RuntimeReadyState = &RuntimeReadyState{runtime: runtime, invokeFlow: invokeFlow} + runtime.RuntimeReadyState = &RuntimeReadyState{runtime: runtime} + runtime.RuntimeRunningState = &RuntimeRunningState{runtime: runtime, invokeFlow: invokeFlow} runtime.RuntimeInvocationResponseState = &RuntimeInvocationResponseState{runtime: runtime, invokeFlow: invokeFlow} runtime.RuntimeInvocationErrorResponseState = &RuntimeInvocationErrorResponseState{runtime: runtime, invokeFlow: invokeFlow} runtime.RuntimeResponseSentState = &RuntimeResponseSentState{runtime: runtime, invokeFlow: invokeFlow} @@ -193,38 +203,28 @@ func NewRuntime(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchroni // RuntimeStartedState runtime started state. type RuntimeStartedState struct { + disallowEveryTransitionByDefault runtime *Runtime initFlow InitFlowSynchronization } // Ready call when runtime init done. func (s *RuntimeStartedState) Ready() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) err := s.initFlow.RuntimeReady() if err != nil { return err } s.runtime.ManagedThread.SuspendUnsafe() + if s.runtime.currentState != s.runtime.RuntimeReadyState && s.runtime.currentState != s.runtime.RuntimeRunningState { + return ErrConcurrentStateModification + } - s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) + s.runtime.setStateUnsafe(s.runtime.RuntimeRunningState) return nil } -// InvocationResponse not allowed in this state. -func (s *RuntimeStartedState) InvocationResponse() error { - return ErrNotAllowed -} - -// InvocationErrorResponse not allowed in this state. -func (s *RuntimeStartedState) InvocationErrorResponse() error { - return ErrNotAllowed -} - -// InvocationErrorResponse not allowed in this state. -func (s *RuntimeStartedState) ResponseSent() error { - return ErrNotAllowed -} - // InitError move runtime to init error state. func (s *RuntimeStartedState) InitError() error { s.runtime.setStateUnsafe(s.runtime.RuntimeInitErrorState) @@ -238,111 +238,79 @@ func (s *RuntimeStartedState) Name() string { // RuntimeInitErrorState runtime started state. type RuntimeInitErrorState struct { + disallowEveryTransitionByDefault runtime *Runtime initFlow InitFlowSynchronization } -// Ready not allowed -func (s *RuntimeInitErrorState) Ready() error { - return ErrNotAllowed -} - -// InvocationResponse not allowed -func (s *RuntimeInitErrorState) InvocationResponse() error { - return ErrNotAllowed +// Name ... +func (s *RuntimeInitErrorState) Name() string { + return RuntimeInitErrorStateName } -// InvocationErrorResponse not allowed -func (s *RuntimeInitErrorState) InvocationErrorResponse() error { - return ErrNotAllowed +// RuntimeReadyState runtime ready state. +type RuntimeReadyState struct { + disallowEveryTransitionByDefault + runtime *Runtime } -// InvocationErrorResponse not allowed -func (s *RuntimeInitErrorState) ResponseSent() error { - return ErrNotAllowed -} +func (s *RuntimeReadyState) Ready() error { + s.runtime.ManagedThread.SuspendUnsafe() + if s.runtime.currentState != s.runtime.RuntimeReadyState && s.runtime.currentState != s.runtime.RuntimeRunningState { + return ErrConcurrentStateModification + } -// InitError not allowed -func (s *RuntimeInitErrorState) InitError() error { - return ErrNotAllowed + s.runtime.setStateUnsafe(s.runtime.RuntimeRunningState) + return nil } // Name ... -func (s *RuntimeInitErrorState) Name() string { - return RuntimeInitErrorStateName +func (s *RuntimeReadyState) Name() string { + return RuntimeReadyStateName } -// RuntimeReadyState runtime ready state. -type RuntimeReadyState struct { +// RuntimeRunningState runtime ready state. +type RuntimeRunningState struct { + disallowEveryTransitionByDefault runtime *Runtime invokeFlow InvokeFlowSynchronization } -func (s *RuntimeReadyState) Ready() error { +func (s *RuntimeRunningState) Ready() error { return nil } // InvocationResponse call when runtime response is available. -func (s *RuntimeReadyState) InvocationResponse() error { +func (s *RuntimeRunningState) InvocationResponse() error { s.runtime.setStateUnsafe(s.runtime.RuntimeInvocationResponseState) return nil } // InvocationErrorResponse call when runtime error response is available. -func (s *RuntimeReadyState) InvocationErrorResponse() error { +func (s *RuntimeRunningState) InvocationErrorResponse() error { s.runtime.setStateUnsafe(s.runtime.RuntimeInvocationErrorResponseState) return nil } -// ResponseSent is a closing state for InvocationResponseState and InvocationErrorResponseState. -func (s *RuntimeReadyState) ResponseSent() error { - return ErrNotAllowed -} - -// InitError not allowed in this state. -func (s *RuntimeReadyState) InitError() error { - return ErrNotAllowed -} - // Name ... -func (s *RuntimeReadyState) Name() string { - return RuntimeReadyStateName +func (s *RuntimeRunningState) Name() string { + return RuntimeRunningStateName } // RuntimeInvocationResponseState runtime response is available. // Start state for runtime response submission. type RuntimeInvocationResponseState struct { + disallowEveryTransitionByDefault runtime *Runtime invokeFlow InvokeFlowSynchronization } -// Ready call when runtime ready. -func (s *RuntimeInvocationResponseState) Ready() error { - return ErrNotAllowed -} - -// InvocationResponse not allowed in this state. -func (s *RuntimeInvocationResponseState) InvocationResponse() error { - return ErrNotAllowed -} - -// InvocationErrorResponse not allowed in this state. -func (s *RuntimeInvocationResponseState) InvocationErrorResponse() error { - return ErrNotAllowed -} - // ResponseSent completes RuntimeInvocationResponseState. func (s *RuntimeInvocationResponseState) ResponseSent() error { s.runtime.setStateUnsafe(s.runtime.RuntimeResponseSentState) return s.invokeFlow.RuntimeResponse(s.runtime) } -// InitError not allowed in this state. -func (s *RuntimeInvocationResponseState) InitError() error { - // TODO log - return ErrNotAllowed -} - // Name ... func (s *RuntimeInvocationResponseState) Name() string { return RuntimeInvocationResponseStateName @@ -351,36 +319,17 @@ func (s *RuntimeInvocationResponseState) Name() string { // RuntimeInvocationErrorResponseState runtime response is available. // Start state for runtime error response submission. type RuntimeInvocationErrorResponseState struct { + disallowEveryTransitionByDefault runtime *Runtime invokeFlow InvokeFlowSynchronization } -// Ready call when runtime ready. -func (s *RuntimeInvocationErrorResponseState) Ready() error { - return ErrNotAllowed -} - -// InvocationResponse not allowed in this state. -func (s *RuntimeInvocationErrorResponseState) InvocationResponse() error { - return ErrNotAllowed -} - -// InvocationErrorResponse not allowed in this state. -func (s *RuntimeInvocationErrorResponseState) InvocationErrorResponse() error { - return ErrNotAllowed -} - // ResponseSent completes RuntimeInvocationErrorResponseState. func (s *RuntimeInvocationErrorResponseState) ResponseSent() error { s.runtime.setStateUnsafe(s.runtime.RuntimeResponseSentState) return s.invokeFlow.RuntimeResponse(s.runtime) } -// InitError not allowed in this state. -func (s *RuntimeInvocationErrorResponseState) InitError() error { - return ErrNotAllowed -} - // Name ... func (s *RuntimeInvocationErrorResponseState) Name() string { return RuntimeInvocationErrorResponseStateName @@ -388,43 +337,27 @@ func (s *RuntimeInvocationErrorResponseState) Name() string { // RuntimeResponseSentState ends started runtime response or runtime error response submission. type RuntimeResponseSentState struct { + disallowEveryTransitionByDefault runtime *Runtime invokeFlow InvokeFlowSynchronization } // Ready call when runtime ready. func (s *RuntimeResponseSentState) Ready() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) if err := s.invokeFlow.RuntimeReady(s.runtime); err != nil { return err } s.runtime.ManagedThread.SuspendUnsafe() + if s.runtime.currentState != s.runtime.RuntimeReadyState && s.runtime.currentState != s.runtime.RuntimeRunningState { + return ErrConcurrentStateModification + } - s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) + s.runtime.setStateUnsafe(s.runtime.RuntimeRunningState) return nil } -// InvocationResponse not allowed in this state. -func (s *RuntimeResponseSentState) InvocationResponse() error { - return ErrNotAllowed -} - -// InvocationErrorResponse not allowed in this state. -func (s *RuntimeResponseSentState) InvocationErrorResponse() error { - return ErrNotAllowed -} - -// ResponseSent completes RuntimeInvocationErrorResponseState. -func (s *RuntimeResponseSentState) ResponseSent() error { - return ErrNotAllowed -} - -// InitError not allowed in this state. -func (s *RuntimeResponseSentState) InitError() error { - // TODO log - return ErrNotAllowed -} - // Name ... func (s *RuntimeResponseSentState) Name() string { return RuntimeResponseSentStateName diff --git a/lambda/core/states_test.go b/lambda/core/states_test.go index 1b6a62e..4b01838 100644 --- a/lambda/core/states_test.go +++ b/lambda/core/states_test.go @@ -7,9 +7,37 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "go.amzn.com/lambda/testdata/mockthread" + "sync" "testing" ) +func TestRuntimeInitErrorAfterReady(t *testing.T) { + initFlow := &mockInitFlowSynchronization{} + initFlow.ReadyCond = sync.NewCond(&sync.Mutex{}) + invokeFlow := &mockInvokeFlowSynchronization{} + runtime := NewRuntime(initFlow, invokeFlow) + + readyChan := make(chan struct{}) + runtime.SetState(runtime.RuntimeStartedState) + go func() { + assert.NoError(t, runtime.Ready()) + readyChan <- struct{}{} + }() + + initFlow.ReadyCond.L.Lock() + for !initFlow.ReadyCalled { + initFlow.ReadyCond.Wait() + } + initFlow.ReadyCond.L.Unlock() + assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) + + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + runtime.Release() + <-readyChan + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) +} + func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { initFlow := &mockInitFlowSynchronization{} invokeFlow := &mockInvokeFlowSynchronization{} @@ -24,7 +52,7 @@ func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { // Started -> Ready runtime.SetState(runtime.RuntimeStartedState) assert.NoError(t, runtime.Ready()) - assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) // Started -> ResponseSent runtime.SetState(runtime.RuntimeStartedState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -78,17 +106,44 @@ func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { // Ready -> Ready runtime.SetState(runtime.RuntimeReadyState) assert.NoError(t, runtime.Ready()) - assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) // Ready -> ResponseSent runtime.SetState(runtime.RuntimeReadyState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) // Ready -> InvocationResponse runtime.SetState(runtime.RuntimeReadyState) - assert.NoError(t, runtime.InvocationResponse()) - assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) + assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) + assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) // Ready -> InvocationErrorResponse runtime.SetState(runtime.RuntimeReadyState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) + assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) +} + +func TestRuntimeStateTransitionsFromRunningState(t *testing.T) { + initFlow := &mockInitFlowSynchronization{} + invokeFlow := &mockInvokeFlowSynchronization{} + runtime := NewRuntime(initFlow, invokeFlow) + runtime.ManagedThread = &mockthread.MockManagedThread{} + // Running -> InitError + runtime.SetState(runtime.RuntimeRunningState) + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // Running -> Ready + runtime.SetState(runtime.RuntimeRunningState) + assert.NoError(t, runtime.Ready()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // Running -> ResponseSent + runtime.SetState(runtime.RuntimeRunningState) + assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // Running -> InvocationResponse + runtime.SetState(runtime.RuntimeRunningState) + assert.NoError(t, runtime.InvocationResponse()) + assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) + // Running -> InvocationErrorResponse + runtime.SetState(runtime.RuntimeRunningState) assert.NoError(t, runtime.InvocationErrorResponse()) assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) } @@ -160,7 +215,7 @@ func TestRuntimeStateTransitionsFromResponseSentState(t *testing.T) { // ResponseSent -> Ready runtime.SetState(runtime.RuntimeResponseSentState) assert.NoError(t, runtime.Ready()) - assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) // ResponseSent -> ResponseSent runtime.SetState(runtime.RuntimeResponseSentState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -175,7 +230,11 @@ func TestRuntimeStateTransitionsFromResponseSentState(t *testing.T) { assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) } -type mockInitFlowSynchronization struct{ mock.Mock } +type mockInitFlowSynchronization struct { + mock.Mock + ReadyCond *sync.Cond + ReadyCalled bool +} func (s *mockInitFlowSynchronization) SetExternalAgentsRegisterCount(agentCount uint16) error { return nil @@ -198,6 +257,12 @@ func (s *mockInitFlowSynchronization) AwaitAgentsReady() error { return nil } func (s *mockInitFlowSynchronization) RuntimeReady() error { + if s.ReadyCond != nil { + s.ReadyCond.L.Lock() + defer s.ReadyCond.L.Unlock() + s.ReadyCalled = true + s.ReadyCond.Signal() + } return nil } func (s *mockInitFlowSynchronization) AgentReady() error { diff --git a/lambda/interop/model.go b/lambda/interop/model.go index 6735a8b..5cdf63f 100644 --- a/lambda/interop/model.go +++ b/lambda/interop/model.go @@ -41,6 +41,7 @@ type Invoke struct { ReservationToken string VersionID string InvokeReceivedTime int64 + ResyncState Resync } type Token struct { @@ -53,11 +54,21 @@ type Token struct { LambdaSegmentID string InvokeMetadata string NeedDebugLogs bool + ResyncState Resync +} + +type Resync struct { + IsResyncReceived bool + AwsKey string + AwsSecret string + AwsSession string + ReceivedTime time.Time } type ErrorResponse struct { // Payload sent via shared memory. - Payload []byte `json:"Payload,omitempty"` + Payload []byte `json:"Payload,omitempty"` + ContentType string `json:"-"` // When error response body (Payload) is not provided, e.g. // not retrievable, error type and error message will be @@ -89,9 +100,9 @@ type Start struct { AwsSecret string AwsSession string SuppressInit bool - XRayDaemonAddress string // only in standalone; not used by slicer - FunctionName string // only in standalone; not used by slicer - FunctionVersion string // only in standalone; not used by slicer + XRayDaemonAddress string // only in standalone + FunctionName string // only in standalone + FunctionVersion string // only in standalone CorrelationID string // internal use only // TODO: define new Init type that has the Start fields as well as env vars below. // In standalone mode, these env vars come from test/init but from environment otherwise. @@ -127,6 +138,7 @@ type LogsAPIMetrics map[string]int type DoneMetadata struct { NumActiveExtensions int ExtensionsResetMs int64 + ExtensionNames string RuntimeRelease string // Metrics for response status of LogsAPI `/subscribe` calls LogsAPIMetrics LogsAPIMetrics @@ -134,6 +146,7 @@ type DoneMetadata struct { InvokeRequestSizeBytes int64 InvokeCompletionTimeNs int64 InvokeReceivedTime int64 + RuntimeReadyTime int64 } type Done struct { @@ -173,6 +186,11 @@ type ErrorResponseTooLarge struct { ResponseSize int } +// ErrorResponseTooLargeDI is used to reproduce ErrorResponseTooLarge behavior for Direct Invoke mode +type ErrorResponseTooLargeDI struct { + ErrorResponseTooLarge +} + // ErrorResponseTooLarge is returned when response provided by Runtime does not fit into shared memory buffer func (s *ErrorResponseTooLarge) Error() string { return fmt.Sprintf("Response payload size (%d bytes) exceeded maximum allowed payload size (%d bytes).", s.ResponseSize, s.MaxResponseSize) @@ -189,6 +207,7 @@ func (s *ErrorResponseTooLarge) AsInteropError() *ErrorResponse { panic("Failed to marshal interop.ErrorResponse") } resp.Payload = respJSON + resp.ContentType = "application/json" return &resp } @@ -202,7 +221,7 @@ type Server interface { // ErrInvalidInvokeID - validation error indicating that provided invokeID doesn't match current invokeID // ErrResponseSent - validation error indicating that response with given invokeID was already sent // Non-nil error - non-nil error indicating transport failure - SendResponse(invokeID string, response io.Reader) error + SendResponse(invokeID string, contentType string, response io.Reader) error // SendErrorResponse sends error response. // Errors returned: diff --git a/lambda/logging/internal_log.go b/lambda/logging/internal_log.go index 265ae8e..018b2c7 100644 --- a/lambda/logging/internal_log.go +++ b/lambda/logging/internal_log.go @@ -4,9 +4,12 @@ package logging import ( + "bytes" + "fmt" "github.com/sirupsen/logrus" "io" "log" + "strings" ) // SetOutput configures logging output for standard loggers. @@ -14,3 +17,33 @@ func SetOutput(w io.Writer) { log.SetOutput(w) logrus.SetOutput(w) } + +type InternalFormatter struct{} + +// format RAPID's internal log like the rest of the sandbox log +func (f *InternalFormatter) Format(entry *logrus.Entry) ([]byte, error) { + b := &bytes.Buffer{} + + // time with comma separator for fraction of second + time := entry.Time.Format("02 Jan 2006 15:04:05.000") + time = strings.Replace(time, ".", ",", 1) + fmt.Fprint(b, time) + + // level + level := strings.ToUpper(entry.Level.String()) + fmt.Fprintf(b, " [%s]", level) + + // label + fmt.Fprint(b, " (rapid)") + + // message + fmt.Fprintf(b, " %s", entry.Message) + + // from WithField and WithError + for field, value := range entry.Data { + fmt.Fprintf(b, " %s=%s", field, value) + } + + fmt.Fprintf(b, "\n") + return b.Bytes(), nil +} diff --git a/lambda/logging/internal_log_test.go b/lambda/logging/internal_log_test.go index 9c3e598..b94ac88 100644 --- a/lambda/logging/internal_log_test.go +++ b/lambda/logging/internal_log_test.go @@ -5,6 +5,7 @@ package logging import ( "bytes" + "fmt" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "io/ioutil" @@ -26,6 +27,45 @@ func TestLogrusPrint(t *testing.T) { assert.Contains(t, buf.String(), "hello logrus") } +func TestInternalFormatter(t *testing.T) { + pattern := `^([0-9]{2}\s[A-Za-z]{3}\s[0-9]{4}\s[0-9]{2}:[0-9]{2}:[0-9]{2}(?:,[0-9]{3})?)\s(?:\s\{sandbox:([0-9]+)\}\s)?\[([A-Za-z]+)\]\s(\(([^\)]+)\)(?:\s\[Logging Metrics\]\sSBLOG:([a-zA-Z:]+) ([0-9]+))?\s?.*)` + + buf := new(bytes.Buffer) + SetOutput(buf) + logrus.SetFormatter(&InternalFormatter{}) + + logrus.Print("hello logrus") + assert.Regexp(t, pattern, buf.String()) + + buf.Reset() + err := fmt.Errorf("error message") + logrus.WithError(err).Warning("hello logrus") + assert.Regexp(t, pattern, buf.String()) + + buf.Reset() + logrus.WithFields(logrus.Fields{ + "field1": "val1", + "field2": "val2", + "field3": "val3", + }).Info("hello logrus") + assert.Regexp(t, pattern, buf.String()) + + // no caller logged + buf.Reset() + logrus.WithFields(logrus.Fields{ + "field1": "val1", + "field2": "val2", + "field3": "val3", + }).Info("hello logrus") + assert.Regexp(t, pattern, buf.String()) + + // invalid format without InternalFormatter + buf.Reset() + logrus.SetFormatter(&logrus.TextFormatter{}) + logrus.Print("hello logrus") + assert.NotRegexp(t, pattern, buf.String()) +} + func BenchmarkLogPrint(b *testing.B) { SetOutput(ioutil.Discard) for n := 0; n < b.N; n++ { @@ -40,6 +80,15 @@ func BenchmarkLogrusPrint(b *testing.B) { } } +func BenchmarkLogrusPrintInternalFormatter(b *testing.B) { + var l = logrus.New() + l.SetFormatter(&InternalFormatter{}) + l.SetOutput(ioutil.Discard) + for n := 0; n < b.N; n++ { + l.Print(1, "two", true) + } +} + func BenchmarkLogPrintf(b *testing.B) { SetOutput(ioutil.Discard) for n := 0; n < b.N; n++ { @@ -54,6 +103,15 @@ func BenchmarkLogrusPrintf(b *testing.B) { } } +func BenchmarkLogrusPrintfInternalFormatter(b *testing.B) { + var l = logrus.New() + l.SetFormatter(&InternalFormatter{}) + l.SetOutput(ioutil.Discard) + for n := 0; n < b.N; n++ { + l.Printf("field:%v,field:%v,field:%v", 1, "two", true) + } +} + func BenchmarkLogrusDebugLogLevelDisabled(b *testing.B) { SetOutput(ioutil.Discard) logrus.SetLevel(logrus.InfoLevel) @@ -62,6 +120,15 @@ func BenchmarkLogrusDebugLogLevelDisabled(b *testing.B) { } } +func BenchmarkLogrusDebugLogLevelDisabledInternalFormatter(b *testing.B) { + var l = logrus.New() + l.SetOutput(ioutil.Discard) + l.SetLevel(logrus.InfoLevel) + for n := 0; n < b.N; n++ { + l.Debug(1, "two", true) + } +} + func BenchmarkLogrusDebugLogLevelEnabled(b *testing.B) { SetOutput(ioutil.Discard) logrus.SetLevel(logrus.DebugLevel) @@ -70,6 +137,16 @@ func BenchmarkLogrusDebugLogLevelEnabled(b *testing.B) { } } +func BenchmarkLogrusDebugLogLevelEnabledInternalFormatter(b *testing.B) { + var l = logrus.New() + l.SetFormatter(&InternalFormatter{}) + l.SetOutput(ioutil.Discard) + l.SetLevel(logrus.DebugLevel) + for n := 0; n < b.N; n++ { + l.Debug(1, "two", true) + } +} + func BenchmarkLogrusDebugWithFieldLogLevelDisabled(b *testing.B) { SetOutput(ioutil.Discard) logrus.SetLevel(logrus.InfoLevel) @@ -77,3 +154,14 @@ func BenchmarkLogrusDebugWithFieldLogLevelDisabled(b *testing.B) { logrus.WithField("field", "value").Debug(1, "two", true) } } + +func BenchmarkLogrusDebugWithFieldLogLevelDisabledInternalFormatter(b *testing.B) { + var l = logrus.New() + l.SetFormatter(&InternalFormatter{}) + l.SetOutput(ioutil.Discard) + l.SetLevel(logrus.InfoLevel) + for n := 0; n < b.N; n++ { + l.WithField("field", "value").Debug(1, "two", true) + } +} + diff --git a/lambda/metering/time.go b/lambda/metering/time.go index 1f5f047..cf3ad1d 100644 --- a/lambda/metering/time.go +++ b/lambda/metering/time.go @@ -12,15 +12,10 @@ import ( //go:linkname Monotime runtime.nanotime func Monotime() int64 -//go:linkname walltime runtime.walltime -func walltime() (sec int64, nsec int32) - // MonoToEpoch converts monotonic time nanos to epoch time nanos. func MonoToEpoch(t int64) int64 { monoNsec := Monotime() - - wallSec, wallNsec32 := walltime() - wallNsec := wallSec*1e9 + int64(wallNsec32) + wallNsec := time.Now().UnixNano() clockOffset := wallNsec - monoNsec return t + clockOffset diff --git a/lambda/rapi/handler/agentregister_test.go b/lambda/rapi/handler/agentregister_test.go index c860f9a..185f249 100644 --- a/lambda/rapi/handler/agentregister_test.go +++ b/lambda/rapi/handler/agentregister_test.go @@ -143,8 +143,7 @@ func TestInternalAgentShutdownSubscription(t *testing.T) { _, found := registrationService.FindInternalAgentByName(agentName) require.False(t, found) - subscribers := registrationService.GetSubscribedInternalAgents(core.ShutdownEvent) - require.Equal(t, 0, len(subscribers)) + require.Equal(t, 0, registrationService.CountAgents()) } func TestInternalAgentInvalidEventType(t *testing.T) { @@ -170,8 +169,7 @@ func TestInternalAgentInvalidEventType(t *testing.T) { _, found := registrationService.FindInternalAgentByName(agentName) require.False(t, found) - subscribers := registrationService.GetSubscribedInternalAgents(core.ShutdownEvent) - require.Equal(t, 0, len(subscribers)) + require.Equal(t, 0, registrationService.CountAgents()) } } @@ -199,8 +197,13 @@ func TestExternalAgentInvalidEventType(t *testing.T) { _, found := registrationService.FindExternalAgentByName(agentName) require.True(t, found) - subscribers := registrationService.GetSubscribedExternalAgents(core.ShutdownEvent) - require.Equal(t, 0, len(subscribers)) + shutdownSubscribers := registrationService.GetSubscribedExternalAgents(core.ShutdownEvent) + require.Equal(t, 0, len(shutdownSubscribers)) + + invokeSubscribers := registrationService.GetSubscribedExternalAgents(core.InvokeEvent) + require.Equal(t, 0, len(invokeSubscribers)) + + require.Equal(t, 1, registrationService.CountAgents()) } } diff --git a/lambda/rapi/handler/credentials.go b/lambda/rapi/handler/credentials.go new file mode 100644 index 0000000..f1536c4 --- /dev/null +++ b/lambda/rapi/handler/credentials.go @@ -0,0 +1,40 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "encoding/json" + "fmt" + "net/http" + + log "github.com/sirupsen/logrus" + + "go.amzn.com/lambda/core" +) + +type credentialsHandler struct { + credentialsService core.CredentialsService +} + +func (h *credentialsHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + token := request.Header.Get("Authorization") + + credentials, err := h.credentialsService.GetCredentials(token) + + if err != nil { + errorMsg := "cannot get credentials for the provided token" + log.WithError(err).Error(errorMsg) + http.Error(writer, errorMsg, http.StatusNotFound) + return + } + + jsonResponse, _ := json.Marshal(*credentials) + fmt.Fprint(writer, string(jsonResponse)) +} + +func NewCredentialsHandler(credentialsService core.CredentialsService) http.Handler { + return &credentialsHandler{ + credentialsService: credentialsService, + } +} diff --git a/lambda/rapi/handler/credentials_test.go b/lambda/rapi/handler/credentials_test.go new file mode 100644 index 0000000..fa4a2bd --- /dev/null +++ b/lambda/rapi/handler/credentials_test.go @@ -0,0 +1,91 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "encoding/json" + "log" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/testdata" +) + +const InitCachingToken = "sampleInitCachingToken" +const InitCachingAwsKey = "sampleAwsKey" +const InitCachingAwsSecret = "sampleAwsSecret" +const InitCachingAwsSessionToken = "sampleAwsSessionToken" + +func getRequestContext(isServiceBlocked bool) (http.Handler, *http.Request, *httptest.ResponseRecorder) { + flowTest := testdata.NewFlowTest() + if isServiceBlocked { + flowTest.ConfigureForBlockedInitCaching(InitCachingToken, InitCachingAwsKey, InitCachingAwsSecret, InitCachingAwsSessionToken) + } else { + flowTest.ConfigureForInitCaching(InitCachingToken, InitCachingAwsKey, InitCachingAwsSecret, InitCachingAwsSessionToken) + } + handler := NewCredentialsHandler(flowTest.CredentialsService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + + return handler, request, responseRecorder +} + +func TestEmptyAuthorizationHeader(t *testing.T) { + handler, request, responseRecorder := getRequestContext(false) + + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusNotFound, responseRecorder.Code) +} + +func TestArbitraryAuthorizationHeader(t *testing.T) { + handler, request, responseRecorder := getRequestContext(false) + request.Header.Set("Authorization", "randomAuthToken") + + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusNotFound, responseRecorder.Code) +} + +func TestSuccessfulGet(t *testing.T) { + handler, request, responseRecorder := getRequestContext(false) + request.Header.Set("Authorization", InitCachingToken) + + handler.ServeHTTP(responseRecorder, request) + + var responseMap map[string]string + json.Unmarshal(responseRecorder.Body.Bytes(), &responseMap) + assert.Equal(t, InitCachingAwsKey, responseMap["AccessKeyId"]) + assert.Equal(t, InitCachingAwsSecret, responseMap["SecretAccessKey"]) + assert.Equal(t, InitCachingAwsSessionToken, responseMap["Token"]) + + expirationTime, err := time.Parse(time.RFC3339, responseMap["Expiration"]) + assert.NoError(t, err) + durationUntilExpiration := time.Until(expirationTime) + assert.True(t, durationUntilExpiration.Minutes() <= 16 && durationUntilExpiration.Minutes() > 15 && durationUntilExpiration.Hours() < 1) + log.Println(responseRecorder.Body.String()) +} + +func TestBlockedGet(t *testing.T) { + handler, request, responseRecorder := getRequestContext(true) + request.Header.Set("Authorization", InitCachingToken) + + timeout := time.After(1 * time.Second) + done := make(chan bool) + + go func() { + handler.ServeHTTP(responseRecorder, request) + done <- true + }() + + select { + case <-done: + t.Fatal("Endpoint should be blocked!") + case <-timeout: + } +} diff --git a/lambda/rapi/handler/initerror.go b/lambda/rapi/handler/initerror.go index 9cc407d..4015a11 100644 --- a/lambda/rapi/handler/initerror.go +++ b/lambda/rapi/handler/initerror.go @@ -4,6 +4,7 @@ package handler import ( + "encoding/json" "io/ioutil" "net/http" @@ -44,8 +45,9 @@ func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.R } response := &interop.ErrorResponse{ - ErrorType: errorType, - Payload: errorBody, + ErrorType: errorType, + Payload: errorBody, + ContentType: determineJSONContentType(errorBody), } if err := server.SendErrorResponse(server.GetCurrentInvokeID(), response); err != nil { @@ -65,3 +67,10 @@ func NewInitErrorHandler(registrationService core.RegistrationService) http.Hand registrationService: registrationService, } } + +func determineJSONContentType(body []byte) string { + if json.Valid(body) { + return "application/json" + } + return "application/octet-stream" +} diff --git a/lambda/rapi/handler/invocationerror.go b/lambda/rapi/handler/invocationerror.go index 38f1ff9..d60b5d6 100644 --- a/lambda/rapi/handler/invocationerror.go +++ b/lambda/rapi/handler/invocationerror.go @@ -49,15 +49,20 @@ func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request * var errorCause json.RawMessage var errorBody []byte + var contentType string var err error switch request.Header.Get("Content-Type") { case errorWithCauseContentType: errorBody, errorCause, err = h.getErrorBodyForErrorCauseContentType(request) - + contentType = "application/json" + if err != nil { + contentType = "application/octet-stream" + } default: errorBody, err = h.getErrorBody(request) errorCause = h.getValidatedErrorCause(request.Header) + contentType = request.Header.Get("Content-Type") } if err != nil { @@ -65,9 +70,10 @@ func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request * } response := &interop.ErrorResponse{ - ErrorType: errorType, - Payload: errorBody, - ErrorCause: errorCause, + ErrorType: errorType, + Payload: errorBody, + ErrorCause: errorCause, + ContentType: contentType, } if err := server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), response); err != nil { diff --git a/lambda/rapi/handler/invocationerror_test.go b/lambda/rapi/handler/invocationerror_test.go index b1e96c8..6defa14 100644 --- a/lambda/rapi/handler/invocationerror_test.go +++ b/lambda/rapi/handler/invocationerror_test.go @@ -279,6 +279,7 @@ func TestInvocationResponsePayloadIsDefaultErrorMessageWhenRequestParsingFailsFo errorResponse := flowTest.InteropServer.ErrorResponse assert.NotNil(t, errorResponse) assert.Nil(t, flowTest.InteropServer.Response) + assert.Equal(t, "application/octet-stream", flowTest.InteropServer.ResponseContentType) invokeResponsePayload := errorResponse.Payload diff --git a/lambda/rapi/handler/invocationresponse.go b/lambda/rapi/handler/invocationresponse.go index 50c575c..7c15342 100644 --- a/lambda/rapi/handler/invocationresponse.go +++ b/lambda/rapi/handler/invocationresponse.go @@ -15,6 +15,8 @@ import ( log "github.com/sirupsen/logrus" ) +const contentTypeOverrideHeaderName = "Content-Type" + type invocationResponseHandler struct { registrationService core.RegistrationService } @@ -37,7 +39,9 @@ func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, reques invokeID := chi.URLParam(request, "awsrequestid") - if err := server.SendResponse(invokeID, request.Body); err != nil { + responseContentType := request.Header.Get(contentTypeOverrideHeaderName) + + if err := server.SendResponse(invokeID, responseContentType, request.Body); err != nil { switch err := err.(type) { case *interop.ErrorResponseTooLarge: if server.SendErrorResponse(invokeID, err.AsInteropError()) != nil { @@ -51,6 +55,15 @@ func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, reques log.Panic(err) } + rendering.RenderRequestEntityTooLarge(writer, request) + return + + case *interop.ErrorResponseTooLargeDI: + // in DirectInvoke case, the (truncated) response is already sent back to the caller + if err := runtime.ResponseSent(); err != nil { + log.Panic(err) + } + rendering.RenderRequestEntityTooLarge(writer, request) return default: diff --git a/lambda/rapi/handler/invocationresponse_test.go b/lambda/rapi/handler/invocationresponse_test.go index e3ede59..e40a5bf 100644 --- a/lambda/rapi/handler/invocationresponse_test.go +++ b/lambda/rapi/handler/invocationresponse_test.go @@ -94,10 +94,11 @@ func TestResponseAccepted(t *testing.T) { flowTest.ConfigureForInvoke(context.Background(), invoke) // Invocation response submitted by runtime. - var responseBody = make([]byte, interop.MaxPayloadSize) + var responseBody = []byte("{'foo': 'bar'}") request := httptest.NewRequest("", "/", bytes.NewReader(responseBody)) request = addInvocationID(request, invoke.ID) + request.Header.Set(contentTypeOverrideHeaderName, "application/json") handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) // Assertions @@ -112,6 +113,7 @@ func TestResponseAccepted(t *testing.T) { response := flowTest.InteropServer.Response assert.NotNil(t, response) assert.Nil(t, flowTest.InteropServer.ErrorResponse) + assert.Equal(t, "application/json", flowTest.InteropServer.ResponseContentType) assert.Equal(t, responseBody, response, "Persisted response data in app context must match the submitted.") } diff --git a/lambda/rapi/handler/runtimelogs.go b/lambda/rapi/handler/runtimelogs.go index 7185ce6..9b4e406 100644 --- a/lambda/rapi/handler/runtimelogs.go +++ b/lambda/rapi/handler/runtimelogs.go @@ -21,7 +21,7 @@ import ( type runtimeLogsHandler struct { registrationService core.RegistrationService - telemetryService telemetry.LogsAPIService + logsSubscriptionAPI telemetry.LogsSubscriptionAPI } func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { @@ -31,10 +31,10 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http switch err := err.(type) { case *ErrAgentIdentifierUnknown: rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension "+err.agentID.String()) - h.telemetryService.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) - h.telemetryService.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) } return } @@ -45,21 +45,21 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http if err != nil { log.Error(err) rendering.RenderInternalServerError(writer, request) - h.telemetryService.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) return } - respBody, status, headers, err := h.telemetryService.Subscribe(agentName, bytes.NewReader(body), request.Header) + respBody, status, headers, err := h.logsSubscriptionAPI.Subscribe(agentName, bytes.NewReader(body), request.Header) if err != nil { log.Errorf("Telemetry API error: %s", err) switch err { case logsapi.ErrTelemetryServiceOff: rendering.RenderForbiddenWithTypeMsg(writer, request, errLogsSubscriptionClosed, "Logs API subscription is closed already") - h.telemetryService.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) - h.telemetryService.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) } return } @@ -67,11 +67,11 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http rendering.RenderRuntimeLogsResponse(writer, respBody, status, headers) switch status / 100 { case 2: // 2xx - h.telemetryService.RecordCounterMetric(logsapi.SubscribeSuccess, 1) + h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeSuccess, 1) case 4: // 4xx - h.telemetryService.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeClientErr, 1) case 5: // 5xx - h.telemetryService.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) } } @@ -124,9 +124,9 @@ func (h *runtimeLogsHandler) getBody(writer http.ResponseWriter, request *http.R // NewRuntimeLogsHandler returns a new instance of http handler // for serving /runtime/logs -func NewRuntimeLogsHandler(registrationService core.RegistrationService, telemetryService telemetry.LogsAPIService) http.Handler { +func NewRuntimeLogsHandler(registrationService core.RegistrationService, logsSubscriptionAPI telemetry.LogsSubscriptionAPI) http.Handler { return &runtimeLogsHandler{ registrationService: registrationService, - telemetryService: telemetryService, + logsSubscriptionAPI: logsSubscriptionAPI, } } diff --git a/lambda/rapi/handler/runtimelogs_test.go b/lambda/rapi/handler/runtimelogs_test.go index b87e230..b7db6df 100644 --- a/lambda/rapi/handler/runtimelogs_test.go +++ b/lambda/rapi/handler/runtimelogs_test.go @@ -23,27 +23,27 @@ import ( "go.amzn.com/lambda/rapidcore/telemetry/logsapi" ) -type mockTelemetryService struct{ mock.Mock } +type mockLogsSubscriptionAPI struct{ mock.Mock } -func (s *mockTelemetryService) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { +func (s *mockLogsSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { args := s.Called(agentName, body, headers) return args.Get(0).([]byte), args.Int(1), args.Get(2).(map[string][]string), args.Error(3) } -func (s *mockTelemetryService) RecordCounterMetric(metricName string, count int) { +func (s *mockLogsSubscriptionAPI) RecordCounterMetric(metricName string, count int) { s.Called(metricName, count) } -func (s *mockTelemetryService) FlushMetrics() interop.LogsAPIMetrics { +func (s *mockLogsSubscriptionAPI) FlushMetrics() interop.LogsAPIMetrics { args := s.Called() return args.Get(0).(interop.LogsAPIMetrics) } -func (s *mockTelemetryService) Clear() { +func (s *mockLogsSubscriptionAPI) Clear() { s.Called() } -func (s *mockTelemetryService) TurnOff() { +func (s *mockLogsSubscriptionAPI) TurnOff() { s.Called() } @@ -60,11 +60,11 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { agent, err := registrationService.CreateExternalAgent(agentName) assert.NoError(t, err) - telemetryService := &mockTelemetryService{} - telemetryService.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return(respBody, respStatus, respHeaders, nil) - telemetryService.On("RecordCounterMetric", clientErrMetric, 1) + logsSubscriptionAPI := &mockLogsSubscriptionAPI{} + logsSubscriptionAPI.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return(respBody, respStatus, respHeaders, nil) + logsSubscriptionAPI.On("RecordCounterMetric", clientErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, telemetryService) + handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -77,8 +77,8 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - telemetryService.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders) - telemetryService.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + logsSubscriptionAPI.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders) + logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) recordedBody, err := ioutil.ReadAll(responseRecorder.Body) assert.NoError(t, err) @@ -98,10 +98,10 @@ func TestErrorUnregisteredAgentID(t *testing.T) { core.NewInvokeFlowSynchronization(), ) - telemetryService := &mockTelemetryService{} - telemetryService.On("RecordCounterMetric", clientErrMetric, 1) + logsSubscriptionAPI := &mockLogsSubscriptionAPI{} + logsSubscriptionAPI.On("RecordCounterMetric", clientErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, telemetryService) + handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -123,7 +123,7 @@ func TestErrorUnregisteredAgentID(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, expectedErrorBody, string(recordedBody)) assert.Equal(t, expectedHeaders, responseRecorder.Header()) - telemetryService.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) } func TestErrorTelemetryAPICallFailure(t *testing.T) { @@ -139,11 +139,11 @@ func TestErrorTelemetryAPICallFailure(t *testing.T) { agent, err := registrationService.CreateExternalAgent(agentName) assert.NoError(t, err) - telemetryService := &mockTelemetryService{} - telemetryService.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) - telemetryService.On("RecordCounterMetric", serverErrMetric, 1) + logsSubscriptionAPI := &mockLogsSubscriptionAPI{} + logsSubscriptionAPI.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + logsSubscriptionAPI.On("RecordCounterMetric", serverErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, telemetryService) + handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -165,7 +165,7 @@ func TestErrorTelemetryAPICallFailure(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) assert.Equal(t, expectedErrorBody, string(recordedBody)) assert.Equal(t, expectedHeaders, responseRecorder.Header()) - telemetryService.AssertCalled(t, "RecordCounterMetric", serverErrMetric, 1) + logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", serverErrMetric, 1) } func TestRenderLogsSubscriptionClosed(t *testing.T) { @@ -181,11 +181,11 @@ func TestRenderLogsSubscriptionClosed(t *testing.T) { agent, err := registrationService.CreateExternalAgent(agentName) assert.NoError(t, err) - telemetryService := &mockTelemetryService{} - telemetryService.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) - telemetryService.On("RecordCounterMetric", clientErrMetric, 1) + logsSubscriptionAPI := &mockLogsSubscriptionAPI{} + logsSubscriptionAPI.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + logsSubscriptionAPI.On("RecordCounterMetric", clientErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, telemetryService) + handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -207,5 +207,5 @@ func TestRenderLogsSubscriptionClosed(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, expectedErrorBody, string(recordedBody)) assert.Equal(t, expectedHeaders, responseRecorder.Header()) - telemetryService.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) } diff --git a/lambda/rapi/router.go b/lambda/rapi/router.go index b8559af..1d2766a 100644 --- a/lambda/rapi/router.go +++ b/lambda/rapi/router.go @@ -78,27 +78,35 @@ func ExtensionsRouter(appCtx appctx.ApplicationContext, registrationService core return router } -// TelemetryAPIRouter returns a new instance of chi router implementing +// LogsAPIRouter returns a new instance of chi router implementing // Logs API specification. -func TelemetryAPIRouter(registrationService core.RegistrationService, telemetryService telemetry.LogsAPIService) http.Handler { +func LogsAPIRouter(registrationService core.RegistrationService, logsSubscriptionAPI telemetry.LogsSubscriptionAPI) http.Handler { router := chi.NewRouter() router.Use(middleware.AccessLogMiddleware()) router.Use(middleware.AllowIfExtensionsEnabled) router.Put("/logs", middleware.AgentUniqueIdentifierHeaderValidator( - handler.NewRuntimeLogsHandler(registrationService, telemetryService)).ServeHTTP) + handler.NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI)).ServeHTTP) return router } -// TelemetryAPIStubRouter returns a new instance of chi router implementing +// LogsAPIStubRouter returns a new instance of chi router implementing // a stub of Logs API that always returns a non-committal response to // prevent customer code from crashing when Logs API is disabled locally -func TelemetryAPIStubRouter() http.Handler { +func LogsAPIStubRouter() http.Handler { router := chi.NewRouter() router.Put("/logs", handler.NewRuntimeLogsStubHandler().ServeHTTP) return router } + +func CredentialsAPIRouter(credentialsService core.CredentialsService) http.Handler { + router := chi.NewRouter() + + router.Get("/credentials", handler.NewCredentialsHandler(credentialsService).ServeHTTP) + + return router +} diff --git a/lambda/rapi/server.go b/lambda/rapi/server.go index 16d955a..e2c6ad4 100644 --- a/lambda/rapi/server.go +++ b/lambda/rapi/server.go @@ -22,6 +22,7 @@ import ( const version20180601 = "/2018-06-01" const version20200101 = "/2020-01-01" const version20200815 = "/2020-08-15" +const version20210423 = "/2021-04-23" // Server is a Runtime API server type Server struct { @@ -43,7 +44,7 @@ func NewServer(host string, port int, appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService, telemetryAPIEnabled bool, - telemetryService telemetry.LogsAPIService) *Server { + logsSubscriptionAPI telemetry.LogsSubscriptionAPI, initCachingEnabled bool, credentialsService core.CredentialsService) *Server { exitErrors := make(chan error, 1) @@ -52,9 +53,13 @@ func NewServer(host string, port int, appCtx appctx.ApplicationContext, router.Mount(version20200101, ExtensionsRouter(appCtx, registrationService, renderingService)) if telemetryAPIEnabled { - router.Mount(version20200815, TelemetryAPIRouter(registrationService, telemetryService)) + router.Mount(version20200815, LogsAPIRouter(registrationService, logsSubscriptionAPI)) } else { - router.Mount(version20200815, TelemetryAPIStubRouter()) + router.Mount(version20200815, LogsAPIStubRouter()) + } + + if initCachingEnabled { + router.Mount(version20210423, CredentialsAPIRouter(credentialsService)) } return &Server{ diff --git a/lambda/rapid/sandbox.go b/lambda/rapid/sandbox.go index a88a97a..a5614b0 100644 --- a/lambda/rapid/sandbox.go +++ b/lambda/rapid/sandbox.go @@ -25,23 +25,27 @@ type EnvironmentVariables interface { StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress string) StoreEnvironmentVariablesFromInit(customerEnv map[string]string, handler, awsKey, awsSecret, awsSession, funcName, funcVer string) + StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) } type Sandbox struct { - EnableTelemetryAPI bool - StandaloneMode bool - Bootstrap Bootstrap - InteropServer interop.Server - Tracer telemetry.Tracer - TelemetryService telemetry.LogsAPIService - Environment EnvironmentVariables - DebugTailLogger *logging.TailLogWriter - PlatformLogger logging.PlatformLogger - ExtensionLogWriter io.Writer - RuntimeLogWriter io.Writer - PreLoadTimeNs int64 - Handler string - SignalCtx context.Context + EnableTelemetryAPI bool + StandaloneMode bool + Bootstrap Bootstrap + InteropServer interop.Server + Tracer telemetry.Tracer + LogsSubscriptionAPI telemetry.LogsSubscriptionAPI + LogsEgressAPI telemetry.LogsEgressAPI + Environment EnvironmentVariables + DebugTailLogger *logging.TailLogWriter + PlatformLogger logging.PlatformLogger + RuntimeStdoutWriter io.Writer + RuntimeStderrWriter io.Writer + PreLoadTimeNs int64 + Handler string + SignalCtx context.Context + EventsAPI telemetry.EventsAPI + InitCachingEnabled bool } // Start is a public version of start() that exports only configurable parameters @@ -51,11 +55,12 @@ func Start(s *Sandbox) { invokeFlow := core.NewInvokeFlowSynchronization() registrationService := core.NewRegistrationService(initFlow, invokeFlow) renderingService := rendering.NewRenderingService() + credentialsService := core.NewCredentialsService() if s.StandaloneMode { s.InteropServer.SetInternalStateGetter(registrationService.GetInternalStateDescriptor(appCtx)) } - server := rapi.NewServer(RuntimeAPIHost, RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, s.TelemetryService) + server := rapi.NewServer(RuntimeAPIHost, RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, s.LogsSubscriptionAPI, s.InitCachingEnabled, credentialsService) postLoadTimeNs := metering.Monotime() @@ -75,9 +80,11 @@ func Start(s *Sandbox) { renderingService: renderingService, exitPidChan: make(chan int), resetChan: make(chan *interop.Reset), + credentialsService: credentialsService, telemetryAPIEnabled: s.EnableTelemetryAPI, - telemetryService: s.TelemetryService, + logsSubscriptionAPI: s.LogsSubscriptionAPI, + logsEgressAPI: s.LogsEgressAPI, bootstrap: s.Bootstrap, interopServer: s.InteropServer, xray: s.Tracer, @@ -85,8 +92,10 @@ func Start(s *Sandbox) { standaloneMode: s.StandaloneMode, debugTailLogger: s.DebugTailLogger, platformLogger: s.PlatformLogger, - extensionLogWriter: s.ExtensionLogWriter, - runtimeLogWriter: s.RuntimeLogWriter, + runtimeStdoutWriter: s.RuntimeStdoutWriter, + runtimeStderrWriter: s.RuntimeStderrWriter, preLoadTimeNs: s.PreLoadTimeNs, + eventsAPI: s.EventsAPI, + initCachingEnabled: s.InitCachingEnabled, }) } diff --git a/lambda/rapid/start.go b/lambda/rapid/start.go index 711f122..087ef13 100644 --- a/lambda/rapid/start.go +++ b/lambda/rapid/start.go @@ -9,6 +9,7 @@ import ( "errors" "io" "os" + "strings" "time" "go.amzn.com/lambda/agents" @@ -24,6 +25,8 @@ import ( "go.amzn.com/lambda/runtimecmd" "go.amzn.com/lambda/telemetry" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" ) @@ -36,6 +39,8 @@ const ( ) const ( + // Same value as defined in LambdaSandbox minus 1. + maxExtensionNamesLength = 127 standaloneShutdownReason = "spindown" ) @@ -55,7 +60,8 @@ type rapidContext struct { registrationService core.RegistrationService renderingService *rendering.EventRenderingService telemetryAPIEnabled bool - telemetryService telemetry.LogsAPIService + logsSubscriptionAPI telemetry.LogsSubscriptionAPI + logsEgressAPI telemetry.LogsEgressAPI xray telemetry.Tracer exitPidChan chan int resetChan chan *interop.Reset @@ -63,14 +69,37 @@ type rapidContext struct { standaloneMode bool debugTailLogger *logging.TailLogWriter platformLogger logging.PlatformLogger - extensionLogWriter io.Writer - runtimeLogWriter io.Writer + runtimeStdoutWriter io.Writer + runtimeStderrWriter io.Writer + eventsAPI telemetry.EventsAPI + initCachingEnabled bool + credentialsService core.CredentialsService +} + +type invokeMetrics struct { + rendererMetrics rendering.InvokeRendererMetrics + runtimeReadyTime int64 } func (c *rapidContext) HasActiveExtensions() bool { return extensions.AreEnabled() && c.registrationService.CountAgents() > 0 } +func (c *rapidContext) GetExtensionNames() string { + var extensionNamesList []string + for _, agent := range c.registrationService.AgentsInfo() { + extensionNamesList = append(extensionNamesList, agent.Name) + } + extensionNames := strings.Join(extensionNamesList, ";") + if len(extensionNames) > maxExtensionNamesLength { + if idx := strings.LastIndex(extensionNames[:maxExtensionNamesLength], ";"); idx != -1 { + return extensionNames[:idx] + } + return "" + } + return extensionNames +} + func logAgentsInitStatus(execCtx *rapidContext) { for _, agent := range execCtx.registrationService.AgentsInfo() { execCtx.platformLogger.LogExtensionInitEvent(agent.Name, agent.State, agent.ErrorType, agent.Subscriptions) @@ -95,10 +124,22 @@ func doInitExtensions(execCtx *rapidContext, watchdog *core.Watchdog) error { for _, agentPath := range agentPaths { env := execCtx.environment.AgentExecEnv() - agentLogSinks := execCtx.extensionLogWriter - agentProc := agents.NewExternalAgentProcess(agentPath, env, agentLogSinks) + + agentStdoutWriter, agentStderrWriter, err := execCtx.logsEgressAPI.GetExtensionSockets() + + if err != nil { + return err + } + + // Compose debug log writer with all log sinks. Debug log writer w + // will not write logs when disabled by invoke parameter + agentStdoutWriter = io.MultiWriter(execCtx.debugTailLogger, agentStdoutWriter) + agentStderrWriter = io.MultiWriter(execCtx.debugTailLogger, agentStderrWriter) + + agentProc := agents.NewExternalAgentProcess(agentPath, env, agentStdoutWriter, agentStderrWriter) agent, err := execCtx.registrationService.CreateExternalAgent(agentProc.Name()) + if err != nil { return err } @@ -176,7 +217,7 @@ func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) } bootstrapExtraFiles := bootstrap.ExtraFiles() - runtimeCmd := runtimecmd.NewCustomRuntimeCmd(ctx, bootstrapCmd, bootstrapCwd, bootstrapEnv, execCtx.runtimeLogWriter, bootstrapExtraFiles) + runtimeCmd := runtimecmd.NewCustomRuntimeCmd(ctx, bootstrapCmd, bootstrapCwd, bootstrapEnv, execCtx.runtimeStdoutWriter, execCtx.runtimeStderrWriter, bootstrapExtraFiles) log.Debug("Start runtime") err = runtimeCmd.Start() @@ -211,14 +252,15 @@ func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) // Logs API subscription phase finished for agents - no more agents can be subscribed to the Logs API if execCtx.telemetryAPIEnabled { - execCtx.telemetryService.TurnOff() + execCtx.logsSubscriptionAPI.TurnOff() } execCtx.initDone = true return nil } -func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, invokeRequest *interop.Invoke, mx *rendering.InvokeRendererMetrics) error { +func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, invokeRequest *interop.Invoke, mx *invokeMetrics) error { + execCtx.eventsAPI.SetCurrentRequestID(invokeRequest.ID) appCtx := execCtx.appCtx appctx.StoreErrorResponse(appCtx, nil) @@ -268,7 +310,7 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo log.Debug("Set renderer for invoke") renderer := rendering.NewInvokeRenderer(ctx, invokeRequest, xray.TracingHeaderParser()) defer func() { - *mx = renderer.GetMetrics() + mx.rendererMetrics = renderer.GetMetrics() }() execCtx.renderingService.SetRenderer(renderer) @@ -297,6 +339,10 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo }); err != nil { return err } + mx.runtimeReadyTime = metering.Monotime() + if err := execCtx.eventsAPI.SendRuntimeDone("success"); err != nil { + log.Errorf("Failed to send RUNDONE: %s", err) + } // Extensions overhead if execCtx.HasActiveExtensions() { @@ -341,8 +387,53 @@ func (c *rapidContext) acceptStartRequest(startRequest *interop.Start) { } } +func (c *rapidContext) acceptStartRequestForInitCaching(startRequest *interop.Start) error { + log.Info("Configure environment for Init Caching.") + c.startRequest = startRequest + randomUUID, err := uuid.NewRandom() + + if err != nil { + return err + } + + initCachingToken := randomUUID.String() + + c.environment.StoreEnvironmentVariablesFromInitForInitCaching( + RuntimeAPIHost, + RuntimeAPIPort, + startRequest.CustomerEnvironmentVariables, + startRequest.Handler, + startRequest.FunctionName, + startRequest.FunctionVersion, + initCachingToken) + + c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ + FunctionName: startRequest.FunctionName, + FunctionVersion: startRequest.FunctionVersion, + Handler: startRequest.Handler, + }) + + c.credentialsService.SetCredentials(initCachingToken, startRequest.AwsKey, startRequest.AwsSecret, startRequest.AwsSession) + + if extensionsDisabledByLayer() { + extensions.Disable() + } + + return nil +} + func handleStart(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, startRequest *interop.Start) { - execCtx.acceptStartRequest(startRequest) + if execCtx.initCachingEnabled { + if err := execCtx.acceptStartRequestForInitCaching(startRequest); err != nil { + handleStartError(execCtx, startRequest.InvokeID, startRequest.CorrelationID, err) + return + } + + execCtx.credentialsService.UnblockService() + defer execCtx.credentialsService.BlockService() + } else { + execCtx.acceptStartRequest(startRequest) + } interopServer, appCtx := execCtx.interopServer, execCtx.appCtx @@ -358,9 +449,7 @@ func handleStart(ctx context.Context, execCtx *rapidContext, watchdog *core.Watc if !startRequest.SuppressInit { if err := doInit(ctx, execCtx, watchdog); err != nil { - log.WithError(err).WithField("InvokeID", startRequest.InvokeID).Error("Init failed") - doneFailMsg := generateDoneFail(execCtx, startRequest.CorrelationID, nil, 0) - handleInitError(doneFailMsg, execCtx, startRequest.InvokeID, interopServer, err) + handleStartError(execCtx, startRequest.InvokeID, startRequest.CorrelationID, err) return } } @@ -370,10 +459,11 @@ func handleStart(ctx context.Context, execCtx *rapidContext, watchdog *core.Watc Meta: interop.DoneMetadata{ RuntimeRelease: appctx.GetRuntimeRelease(appCtx), NumActiveExtensions: execCtx.registrationService.CountAgents(), + ExtensionNames: execCtx.GetExtensionNames(), }, } if execCtx.telemetryAPIEnabled { - doneMsg.Meta.LogsAPIMetrics = execCtx.telemetryService.FlushMetrics() + doneMsg.Meta.LogsAPIMetrics = execCtx.logsSubscriptionAPI.FlushMetrics() } if err := interopServer.SendDone(doneMsg); err != nil { log.Panic(err) @@ -384,7 +474,13 @@ func handleStart(ctx context.Context, execCtx *rapidContext, watchdog *core.Watc } } -func generateDoneFail(execCtx *rapidContext, correlationID string, invokeMx *rendering.InvokeRendererMetrics, invokeReceivedTime int64) *interop.DoneFail { +func handleStartError(execCtx *rapidContext, invokeID string, correlationID string, err error) { + log.WithError(err).WithField("InvokeID", invokeID).Error("Init failed") + doneFailMsg := generateDoneFail(execCtx, correlationID, nil, 0) + handleInitError(doneFailMsg, execCtx, invokeID, execCtx.interopServer, err) +} + +func generateDoneFail(execCtx *rapidContext, correlationID string, invokeMx *invokeMetrics, invokeReceivedTime int64) *interop.DoneFail { errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) if !found { errorType = fatalerror.Unknown @@ -401,12 +497,14 @@ func generateDoneFail(execCtx *rapidContext, correlationID string, invokeMx *ren } if invokeMx != nil { - doneFailMsg.Meta.InvokeRequestReadTimeNs = invokeMx.ReadTime.Nanoseconds() - doneFailMsg.Meta.InvokeRequestSizeBytes = int64(invokeMx.SizeBytes) + doneFailMsg.Meta.InvokeRequestReadTimeNs = invokeMx.rendererMetrics.ReadTime.Nanoseconds() + doneFailMsg.Meta.InvokeRequestSizeBytes = int64(invokeMx.rendererMetrics.SizeBytes) + doneFailMsg.Meta.RuntimeReadyTime = int64(invokeMx.runtimeReadyTime) + doneFailMsg.Meta.ExtensionNames = execCtx.GetExtensionNames() } if execCtx.telemetryAPIEnabled { - doneFailMsg.Meta.LogsAPIMetrics = execCtx.telemetryService.FlushMetrics() + doneFailMsg.Meta.LogsAPIMetrics = execCtx.logsSubscriptionAPI.FlushMetrics() } return doneFailMsg @@ -415,7 +513,18 @@ func generateDoneFail(execCtx *rapidContext, correlationID string, invokeMx *ren func handleInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, invokeRequest *interop.Invoke) { interopServer, appCtx := execCtx.interopServer, execCtx.appCtx - invokeMx := rendering.InvokeRendererMetrics{} + invokeMx := invokeMetrics{} + + if invokeRequest.ResyncState.IsResyncReceived { + err := execCtx.credentialsService.UpdateCredentials(invokeRequest.ResyncState.AwsKey, invokeRequest.ResyncState.AwsSecret, invokeRequest.ResyncState.AwsSession) + execCtx.credentialsService.UnblockService() + + if err != nil { + log.WithError(err).WithField("InvokeID", invokeRequest.ID).Error("Resync for Invoke failed") + doneFailMsg := generateDoneFail(execCtx, invokeRequest.CorrelationID, &invokeMx, invokeRequest.InvokeReceivedTime) + handleInvokeError(doneFailMsg, execCtx, invokeRequest.ID, interopServer, err) + } + } if err := doInvoke(ctx, execCtx, watchdog, invokeRequest, &invokeMx); err != nil { log.WithError(err).WithField("InvokeID", invokeRequest.ID).Error("Invoke failed") @@ -438,14 +547,16 @@ func handleInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Wat Meta: interop.DoneMetadata{ RuntimeRelease: appctx.GetRuntimeRelease(appCtx), NumActiveExtensions: execCtx.registrationService.CountAgents(), - InvokeRequestReadTimeNs: invokeMx.ReadTime.Nanoseconds(), - InvokeRequestSizeBytes: int64(invokeMx.SizeBytes), + ExtensionNames: execCtx.GetExtensionNames(), + InvokeRequestReadTimeNs: invokeMx.rendererMetrics.ReadTime.Nanoseconds(), + InvokeRequestSizeBytes: int64(invokeMx.rendererMetrics.SizeBytes), InvokeCompletionTimeNs: invokeCompletionTimeNs, InvokeReceivedTime: invokeRequest.InvokeReceivedTime, + RuntimeReadyTime: invokeMx.runtimeReadyTime, }, } if execCtx.telemetryAPIEnabled { - doneMsg.Meta.LogsAPIMetrics = execCtx.telemetryService.FlushMetrics() + doneMsg.Meta.LogsAPIMetrics = execCtx.logsSubscriptionAPI.FlushMetrics() } if err := interopServer.SendDone(doneMsg); err != nil { @@ -464,7 +575,7 @@ func reinitialize(execCtx *rapidContext, watchdog *core.Watchdog) { execCtx.initFlow.Clear() execCtx.invokeFlow.Clear() if execCtx.telemetryAPIEnabled { - execCtx.telemetryService.Clear() + execCtx.logsSubscriptionAPI.Clear() } watchdog.Clear() } @@ -476,6 +587,13 @@ func blockForever() { // handle notification of reset func handleReset(execCtx *rapidContext, watchdog *core.Watchdog, reset *interop.Reset) { log.Warnf("Reset initiated: %s", reset.Reason) + if execCtx.initCachingEnabled { + execCtx.credentialsService.UnblockService() + } + + if err := execCtx.eventsAPI.SendRuntimeDone(reset.Reason); err != nil { + log.Errorf("Failed to send RUNDONE: %s", err) + } profiler := metering.ExtensionsResetDurationProfiler{} gracefulShutdown(execCtx, watchdog, &profiler, reset.DeadlineNs, execCtx.standaloneMode, reset.Reason) diff --git a/lambda/rapid/start_test.go b/lambda/rapid/start_test.go index 3210b68..2363705 100644 --- a/lambda/rapid/start_test.go +++ b/lambda/rapid/start_test.go @@ -6,8 +6,11 @@ package rapid import ( "context" "fmt" + "go.amzn.com/lambda/core" "io/ioutil" "net/http" + "regexp" + "strconv" "strings" "testing" "time" @@ -75,6 +78,61 @@ func BenchmarkChannelsSelect2(b *testing.B) { } } +func TestGetExtensionNamesWithNoExtensions(t *testing.T) { + rs := core.NewRegistrationService(nil, nil) + + c := &rapidContext{ + registrationService: rs, + } + + assert.Equal(t, "", c.GetExtensionNames()) +} + +func TestGetExtensionNamesWithMultipleExtensions(t *testing.T) { + rs := core.NewRegistrationService(nil, nil) + _, _ = rs.CreateExternalAgent("Example1") + _, _ = rs.CreateInternalAgent("Example2") + _, _ = rs.CreateExternalAgent("Example3") + _, _ = rs.CreateInternalAgent("Example4") + + c := &rapidContext{ + registrationService: rs, + } + + r := regexp.MustCompile(`^(Example\d;){3}(Example\d)$`) + assert.True(t, r.MatchString(c.GetExtensionNames())) +} + +func TestGetExtensionNamesWithTooManyExtensions(t *testing.T) { + rs := core.NewRegistrationService(nil, nil) + for i := 10; i < 60; i++ { + _, _ = rs.CreateExternalAgent("E" + strconv.Itoa(i)) + } + + c := &rapidContext{ + registrationService: rs, + } + + output := c.GetExtensionNames() + + r := regexp.MustCompile(`^(E\d\d;){30}(E\d\d)$`) + assert.LessOrEqual(t, len(output), maxExtensionNamesLength) + assert.True(t, r.MatchString(output)) +} + +func TestGetExtensionNamesWithTooLongExtensionName(t *testing.T) { + rs := core.NewRegistrationService(nil, nil) + for i := 10; i < 60; i++ { + _, _ = rs.CreateExternalAgent(strings.Repeat("E", 130)) + } + + c := &rapidContext{ + registrationService: rs, + } + + assert.Equal(t, "", c.GetExtensionNames()) +} + // This test confirms our assumption that http client can establish a tcp connection // to a listening server. func TestListen(t *testing.T) { @@ -84,7 +142,7 @@ func TestListen(t *testing.T) { ctx := context.Background() telemetryAPIEnabled := true - server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.TelemetryService) + server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.LogsSubscriptionAPI, false, flowTest.CredentialsService) err := server.Listen() assert.NoError(t, err) diff --git a/lambda/rapidcore/bootstrap.go b/lambda/rapidcore/bootstrap.go index e39f7e9..9faf518 100644 --- a/lambda/rapidcore/bootstrap.go +++ b/lambda/rapidcore/bootstrap.go @@ -52,6 +52,11 @@ func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string) *Bootstrap } func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string) *Bootstrap { + if currentWorkingDir == "" { + // use the root directory as the default working directory + currentWorkingDir = "/" + } + // a single candidate command makes it automatically valid return &Bootstrap{ validCmd: cmd, diff --git a/lambda/rapidcore/bootstrap_test.go b/lambda/rapidcore/bootstrap_test.go index a0e466e..4700130 100644 --- a/lambda/rapidcore/bootstrap_test.go +++ b/lambda/rapidcore/bootstrap_test.go @@ -216,3 +216,10 @@ func TestDefaultWorkeringDirectory(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "/", cwd) } + +func TestBootstrapSingleCmdDefaultWorkingDir(t *testing.T) { + b := NewBootstrapSingleCmd([]string{}, "") + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, "/", bCwd) +} diff --git a/lambda/rapidcore/env/environment.go b/lambda/rapidcore/env/environment.go index 5c80229..699abda 100644 --- a/lambda/rapidcore/env/environment.go +++ b/lambda/rapidcore/env/environment.go @@ -4,6 +4,7 @@ package env import ( + "fmt" "os" "strconv" "strings" @@ -103,10 +104,22 @@ func (e *Environment) SetExecutionEnv(executionEnv string) { // StoreEnvironmentVariablesFromInit sets the environment variables // for credentials & _HANDLER which are received in the START message func (e *Environment) StoreEnvironmentVariablesFromInit(customerEnv map[string]string, handler, awsKey, awsSecret, awsSession, funcName, funcVer string) { + e.Credentials["AWS_ACCESS_KEY_ID"] = awsKey e.Credentials["AWS_SECRET_ACCESS_KEY"] = awsSecret e.Credentials["AWS_SESSION_TOKEN"] = awsSession + e.storeNonCredentialEnvironmentVariablesFromInit(customerEnv, handler, funcName, funcVer) +} + +func (e *Environment) StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) { + e.Credentials["AWS_CONTAINER_CREDENTIALS_FULL_URI"] = fmt.Sprintf("http://%s:%d/2021-04-23/credentials", host, port) + e.Credentials["AWS_CONTAINER_AUTHORIZATION_TOKEN"] = token + + e.storeNonCredentialEnvironmentVariablesFromInit(customerEnv, handler, funcName, funcVer) +} + +func (e *Environment) storeNonCredentialEnvironmentVariablesFromInit(customerEnv map[string]string, handler, funcName, funcVer string) { if handler != "" { e.SetHandler(handler) } diff --git a/lambda/rapidcore/env/environment_test.go b/lambda/rapidcore/env/environment_test.go index 53a2f1e..cdfef24 100644 --- a/lambda/rapidcore/env/environment_test.go +++ b/lambda/rapidcore/env/environment_test.go @@ -4,6 +4,7 @@ package env import ( + "fmt" "os" "strings" "testing" @@ -277,6 +278,25 @@ func TestAgentExecEnvironmentVariables(t *testing.T) { assert.Contains(t, agentEnvVars, runtimeAPIAddressKey+"="+env.Platform[runtimeAPIAddressKey]) } +func TestStoreEnvironmentVariablesFromInitCaching(t *testing.T) { + host := "samplehost" + port := 1234 + handler := "samplehandler" + funcName := "samplefunctionname" + funcVer := "samplefunctionver" + token := "sampletoken" + env := NewEnvironment() + customerEnv := CustomerEnvironmentVariables() + + env.StoreEnvironmentVariablesFromInitForInitCaching("samplehost", 1234, customerEnv, handler, funcName, funcVer, token) + + assert.Equal(t, fmt.Sprintf("http://%s:%d/2021-04-23/credentials", host, port), env.Credentials["AWS_CONTAINER_CREDENTIALS_FULL_URI"]) + assert.Equal(t, token, env.Credentials["AWS_CONTAINER_AUTHORIZATION_TOKEN"]) + assert.Equal(t, funcName, env.Platform["AWS_LAMBDA_FUNCTION_NAME"]) + assert.Equal(t, funcVer, env.Platform["AWS_LAMBDA_FUNCTION_VERSION"]) + assert.Equal(t, handler, env.Runtime["_HANDLER"]) +} + func setAll(keys map[string]bool, value string) { for key := range keys { os.Setenv(key, value) diff --git a/lambda/rapidcore/sandbox.go b/lambda/rapidcore/sandbox.go index fee0e73..7d5a8a9 100644 --- a/lambda/rapidcore/sandbox.go +++ b/lambda/rapidcore/sandbox.go @@ -67,17 +67,23 @@ const ( func NewSandboxBuilder(bootstrap *Bootstrap) *SandboxBuilder { defaultInteropServer := NewServer(context.Background()) signalCtx, cancelSignalCtx := context.WithCancel(context.Background()) + logsEgressAPI := &telemetry.NoOpLogsEgressAPI{} + runtimeStdoutWriter, runtimeStderrWriter, _ := logsEgressAPI.GetRuntimeSockets() + b := &SandboxBuilder{ sandbox: &rapid.Sandbox{ - Bootstrap: bootstrap, - PreLoadTimeNs: 0, // TODO - StandaloneMode: true, - ExtensionLogWriter: os.Stdout, - RuntimeLogWriter: os.Stdout, - EnableTelemetryAPI: false, - Environment: env.NewEnvironment(), - Tracer: telemetry.NewNoOpTracer(), - SignalCtx: signalCtx, + Bootstrap: bootstrap, + PreLoadTimeNs: 0, // TODO + StandaloneMode: true, + RuntimeStdoutWriter: runtimeStdoutWriter, + RuntimeStderrWriter: runtimeStderrWriter, + LogsEgressAPI: logsEgressAPI, + EnableTelemetryAPI: false, + Environment: env.NewEnvironment(), + Tracer: telemetry.NewNoOpTracer(), + SignalCtx: signalCtx, + EventsAPI: &telemetry.NoOpEventsAPI{}, + InitCachingEnabled: false, }, defaultInteropServer: defaultInteropServer, shutdownFuncs: []context.CancelFunc{}, @@ -100,6 +106,11 @@ func (b *SandboxBuilder) SetInteropServer(interopServer interop.Server) *Sandbox return b } +func (b *SandboxBuilder) SetEventsAPI(eventsAPI telemetry.EventsAPI) *SandboxBuilder { + b.sandbox.EventsAPI = eventsAPI + return b +} + func (b *SandboxBuilder) SetTracer(tracer telemetry.Tracer) *SandboxBuilder { b.sandbox.Tracer = tracer return b @@ -119,6 +130,11 @@ func (b *SandboxBuilder) SetExtensionsFlag(extensionsEnabled bool) *SandboxBuild return b } +func (b *SandboxBuilder) SetInitCachingFlag(initCachingEnabled bool) *SandboxBuilder { + b.sandbox.InitCachingEnabled = initCachingEnabled + return b +} + func (b *SandboxBuilder) SetPreLoadTimeNs(preLoadTimeNs int64) *SandboxBuilder { b.sandbox.PreLoadTimeNs = preLoadTimeNs return b @@ -139,19 +155,22 @@ func (b *SandboxBuilder) SetTailLogOutput(w io.Writer) *SandboxBuilder { return b } -func (b *SandboxBuilder) SetLogWriter(logSink logSink, w io.Writer) *SandboxBuilder { - switch logSink { - case RuntimeLogSink: - b.sandbox.RuntimeLogWriter = w - case ExtensionLogSink: - b.sandbox.ExtensionLogWriter = w - } +func (b *SandboxBuilder) SetLogsSubscriptionAPI(logsSubscriptionAPI telemetry.LogsSubscriptionAPI) *SandboxBuilder { + b.sandbox.EnableTelemetryAPI = true + b.sandbox.LogsSubscriptionAPI = logsSubscriptionAPI return b } -func (b *SandboxBuilder) SetTelemetryService(telemetryService telemetry.LogsAPIService) *SandboxBuilder { - b.sandbox.EnableTelemetryAPI = true - b.sandbox.TelemetryService = telemetryService +func (b *SandboxBuilder) SetLogsEgressAPI(logsEgressAPI telemetry.LogsEgressAPI) *SandboxBuilder { + runtimeStdoutWriter, runtimeStderrWriter, err := logsEgressAPI.GetRuntimeSockets() + + if err != nil { + log.WithError(err).Fatal("failed to get the Runtime sockets from the logs egress API") + } + + b.sandbox.LogsEgressAPI = logsEgressAPI + b.sandbox.RuntimeStdoutWriter = runtimeStdoutWriter + b.sandbox.RuntimeStderrWriter = runtimeStderrWriter return b } @@ -170,8 +189,8 @@ func (b *SandboxBuilder) setupLoggingWithDebugLogs() { // will not write logs when disabled by invoke parameter b.sandbox.DebugTailLogger = logging.NewTailLogWriter(b.debugTailLogWriter) b.sandbox.PlatformLogger = logging.NewPlatformLogger(b.platformLogWriter, b.sandbox.DebugTailLogger) - b.sandbox.RuntimeLogWriter = io.MultiWriter(b.sandbox.DebugTailLogger, b.sandbox.RuntimeLogWriter) - b.sandbox.ExtensionLogWriter = io.MultiWriter(b.sandbox.ExtensionLogWriter, b.sandbox.DebugTailLogger) + b.sandbox.RuntimeStdoutWriter = io.MultiWriter(b.sandbox.DebugTailLogger, b.sandbox.RuntimeStdoutWriter) + b.sandbox.RuntimeStderrWriter = io.MultiWriter(b.sandbox.DebugTailLogger, b.sandbox.RuntimeStderrWriter) } func (b *SandboxBuilder) Create() { @@ -221,9 +240,7 @@ func SetLogLevel(logLevel string) { } log.SetLevel(level) - Formatter := new(log.TextFormatter) - Formatter.TimestampFormat = "2006-01-02T15:04:05.999" - log.SetFormatter(Formatter) + log.SetFormatter(&logging.InternalFormatter{}) } func SetInternalLogOutput(w io.Writer) { diff --git a/lambda/rapidcore/server.go b/lambda/rapidcore/server.go index ee5b243..e3e01b6 100644 --- a/lambda/rapidcore/server.go +++ b/lambda/rapidcore/server.go @@ -33,8 +33,6 @@ const ( resetDefaultTimeoutMs = 2000 ) -// rapidPhase tracks the state machine in the go.amzn.com/lambda/rapid receive loop. See -// a state diagram of how the events and states of rapid package and this interop server type rapidPhase int const ( @@ -311,7 +309,7 @@ func (s *Server) TransportErrorChan() <-chan error { return s.errorChanOut } -func (s *Server) sendResponseUnsafe(invokeID string, status int, payload io.Reader) error { +func (s *Server) sendResponseUnsafe(invokeID string, contentType string, status int, payload io.Reader) error { if s.invokeCtx == nil || invokeID != s.invokeCtx.Token.InvokeID { return interop.ErrInvalidInvokeID } @@ -330,12 +328,17 @@ func (s *Server) sendResponseUnsafe(invokeID string, status int, payload io.Read // s.invokeCtx.ReplyStream.WriteHeader(status) + var reportedErr error if s.invokeCtx.Direct { - if err := directinvoke.SendDirectInvokeResponse(nil, payload, s.invokeCtx.ReplyStream); err != nil { - // we intentionally do not return an error here: - // even if error happened, the response has already been initiated (and might be partially written into the socket) - // so there is no other option except to consider response to be sent. + if err := directinvoke.SendDirectInvokeResponse(map[string]string{"Content-Type": contentType}, payload, s.invokeCtx.ReplyStream); err != nil { + // TODO: Do we need to drain the reader in case of a large payload and connection reuse? log.Errorf("Failed to write response to %s: %s", invokeID, err) + flusher, ok := s.invokeCtx.ReplyStream.(http.Flusher) + if !ok { + log.Error("Failed to flush response") + } + flusher.Flush() + reportedErr = err } } else { data, err := ioutil.ReadAll(payload) @@ -348,6 +351,8 @@ func (s *Server) sendResponseUnsafe(invokeID string, status int, payload io.Read MaxResponseSize: interop.MaxPayloadSize, } } + + s.invokeCtx.ReplyStream.Header().Add("Content-Type", contentType) if _, err := s.invokeCtx.ReplyStream.Write(data); err != nil { return fmt.Errorf("Failed to write response to %s: %s", invokeID, err) } @@ -356,14 +361,14 @@ func (s *Server) sendResponseUnsafe(invokeID string, status int, payload io.Read s.sendResponseChan <- struct{}{} s.invokeCtx.ReplySent = true s.invokeCtx.Direct = false - return nil + return reportedErr } -func (s *Server) SendResponse(invokeID string, reader io.Reader) error { +func (s *Server) SendResponse(invokeID string, contentType string, reader io.Reader) error { s.setRuntimeState(runtimeInvokeResponseSent) s.mutex.Lock() defer s.mutex.Unlock() - return s.sendResponseUnsafe(invokeID, http.StatusOK, reader) + return s.sendResponseUnsafe(invokeID, contentType, http.StatusOK, reader) } func (s *Server) CommitResponse() error { return nil } @@ -384,7 +389,7 @@ func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorResponse) s.setRuntimeState(runtimeInvokeError) s.mutex.Lock() defer s.mutex.Unlock() - return s.sendResponseUnsafe(invokeID, http.StatusInternalServerError, bytes.NewReader(resp.Payload)) + return s.sendResponseUnsafe(invokeID, resp.ContentType, http.StatusInternalServerError, bytes.NewReader(resp.Payload)) default: panic("received unexpected error response outside invoke or init phases") } @@ -641,6 +646,7 @@ func (s *Server) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invo s.Reset(autoresetReasonReleaseFail, resetDefaultTimeoutMs) } } + return err } diff --git a/lambda/rapidcore/server_test.go b/lambda/rapidcore/server_test.go index 85530f1..416304c 100644 --- a/lambda/rapidcore/server_test.go +++ b/lambda/rapidcore/server_test.go @@ -129,7 +129,7 @@ func TestInvokeSuccess(t *testing.T) { require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) <-srv.InvokeChan() - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), bytes.NewReader([]byte("response")))) + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), "application/json", bytes.NewReader([]byte("response")))) require.NoError(t, srv.SendRuntimeReady()) require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) }() @@ -146,6 +146,7 @@ func TestInvokeSuccess(t *testing.T) { invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) require.NoError(t, invokeErr) require.Equal(t, "response", responseRecorder.Body.String()) + require.Equal(t, "application/json", responseRecorder.Result().Header.Get("Content-Type")) _, err = srv.AwaitRelease() require.NoError(t, err) @@ -163,7 +164,7 @@ func TestInvokeError(t *testing.T) { <-srv.InvokeChan() - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }"), ContentType: "application/json"})) require.NoError(t, srv.SendRuntimeReady()) require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) }() @@ -180,6 +181,7 @@ func TestInvokeError(t *testing.T) { invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) require.NoError(t, invokeErr) require.Equal(t, "{ 'errorType': 'A.B' }", responseRecorder.Body.String()) + require.Equal(t, "application/json", responseRecorder.Result().Header.Get("Content-Type")) _, err = srv.AwaitRelease() require.NoError(t, err) @@ -212,7 +214,7 @@ func TestInvokeWithSuppressedInitSuccess(t *testing.T) { require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"})) <-srv.InvokeChan() // run only after FastInvoke is called - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), bytes.NewReader([]byte("response")))) + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), "", bytes.NewReader([]byte("response")))) require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) }() @@ -276,7 +278,6 @@ func TestInvokeWithSuppressedInitErrorDueToInitError(t *testing.T) { }() srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) _, err := srv.Reserve("", "", "") require.EqualError(t, err, ErrInitError.Error()) @@ -366,7 +367,7 @@ func TestMultipleInvokeSuccess(t *testing.T) { invokeFunc := func(i int) { <-srv.InvokeChan() - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), bytes.NewReader([]byte("response-"+fmt.Sprint(i))))) + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), "", bytes.NewReader([]byte("response-"+fmt.Sprint(i))))) require.NoError(t, srv.SendRuntimeReady()) require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) } diff --git a/lambda/rapidcore/standalone/executeHandler.go b/lambda/rapidcore/standalone/executeHandler.go index 0b89322..36c257a 100644 --- a/lambda/rapidcore/standalone/executeHandler.go +++ b/lambda/rapidcore/standalone/executeHandler.go @@ -43,6 +43,7 @@ func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandbox) // DONE failures: case rapidcore.ErrTerminated, rapidcore.ErrInitDoneFailed, rapidcore.ErrInvokeDoneFailed: + copyHeaders(invokeResp, w) w.WriteHeader(DoneFailedHTTPCode) w.Write(invokeResp.Body) return @@ -54,8 +55,15 @@ func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandbox) return } + copyHeaders(invokeResp, w) if invokeResp.StatusCode != 0 { w.WriteHeader(invokeResp.StatusCode) } w.Write(invokeResp.Body) } + +func copyHeaders(proxyWriter, writer http.ResponseWriter) { + for key, val := range proxyWriter.Header() { + writer.Header().Set(key, val[0]) + } +} diff --git a/lambda/rapidcore/standalone/invokeHandler.go b/lambda/rapidcore/standalone/invokeHandler.go index 25819e3..0d89f1c 100644 --- a/lambda/rapidcore/standalone/invokeHandler.go +++ b/lambda/rapidcore/standalone/invokeHandler.go @@ -22,12 +22,28 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropSe return } + isResyncReceivedFlag := false + + awsKey := r.Header.Get("ResyncAwsKey") + awsSecret := r.Header.Get("ResyncAwsSecret") + awsSession := r.Header.Get("ResyncAwsSession") + + if len(awsKey) > 0 && len(awsSecret) > 0 && len(awsSession) > 0 { + isResyncReceivedFlag = true + } + invokePayload := &interop.Invoke{ TraceID: r.Header.Get("X-Amzn-Trace-Id"), LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), Payload: r.Body, CorrelationID: "invokeCorrelationID", DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), + ResyncState: interop.Resync{ + IsResyncReceived: isResyncReceivedFlag, + AwsKey: awsKey, + AwsSecret: awsSecret, + AwsSession: awsSession, + }, } if err := s.FastInvoke(w, invokePayload, false); err != nil { diff --git a/lambda/rapidcore/standalone/util.go b/lambda/rapidcore/standalone/util.go index 84ecc22..21ee08f 100644 --- a/lambda/rapidcore/standalone/util.go +++ b/lambda/rapidcore/standalone/util.go @@ -34,10 +34,14 @@ func (t ErrorType) String() string { type ResponseWriterProxy struct { Body []byte StatusCode int + header http.Header } func (w *ResponseWriterProxy) Header() http.Header { - return http.Header{} + if w.header == nil { + w.header = http.Header{} + } + return w.header } func (w *ResponseWriterProxy) Write(b []byte) (int, error) { diff --git a/lambda/runtimecmd/runtime_command.go b/lambda/runtimecmd/runtime_command.go index 3a7a051..adf7886 100644 --- a/lambda/runtimecmd/runtime_command.go +++ b/lambda/runtimecmd/runtime_command.go @@ -19,12 +19,12 @@ type CustomRuntimeCmd struct { } // NewCustomRuntimeCmd returns a new CustomRuntimeCmd -func NewCustomRuntimeCmd(ctx context.Context, bootstrapCmd []string, dir string, env []string, runtimeLogWriter io.Writer, extraFiles []*os.File) *CustomRuntimeCmd { +func NewCustomRuntimeCmd(ctx context.Context, bootstrapCmd []string, dir string, env []string, stdoutWriter io.Writer, stderrWriter io.Writer, extraFiles []*os.File) *CustomRuntimeCmd { cmd := exec.CommandContext(ctx, bootstrapCmd[0], bootstrapCmd[1:]...) cmd.Dir = dir - cmd.Stdout = runtimeLogWriter - cmd.Stderr = runtimeLogWriter + cmd.Stdout = stdoutWriter + cmd.Stderr = stderrWriter cmd.Env = env diff --git a/lambda/runtimecmd/runtime_command_test.go b/lambda/runtimecmd/runtime_command_test.go index 0e8f170..f99599d 100644 --- a/lambda/runtimecmd/runtime_command_test.go +++ b/lambda/runtimecmd/runtime_command_test.go @@ -20,7 +20,7 @@ func TestRuntimeCommandSetsEnvironmentVariables(t *testing.T) { assert.NoError(t, err, errors.New("Failed to get working directory to execute helper process")) execCmdArgs := []string{"foobar"} - runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, nil) + runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, ioutil.Discard, nil) assert.ElementsMatch(t, envVars, runtimeCmd.Env) assert.Equal(t, execCmdArgs, runtimeCmd.Args) @@ -33,7 +33,7 @@ func TestRuntimeCommandSetsCurrentWorkingDir(t *testing.T) { assert.NoError(t, err, errors.New("Failed to get working directory to execute helper process")) execCmdArgs := []string{"foobar"} - runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, nil) + runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, ioutil.Discard, nil) assert.Equal(t, currentDir, runtimeCmd.Dir) } @@ -45,7 +45,7 @@ func TestRuntimeCommandSetsMultipleArgs(t *testing.T) { assert.NoError(t, err, errors.New("Failed to get working directory to execute helper process")) execCmdArgs := []string{"foobar", "--baz", "22"} - runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, nil) + runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, ioutil.Discard, nil) assert.Equal(t, execCmdArgs, runtimeCmd.Args) } diff --git a/lambda/telemetry/events_api.go b/lambda/telemetry/events_api.go new file mode 100644 index 0000000..132977e --- /dev/null +++ b/lambda/telemetry/events_api.go @@ -0,0 +1,14 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +type EventsAPI interface { + SetCurrentRequestID(requestID string) + SendRuntimeDone(status string) error +} + +type NoOpEventsAPI struct{} + +func (s *NoOpEventsAPI) SetCurrentRequestID(requestID string) {} +func (s *NoOpEventsAPI) SendRuntimeDone(status string) error { return nil } diff --git a/lambda/telemetry/logs_api.go b/lambda/telemetry/logs_api.go deleted file mode 100644 index c0dfe0d..0000000 --- a/lambda/telemetry/logs_api.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "io" - - "go.amzn.com/lambda/interop" -) - -// LogsAPIService represents interface that implementations of Telemetry API have to satisfy to be RAPID-compatible -type LogsAPIService interface { - Subscribe(agentName string, body io.Reader, headers map[string][]string) (resp []byte, status int, respHeaders map[string][]string, err error) - RecordCounterMetric(metricName string, count int) - FlushMetrics() interop.LogsAPIMetrics - Clear() - TurnOff() -} diff --git a/lambda/telemetry/logs_egress_api.go b/lambda/telemetry/logs_egress_api.go new file mode 100644 index 0000000..ac9a754 --- /dev/null +++ b/lambda/telemetry/logs_egress_api.go @@ -0,0 +1,26 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "io" + "os" +) + +type LogsEgressAPI interface { + GetExtensionSockets() (io.Writer, io.Writer, error) + GetRuntimeSockets() (io.Writer, io.Writer, error) +} + +type NoOpLogsEgressAPI struct{} + +func (s *NoOpLogsEgressAPI) GetExtensionSockets() (io.Writer, io.Writer, error) { + // os.Stderr can not be used for the stderrWriter because stderr is for internal logging (not customer visible). + return os.Stdout, os.Stdout, nil +} + +func (s *NoOpLogsEgressAPI) GetRuntimeSockets() (io.Writer, io.Writer, error) { + // os.Stderr can not be used for the stderrWriter because stderr is for internal logging (not customer visible). + return os.Stdout, os.Stdout, nil +} diff --git a/lambda/telemetry/logs_subscription_api.go b/lambda/telemetry/logs_subscription_api.go new file mode 100644 index 0000000..3ea7a20 --- /dev/null +++ b/lambda/telemetry/logs_subscription_api.go @@ -0,0 +1,37 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "io" + "net/http" + + "go.amzn.com/lambda/interop" +) + +// LogsSubscriptionAPI represents interface that implementations of Telemetry API have to satisfy to be RAPID-compatible +type LogsSubscriptionAPI interface { + Subscribe(agentName string, body io.Reader, headers map[string][]string) (resp []byte, status int, respHeaders map[string][]string, err error) + RecordCounterMetric(metricName string, count int) + FlushMetrics() interop.LogsAPIMetrics + Clear() + TurnOff() +} + +type NoOpLogsSubscriptionAPI struct{} + +// Subscribe writes response to a shared memory +func (m *NoOpLogsSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { + return []byte(`{}`), http.StatusOK, map[string][]string{}, nil +} + +func (m *NoOpLogsSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} + +func (m *NoOpLogsSubscriptionAPI) FlushMetrics() interop.LogsAPIMetrics { + return interop.LogsAPIMetrics(map[string]int{}) +} + +func (m *NoOpLogsSubscriptionAPI) Clear() {} + +func (m *NoOpLogsSubscriptionAPI) TurnOff() {} diff --git a/lambda/testdata/agents/bash_echo.sh b/lambda/testdata/agents/bash_echo.sh deleted file mode 100755 index ceb4228..0000000 --- a/lambda/testdata/agents/bash_echo.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env bash - -printf "hello world\nbarbaz\n" -printf "hello world\nbarbaz2" \ No newline at end of file diff --git a/lambda/testdata/agents/bash_stderr.sh b/lambda/testdata/agents/bash_stderr.sh new file mode 100755 index 0000000..65c0ff1 --- /dev/null +++ b/lambda/testdata/agents/bash_stderr.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +printf "stderr line 1\n" >&2 +printf "stderr line 2\n" >&2 +printf "stderr line 3\n" >&2 diff --git a/lambda/testdata/agents/bash_stdout.sh b/lambda/testdata/agents/bash_stdout.sh new file mode 100755 index 0000000..d0cb893 --- /dev/null +++ b/lambda/testdata/agents/bash_stdout.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +printf "stdout line 1\n" +printf "stdout line 2\n" +printf "stdout line 3\n" diff --git a/lambda/testdata/agents/bash_stdout_and_stderr.sh b/lambda/testdata/agents/bash_stdout_and_stderr.sh new file mode 100755 index 0000000..cf87e60 --- /dev/null +++ b/lambda/testdata/agents/bash_stdout_and_stderr.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +printf "stdout line 1\n" +printf "stderr line 1\n" >&2 +printf "stdout line 2\n" +printf "stderr line 2\n" >&2 +printf "stdout line 3\n" +printf "stderr line 3\n" >&2 diff --git a/lambda/testdata/flowtesting.go b/lambda/testdata/flowtesting.go index f729632..ee163bb 100644 --- a/lambda/testdata/flowtesting.go +++ b/lambda/testdata/flowtesting.go @@ -19,16 +19,17 @@ import ( ) type MockInteropServer struct { - Response []byte - ErrorResponse *interop.ErrorResponse - ActiveInvokeID string + Response []byte + ErrorResponse *interop.ErrorResponse + ResponseContentType string + ActiveInvokeID string } // StartAcceptingDirectInvokes func (i *MockInteropServer) StartAcceptingDirectInvokes() error { return nil } // SendResponse writes response to a shared memory. -func (i *MockInteropServer) SendResponse(invokeID string, reader io.Reader) error { +func (i *MockInteropServer) SendResponse(invokeID string, contentType string, reader io.Reader) error { bytes, err := ioutil.ReadAll(reader) if err != nil { return err @@ -40,12 +41,14 @@ func (i *MockInteropServer) SendResponse(invokeID string, reader io.Reader) erro } } i.Response = bytes + i.ResponseContentType = contentType return nil } // SendErrorResponse writes error response to a shared memory and sends GIRD FAULT. func (i *MockInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorResponse) error { i.ErrorResponse = response + i.ResponseContentType = response.ContentType return nil } @@ -93,7 +96,9 @@ func (m *MockInteropServer) Init(i *interop.Start, invokeTimeoutMs int64) {} func (m *MockInteropServer) Invoke(w http.ResponseWriter, i *interop.Invoke) error { return nil } -func (m *MockInteropServer) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription { return nil } +func (m *MockInteropServer) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription { + return nil +} // FlowTest provides configuration for tests that involve synchronization flows. type FlowTest struct { @@ -104,7 +109,8 @@ type FlowTest struct { RenderingService *rendering.EventRenderingService Runtime *core.Runtime InteropServer *MockInteropServer - TelemetryService *MockNoOpTelemetryService + LogsSubscriptionAPI *telemetry.NoOpLogsSubscriptionAPI + CredentialsService core.CredentialsService } // ConfigureForInit initialize synchronization gates and states for init. @@ -119,28 +125,13 @@ func (s *FlowTest) ConfigureForInvoke(ctx context.Context, invoke *interop.Invok s.RenderingService.SetRenderer(rendering.NewInvokeRenderer(ctx, invoke, telemetry.GetCustomerTracingHeader)) } -// MockNoOpTelemetryService is a no-op telemetry API used in tests where it does not matter -type MockNoOpTelemetryService struct{} - -// Subscribe writes response to a shared memory -func (m *MockNoOpTelemetryService) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { - return []byte(`{}`), http.StatusOK, map[string][]string{}, nil -} - -func (s *MockNoOpTelemetryService) RecordCounterMetric(metricName string, count int) { - // NOOP -} - -func (s *MockNoOpTelemetryService) FlushMetrics() interop.LogsAPIMetrics { - return interop.LogsAPIMetrics(map[string]int{}) -} - -func (m *MockNoOpTelemetryService) Clear() { - // NOOP +func (s *FlowTest) ConfigureForInitCaching(token, awsKey, awsSecret, awsSession string) { + s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession) } -func (m *MockNoOpTelemetryService) TurnOff() { - // NOOP +func (s *FlowTest) ConfigureForBlockedInitCaching(token, awsKey, awsSecret, awsSession string) { + s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession) + s.CredentialsService.BlockService() } // NewFlowTest returns new FlowTest configuration. @@ -150,6 +141,7 @@ func NewFlowTest() *FlowTest { invokeFlow := core.NewInvokeFlowSynchronization() registrationService := core.NewRegistrationService(initFlow, invokeFlow) renderingService := rendering.NewRenderingService() + credentialsService := core.NewCredentialsService() runtime := core.NewRuntime(initFlow, invokeFlow) runtime.ManagedThread = &mockthread.MockManagedThread{} interopServer := &MockInteropServer{} @@ -160,8 +152,9 @@ func NewFlowTest() *FlowTest { InvokeFlow: invokeFlow, RegistrationService: registrationService, RenderingService: renderingService, - TelemetryService: &MockNoOpTelemetryService{}, + LogsSubscriptionAPI: &telemetry.NoOpLogsSubscriptionAPI{}, Runtime: runtime, InteropServer: interopServer, + CredentialsService: credentialsService, } }