diff --git a/Makefile b/Makefile index 9ff6c1a..80ccb89 100644 --- a/Makefile +++ b/Makefile @@ -21,10 +21,10 @@ compile-lambda-linux-all: make ARCH=old compile-lambda-linux compile-with-docker: - docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.19 make ARCH=${ARCH} compile-lambda-linux + docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.20 make ARCH=${ARCH} compile-lambda-linux compile-lambda-linux: - CGO_ENABLED=0 GOOS=linux GOARCH=${GO_ARCH_${ARCH}} go build -ldflags "${RELEASE_BUILD_LINKER_FLAGS}" -o ${DESTINATION_${ARCH}} ./cmd/aws-lambda-rie + CGO_ENABLED=0 GOOS=linux GOARCH=${GO_ARCH_${ARCH}} go build -buildvcs=false -ldflags "${RELEASE_BUILD_LINKER_FLAGS}" -o ${DESTINATION_${ARCH}} ./cmd/aws-lambda-rie tests: go test ./... diff --git a/cmd/aws-lambda-rie/main.go b/cmd/aws-lambda-rie/main.go index 65879c0..bd15402 100644 --- a/cmd/aws-lambda-rie/main.go +++ b/cmd/aws-lambda-rie/main.go @@ -11,6 +11,7 @@ import ( "runtime/debug" "github.com/jessevdk/go-flags" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" log "github.com/sirupsen/logrus" @@ -103,7 +104,7 @@ func isBootstrapFileExist(filePath string) bool { return !os.IsNotExist(err) && !file.IsDir() } -func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) { +func getBootstrap(args []string, opts options) (interop.Bootstrap, string) { var bootstrapLookupCmd []string var handler string currentWorkingDir := "/var/task" // default value @@ -149,5 +150,5 @@ func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) { log.Panic("insufficient arguments: bootstrap not provided") } - return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir, ""), handler + return NewSimpleBootstrap(bootstrapLookupCmd, currentWorkingDir), handler } diff --git a/cmd/aws-lambda-rie/simple_bootstrap.go b/cmd/aws-lambda-rie/simple_bootstrap.go new file mode 100644 
index 0000000..c9111a2 --- /dev/null +++ b/cmd/aws-lambda-rie/simple_bootstrap.go @@ -0,0 +1,69 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "fmt" + "os" + "path/filepath" + + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapidcore/env" +) + +// the type implement a simpler version of the Bootstrap +// this is useful in the Standalone Core implementation. +type simpleBootstrap struct { + cmd []string + workingDir string +} + +func NewSimpleBootstrap(cmd []string, currentWorkingDir string) interop.Bootstrap { + if currentWorkingDir == "" { + // use the root directory as the default working directory + currentWorkingDir = "/" + } + + // a single candidate command makes it automatically valid + return &simpleBootstrap{ + cmd: cmd, + workingDir: currentWorkingDir, + } +} + +func (b *simpleBootstrap) Cmd() ([]string, error) { + return b.cmd, nil +} + +// Cwd returns the working directory of the bootstrap process +// The path is validated against the chroot identified by `root` +func (b *simpleBootstrap) Cwd() (string, error) { + if !filepath.IsAbs(b.workingDir) { + return "", fmt.Errorf("the working directory '%s' is invalid, it needs to be an absolute path", b.workingDir) + } + + // evaluate the path relatively to the domain's mnt namespace root + if _, err := os.Stat(b.workingDir); os.IsNotExist(err) { + return "", fmt.Errorf("the working directory doesn't exist: %s", b.workingDir) + } + + return b.workingDir, nil +} + +// Env returns the environment variables available to +// the bootstrap process +func (b *simpleBootstrap) Env(e *env.Environment) map[string]string { + return e.RuntimeExecEnv() +} + +// ExtraFiles returns the extra file descriptors apart from 1 & 2 to be passed to runtime +func (b *simpleBootstrap) ExtraFiles() []*os.File { + return make([]*os.File, 0) +} + +func (b *simpleBootstrap) CachedFatalError(err error) 
(fatalerror.ErrorType, string, bool) { + // not implemented as it is not needed in Core but we need to fullfil the interface anyway + return fatalerror.ErrorType(""), "", false +} diff --git a/cmd/aws-lambda-rie/simple_bootstrap_test.go b/cmd/aws-lambda-rie/simple_bootstrap_test.go new file mode 100644 index 0000000..de00ee2 --- /dev/null +++ b/cmd/aws-lambda-rie/simple_bootstrap_test.go @@ -0,0 +1,78 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "os" + "reflect" + "testing" + + "go.amzn.com/lambda/rapidcore/env" + + "github.com/stretchr/testify/assert" +) + +func TestSimpleBootstrap(t *testing.T) { + tmpFile, err := os.CreateTemp("", "oci-test-bootstrap") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // Setup single cmd candidate + file := []string{tmpFile.Name(), "--arg1 s", "foo"} + cmdCandidate := file + + // Setup working dir + cwd, err := os.Getwd() + assert.NoError(t, err) + + // Setup environment + environment := env.NewEnvironment() + environment.StoreRuntimeAPIEnvironmentVariable("host:port") + environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") + + // Test + b := NewSimpleBootstrap(cmdCandidate, cwd) + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, cwd, bCwd) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) + + cmd, err := b.Cmd() + assert.NoError(t, err) + assert.Equal(t, file, cmd) +} + +func TestSimpleBootstrapCmdNonExistingCandidate(t *testing.T) { + // Setup inexistent single cmd candidate + file := []string{"/foo/bar", "--arg1 s", "foo"} + cmdCandidate := file + + // Setup working dir + cwd, err := os.Getwd() + assert.NoError(t, err) + + // Setup environment + environment := env.NewEnvironment() + environment.StoreRuntimeAPIEnvironmentVariable("host:port") + environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", 
"") + + // Test + b := NewSimpleBootstrap(cmdCandidate, cwd) + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, cwd, bCwd) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) + + // No validations run against single candidates + cmd, err := b.Cmd() + assert.NoError(t, err) + assert.Equal(t, file, cmd) +} + +func TestSimpleBootstrapCmdDefaultWorkingDir(t *testing.T) { + b := NewSimpleBootstrap([]string{}, "") + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, "/", bCwd) +} diff --git a/go.mod b/go.mod index 053c7e0..990a7dd 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module go.amzn.com -go 1.19 +go 1.20 require ( github.com/aws/aws-lambda-go v1.41.0 @@ -16,7 +16,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.0 // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.8.0 // indirect + golang.org/x/net v0.18.0 // indirect + golang.org/x/sys v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d8fb9e9..0ea11d6 100644 --- a/go.sum +++ b/go.sum @@ -22,15 +22,15 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= +golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lambda/agents/agent.go b/lambda/agents/agent.go index b1f8563..cabe1fa 100644 --- a/lambda/agents/agent.go +++ b/lambda/agents/agent.go @@ -20,10 +20,18 @@ func ListExternalAgentPaths(dir string, root string) []string { } fullDir := path.Join(root, dir) files, err := os.ReadDir(fullDir) + if err != nil { - log.WithError(err).Warning("Cannot list external agents") + if os.IsNotExist(err) { + log.Infof("The extension's directory %q does not exist, assuming no extensions to be loaded.", fullDir) + } else { + // TODO - Should this return an error rather than ignore failing to load? + log.WithError(err).Error("Cannot list external agents") + } + return agentPaths } + for _, file := range files { if !file.IsDir() { // The returned path is absolute wrt to `root`. 
This allows diff --git a/lambda/appctx/appctx.go b/lambda/appctx/appctx.go index 6c81653..931a2ec 100644 --- a/lambda/appctx/appctx.go +++ b/lambda/appctx/appctx.go @@ -13,9 +13,9 @@ type Key int type InitType int const ( - // AppCtxInvokeErrorResponseKey is used for storing deferred invoke error response. + // AppCtxInvokeErrorTraceDataKey is used for storing deferred invoke error cause header value. // Only used by xray. TODO refactor xray interface so it doesn't use appctx - AppCtxInvokeErrorResponseKey Key = iota + AppCtxInvokeErrorTraceDataKey Key = iota // AppCtxRuntimeReleaseKey is used for storing runtime release information (parsed from User_Agent Http header string). AppCtxRuntimeReleaseKey @@ -23,6 +23,9 @@ const ( // AppCtxInteropServerKey is used to store a reference to the interop server. AppCtxInteropServerKey + // AppCtxResponseSenderKey is used to store a reference to the response sender + AppCtxResponseSenderKey + // AppCtxFirstFatalErrorKey is used to store first unrecoverable error message encountered to propagate it to slicer with DONE(errortype) or DONEFAIL(errortype) AppCtxFirstFatalErrorKey diff --git a/lambda/appctx/appctxutil.go b/lambda/appctx/appctxutil.go index a30677f..cd6e6d3 100644 --- a/lambda/appctx/appctxutil.go +++ b/lambda/appctx/appctxutil.go @@ -119,16 +119,16 @@ func UpdateAppCtxWithRuntimeRelease(request *http.Request, appCtx ApplicationCon return false } -// StoreErrorResponse stores response in the applicaton context. -func StoreErrorResponse(appCtx ApplicationContext, errorResponse *interop.ErrorResponse) { - appCtx.Store(AppCtxInvokeErrorResponseKey, errorResponse) +// StoreInvokeErrorTraceData stores invocation error x-ray cause header in the applicaton context. +func StoreInvokeErrorTraceData(appCtx ApplicationContext, invokeError *interop.InvokeErrorTraceData) { + appCtx.Store(AppCtxInvokeErrorTraceDataKey, invokeError) } -// LoadErrorResponse retrieves response from the application context. 
-func LoadErrorResponse(appCtx ApplicationContext) *interop.ErrorResponse { - v, ok := appCtx.Load(AppCtxInvokeErrorResponseKey) +// LoadInvokeErrorTraceData retrieves invocation error x-ray cause header from the application context. +func LoadInvokeErrorTraceData(appCtx ApplicationContext) *interop.InvokeErrorTraceData { + v, ok := appCtx.Load(AppCtxInvokeErrorTraceDataKey) if ok { - return v.(*interop.ErrorResponse) + return v.(*interop.InvokeErrorTraceData) } return nil } @@ -147,6 +147,20 @@ func LoadInteropServer(appCtx ApplicationContext) interop.Server { return nil } +// StoreResponseSender stores a reference to the response sender +func StoreResponseSender(appCtx ApplicationContext, server interop.InvokeResponseSender) { + appCtx.Store(AppCtxResponseSenderKey, server) +} + +// LoadResponseSender retrieves the response sender +func LoadResponseSender(appCtx ApplicationContext) interop.InvokeResponseSender { + v, ok := appCtx.Load(AppCtxResponseSenderKey) + if ok { + return v.(interop.InvokeResponseSender) + } + return nil +} + // StoreFirstFatalError stores unrecoverable error code in appctx once. 
This error is considered to be the rootcause of failure func StoreFirstFatalError(appCtx ApplicationContext, err fatalerror.ErrorType) { if existing := appCtx.StoreIfNotExists(AppCtxFirstFatalErrorKey, err); existing != nil { diff --git a/lambda/core/directinvoke/directinvoke.go b/lambda/core/directinvoke/directinvoke.go index 8ef59ae..3510132 100644 --- a/lambda/core/directinvoke/directinvoke.go +++ b/lambda/core/directinvoke/directinvoke.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strconv" + "strings" "github.com/go-chi/chi" "go.amzn.com/lambda/core/bandwidthlimiter" @@ -27,6 +28,7 @@ const ( CustomerHeadersHeader = "Customer-Headers" ContentTypeHeader = "Content-Type" MaxPayloadSizeHeader = "MaxPayloadSize" + InvokeResponseModeHeader = "InvokeResponseMode" ResponseBandwidthRateHeader = "ResponseBandwidthRate" ResponseBandwidthBurstSizeHeader = "ResponseBandwidthBurstSize" FunctionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" @@ -53,6 +55,10 @@ var MaxDirectResponseSize int64 = interop.MaxPayloadSize // this is intentionall var ResponseBandwidthRate int64 = interop.ResponseBandwidthRate var ResponseBandwidthBurstSize int64 = interop.ResponseBandwidthBurstSize +// InvokeResponseMode controls the context in which the invoke is. Since this was introduced +// in Streaming invokes, we default it to Buffered. +var InvokeResponseMode interop.InvokeResponseMode = interop.InvokeResponseModeBuffered + func renderBadRequest(w http.ResponseWriter, r *http.Request, errorType string) { w.Header().Set(ErrorTypeHeader, errorType) w.WriteHeader(http.StatusBadRequest) @@ -65,9 +71,29 @@ func renderInternalServerError(w http.ResponseWriter, errorType string) { w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) } +// convertToInvokeResponseMode converts the given string to a InvokeResponseMode +// It is case insensitive and if there is no match, an error is thrown. 
+func convertToInvokeResponseMode(value string) (interop.InvokeResponseMode, error) { + // buffered + if strings.EqualFold(value, string(interop.InvokeResponseModeBuffered)) { + return interop.InvokeResponseModeBuffered, nil + } + + // streaming + if strings.EqualFold(value, string(interop.InvokeResponseModeStreaming)) { + return interop.InvokeResponseModeStreaming, nil + } + + // unknown + allowedValues := strings.Join(interop.AllInvokeResponseModes, ", ") + log.Errorf("Unable to map %s to %s.", value, allowedValues) + return "", interop.ErrInvalidInvokeResponseMode +} + // ReceiveDirectInvoke parses invoke and verifies it against Token message. Uses deadline provided by Token // Renders BadRequest in case of error func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.Token) (*interop.Invoke, error) { + log.Infof("Received Invoke(invokeID: %s) Request", token.InvokeID) w.Header().Set("Trailer", EndOfResponseTrailer) custHeaders := CustomerHeaders{} @@ -89,10 +115,30 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T } } - if MaxDirectResponseSize == -1 { + if valueFromHeader := r.Header.Get(InvokeResponseModeHeader); valueFromHeader != "" { + invokeResponseMode, err := convertToInvokeResponseMode(valueFromHeader) + if err != nil { + log.Errorf( + "InvokeResponseMode header is not a valid string. 
Was: %#v, Allowed: %#v.", + valueFromHeader, + strings.Join(interop.AllInvokeResponseModes, ", "), + ) + renderBadRequest(w, r, err.Error()) + return nil, err + } + InvokeResponseMode = invokeResponseMode + } + + // TODO: stop using `MaxDirectResponseSize` + if isStreamingInvoke(int(MaxDirectResponseSize), InvokeResponseMode) { w.Header().Add("Trailer", FunctionErrorTypeTrailer) w.Header().Add("Trailer", FunctionErrorBodyTrailer) + // FIXME + // Until WorkerProxy stops sending MaxDirectResponseSize == -1 to identify streaming + // invokes, we need to override InvokeResponseMode to avoid setting InvokeResponseMode to buffered (default) for a streaming invoke (MaxDirectResponseSize == -1). + InvokeResponseMode = interop.InvokeResponseModeStreaming + ResponseBandwidthRate = interop.ResponseBandwidthRate if responseBandwidthRate := r.Header.Get(ResponseBandwidthRateHeader); responseBandwidthRate != "" { if n, err := strconv.ParseInt(responseBandwidthRate, 10, 64); err == nil && @@ -119,20 +165,23 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T } inv := &interop.Invoke{ - ID: r.Header.Get(InvokeIDHeader), - ReservationToken: chi.URLParam(r, "reservationtoken"), - InvokedFunctionArn: r.Header.Get(InvokedFunctionArnHeader), - VersionID: r.Header.Get(VersionIDHeader), - ContentType: r.Header.Get(ContentTypeHeader), - CognitoIdentityID: custHeaders.CognitoIdentityID, - CognitoIdentityPoolID: custHeaders.CognitoIdentityPoolID, - TraceID: token.TraceID, - LambdaSegmentID: token.LambdaSegmentID, - ClientContext: custHeaders.ClientContext, - Payload: r.Body, - DeadlineNs: fmt.Sprintf("%d", now+token.FunctionTimeout.Nanoseconds()), - NeedDebugLogs: token.NeedDebugLogs, - InvokeReceivedTime: now, + ID: r.Header.Get(InvokeIDHeader), + ReservationToken: chi.URLParam(r, "reservationtoken"), + InvokedFunctionArn: r.Header.Get(InvokedFunctionArnHeader), + VersionID: r.Header.Get(VersionIDHeader), + ContentType: r.Header.Get(ContentTypeHeader), + 
CognitoIdentityID: custHeaders.CognitoIdentityID, + CognitoIdentityPoolID: custHeaders.CognitoIdentityPoolID, + TraceID: token.TraceID, + LambdaSegmentID: token.LambdaSegmentID, + ClientContext: custHeaders.ClientContext, + Payload: r.Body, + DeadlineNs: fmt.Sprintf("%d", now+token.FunctionTimeout.Nanoseconds()), + NeedDebugLogs: token.NeedDebugLogs, + InvokeReceivedTime: now, + InvokeResponseMode: InvokeResponseMode, + RestoreDurationNs: token.RestoreDurationNs, + RestoreStartTimeMonotime: token.RestoreStartTimeMonotime, } if inv.ID != token.InvokeID { @@ -170,7 +219,7 @@ type CopyDoneResult struct { func getErrorTypeFromResetReason(resetReason string) fatalerror.ErrorType { errorTypeTrailer, ok := ResetReasonMap[resetReason] if !ok { - errorTypeTrailer = fatalerror.Unknown + errorTypeTrailer = fatalerror.SandboxFailure } return errorTypeTrailer } @@ -180,8 +229,11 @@ func isErrorResponse(additionalHeaders map[string]string) (isErrorResponse bool) return } -func isStreamingInvoke() bool { - return MaxDirectResponseSize == -1 +// isStreamingInvoke checks whether the invoke mode is streaming or not. +// `maxDirectResponseSize == -1` is used as it was the first check we did when we released +// streaming invokes. +func isStreamingInvoke(maxDirectResponseSize int, invokeResponseMode interop.InvokeResponseMode) bool { + return maxDirectResponseSize == -1 || invokeResponseMode == interop.InvokeResponseModeStreaming } func asyncPayloadCopy(w http.ResponseWriter, payload io.Reader) (copyDone chan CopyDoneResult, cancel context.CancelFunc, err error) { @@ -190,10 +242,34 @@ func asyncPayloadCopy(w http.ResponseWriter, payload io.Reader) (copyDone chan C if err != nil { return nil, nil, &interop.ErrInternalPlatformError{} } + go func() { // copy payload in a separate go routine - _, copyError := bandwidthlimiter.BandwidthLimitingCopy(streamedResponseWriter, payload) + // -1 size indicates the payload size is unlimited. 
+ isPayloadsSizeRestricted := MaxDirectResponseSize != -1 + + if isPayloadsSizeRestricted { + // Setting the limit to MaxDirectResponseSize + 1 so we can do + // readBytes > MaxDirectResponseSize to check if the response is oversized. + // As the response is allowed to be of the size MaxDirectResponseSize but not larger than it. + payload = io.LimitReader(payload, MaxDirectResponseSize+1) + } + + // FIXME: inject bandwidthlimiter as a dependency, so that we can mock it in tests + copiedBytes, copyError := bandwidthlimiter.BandwidthLimitingCopy(streamedResponseWriter, payload) + + isPayloadsSizeOversized := copiedBytes > MaxDirectResponseSize + if copyError != nil { w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) + copyError = &interop.ErrTruncatedResponse{} + } else if isPayloadsSizeRestricted && isPayloadsSizeOversized { + w.Header().Set(EndOfResponseTrailer, EndOfResponseOversized) + copyError = &interop.ErrorResponseTooLargeDI{ + ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ + ResponseSize: int(copiedBytes), + MaxResponseSize: int(MaxDirectResponseSize), + }, + } } else { w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) } @@ -227,8 +303,8 @@ func sendStreamingInvokeResponse(payload io.Reader, trailers http.Header, w http case copyDoneResult = <-copyDone: // copy finished errorTypeTrailer = trailers.Get(FunctionErrorTypeTrailer) errorBodyTrailer = trailers.Get(FunctionErrorBodyTrailer) - if copyDoneResult.Error != nil && errorTypeTrailer == "" { // truncated payload, error type not known - errorTypeTrailer = string(fatalerror.TruncatedResponse) + if copyDoneResult.Error != nil && errorTypeTrailer == "" { + errorTypeTrailer = string(mapCopyDoneResultErrorToErrorType(copyDoneResult.Error)) } case reset := <-interruptedResponseChan: // reset initiated cancel() @@ -247,6 +323,7 @@ func sendStreamingInvokeResponse(payload io.Reader, trailers http.Header, w http } copyDoneResult = <-copyDone reset.InvokeResponseMetrics = 
copyDoneResult.Metrics + reset.InvokeResponseMode = InvokeResponseMode interruptedResponseChan <- nil errorTypeTrailer = string(getErrorTypeFromResetReason(reset.Reason)) } @@ -258,11 +335,23 @@ func sendStreamingInvokeResponse(payload io.Reader, trailers http.Header, w http if copyDoneResult.Error != nil { log.Errorf("Error while streaming response payload: %s", copyDoneResult.Error) - err = &interop.ErrTruncatedResponse{} + err = copyDoneResult.Error } return } +// mapCopyDoneResultErrorToErrorType map a copyDoneResult error into a fatalerror +func mapCopyDoneResultErrorToErrorType(err interface{}) fatalerror.ErrorType { + switch err.(type) { + case *interop.ErrTruncatedResponse: + return fatalerror.TruncatedResponse + case *interop.ErrorResponseTooLargeDI: + return fatalerror.FunctionOversizedResponse + default: + return fatalerror.SandboxFailure + } +} + func sendStreamingInvokeErrorResponse(payload io.Reader, w http.ResponseWriter, interruptedResponseChan chan *interop.Reset, sendResponseChan chan *interop.InvokeResponseMetrics, runtimeCalledResponse bool) (err error) { @@ -279,6 +368,7 @@ func sendStreamingInvokeErrorResponse(payload io.Reader, w http.ResponseWriter, cancel() copyDoneResult = <-copyDone reset.InvokeResponseMetrics = copyDoneResult.Metrics + reset.InvokeResponseMode = InvokeResponseMode interruptedResponseChan <- nil } @@ -287,8 +377,9 @@ func sendStreamingInvokeErrorResponse(payload io.Reader, w http.ResponseWriter, if copyDoneResult.Error != nil { log.Errorf("Error while streaming error response payload: %s", copyDoneResult.Error) - err = &interop.ErrTruncatedResponse{} + err = copyDoneResult.Error } + return } @@ -317,7 +408,10 @@ func sendPayloadLimitedResponse(payload io.Reader, trailers http.Header, w http. 
} startReadingResponseMonoTimeMs := metering.Monotime() - written, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) // +1 because we do allow 10MB but not 10MB + 1 byte + // Setting the limit to MaxDirectResponseSize + 1 so we can do + // readBytes > MaxDirectResponseSize to check if the response is oversized. + // As the response is allowed to be of the size MaxDirectResponseSize but not larger than it. + written, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) // non-streaming invoke request but runtime is streaming: set response trailers if functionResponseMode == interop.FunctionResponseModeStreaming { @@ -325,10 +419,12 @@ func sendPayloadLimitedResponse(payload io.Reader, trailers http.Header, w http. w.Header().Set(FunctionErrorBodyTrailer, trailers.Get(FunctionErrorBodyTrailer)) } + isNotStreamingInvoke := InvokeResponseMode != interop.InvokeResponseModeStreaming + if err != nil { w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) err = &interop.ErrTruncatedResponse{} - } else if MaxDirectResponseSize != -1 && written == MaxDirectResponseSize+1 { + } else if isNotStreamingInvoke && written == MaxDirectResponseSize+1 { w.Header().Set(EndOfResponseTrailer, EndOfResponseOversized) err = &interop.ErrorResponseTooLargeDI{ ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ @@ -358,19 +454,33 @@ func sendPayloadLimitedResponse(payload io.Reader, trailers http.Header, w http. 
func SendDirectInvokeResponse(additionalHeaders map[string]string, payload io.Reader, trailers http.Header, w http.ResponseWriter, interruptedResponseChan chan *interop.Reset, - sendResponseChan chan *interop.InvokeResponseMetrics, request *interop.CancellableRequest, runtimeCalledResponse bool) error { + sendResponseChan chan *interop.InvokeResponseMetrics, request *interop.CancellableRequest, runtimeCalledResponse bool, invokeID string) error { for k, v := range additionalHeaders { w.Header().Add(k, v) } - if isStreamingInvoke() { // unlimited payload; response streaming mode - if isErrorResponse(additionalHeaders) { // send streamed error response when runtime called /error - return sendStreamingInvokeErrorResponse(payload, w, interruptedResponseChan, sendResponseChan, runtimeCalledResponse) + var err error + log.Infof("Started sending response (mode: %s, requestID: %s)", InvokeResponseMode, invokeID) + if InvokeResponseMode == interop.InvokeResponseModeStreaming { + // send streamed error response when runtime called /error + if isErrorResponse(additionalHeaders) { + err = sendStreamingInvokeErrorResponse(payload, w, interruptedResponseChan, sendResponseChan, runtimeCalledResponse) + if err != nil { + log.Infof("Error in sending error response (mode: %s, requestID: %s, error: %v)", InvokeResponseMode, invokeID, err) + } + return err } // send streamed response when runtime called /response - return sendStreamingInvokeResponse(payload, trailers, w, interruptedResponseChan, sendResponseChan, request, runtimeCalledResponse) + err = sendStreamingInvokeResponse(payload, trailers, w, interruptedResponseChan, sendResponseChan, request, runtimeCalledResponse) + } else { + err = sendPayloadLimitedResponse(payload, trailers, w, sendResponseChan, runtimeCalledResponse) } - return sendPayloadLimitedResponse(payload, trailers, w, sendResponseChan, runtimeCalledResponse) + if err != nil { + log.Infof("Error in sending response (mode: %s, requestID: %s, error: %v)", 
InvokeResponseMode, invokeID, err) + } else { + log.Infof("Completed sending response (mode: %s, requestID: %s)", InvokeResponseMode, invokeID) + } + return err } diff --git a/lambda/core/directinvoke/directinvoke_test.go b/lambda/core/directinvoke/directinvoke_test.go index 4e26161..94b6323 100644 --- a/lambda/core/directinvoke/directinvoke_test.go +++ b/lambda/core/directinvoke/directinvoke_test.go @@ -5,14 +5,24 @@ package directinvoke import ( "bytes" + "context" + "errors" + "fmt" "io" + "math" "net/http" + "net/http/httptest" + "strconv" "strings" "testing" "time" + "github.com/go-chi/chi" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" ) func NewResponseWriterWithoutFlushMethod() *ResponseWriterWithoutFlushMethod { @@ -93,24 +103,87 @@ func (r *Reader) Read(b []byte) (n int, err error) { return } -func TestSendDirectInvokeWithIncompatibleResponseWriter(t *testing.T) { - MaxDirectResponseSize = -1 - err := SendDirectInvokeResponse(nil, nil, nil, NewResponseWriterWithoutFlushMethod(), nil, nil, nil, false) - require.Error(t, err) - require.Equal(t, "ErrInternalPlatformError", err.Error()) +func TestAsyncPayloadCopyWhenPayloadSizeBelowMaxAllowed(t *testing.T) { + MaxDirectResponseSize = 2 + payloadSize := int(MaxDirectResponseSize - 1) + payloadString := strings.Repeat("a", payloadSize) + writer := NewSimpleResponseWriter() + + copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) + require.Nil(t, err) + + copyDoneResult := <-copyDone + require.Nil(t, copyDoneResult.Error) + + require.Equal(t, payloadString, writer.buffer.String()) + require.Equal(t, EndOfResponseComplete, writer.Header().Get(EndOfResponseTrailer)) + + // reset it to its original value + MaxDirectResponseSize = interop.MaxPayloadSize } -func TestAsyncPayloadCopySuccess(t *testing.T) { - payloadString := strings.Repeat("a", 10*1024*1024) +func 
TestAsyncPayloadCopyWhenPayloadSizeEqualMaxAllowed(t *testing.T) { + MaxDirectResponseSize = 2 + payloadSize := int(MaxDirectResponseSize) + payloadString := strings.Repeat("a", payloadSize) writer := NewSimpleResponseWriter() - expectedPayloadString := payloadString + copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) + require.Nil(t, err) + + copyDoneResult := <-copyDone + require.Nil(t, copyDoneResult.Error) + + require.Equal(t, payloadString, writer.buffer.String()) + require.Equal(t, EndOfResponseComplete, writer.Header().Get(EndOfResponseTrailer)) + + // reset it to its original value + MaxDirectResponseSize = interop.MaxPayloadSize +} + +func TestAsyncPayloadCopyWhenPayloadSizeAboveMaxAllowed(t *testing.T) { + MaxDirectResponseSize = 2 + payloadSize := int(MaxDirectResponseSize) + 1 + payloadString := strings.Repeat("a", payloadSize) + writer := NewSimpleResponseWriter() + expectedCopyDoneResultError := &interop.ErrorResponseTooLargeDI{ + ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ + ResponseSize: payloadSize, + MaxResponseSize: int(MaxDirectResponseSize), + }, + } copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) require.Nil(t, err) - <-copyDone - require.Equal(t, expectedPayloadString, writer.buffer.String()) + copyDoneResult := <-copyDone + require.Equal(t, expectedCopyDoneResultError, copyDoneResult.Error) + + require.Equal(t, payloadString, writer.buffer.String()) + require.Equal(t, EndOfResponseOversized, writer.Header().Get(EndOfResponseTrailer)) + + // reset it to its original value + MaxDirectResponseSize = interop.MaxPayloadSize +} + +// This is only allowed in streaming mode, currently. 
+func TestAsyncPayloadCopyWhenUnlimitedPayloadSizeAllowed(t *testing.T) { + MaxDirectResponseSize = -1 + payloadSize := int(interop.MaxPayloadSize + 1) + payloadString := strings.Repeat("a", payloadSize) + writer := NewSimpleResponseWriter() + + copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) + require.Nil(t, err) + + copyDoneResult := <-copyDone + require.Nil(t, copyDoneResult.Error) + + require.Equal(t, payloadString, writer.buffer.String()) + require.Equal(t, EndOfResponseComplete, writer.Header().Get(EndOfResponseTrailer)) + + // reset it to its original value + MaxDirectResponseSize = interop.MaxPayloadSize } // We use an interruptable response writer which informs on a channel that it's ready to be interrupted after @@ -135,7 +208,6 @@ func TestAsyncPayloadCopySuccessAfterCancel(t *testing.T) { <-copyDone require.Equal(t, expectedPayloadString, writer.buffer.String()) } - func TestAsyncPayloadCopyWithIncompatibleResponseWriter(t *testing.T) { copyDone, cancel, err := asyncPayloadCopy(&ResponseWriterWithoutFlushMethod{}, nil) require.Nil(t, copyDone) @@ -144,6 +216,13 @@ func TestAsyncPayloadCopyWithIncompatibleResponseWriter(t *testing.T) { require.Equal(t, "ErrInternalPlatformError", err.Error()) } +// TODO: in order to implement this test we need bandwidthlimiter to be received by asyncPayloadCopy +// as an argument. Otherwise, this test will need to know how to force bandwidthlimiter to fail, +// which isn't a good practice. 
+func TestAsyncPayloadCopyWhenResponseIsTruncated(t *testing.T) { + t.Skip("Pending injection of bandwidthlimiter as a dependency of asyncPayloadCopy.") +} + func TestSendStreamingInvokeResponseSuccess(t *testing.T) { payloadString := strings.Repeat("a", 128*1024) // 128 KiB payload := NewReader(payloadString) @@ -289,6 +368,7 @@ func TestSendStreamingInvokeResponseReset(t *testing.T) { // Reset initiated aft interruptedTestWriterChan <- struct{}{} // inform test writer about interruption <-interruptedResponseChan // wait for copy done after interruption require.NotNil(t, reset.InvokeResponseMetrics) + require.Equal(t, interop.InvokeResponseMode("Buffered"), reset.InvokeResponseMode) <-sendResponseChan require.Equal(t, expectedPayloadString, writer.buffer.String()) @@ -298,6 +378,60 @@ func TestSendStreamingInvokeResponseReset(t *testing.T) { // Reset initiated aft <-testFinished } +// TODO: mock asyncPayloadCopy and force it to return Oversized in copyDone +func TestSendStreamingInvokeResponseOversizedRuntimesWithTrailers(t *testing.T) { + oversizedPayloadString := strings.Repeat("a", int(MaxDirectResponseSize)+1) + payload := NewReader(oversizedPayloadString) + trailers := http.Header{ + FunctionErrorTypeTrailer: []string{"RuntimesErrorType"}, + FunctionErrorBodyTrailer: []string{"RuntimesBody"}, + } + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) + require.Error(t, err) + require.IsType(t, &interop.ErrorResponseTooLargeDI{}, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, trailers.Get(FunctionErrorTypeTrailer), writer.Header().Get(FunctionErrorTypeTrailer)) + require.Equal(t, trailers.Get(FunctionErrorBodyTrailer), 
writer.Header().Get(FunctionErrorBodyTrailer)) + require.Equal(t, EndOfResponseOversized, writer.Header().Get(EndOfResponseTrailer)) + <-testFinished +} + +// TODO: mock asyncPayloadCopy and force it to return Oversized in copyDone +func TestSendStreamingInvokeResponseOversizedRuntimesWithoutErrorTypeTrailer(t *testing.T) { + oversizedPayloadString := strings.Repeat("a", int(MaxDirectResponseSize)+1) + payload := NewReader(oversizedPayloadString) + trailers := http.Header{ + FunctionErrorTypeTrailer: []string{""}, + FunctionErrorBodyTrailer: []string{"RuntimesErrorBody"}, + } + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) + require.Error(t, err) + require.IsType(t, &interop.ErrorResponseTooLargeDI{}, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, "Function.ResponseSizeTooLarge", writer.Header().Get(FunctionErrorTypeTrailer)) + require.Equal(t, trailers.Get(FunctionErrorBodyTrailer), writer.Header().Get(FunctionErrorBodyTrailer)) + require.Equal(t, EndOfResponseOversized, writer.Header().Get(EndOfResponseTrailer)) + <-testFinished +} + func TestSendStreamingInvokeErrorResponseSuccess(t *testing.T) { payloadString := strings.Repeat("a", 128*1024) // 128 KiB payload := NewReader(payloadString) @@ -356,3 +490,247 @@ func TestSendStreamingInvokeErrorResponseReset(t *testing.T) { // Reset initiate require.Equal(t, "Truncated", writer.Header().Get("End-Of-Response")) <-testFinished } + +func TestIsStreamingInvokeTrue(t *testing.T) { + fallbackFlag := -1 + reponseForFallback := isStreamingInvoke(fallbackFlag, interop.InvokeResponseModeBuffered) + + require.True(t, reponseForFallback) + + nonFallbackFlag := 1 + reponseForResponseMode := 
isStreamingInvoke(nonFallbackFlag, interop.InvokeResponseModeStreaming) + + require.True(t, reponseForResponseMode) +} + +func TestIsStreamingInvokeFalse(t *testing.T) { + nonFallbackFlag := 1 + response := isStreamingInvoke(nonFallbackFlag, interop.InvokeResponseModeBuffered) + + require.False(t, response) +} + +func TestMapCopyDoneResultErrorToErrorType(t *testing.T) { + require.Equal(t, fatalerror.TruncatedResponse, mapCopyDoneResultErrorToErrorType(&interop.ErrTruncatedResponse{})) + require.Equal(t, fatalerror.FunctionOversizedResponse, mapCopyDoneResultErrorToErrorType(&interop.ErrorResponseTooLargeDI{})) + require.Equal(t, fatalerror.SandboxFailure, mapCopyDoneResultErrorToErrorType(errors.New(""))) +} + +func TestConvertToInvokeResponseMode(t *testing.T) { + response, err := convertToInvokeResponseMode("buffered") + require.Equal(t, interop.InvokeResponseModeBuffered, response) + require.Nil(t, err) + + response, err = convertToInvokeResponseMode("streaming") + require.Equal(t, interop.InvokeResponseModeStreaming, response) + require.Nil(t, err) + + response, err = convertToInvokeResponseMode("foo-bar") + require.Equal(t, interop.InvokeResponseMode(""), response) + require.Equal(t, interop.ErrInvalidInvokeResponseMode, err) +} + +func FuzzReceiveDirectInvoke(f *testing.F) { + testCustHeaders := CustomerHeaders{ + CognitoIdentityID: "id1", + CognitoIdentityPoolID: "id2", + ClientContext: "clientcontext1", + } + custHeadersJSON := testCustHeaders.Dump() + + f.Add([]byte{'a'}, "res-token", "invokeid", "functionarn", "versionid", "contenttype", + custHeadersJSON, "1000", + "Streaming", fmt.Sprint(interop.MinResponseBandwidthRate), fmt.Sprint(interop.MinResponseBandwidthBurstSize)) + f.Add([]byte{'b'}, "res-token", "invokeid", "functionarn", "versionid", "contenttype", + custHeadersJSON, "2000", "Buffered", + "0", "0") + f.Add([]byte{'0'}, "0", "0", "0", "0", "0", + "", "", "0", + "0", "0") + + f.Fuzz(func( + t *testing.T, + payload []byte, + reservationToken 
string, + invokeID string, + invokedFunctionArn string, + versionID string, + contentType string, + custHeadersStr string, + maxPayloadSizeStr string, + invokeResponseModeStr string, + responseBandwidthRateStr string, + responseBandwidthBurstSizeStr string, + ) { + request := makeDirectInvokeRequest(payload, reservationToken, invokeID, + invokedFunctionArn, versionID, contentType, custHeadersStr, maxPayloadSizeStr, + invokeResponseModeStr, responseBandwidthRateStr, responseBandwidthBurstSizeStr) + + token := createDummyToken() + responseRecorder := httptest.NewRecorder() + + receivedInvoke, err := ReceiveDirectInvoke(responseRecorder, request, token) + + // default values used if header values are empty + responseMode := interop.InvokeResponseModeBuffered + maxDirectResponseSize := interop.MaxPayloadSize + + custHeaders := CustomerHeaders{} + + if err != nil { + if err = custHeaders.Load(custHeadersStr); err != nil { + assertBadRequestErrorType(t, responseRecorder, interop.ErrMalformedCustomerHeaders) + return + } + + if !isValidMaxPayloadSize(maxPayloadSizeStr) { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidMaxPayloadSize) + return + } + + n, _ := strconv.ParseInt(maxPayloadSizeStr, 10, 64) + maxDirectResponseSize = int(n) + + if invokeResponseModeStr != "" { + if responseMode, err = convertToInvokeResponseMode(invokeResponseModeStr); err != nil { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidInvokeResponseMode) + return + } + } + + if isStreamingInvoke(maxDirectResponseSize, responseMode) { + if !isValidResponseBandwidthRate(responseBandwidthRateStr) { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidResponseBandwidthRate) + return + } + + if !isValidResponseBandwidthBurstSize(responseBandwidthBurstSizeStr) { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidResponseBandwidthBurstSize) + return + } + } + + } else { + if isStreamingInvoke(maxDirectResponseSize, responseMode) { + // FIXME 
+ // Until WorkerProxy stops sending MaxDirectResponseSize == -1 to identify streaming + // invokes, the ReceiveDirectInvoke() implementation overrides InvokeResponseMode + // to avoid setting InvokeResponseMode to buffered (default) for a streaming invoke (MaxDirectResponseSize == -1). + responseMode = interop.InvokeResponseModeStreaming + + assert.Equal(t, responseRecorder.Header().Values("Trailer"), []string{FunctionErrorTypeTrailer, FunctionErrorBodyTrailer}) + } + + if receivedInvoke.ID != token.InvokeID { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidInvokeID) + return + } + + if receivedInvoke.ReservationToken != token.ReservationToken { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidReservationToken) + return + } + + if receivedInvoke.VersionID != token.VersionID { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidFunctionVersion) + return + } + + if now := metering.Monotime(); now > token.InvackDeadlineNs { + assertBadRequestErrorType(t, responseRecorder, interop.ErrReservationExpired) + return + } + + assert.Equal(t, responseRecorder.Header().Get(VersionIDHeader), token.VersionID) + assert.Equal(t, responseRecorder.Header().Get(ReservationTokenHeader), token.ReservationToken) + assert.Equal(t, responseRecorder.Header().Get(InvokeIDHeader), token.InvokeID) + + expectedInvoke := &interop.Invoke{ + ID: invokeID, + ReservationToken: reservationToken, + InvokedFunctionArn: invokedFunctionArn, + VersionID: versionID, + ContentType: contentType, + CognitoIdentityID: custHeaders.CognitoIdentityID, + CognitoIdentityPoolID: custHeaders.CognitoIdentityPoolID, + TraceID: token.TraceID, + LambdaSegmentID: token.LambdaSegmentID, + ClientContext: custHeaders.ClientContext, + Payload: request.Body, + DeadlineNs: receivedInvoke.DeadlineNs, + NeedDebugLogs: token.NeedDebugLogs, + InvokeReceivedTime: receivedInvoke.InvokeReceivedTime, + InvokeResponseMode: responseMode, + RestoreDurationNs: token.RestoreDurationNs, + 
RestoreStartTimeMonotime: token.RestoreStartTimeMonotime, + } + + assert.Equal(t, expectedInvoke, receivedInvoke) + } + }) +} + +func createDummyToken() interop.Token { + return interop.Token{ + ReservationToken: "reservation_token", + TraceID: "trace_id", + InvokeID: "invoke_id", + InvackDeadlineNs: math.MaxInt64, + VersionID: "version_id", + } +} + +func assertBadRequestErrorType(t *testing.T, responseRecorder *httptest.ResponseRecorder, expectedErrType error) { + assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) + + assert.Equal(t, expectedErrType.Error(), responseRecorder.Header().Get(ErrorTypeHeader)) + assert.Equal(t, EndOfResponseComplete, responseRecorder.Header().Get(EndOfResponseTrailer)) +} + +func isValidResponseBandwidthBurstSize(sizeStr string) bool { + size, err := strconv.ParseInt(sizeStr, 10, 64) + return err == nil && + interop.MinResponseBandwidthBurstSize <= size && size <= interop.MaxResponseBandwidthBurstSize +} + +func isValidResponseBandwidthRate(rateStr string) bool { + rate, err := strconv.ParseInt(rateStr, 10, 64) + return err == nil && + interop.MinResponseBandwidthRate <= rate && rate <= interop.MaxResponseBandwidthRate +} + +func isValidMaxPayloadSize(maxPayloadSizeStr string) bool { + if maxPayloadSizeStr != "" { + maxPayloadSize, err := strconv.ParseInt(maxPayloadSizeStr, 10, 64) + return err == nil && maxPayloadSize >= -1 + } + + return true +} + +func makeDirectInvokeRequest( + payload []byte, reservationToken string, invokeID string, invokedFunctionArn string, + versionID string, contentType string, custHeadersStr string, maxPayloadSize string, + invokeResponseModeStr string, responseBandwidthRate string, responseBandwidthBurstSize string, +) *http.Request { + request := httptest.NewRequest("POST", "http://example.com/", bytes.NewReader(payload)) + request = addReservationToken(request, reservationToken) + + request.Header.Set(InvokeIDHeader, invokeID) + request.Header.Set(InvokedFunctionArnHeader, invokedFunctionArn) 
+ request.Header.Set(VersionIDHeader, versionID) + request.Header.Set(ContentTypeHeader, contentType) + request.Header.Set(CustomerHeadersHeader, custHeadersStr) + request.Header.Set(MaxPayloadSizeHeader, maxPayloadSize) + request.Header.Set(InvokeResponseModeHeader, invokeResponseModeStr) + request.Header.Set(ResponseBandwidthRateHeader, responseBandwidthRate) + request.Header.Set(ResponseBandwidthBurstSizeHeader, responseBandwidthBurstSize) + + return request +} + +func addReservationToken(r *http.Request, reservationToken string) *http.Request { + rctx := chi.NewRouteContext() + rctx.URLParams.Add("reservationtoken", reservationToken) + return r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) +} diff --git a/lambda/core/flow.go b/lambda/core/flow.go index b2cb538..08d5e4b 100644 --- a/lambda/core/flow.go +++ b/lambda/core/flow.go @@ -3,6 +3,12 @@ package core +import ( + "context" + + "go.amzn.com/lambda/interop" +) + // InitFlowSynchronization wraps init flow barriers. 
type InitFlowSynchronization interface { SetExternalAgentsRegisterCount(uint16) error @@ -13,6 +19,7 @@ type InitFlowSynchronization interface { RuntimeReady() error AwaitRuntimeReady() error + AwaitRuntimeReadyWithDeadline(context.Context) error AgentReady() error AwaitAgentsReady() error @@ -47,6 +54,26 @@ func (s *initFlowSynchronizationImpl) AwaitRuntimeReady() error { return s.runtimeReadyGate.AwaitGateCondition() } +func (s *initFlowSynchronizationImpl) AwaitRuntimeReadyWithDeadline(ctx context.Context) error { + var err error + errorChan := make(chan error) + + go func() { + errorChan <- s.runtimeReadyGate.AwaitGateCondition() + }() + + select { + case err = <-errorChan: + break + case <-ctx.Done(): + err = interop.ErrRestoreHookTimeout + s.CancelWithError(err) + break + } + + return err +} + // AwaitRuntimeRestoreReady awaits runtime restore ready state (/restore/next is called by runtime) func (s *initFlowSynchronizationImpl) AwaitRuntimeRestoreReady() error { return s.runtimeRestoreReadyGate.AwaitGateCondition() diff --git a/lambda/core/registrations.go b/lambda/core/registrations.go index f68612c..26f6f2f 100644 --- a/lambda/core/registrations.go +++ b/lambda/core/registrations.go @@ -70,10 +70,12 @@ type AgentInfo struct { // FunctionMetadata holds static information regarding the function (Name, Version, Handler) type FunctionMetadata struct { - FunctionName string - FunctionVersion string - Handler string - RuntimeInfo interop.RuntimeInfo + AccountID string + FunctionName string + FunctionVersion string + InstanceMaxMemory uint64 + Handler string + RuntimeInfo interop.RuntimeInfo } // RegistrationService keeps track of registered parties, including external agents, threads, and runtime. 
diff --git a/lambda/core/runtime_state_names.go b/lambda/core/runtime_state_names.go index b04ba5d..4a2184d 100644 --- a/lambda/core/runtime_state_names.go +++ b/lambda/core/runtime_state_names.go @@ -16,4 +16,5 @@ const ( RuntimeInvocationResponseStateName = "InvocationResponse" RuntimeInvocationErrorResponseStateName = "InvocationErrorResponse" RuntimeResponseSentStateName = "RuntimeResponseSentState" + RuntimeRestoreErrorStateName = "RuntimeRestoreErrorState" ) diff --git a/lambda/core/statejson/description.go b/lambda/core/statejson/description.go index eb46946..a614d20 100644 --- a/lambda/core/statejson/description.go +++ b/lambda/core/statejson/description.go @@ -5,9 +5,24 @@ package statejson import ( "encoding/json" + log "github.com/sirupsen/logrus" ) +// ResponseMode are top-level constants used in combination with the various types of +// modes we have for responses, such as invoke's response mode and function's response mode. +// In the future we might have invoke's request mode or similar, so these help set the ground +// for consistency. +type ResponseMode string + +const ResponseModeBuffered = "Buffered" +const ResponseModeStreaming = "Streaming" + +type InvokeResponseMode string + +const InvokeResponseModeBuffered InvokeResponseMode = ResponseModeBuffered +const InvokeResponseModeStreaming InvokeResponseMode = ResponseModeStreaming + // StateDescription ... 
type StateDescription struct { Name string `json:"name"` @@ -35,9 +50,24 @@ type InternalStateDescription struct { FirstFatalError string `json:"firstFatalError"` } +type ResponseMetricsDimensions struct { + InvokeResponseMode InvokeResponseMode `json:"invokeResponseMode"` +} + +type ResponseMetrics struct { + RuntimeResponseLatencyMs float64 `json:"runtimeResponseLatencyMs"` + Dimensions ResponseMetricsDimensions `json:"dimensions"` +} + +type ReleaseResponse struct { + *InternalStateDescription + ResponseMetrics ResponseMetrics `json:"responseMetrics"` +} + // ResetDescription describes fields of the response to an INVOKE API request type ResetDescription struct { - ExtensionsResetMs int64 `json:"extensionsResetMs"` + ExtensionsResetMs int64 `json:"extensionsResetMs"` + ResponseMetrics ResponseMetrics `json:"responseMetrics"` } func (s *InternalStateDescription) AsJSON() []byte { @@ -55,3 +85,11 @@ func (s *ResetDescription) AsJSON() []byte { } return bytes } + +func (s *ReleaseResponse) AsJSON() []byte { + bytes, err := json.Marshal(s) + if err != nil { + log.Panicf("Failed to marshall release response: %s", err) + } + return bytes +} diff --git a/lambda/core/states.go b/lambda/core/states.go index a5e2010..0de88ec 100644 --- a/lambda/core/states.go +++ b/lambda/core/states.go @@ -9,6 +9,7 @@ import ( "time" "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/interop" ) // Suspendable on operator condition. 
@@ -76,6 +77,7 @@ type RuntimeState interface { InvocationResponse() error InvocationErrorResponse() error ResponseSent() error + RestoreError(interop.FunctionError) error Name() string } @@ -87,6 +89,9 @@ func (s *disallowEveryTransitionByDefault) RestoreReady() error { ret func (s *disallowEveryTransitionByDefault) InvocationResponse() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) InvocationErrorResponse() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) ResponseSent() error { return ErrNotAllowed } +func (s *disallowEveryTransitionByDefault) RestoreError(interop.FunctionError) error { + return ErrNotAllowed +} // Runtime is runtime object. type Runtime struct { @@ -105,6 +110,7 @@ type Runtime struct { RuntimeInvocationResponseState RuntimeState RuntimeInvocationErrorResponseState RuntimeState RuntimeResponseSentState RuntimeState + RuntimeRestoreErrorState RuntimeState } // Release ... @@ -176,6 +182,12 @@ func (s *Runtime) ResponseSent() error { return err } +func (s *Runtime) RestoreError(UserError interop.FunctionError) error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.RestoreError(UserError) +} + // GetRuntimeDescription returns runtime description object for debugging purposes func (s *Runtime) GetRuntimeDescription() statejson.RuntimeDescription { s.ManagedThread.Lock() @@ -207,6 +219,7 @@ func NewRuntime(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchroni runtime.RuntimeResponseSentState = &RuntimeResponseSentState{runtime: runtime, invokeFlow: invokeFlow} runtime.RuntimeRestoreReadyState = &RuntimeRestoreReadyState{} runtime.RuntimeRestoringState = &RuntimeRestoringState{runtime: runtime, initFlow: initFlow} + runtime.RuntimeRestoreErrorState = &RuntimeRestoreErrorState{runtime: runtime, initFlow: initFlow} runtime.setStateUnsafe(runtime.RuntimeStartedState) return runtime @@ -292,9 +305,9 @@ func (s *RuntimeRestoringState) Ready() error { return 
nil } -// Runtime has thrown an exception when executing restore hooks and called /init/error -func (s *RuntimeRestoringState) InitError() error { - s.runtime.setStateUnsafe(s.runtime.RuntimeInitErrorState) +func (s *RuntimeRestoringState) RestoreError(userError interop.FunctionError) error { + s.runtime.setStateUnsafe(s.runtime.RuntimeRestoreErrorState) + s.initFlow.CancelWithError(interop.ErrRestoreHookUserError{UserError: userError}) return nil } @@ -436,3 +449,13 @@ func (s *RuntimeResponseSentState) Ready() error { func (s *RuntimeResponseSentState) Name() string { return RuntimeResponseSentStateName } + +type RuntimeRestoreErrorState struct { + disallowEveryTransitionByDefault + runtime *Runtime + initFlow InitFlowSynchronization +} + +func (s *RuntimeRestoreErrorState) Name() string { + return RuntimeRestoreErrorStateName +} diff --git a/lambda/core/states_test.go b/lambda/core/states_test.go index 37f38e2..b6d2955 100644 --- a/lambda/core/states_test.go +++ b/lambda/core/states_test.go @@ -4,11 +4,14 @@ package core import ( + "context" + "sync" + "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/testdata/mockthread" - "sync" - "testing" ) func TestRuntimeInitErrorAfterReady(t *testing.T) { @@ -96,6 +99,34 @@ func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) } +func TestRuntimeStateTransitionsFromRestoreErrorState(t *testing.T) { + runtime := newRuntime() + // RestoreError -> InitError + runtime.SetState(runtime.RuntimeRestoreErrorState) + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) + // RestoreError -> Ready + runtime.SetState(runtime.RuntimeRestoreErrorState) + assert.Equal(t, ErrNotAllowed, runtime.Ready()) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) + // RestoreError -> 
RestoreReady + runtime.SetState(runtime.RuntimeRestoreErrorState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) + // RestoreError -> ResponseSent + runtime.SetState(runtime.RuntimeRestoreErrorState) + assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) + // RestoreError -> InvocationResponse + runtime.SetState(runtime.RuntimeRestoreErrorState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) + // RestoreError -> InvocationErrorResponse + runtime.SetState(runtime.RuntimeRestoreErrorState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) +} + func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { runtime := newRuntime() // Ready -> InitError @@ -266,11 +297,9 @@ func TestRuntimeStateTransitionsFromRestoreReadyState(t *testing.T) { } func TestRuntimeStateTransitionsFromRestoringState(t *testing.T) { - runtime := newRuntime() - // RestoreRunning -> InitError + runtime, mockInitFlow, _ := newRuntimeGetMockFlows() runtime.SetState(runtime.RuntimeRestoringState) - assert.NoError(t, runtime.InitError()) - assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) + mockInitFlow.On("CancelWithError", interop.ErrRestoreHookUserError{UserError: interop.FunctionError{}}).Return() // RestoreRunning -> Ready runtime.SetState(runtime.RuntimeRestoringState) assert.NoError(t, runtime.Ready()) @@ -291,6 +320,10 @@ func TestRuntimeStateTransitionsFromRestoringState(t *testing.T) { runtime.SetState(runtime.RuntimeRestoringState) assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) + // RestoreRunning -> RestoreError + 
runtime.SetState(runtime.RuntimeRestoringState) + assert.NoError(t, runtime.RestoreError(interop.FunctionError{})) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) } func newRuntime() *Runtime { @@ -302,6 +335,15 @@ func newRuntime() *Runtime { return runtime } +func newRuntimeGetMockFlows() (*Runtime, *mockInitFlowSynchronization, *mockInvokeFlowSynchronization) { + initFlow := &mockInitFlowSynchronization{} + invokeFlow := &mockInvokeFlowSynchronization{} + runtime := NewRuntime(initFlow, invokeFlow) + runtime.ManagedThread = &mockthread.MockManagedThread{} + + return runtime, initFlow, invokeFlow +} + type mockInitFlowSynchronization struct { mock.Mock ReadyCond *sync.Cond @@ -325,6 +367,9 @@ func (s *mockInitFlowSynchronization) ExternalAgentRegistered() error { func (s *mockInitFlowSynchronization) AwaitRuntimeReady() error { return nil } +func (s *mockInitFlowSynchronization) AwaitRuntimeReadyWithDeadline(ctx context.Context) error { + return nil +} func (s *mockInitFlowSynchronization) AwaitAgentsReady() error { return nil } diff --git a/lambda/extensions/extensions.go b/lambda/extensions/extensions.go index b55dc51..abe0c87 100644 --- a/lambda/extensions/extensions.go +++ b/lambda/extensions/extensions.go @@ -4,7 +4,14 @@ package extensions import ( + "os" "sync/atomic" + + log "github.com/sirupsen/logrus" +) + +const ( + disableExtensionsFile = "/opt/disable-extensions-jwigqn8j" ) var enabled atomic.Value @@ -27,3 +34,11 @@ func AreEnabled() bool { } return val.(bool) } + +func DisableViaMagicLayer() { + _, err := os.Stat(disableExtensionsFile) + if err == nil { + log.Infof("Extensions disabled by attached layer (%s)", disableExtensionsFile) + Disable() + } +} diff --git a/lambda/fatalerror/fatalerror.go b/lambda/fatalerror/fatalerror.go index bb8a86a..665627d 100644 --- a/lambda/fatalerror/fatalerror.go +++ b/lambda/fatalerror/fatalerror.go @@ -3,23 +3,69 @@ package fatalerror +import ( + "regexp" + "strings" +) + // This package 
defines constant error types returned to slicer with DONE(failure), and also sandbox errors // Separate package for namespacing // ErrorType is returned to slicer inside DONE type ErrorType string +// TODO: Find another name than "fatalerror" +// TODO: Rename all const so that they always begin with Agent/Runtime/Sandbox/Function +// TODO: Add filtering for extensions as well const ( - AgentInitError ErrorType = "Extension.InitError" // agent exited after calling /extension/init/error - AgentExitError ErrorType = "Extension.ExitError" // agent exited after calling /extension/exit/error - AgentCrash ErrorType = "Extension.Crash" // agent crashed unexpectedly - AgentLaunchError ErrorType = "Extension.LaunchError" // agent could not be launched - RuntimeExit ErrorType = "Runtime.ExitError" - InvalidEntrypoint ErrorType = "Runtime.InvalidEntrypoint" - InvalidWorkingDir ErrorType = "Runtime.InvalidWorkingDir" - InvalidTaskConfig ErrorType = "Runtime.InvalidTaskConfig" - TruncatedResponse ErrorType = "Runtime.TruncatedResponse" - SandboxFailure ErrorType = "Sandbox.Failure" - SandboxTimeout ErrorType = "Sandbox.Timeout" - Unknown ErrorType = "Unknown" + // Extension errors + AgentInitError ErrorType = "Extension.InitError" // agent exited after calling /extension/init/error + AgentExitError ErrorType = "Extension.ExitError" // agent exited after calling /extension/exit/error + AgentCrash ErrorType = "Extension.Crash" // agent crashed unexpectedly + AgentLaunchError ErrorType = "Extension.LaunchError" // agent could not be launched + + // Runtime errors + RuntimeExit ErrorType = "Runtime.ExitError" + InvalidEntrypoint ErrorType = "Runtime.InvalidEntrypoint" + InvalidWorkingDir ErrorType = "Runtime.InvalidWorkingDir" + InvalidTaskConfig ErrorType = "Runtime.InvalidTaskConfig" + TruncatedResponse ErrorType = "Runtime.TruncatedResponse" + RuntimeInvalidResponseModeHeader ErrorType = "Runtime.InvalidResponseModeHeader" + RuntimeUnknown ErrorType = "Runtime.Unknown" + + // 
Function errors + FunctionOversizedResponse ErrorType = "Function.ResponseSizeTooLarge" + FunctionUnknown ErrorType = "Function.Unknown" + + // Sandbox errors + SandboxFailure ErrorType = "Sandbox.Failure" + SandboxTimeout ErrorType = "Sandbox.Timeout" ) + +var validRuntimeAndFunctionErrors = map[ErrorType]struct{}{ + // Runtime errors + RuntimeExit: {}, + InvalidEntrypoint: {}, + InvalidWorkingDir: {}, + InvalidTaskConfig: {}, + TruncatedResponse: {}, + RuntimeInvalidResponseModeHeader: {}, + RuntimeUnknown: {}, + + // Function errors + FunctionOversizedResponse: {}, + FunctionUnknown: {}, +} + +func GetValidRuntimeOrFunctionErrorType(errorType string) ErrorType { + match, _ := regexp.MatchString("(Runtime|Function)\\.[A-Z][a-zA-Z]+", errorType) + if match { + return ErrorType(errorType) + } + + if strings.HasPrefix(errorType, "Function.") { + return FunctionUnknown + } + + return RuntimeUnknown +} diff --git a/lambda/fatalerror/fatalerror_test.go b/lambda/fatalerror/fatalerror_test.go new file mode 100644 index 0000000..72c34aa --- /dev/null +++ b/lambda/fatalerror/fatalerror_test.go @@ -0,0 +1,51 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package fatalerror + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidRuntimeAndFunctionErrors(t *testing.T) { + type test struct { + input string + expected ErrorType + } + + var tests = []test{} + for validError := range validRuntimeAndFunctionErrors { + tests = append(tests, test{input: string(validError), expected: validError}) + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + assert.Equal(t, GetValidRuntimeOrFunctionErrorType(tt.input), tt.expected) + }) + } +} + +func TestGetValidRuntimeOrFunctionErrorType(t *testing.T) { + type test struct { + input string + expected ErrorType + } + + var tests = []test{ + {"", RuntimeUnknown}, + {"MyCustomError", RuntimeUnknown}, + {"MyCustomError.Error", RuntimeUnknown}, + {"Runtime.MyCustomErrorTypeHere", ErrorType("Runtime.MyCustomErrorTypeHere")}, + {"Function.MyCustomErrorTypeHere", ErrorType("Function.MyCustomErrorTypeHere")}, + } + + for _, tt := range tests { + testname := fmt.Sprintf("TestGetValidRuntimeOrFunctionErrorType with %s", tt.input) + t.Run(testname, func(t *testing.T) { + assert.Equal(t, GetValidRuntimeOrFunctionErrorType(tt.input), tt.expected) + }) + } +} diff --git a/lambda/interop/bootstrap.go b/lambda/interop/bootstrap.go index 4a9b6af..d3f4500 100644 --- a/lambda/interop/bootstrap.go +++ b/lambda/interop/bootstrap.go @@ -7,12 +7,13 @@ import ( "os" "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/rapidcore/env" ) type Bootstrap interface { - Cmd() ([]string, error) // returns the args of bootstrap, where args[0] is the path to executable - Env(e EnvironmentVariables) map[string]string // returns the environment variables to be passed to the bootstrapped process - Cwd() (string, error) // returns the working directory of the bootstrap process - ExtraFiles() []*os.File // returns the extra file descriptors apart from 1 & 2 to be passed to runtime + Cmd() ([]string, error) // returns the 
args of bootstrap, where args[0] is the path to executable + Env(e *env.Environment) map[string]string // returns the environment variables to be passed to the bootstrapped process + Cwd() (string, error) // returns the working directory of the bootstrap process + ExtraFiles() []*os.File // returns the extra file descriptors apart from 1 & 2 to be passed to runtime CachedFatalError(err error) (fatalerror.ErrorType, string, bool) } diff --git a/lambda/interop/environment_variables.go b/lambda/interop/environment_variables.go deleted file mode 100644 index 46bdf8b..0000000 --- a/lambda/interop/environment_variables.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package interop - -type EnvironmentVariables interface { - AgentExecEnv() map[string]string - RuntimeExecEnv() map[string]string - SetHandler(handler string) - StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress string) - StoreEnvironmentVariablesFromInit(customerEnv map[string]string, - handler, awsKey, awsSecret, awsSession, funcName, funcVer string) - StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) -} diff --git a/lambda/interop/events_api.go b/lambda/interop/events_api.go new file mode 100644 index 0000000..a0e9967 --- /dev/null +++ b/lambda/interop/events_api.go @@ -0,0 +1,197 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "fmt" + + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/rapi/model" +) + +type InitPhase string + +// InitType describes the possible types of the INIT phase. +type InitType string + +type InitStartData struct { + InitializationType InitType `json:"initializationType"` + RuntimeVersion string `json:"runtimeVersion"` + RuntimeVersionArn string `json:"runtimeVersionArn"` + FunctionName string `json:"functionName"` + FunctionArn string `json:"functionArn"` + FunctionVersion string `json:"functionVersion"` + InstanceID string `json:"instanceId"` + InstanceMaxMemory uint64 `json:"instanceMaxMemory"` + Phase InitPhase `json:"phase"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *InitStartData) String() string { + return fmt.Sprintf("INIT START(type: %s, phase: %s)", d.InitializationType, d.Phase) +} + +type InitRuntimeDoneData struct { + InitializationType InitType `json:"initializationType"` + Status string `json:"status"` + Phase InitPhase `json:"phase"` + ErrorType *string `json:"errorType,omitempty"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *InitRuntimeDoneData) String() string { + return fmt.Sprintf("INIT RTDONE(status: %s)", d.Status) +} + +type InitReportMetrics struct { + DurationMs float64 `json:"durationMs"` +} + +type InitReportData struct { + InitializationType InitType `json:"initializationType"` + Metrics InitReportMetrics `json:"metrics"` + Phase InitPhase `json:"phase"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *InitReportData) String() string { + return fmt.Sprintf("INIT REPORT(durationMs: %f)", d.Metrics.DurationMs) +} + +type RestoreRuntimeDoneData struct { + Status string `json:"status"` + ErrorType *string `json:"errorType,omitempty"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *RestoreRuntimeDoneData) String() string { + return fmt.Sprintf("RESTORE RTDONE(status: %s)", d.Status) +} + +type
TracingCtx struct { + SpanID string `json:"spanId,omitempty"` + Type model.TracingType `json:"type"` + Value string `json:"value"` +} + +type InvokeStartData struct { + RequestID string `json:"requestId"` + Version string `json:"version,omitempty"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *InvokeStartData) String() string { + return fmt.Sprintf("INVOKE START(requestId: %s)", d.RequestID) +} + +type RuntimeDoneInvokeMetrics struct { + ProducedBytes int64 `json:"producedBytes"` + DurationMs float64 `json:"durationMs"` +} + +type Span struct { + Name string `json:"name"` + Start string `json:"start"` + DurationMs float64 `json:"durationMs"` +} + +func (s *Span) String() string { + return fmt.Sprintf("SPAN(name: %s)", s.Name) +} + +type InvokeRuntimeDoneData struct { + RequestID RequestID `json:"requestId"` + Status string `json:"status"` + Metrics *RuntimeDoneInvokeMetrics `json:"metrics,omitempty"` + Tracing *TracingCtx `json:"tracing,omitempty"` + Spans []Span `json:"spans,omitempty"` + ErrorType *string `json:"errorType,omitempty"` + InternalMetrics *InvokeResponseMetrics `json:"-"` +} + +func (d *InvokeRuntimeDoneData) String() string { + if d.Metrics == nil { + // Metrics is an optional pointer (omitempty); do not dereference it when absent. + return fmt.Sprintf("INVOKE RTDONE(status: %s)", d.Status) + } + return fmt.Sprintf("INVOKE RTDONE(status: %s, produced bytes: %d, duration: %fms)", d.Status, d.Metrics.ProducedBytes, d.Metrics.DurationMs) +} + +type ExtensionInitData struct { + AgentName string `json:"name"` + State string `json:"state"` + Subscriptions []string `json:"events"` + ErrorType string `json:"errorType,omitempty"` +} + +func (d *ExtensionInitData) String() string { + return fmt.Sprintf("EXTENSION INIT(agent name: %s, state: %s, error type: %s)", d.AgentName, d.State, d.ErrorType) +} + +type ReportMetrics struct { + DurationMs float64 `json:"durationMs"` + BilledDurationMs float64 `json:"billedDurationMs"` + MemorySizeMB uint64 `json:"memorySizeMB"` + MaxMemoryUsedMB uint64 `json:"maxMemoryUsedMB"` + InitDurationMs float64 `json:"initDurationMs,omitempty"` +} + +type ReportData struct { +
RequestID RequestID `json:"requestId"` + Status string `json:"status"` + Metrics ReportMetrics `json:"metrics"` + Tracing *TracingCtx `json:"tracing,omitempty"` + Spans []Span `json:"spans,omitempty"` + ErrorType *string `json:"errorType,omitempty"` +} + +func (d *ReportData) String() string { + return fmt.Sprintf("REPORT(status: %s, durationMs: %f)", d.Status, d.Metrics.DurationMs) +} + +type EndData struct { + RequestID RequestID `json:"requestId"` +} + +func (d *EndData) String() string { + return "END" +} + +type RequestID string + +type FaultData struct { + RequestID RequestID + ErrorMessage error + ErrorType fatalerror.ErrorType +} + +func (d *FaultData) String() string { + return fmt.Sprintf("RequestId: %s Error: %s\n%s\n", d.RequestID, d.ErrorMessage, d.ErrorType) +} + +type ImageErrorLogData string + +type EventsAPI interface { + SetCurrentRequestID(RequestID) + SendInitStart(InitStartData) error + SendInitRuntimeDone(InitRuntimeDoneData) error + SendInitReport(InitReportData) error + SendRestoreRuntimeDone(RestoreRuntimeDoneData) error + SendInvokeStart(InvokeStartData) error + SendInvokeRuntimeDone(InvokeRuntimeDoneData) error + SendExtensionInit(ExtensionInitData) error + SendReportSpan(Span) error + SendReport(ReportData) error + SendEnd(EndData) error + SendFault(FaultData) error + SendImageErrorLog(ImageErrorLogData) + + FetchTailLogs(string) (string, error) + GetRuntimeDoneSpans( + runtimeStartedTime int64, + invokeResponseMetrics *InvokeResponseMetrics, + runtimeOverheadStartedTime int64, + runtimeReadyTime int64, + ) []Span +} diff --git a/lambda/interop/events_api_test.go b/lambda/interop/events_api_test.go new file mode 100644 index 0000000..d3a7dc1 --- /dev/null +++ b/lambda/interop/events_api_test.go @@ -0,0 +1,656 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/rapi/model" +) + +const requestID RequestID = "REQUEST_ID" + +func TestJsonMarshalInvokeRuntimeDone(t *testing.T) { + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(52.56), + }, + Spans: []Span{ + { + Name: "responseLatency", + Start: "2022-04-11T15:01:28.543Z", + DurationMs: float64(23.02), + }, + { + Name: "responseDuration", + Start: "2022-04-11T15:00:00.000Z", + DurationMs: float64(20), + }, + }, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "spans": [ + { + "name": "responseLatency", + "start": "2022-04-11T15:01:28.543Z", + "durationMs": 23.02 + }, + { + "name": "responseDuration", + "start": "2022-04-11T15:00:00.000Z", + "durationMs": 20 + } + ], + "metrics": { + "producedBytes": 100, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneNoTracing(t *testing.T) { + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(52.56), + }, + Spans: []Span{ + { + Name: "responseLatency", + Start: "2022-04-11T15:01:28.543Z", + DurationMs: float64(23.02), + }, + { + Name: "responseDuration", + Start: "2022-04-11T15:00:00.000Z", + DurationMs: float64(20), + }, + }, 
+ } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "spans": [ + { + "name": "responseLatency", + "start": "2022-04-11T15:01:28.543Z", + "durationMs": 23.02 + }, + { + "name": "responseDuration", + "start": "2022-04-11T15:00:00.000Z", + "durationMs": 20 + } + ], + "metrics": { + "producedBytes": 100, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneNoMetrics(t *testing.T) { + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "success", + Spans: []Span{ + { + Name: "responseLatency", + Start: "2022-04-11T15:01:28.543Z", + DurationMs: float64(23.02), + }, + { + Name: "responseDuration", + Start: "2022-04-11T15:00:00.000Z", + DurationMs: float64(20), + }, + }, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "spans": [ + { + "name": "responseLatency", + "start": "2022-04-11T15:01:28.543Z", + "durationMs": 23.02 + }, + { + "name": "responseDuration", + "start": "2022-04-11T15:00:00.000Z", + "durationMs": 20 + } + ] + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneWithProducedBytesEqualToZero(t *testing.T) { + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(52.56), + }, + Spans: []Span{ + { + Name: "responseLatency", + Start: "2022-04-11T15:01:28.543Z", + DurationMs: float64(23.02), + }, + { + Name: "responseDuration", + 
Start: "2022-04-11T15:00:00.000Z", + DurationMs: float64(20), + }, + }, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "spans": [ + { + "name": "responseLatency", + "start": "2022-04-11T15:01:28.543Z", + "durationMs": 23.02 + }, + { + "name": "responseDuration", + "start": "2022-04-11T15:00:00.000Z", + "durationMs": 20 + } + ], + "metrics": { + "producedBytes": 0, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneWithNoSpans(t *testing.T) { + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(52.56), + }, + Spans: []Span{}, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "metrics": { + "producedBytes": 100, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneTimeout(t *testing.T) { + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "timeout", + Metrics: &RuntimeDoneInvokeMetrics{ + DurationMs: float64(52.56), + }, + Spans: []Span{}, + Tracing: &TracingCtx{ + SpanID: 
"spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "timeout", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "metrics": { + "producedBytes": 0, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneFailure(t *testing.T) { + errorType := "Runtime.ExitError" + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "failure", + ErrorType: &errorType, + Metrics: &RuntimeDoneInvokeMetrics{ + DurationMs: float64(52.56), + }, + Spans: []Span{}, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "failure", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "metrics": { + "producedBytes": 0, + "durationMs": 52.56 + }, + "errorType": "Runtime.ExitError" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneWithEmptyErrorType(t *testing.T) { + errorType := "" + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "failure", + ErrorType: &errorType, + Metrics: &RuntimeDoneInvokeMetrics{ + DurationMs: float64(52.56), + }, + Spans: []Span{}, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": 
"failure", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "metrics": { + "producedBytes": 0, + "durationMs": 52.56 + }, + "errorType": "" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitRuntimeDoneSuccess(t *testing.T) { + var errorType *string + data := InitRuntimeDoneData{ + InitializationType: "snap-start", + Phase: "init", + Status: "success", + ErrorType: errorType, + } + + expected := ` + { + "initializationType": "snap-start", + "phase": "init", + "status": "success" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitRuntimeDoneError(t *testing.T) { + errorType := "Runtime.ExitError" + data := InitRuntimeDoneData{ + InitializationType: "snap-start", + Phase: "init", + Status: "error", + ErrorType: &errorType, + } + + expected := ` + { + "initializationType": "snap-start", + "phase": "init", + "status": "error", + "errorType": "Runtime.ExitError" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitRuntimeDoneFailureWithEmptyErrorType(t *testing.T) { + errorType := "" + data := InitRuntimeDoneData{ + InitializationType: "snap-start", + Phase: "init", + Status: "error", + ErrorType: &errorType, + } + + expected := ` + { + "initializationType": "snap-start", + "phase": "init", + "status": "error", + "errorType": "" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalRestoreRuntimeDoneSuccess(t *testing.T) { + var errorType *string + data := RestoreRuntimeDoneData{ + Status: "success", + ErrorType: errorType, + } + + expected := ` + { + "status": "success" + } + ` + + actual, err 
:= json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalRestoreRuntimeDoneError(t *testing.T) { + errorType := "Runtime.ExitError" + data := RestoreRuntimeDoneData{ + Status: "error", + ErrorType: &errorType, + } + + expected := ` + { + "status": "error", + "errorType": "Runtime.ExitError" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalRestoreRuntimeDoneErrorWithEmptyErrorType(t *testing.T) { + errorType := "" + data := RestoreRuntimeDoneData{ + Status: "error", + ErrorType: &errorType, + } + + expected := ` + { + "status": "error", + "errorType": "" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalExtensionInit(t *testing.T) { + data := ExtensionInitData{ + AgentName: "agentName", + State: "Registered", + ErrorType: "", + Subscriptions: []string{"INVOKE", "SHUTDOWN"}, + } + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, `{"name":"agentName","state":"Registered","events":["INVOKE","SHUTDOWN"]}`, string(actual)) +} + +func TestJsonMarshalExtensionInitWithError(t *testing.T) { + data := ExtensionInitData{ + AgentName: "agentName", + State: "Registered", + ErrorType: "Extension.FooBar", + Subscriptions: []string{"INVOKE", "SHUTDOWN"}, + } + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, `{"name":"agentName","state":"Registered","events":["INVOKE","SHUTDOWN"],"errorType":"Extension.FooBar"}`, string(actual)) +} + +func TestJsonMarshalExtensionInitEmptyEvents(t *testing.T) { + data := ExtensionInitData{ + AgentName: "agentName", + State: "Registered", + ErrorType: "Extension.FooBar", + Subscriptions: []string{}, + } + + actual, err := json.Marshal(data) + require.NoError(t, err) + require.JSONEq(t, 
`{"name":"agentName","state":"Registered","events":[],"errorType":"Extension.FooBar"}`, string(actual)) +} + +func TestJsonMarshalReportWithTracing(t *testing.T) { + errorType := "Runtime.ExitError" + data := ReportData{ + RequestID: requestID, + Status: "error", + ErrorType: &errorType, + Metrics: ReportMetrics{ + DurationMs: float64(52.56), + BilledDurationMs: float64(52.40), + MemorySizeMB: uint64(1024), + MaxMemoryUsedMB: uint64(512), + }, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "error", + "errorType": "Runtime.ExitError", + "metrics": { + "durationMs": 52.56, + "billedDurationMs": 52.40, + "memorySizeMB": 1024, + "maxMemoryUsedMB": 512 + }, + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + } + } + ` + + actual, err := json.Marshal(data) + require.NoError(t, err) + require.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalReportWithoutErrorSpansAndTracing(t *testing.T) { + data := ReportData{ + RequestID: requestID, + Status: "timeout", + Metrics: ReportMetrics{ + DurationMs: float64(52.56), + BilledDurationMs: float64(52.40), + MemorySizeMB: uint64(1024), + MaxMemoryUsedMB: uint64(512), + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "timeout", + "metrics": { + "durationMs": 52.56, + "billedDurationMs": 52.40, + "memorySizeMB": 1024, + "maxMemoryUsedMB": 512 + } + } + ` + + actual, err := json.Marshal(data) + require.NoError(t, err) + require.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalReportWithInit(t *testing.T) { + data := ReportData{ + RequestID: requestID, + Status: "success", + Metrics: ReportMetrics{ + DurationMs: float64(52.56), + BilledDurationMs: float64(52.40), + MemorySizeMB: uint64(1024), + 
MaxMemoryUsedMB: uint64(512), + InitDurationMs: float64(3.15), + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "metrics": { + "durationMs": 52.56, + "billedDurationMs": 52.40, + "memorySizeMB": 1024, + "maxMemoryUsedMB": 512, + "initDurationMs": 3.15 + } + } + ` + + actual, err := json.Marshal(data) + require.NoError(t, err) + require.JSONEq(t, expected, string(actual)) +} diff --git a/lambda/interop/messages.go b/lambda/interop/messages.go new file mode 100644 index 0000000..ee1c783 --- /dev/null +++ b/lambda/interop/messages.go @@ -0,0 +1,68 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +// conversion from internal data structure into well defined messages + +func DoneFromInvokeSuccess(successMsg InvokeSuccess) *Done { + return &Done{ + Meta: DoneMetadata{ + RuntimeRelease: successMsg.RuntimeRelease, + NumActiveExtensions: successMsg.NumActiveExtensions, + ExtensionNames: successMsg.ExtensionNames, + InvokeRequestReadTimeNs: successMsg.InvokeMetrics.InvokeRequestReadTimeNs, + InvokeRequestSizeBytes: successMsg.InvokeMetrics.InvokeRequestSizeBytes, + RuntimeReadyTime: successMsg.InvokeMetrics.RuntimeReadyTime, + + InvokeCompletionTimeNs: successMsg.InvokeCompletionTimeNs, + InvokeReceivedTime: successMsg.InvokeReceivedTime, + RuntimeResponseLatencyMs: successMsg.ResponseMetrics.RuntimeResponseLatencyMs, + RuntimeTimeThrottledMs: successMsg.ResponseMetrics.RuntimeTimeThrottledMs, + RuntimeProducedBytes: successMsg.ResponseMetrics.RuntimeProducedBytes, + RuntimeOutboundThroughputBps: successMsg.ResponseMetrics.RuntimeOutboundThroughputBps, + LogsAPIMetrics: successMsg.LogsAPIMetrics, + MetricsDimensions: DoneMetadataMetricsDimensions{ + InvokeResponseMode: successMsg.InvokeResponseMode, + }, + }, + } +} + +func DoneFailFromInvokeFailure(failureMsg *InvokeFailure) *DoneFail { + return &DoneFail{ + ErrorType: failureMsg.ErrorType, + Meta: 
DoneMetadata{ + RuntimeRelease: failureMsg.RuntimeRelease, + NumActiveExtensions: failureMsg.NumActiveExtensions, + InvokeReceivedTime: failureMsg.InvokeReceivedTime, + + RuntimeResponseLatencyMs: failureMsg.ResponseMetrics.RuntimeResponseLatencyMs, + RuntimeTimeThrottledMs: failureMsg.ResponseMetrics.RuntimeTimeThrottledMs, + RuntimeProducedBytes: failureMsg.ResponseMetrics.RuntimeProducedBytes, + RuntimeOutboundThroughputBps: failureMsg.ResponseMetrics.RuntimeOutboundThroughputBps, + + InvokeRequestReadTimeNs: failureMsg.InvokeMetrics.InvokeRequestReadTimeNs, + InvokeRequestSizeBytes: failureMsg.InvokeMetrics.InvokeRequestSizeBytes, + RuntimeReadyTime: failureMsg.InvokeMetrics.RuntimeReadyTime, + + ExtensionNames: failureMsg.ExtensionNames, + LogsAPIMetrics: failureMsg.LogsAPIMetrics, + + MetricsDimensions: DoneMetadataMetricsDimensions{ + InvokeResponseMode: failureMsg.InvokeResponseMode, + }, + }, + } +} + +func DoneFailFromInitFailure(initFailure *InitFailure) *DoneFail { + return &DoneFail{ + ErrorType: initFailure.ErrorType, + Meta: DoneMetadata{ + RuntimeRelease: initFailure.RuntimeRelease, + NumActiveExtensions: initFailure.NumActiveExtensions, + LogsAPIMetrics: initFailure.LogsAPIMetrics, + }, + } +} diff --git a/lambda/interop/model.go b/lambda/interop/model.go index cc9c7d0..a4bdbf4 100644 --- a/lambda/interop/model.go +++ b/lambda/interop/model.go @@ -5,9 +5,9 @@ package interop import ( "encoding/json" + "errors" "fmt" "io" - "net/http" "strings" "time" @@ -32,8 +32,6 @@ const ( MaxResponseBandwidthBurstSize = 64 * 1024 * 1024 // 64 MiB ) -const functionResponseSizeTooLargeType = "Function.ResponseSizeTooLarge" - // ResponseMode are top-level constants used in combination with the various types of // modes we have for responses, such as invoke's response mode and function's response mode. 
// In the future we might have invoke's request mode or similar, so these help set the ground @@ -52,25 +50,6 @@ var AllInvokeResponseModes = []string{ string(InvokeResponseModeBuffered), string(InvokeResponseModeStreaming), } -// ConvertToInvokeResponseMode converts the given string to a InvokeResponseMode -// It is case insensitive and if there is no match, an error is thrown. -func ConvertToInvokeResponseMode(value string) (InvokeResponseMode, error) { - // buffered - if strings.EqualFold(value, string(InvokeResponseModeBuffered)) { - return InvokeResponseModeBuffered, nil - } - - // streaming - if strings.EqualFold(value, string(InvokeResponseModeStreaming)) { - return InvokeResponseModeStreaming, nil - } - - // unknown - allowedValues := strings.Join(AllInvokeResponseModes, ", ") - log.Errorf("Unlable to map %s to %s.", value, allowedValues) - return "", ErrInvalidInvokeResponseMode -} - // FunctionResponseMode is passed by Runtime to tell whether the response should be // streamed or not. type FunctionResponseMode string @@ -82,6 +61,7 @@ var AllFunctionResponseModes = []string{ string(FunctionResponseModeBuffered), string(FunctionResponseModeStreaming), } +// TODO: move to directinvoke.go as we're trying to deprecate interop.* package // ConvertToFunctionResponseMode converts the given string to a FunctionResponseMode // It is case insensitive and if there is no match, an error is thrown. func ConvertToFunctionResponseMode(value string) (FunctionResponseMode, error) { @@ -108,57 +88,79 @@ type Message interface{} type Invoke struct { // Tracing header. 
// https://docs.aws.amazon.com/xray/latest/devguide/xray-concepts.html#xray-concepts-tracingheader - TraceID string - LambdaSegmentID string - ID string - InvokedFunctionArn string - CognitoIdentityID string - CognitoIdentityPoolID string - DeadlineNs string - ClientContext string - ContentType string - Payload io.Reader - NeedDebugLogs bool - ReservationToken string - VersionID string - InvokeReceivedTime int64 - InvokeResponseMetrics *InvokeResponseMetrics + TraceID string + LambdaSegmentID string + ID string + InvokedFunctionArn string + CognitoIdentityID string + CognitoIdentityPoolID string + DeadlineNs string + ClientContext string + ContentType string + Payload io.Reader + NeedDebugLogs bool + ReservationToken string + VersionID string + InvokeReceivedTime int64 + InvokeResponseMetrics *InvokeResponseMetrics + InvokeResponseMode InvokeResponseMode + RestoreDurationNs int64 // equals 0 for non-snapstart functions + RestoreStartTimeMonotime int64 // equals 0 for non-snapstart functions } type Token struct { - ReservationToken string - InvokeID string - VersionID string - FunctionTimeout time.Duration - InvackDeadlineNs int64 - TraceID string - LambdaSegmentID string - InvokeMetadata string - NeedDebugLogs bool + ReservationToken string + InvokeID string + VersionID string + FunctionTimeout time.Duration + InvackDeadlineNs int64 + TraceID string + LambdaSegmentID string + InvokeMetadata string + NeedDebugLogs bool + RestoreDurationNs int64 + RestoreStartTimeMonotime int64 } -type ErrorResponse struct { - // Payload sent via shared memory. - Payload []byte `json:"Payload,omitempty"` - ContentType string `json:"-"` - FunctionResponseMode string `json:"-"` - - // When error response body (Payload) is not provided, e.g. 
- // not retrievable, error type and error message will be - // used by the Slicer to construct a response json, e.g: - // - // default error response produced by the Slicer: - // '{"errorMessage":"Unknown application error occurred"}', - // - // when error type is provided, error response becomes: - // '{"errorMessage":"Unknown application error occurred","errorType":"ErrorType"}' - ErrorType string `json:"errorType,omitempty"` - ErrorMessage string `json:"errorMessage,omitempty"` - +// InvokeErrorTraceData is used by the tracer to mark segments as being invocation error +type InvokeErrorTraceData struct { // Attached to invoke segment ErrorCause json.RawMessage `json:"ErrorCause,omitempty"` } +func GetErrorResponseWithFormattedErrorMessage(errorType fatalerror.ErrorType, err error, invokeRequestID string) *ErrorInvokeResponse { + var errorMessage string + if invokeRequestID != "" { + errorMessage = fmt.Sprintf("RequestId: %s Error: %v", invokeRequestID, err) + } else { + errorMessage = fmt.Sprintf("Error: %v", err) + } + + jsonPayload, err := json.Marshal(FunctionError{ + Type: errorType, + Message: errorMessage, + }) + + if err != nil { + return &ErrorInvokeResponse{ + Headers: InvokeResponseHeaders{}, + FunctionError: FunctionError{ + Type: fatalerror.SandboxFailure, + Message: errorMessage, + }, + Payload: []byte{}, + } + } + + headers := InvokeResponseHeaders{} + functionError := FunctionError{ + Type: errorType, + Message: errorMessage, + } + + return &ErrorInvokeResponse{Headers: headers, FunctionError: functionError, Payload: jsonPayload} +} + // SandboxType identifies sandbox type (PreWarmed vs Classic) type SandboxType string @@ -178,7 +180,7 @@ type DynamicDomainConfig struct { // extra hooks to execute at domain start. Currently used for filesystem and network hooks. // It can be empty. 
AdditionalStartHooks []model.Hook - Mounts []model.DriveMount + Mounts []model.Mount //TODO: other dynamic configurations for the domain go here } @@ -189,14 +191,17 @@ type Reset struct { InvokeResponseMetrics *InvokeResponseMetrics TraceID string LambdaSegmentID string + InvokeResponseMode InvokeResponseMode } // Restore message is sent to rapid to restore runtime to make it ready for consecutive invokes type Restore struct { - AwsKey string - AwsSecret string - AwsSession string - CredentialsExpiry time.Time + AwsKey string + AwsSecret string + AwsSession string + CredentialsExpiry time.Time + RestoreHookTimeoutMs int64 + LogStreamName string } type Resync struct { @@ -224,7 +229,10 @@ func MergeSubscriptionMetrics(logsAPIMetrics TelemetrySubscriptionMetrics, telem // InvokeResponseMetrics are produced while sending streaming invoke response to WP type InvokeResponseMetrics struct { - StartReadingResponseMonoTimeMs int64 + // FIXME: this assumes a value in nanoseconds, let's rename it + // to StartReadingResponseMonoTimeNs + StartReadingResponseMonoTimeMs int64 + // Same as the one above FinishReadingResponseMonoTimeMs int64 TimeShapedNs int64 ProducedBytes int64 @@ -240,6 +248,22 @@ func IsResponseStreamingMetrics(metrics *InvokeResponseMetrics) bool { return metrics.FunctionResponseMode == FunctionResponseModeStreaming } +type DoneMetadataMetricsDimensions struct { + InvokeResponseMode InvokeResponseMode +} + +func (dimensions DoneMetadataMetricsDimensions) String() string { + var stringDimensions []string + + if dimensions.InvokeResponseMode != "" { + dimension := string("invoke_response_mode=" + dimensions.InvokeResponseMode) + stringDimensions = append(stringDimensions, dimension) + } + return strings.ToLower( + strings.Join(stringDimensions, ","), + ) +} + type DoneMetadata struct { NumActiveExtensions int ExtensionsResetMs int64 @@ -252,9 +276,11 @@ type DoneMetadata struct { InvokeCompletionTimeNs int64 InvokeReceivedTime int64 RuntimeReadyTime int64 + 
RuntimeResponseLatencyMs float64 RuntimeTimeThrottledMs int64 RuntimeProducedBytes int64 RuntimeOutboundThroughputBps int64 + MetricsDimensions DoneMetadataMetricsDimensions } type Done struct { @@ -332,19 +358,18 @@ func (s *ErrorResponseTooLarge) Error() string { return fmt.Sprintf("Response payload size (%d bytes) exceeded maximum allowed payload size (%d bytes).", s.ResponseSize, s.MaxResponseSize) } -// AsErrorResponse generates ErrorResponse from ErrorResponseTooLarge -func (s *ErrorResponseTooLarge) AsInteropError() *ErrorResponse { - resp := ErrorResponse{ - ErrorType: functionResponseSizeTooLargeType, - ErrorMessage: s.Error(), +// AsErrorResponse generates ErrorInvokeResponse from ErrorResponseTooLarge +func (s *ErrorResponseTooLarge) AsErrorResponse() *ErrorInvokeResponse { + functionError := FunctionError{ + Type: fatalerror.FunctionOversizedResponse, + Message: s.Error(), } - respJSON, err := json.Marshal(resp) + jsonPayload, err := json.Marshal(functionError) if err != nil { - panic("Failed to marshal interop.ErrorResponse") + panic("Failed to marshal interop.FunctionError") } - resp.Payload = respJSON - resp.ContentType = "application/json" - return &resp + headers := InvokeResponseHeaders{ContentType: "application/json"} + return &ErrorInvokeResponse{Headers: headers, FunctionError: functionError, Payload: jsonPayload} } // Server used for sending messages and sharing data between the Runtime API handlers and the @@ -356,21 +381,6 @@ func (s *ErrorResponseTooLarge) AsInteropError() *ErrorResponse { // protocol used by the specific implementation // TODO: rename this to InvokeResponseContext, used to send responses from handlers to platform-facing server type Server interface { - // SendResponse sends response. 
- // Errors returned: - // ErrInvalidInvokeID - validation error indicating that provided invokeID doesn't match current invokeID - // ErrResponseSent - validation error indicating that response with given invokeID was already sent - // Non-nil error - non-nil error indicating transport failure - SendResponse(invokeID string, headers map[string]string, response io.Reader, trailers http.Header, request *CancellableRequest) error - - // SendErrorResponse sends error response. - // Errors returned: - // ErrInvalidInvokeID - validation error indicating that provided invokeID doesn't match current invokeID - // ErrResponseSent - validation error indicating that response with given invokeID was already sent - // Non-nil error - non-nil error indicating transport failure - SendErrorResponse(invokeID string, response *ErrorResponse) error - SendInitErrorResponse(invokeID string, response *ErrorResponse) error - // GetCurrentInvokeID returns current invokeID. // NOTE, in case of INIT, when invokeID is not known in advance (e.g. provisioned concurrency), // returned invokeID will contain empty value. @@ -381,24 +391,40 @@ type Server interface { // from the time when all extensions have called /next. 
// TODO: this method is a lifecycle event used only for metrics, and doesn't belong here SendRuntimeReady() error + + // SendInitErrorResponse does two separate things when init/error is called: + // a) sends the init error response if called during invoke, and + // b) notifies platform of a user fault if called during either init or invoke + // TODO: + // separate the two concerns & unify with SendErrorResponse in response sender + SendInitErrorResponse(response *ErrorInvokeResponse) error } type InternalStateGetter func() statejson.InternalStateDescription -const OnDemandInitTelemetrySource string = "on-demand" -const ProvisionedConcurrencyInitTelemetrySource string = "provisioned-concurrency" -const InitCachingInitTelemetrySource string = "snap-start" +// ErrRestoreHookTimeout is returned as a response to `RESTORE` message +// when function's restore hook takes more time to execute than +// the timeout value. +var ErrRestoreHookTimeout = errors.New("Runtime.RestoreHookUserTimeout") -func InferTelemetryInitSource(initCachingEnabled bool, sandboxType SandboxType) string { - initSource := OnDemandInitTelemetrySource - - // ToDo: Unify this selection of SandboxType by using the START message - // after having a roadmap on the combination of INIT modes - if initCachingEnabled { - initSource = InitCachingInitTelemetrySource - } else if sandboxType == SandboxPreWarmed { - initSource = ProvisionedConcurrencyInitTelemetrySource - } +// ErrRestoreHookUserError is returned as a response to `RESTORE` message +// when function's restore hook encounters an error or throws an exception. +// UserError contains the error type that the runtime encountered. 
+type ErrRestoreHookUserError struct { + UserError FunctionError +} - return initSource +func (err ErrRestoreHookUserError) Error() string { + return "errRestoreHookUserError" } + +// ErrRestoreUpdateCredentials is returned as a response to `RESTORE` message +// if RAPID cannot update the credentials served by credentials API +// during the RESTORE phase. +var ErrRestoreUpdateCredentials = errors.New("errRestoreUpdateCredentials") + +var ErrCannotParseCredentialsExpiry = errors.New("errCannotParseCredentialsExpiry") + +var ErrCannotParseRestoreHookTimeoutMs = errors.New("errCannotParseRestoreHookTimeoutMs") + +var ErrMissingRestoreCredentials = errors.New("errMissingRestoreCredentials") diff --git a/lambda/interop/model_test.go b/lambda/interop/model_test.go index 9ad4d17..d9ba36a 100644 --- a/lambda/interop/model_test.go +++ b/lambda/interop/model_test.go @@ -4,8 +4,11 @@ package interop import ( + "fmt" "testing" + "go.amzn.com/lambda/fatalerror" + "github.com/stretchr/testify/assert" ) @@ -25,3 +28,39 @@ func TestMergeSubscriptionMetrics(t *testing.T) { assert.Equal(t, 2, metrics["server_error"]) assert.Equal(t, 2, metrics["client_error"]) } + +func TestGetErrorResponseWithFormattedErrorMessageWithoutInvokeRequestId(t *testing.T) { + errorType := fatalerror.RuntimeExit + errorMessage := fmt.Errorf("Divided by 0") + expectedMsg := fmt.Sprintf(`Error: %s`, errorMessage) + expectedJSON := fmt.Sprintf(`{"errorType": "%s", "errorMessage": "%s"}`, string(errorType), expectedMsg) + + actual := GetErrorResponseWithFormattedErrorMessage(errorType, errorMessage, "") + assert.Equal(t, errorType, actual.FunctionError.Type) + assert.Equal(t, expectedMsg, actual.FunctionError.Message) + assert.JSONEq(t, expectedJSON, string(actual.Payload)) +} + +func TestGetErrorResponseWithFormattedErrorMessageWithInvokeRequestId(t *testing.T) { + errorType := fatalerror.RuntimeExit + errorMessage := fmt.Errorf("Divided by 0") + invokeID := "invoke-id" + expectedMsg := 
fmt.Sprintf(`RequestId: %s Error: %s`, invokeID, errorMessage) + expectedJSON := fmt.Sprintf(`{"errorType": "%s", "errorMessage": "%s"}`, string(errorType), expectedMsg) + + actual := GetErrorResponseWithFormattedErrorMessage(errorType, errorMessage, invokeID) + assert.Equal(t, errorType, actual.FunctionError.Type) + assert.Equal(t, expectedMsg, actual.FunctionError.Message) + assert.JSONEq(t, expectedJSON, string(actual.Payload)) +} + +func TestDoneMetadataMetricsDimensionsStringWhenInvokeResponseModeIsPresent(t *testing.T) { + dimensions := DoneMetadataMetricsDimensions{ + InvokeResponseMode: InvokeResponseModeStreaming, + } + assert.Equal(t, "invoke_response_mode=streaming", dimensions.String()) +} +func TestDoneMetadataMetricsDimensionsStringWhenEmpty(t *testing.T) { + dimensions := DoneMetadataMetricsDimensions{} + assert.Equal(t, "", dimensions.String()) +} diff --git a/lambda/interop/sandbox_model.go b/lambda/interop/sandbox_model.go index b5d15b0..3011c48 100644 --- a/lambda/interop/sandbox_model.go +++ b/lambda/interop/sandbox_model.go @@ -4,9 +4,13 @@ package interop import ( + "bytes" + "io" + "net/http" "time" "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/rapidcore/env" ) // Init represents an init message @@ -15,6 +19,7 @@ import ( type Init struct { InvokeID string Handler string + AccountID string AwsKey string AwsSecret string AwsSession string @@ -28,23 +33,13 @@ type Init struct { // In standalone mode, these env vars come from test/init but from environment otherwise. 
CustomerEnvironmentVariables map[string]string SandboxType SandboxType - // there is no dynamic config at the moment for the runtime domain - OperatorDomainExtraConfig DynamicDomainConfig - RuntimeInfo RuntimeInfo - Bootstrap Bootstrap - EnvironmentVariables EnvironmentVariables // contains env vars for agents and runtime procs -} - -// InitStarted contains metadata about the initialized sandbox -// In Rapid Shim, this translates to a RUNNING GirD message to Slicer -// In Rapid Daemon, this is followed by a SANDBOX GirP message to MM -type InitStarted struct { - WaitStartTimeNs int64 - WaitEndTimeNs int64 - PreLoadTimeNs int64 - PostLoadTimeNs int64 - ExtensionsEnabled bool - Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent + LogStreamName string + InstanceMaxMemory uint64 + OperatorDomainExtraConfig DynamicDomainConfig + RuntimeDomainExtraConfig DynamicDomainConfig + RuntimeInfo RuntimeInfo + Bootstrap Bootstrap + EnvironmentVariables *env.Environment // contains env vars for agents and runtime procs } // InitSuccess indicates that runtime/extensions initialization completed successfully @@ -72,11 +67,61 @@ type InitFailure struct { Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent } +// ErrorInvokeResponse represents a buffered response received via Runtime API +// for error responses. When body (Payload) is not provided, e.g. 
+// not retrievable, error type and error message headers will be +// used by the platform to construct a response json, e.g: +// +// default error response produced by the Slicer: +// '{"errorMessage":"Unknown application error occurred"}', +// +// when error type is provided, error response becomes: +// '{"errorMessage":"Unknown application error occurred","errorType":"ErrorType"}' +type ErrorInvokeResponse struct { + Headers InvokeResponseHeaders + Payload []byte + FunctionError FunctionError +} + +// StreamableInvokeResponse represents a response received via Runtime API that can be streamed +type StreamableInvokeResponse struct { + Headers map[string]string + Payload io.Reader + Trailers http.Header + Request *CancellableRequest // streaming request may need to gracefully terminate request streams +} + +// InvokeResponseHeaders contains the headers received via Runtime API /invocation/response +type InvokeResponseHeaders struct { + ContentType string + FunctionResponseMode string +} + +// FunctionError represents information about function errors or 'user errors' +// These are not platform errors and hence are returned as 200 by Lambda +// In the absence of a response payload, the Function Error is serialized and sent +type FunctionError struct { + // Type of error is derived from the Lambda-Runtime-Function-Error-Type set by the Runtime + // This is customer data, so RAPID scrubs this error type to contain only allowlisted values + Type fatalerror.ErrorType `json:"errorType,omitempty"` + // ErrorMessage is generated by RAPID and can never be specified by runtime + Message string `json:"errorMessage,omitempty"` +} + +type InvokeResponseSender interface { + // SendResponse sends invocation response received from Runtime to platform + // This response may be streamed based on function and invoke response mode + SendResponse(invokeID string, response *StreamableInvokeResponse) error + // SendErrorResponse sends error response in the case of function errors, 
which are always buffered + SendErrorResponse(invokeID string, response *ErrorInvokeResponse) error +} + // ResponseMetrics groups metrics related to the response stream type ResponseMetrics struct { - RuntimeTimeThrottledMs int64 - RuntimeProducedBytes int64 RuntimeOutboundThroughputBps int64 + RuntimeProducedBytes int64 + RuntimeResponseLatencyMs float64 + RuntimeTimeThrottledMs int64 } // InvokeMetrics groups metrics related to the invoke phase @@ -96,6 +141,7 @@ type InvokeSuccess struct { LogsAPIMetrics TelemetrySubscriptionMetrics ResponseMetrics ResponseMetrics InvokeMetrics InvokeMetrics + InvokeResponseMode InvokeResponseMode } // InvokeFailure is the failure response to invoke phase end @@ -111,21 +157,24 @@ type InvokeFailure struct { ResponseMetrics ResponseMetrics InvokeMetrics InvokeMetrics ExtensionNames string - DefaultErrorResponse *ErrorResponse // error resp constructed by platform during fn errors + DefaultErrorResponse *ErrorInvokeResponse // error resp constructed by platform during fn errors + InvokeResponseMode InvokeResponseMode } // ResetSuccess is the success response to reset request type ResetSuccess struct { - ExtensionsResetMs int64 - ErrorType fatalerror.ErrorType - ResponseMetrics ResponseMetrics + ExtensionsResetMs int64 + ErrorType fatalerror.ErrorType + ResponseMetrics ResponseMetrics + InvokeResponseMode InvokeResponseMode } // ResetFailure is the failure response to reset request type ResetFailure struct { - ExtensionsResetMs int64 - ErrorType fatalerror.ErrorType - ResponseMetrics ResponseMetrics + ExtensionsResetMs int64 + ErrorType fatalerror.ErrorType + ResponseMetrics ResponseMetrics + InvokeResponseMode InvokeResponseMode } // ShutdownSuccess is the response to a shutdown request @@ -136,35 +185,46 @@ type ShutdownSuccess struct { // SandboxInfoFromInit captures data from init request that // is required during invoke (e.g. 
for suppressed init) type SandboxInfoFromInit struct { - EnvironmentVariables EnvironmentVariables // contains agent env vars (creds, customer, platform) - SandboxType SandboxType // indicating Pre-Warmed, On-Demand etc - RuntimeBootstrap Bootstrap // contains the runtime bootstrap binary path, Cwd, Args, Env, Cmd + EnvironmentVariables *env.Environment // contains agent env vars (creds, customer, platform) + SandboxType SandboxType // indicating Pre-Warmed, On-Demand etc + RuntimeBootstrap Bootstrap // contains the runtime bootstrap binary path, Cwd, Args, Env, Cmd +} + +// RestoreResult represents the result of `HandleRestore` function +// in RapidCore +type RestoreResult struct { + RestoreMs int64 } // RapidContext expose methods for functionality of the Rapid Core library type RapidContext interface { - HandleInit(i *Init, started chan<- InitStarted, success chan<- InitSuccess, failure chan<- InitFailure) - HandleInvoke(i *Invoke, sbMetadata SandboxInfoFromInit) (InvokeSuccess, *InvokeFailure) - HandleReset(reset *Reset, invokeReceivedTime int64, InvokeResponseMetrics *InvokeResponseMetrics) (ResetSuccess, *ResetFailure) + HandleInit(i *Init, success chan<- InitSuccess, failure chan<- InitFailure) + HandleInvoke(i *Invoke, sbMetadata SandboxInfoFromInit, requestBuf *bytes.Buffer, responseSender InvokeResponseSender) (InvokeSuccess, *InvokeFailure) + HandleReset(reset *Reset) (ResetSuccess, *ResetFailure) HandleShutdown(shutdown *Shutdown) ShutdownSuccess - HandleRestore(restore *Restore) error + HandleRestore(restore *Restore) (RestoreResult, error) Clear() + + SetRuntimeStartedTime(runtimeStartedTime int64) + SetInvokeResponseMetrics(metrics *InvokeResponseMetrics) + + SetEventsAPI(eventsAPI EventsAPI) } // SandboxContext represents the sandbox lifecycle context type SandboxContext interface { - Init(i *Init, timeoutMs int64) (InitStarted, InitContext) + Init(i *Init, timeoutMs int64) InitContext Reset(reset *Reset) (ResetSuccess, *ResetFailure) 
Shutdown(shutdown *Shutdown) ShutdownSuccess - Restore(restore *Restore) error + Restore(restore *Restore) (RestoreResult, error) // TODO: refactor this - // invokeReceivedTime and InvokeResponseMetrics are needed to compute the runtimeDone metrics + // runtimeStartedTime and InvokeResponseMetrics are needed to compute the runtimeDone metrics // in case of a Reset during an invoke (reset.reason=failure or reset.reason=timeout). // Ideally: - // - the InvokeContext will have a Reset method to deal with Reset during an invoke and will hold invokeReceivedTime and InvokeResponseMetrics + // - the InvokeContext will have a Reset method to deal with Reset during an invoke and will hold runtimeStartedTime and InvokeResponseMetrics // - the SandboxContext will have its own Reset/Spindown method - SetInvokeReceivedTime(invokeReceivedTime int64) + SetRuntimeStartedTime(invokeReceivedTime int64) SetInvokeResponseMetrics(metrics *InvokeResponseMetrics) } @@ -176,10 +236,14 @@ type InitContext interface { // InvokeContext represents the lifecycle of a sandbox reservation type InvokeContext interface { - SendRequest(i *Invoke) + SendRequest(i *Invoke, r InvokeResponseSender) Wait() (InvokeSuccess, *InvokeFailure) } -// Restored message is sent to Slicer to inform Runtime Restore Hook execution was successful -type Restored struct { -} +// LifecyclePhase represents enum for possible Sandbox lifecycle phases, like init, invoke, etc. +type LifecyclePhase int + +const ( + LifecyclePhaseInit LifecyclePhase = iota + 1 + LifecyclePhaseInvoke +) diff --git a/lambda/metering/time.go b/lambda/metering/time.go index cf3ad1d..9e0fa01 100644 --- a/lambda/metering/time.go +++ b/lambda/metering/time.go @@ -12,15 +12,19 @@ import ( //go:linkname Monotime runtime.nanotime func Monotime() int64 -// MonoToEpoch converts monotonic time nanos to epoch time nanos. +// MonoToEpoch converts monotonic time nanos to unix epoch time nanos. 
func MonoToEpoch(t int64) int64 { monoNsec := Monotime() wallNsec := time.Now().UnixNano() - clockOffset := wallNsec - monoNsec return t + clockOffset } +func TimeToMono(t time.Time) int64 { + durNs := time.Since(t).Nanoseconds() + return Monotime() - durNs +} + type ExtensionsResetDurationProfiler struct { NumAgentsRegisteredForShutdown int AvailableNs int64 diff --git a/lambda/metering/time_test.go b/lambda/metering/time_test.go index 0088f9f..5c37a87 100644 --- a/lambda/metering/time_test.go +++ b/lambda/metering/time_test.go @@ -19,6 +19,14 @@ func TestMonoToEpochPrecision(t *testing.T) { assert.True(t, math.Abs(float64(a-b)) < float64(time.Millisecond)) } +func TestEpochToMonoPrecision(t *testing.T) { + a := Monotime() + b := TimeToMono(time.Now()) + + // Conversion error is less than a millisecond. + assert.Less(t, math.Abs(float64(b-a)), float64(1*time.Millisecond)) +} + func TestExtensionsResetDurationProfilerForExtensionsResetWithNoExtensions(t *testing.T) { mono := Monotime() profiler := ExtensionsResetDurationProfiler{} diff --git a/lambda/rapi/extensions_fuzz_test.go b/lambda/rapi/extensions_fuzz_test.go new file mode 100644 index 0000000..c223859 --- /dev/null +++ b/lambda/rapi/extensions_fuzz_test.go @@ -0,0 +1,344 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package rapi + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/extensions" + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/rapi/handler" + "go.amzn.com/lambda/rapi/model" + "go.amzn.com/lambda/rapi/rendering" + "go.amzn.com/lambda/telemetry" + "go.amzn.com/lambda/testdata" +) + +func FuzzAgentRegisterHandler(f *testing.F) { + extensions.Enable() + defer extensions.Disable() + + registerReq := handler.RegisterRequest{ + Events: []core.Event{core.InvokeEvent, core.ShutdownEvent}, + } + regReqBytes, err := json.Marshal(&registerReq) + if err != nil { + f.Errorf("failed to marshal register request: %v", err) + } + f.Add("agent", "accountId", true, regReqBytes) + f.Add("agent", "accountId", false, regReqBytes) + + f.Fuzz(func(t *testing.T, + agentName string, + featuresHeader string, + external bool, + payload []byte, + ) { + flowTest := testdata.NewFlowTest() + + if external { + flowTest.RegistrationService.CreateExternalAgent(agentName) + } + + functionMetadata := createDummyFunctionMetadata() + flowTest.RegistrationService.SetFunctionMetadata(functionMetadata) + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL("/extension/register", version20200101) + request := httptest.NewRequest("POST", target, bytes.NewReader(payload)) + request.Header.Add(handler.LambdaAgentName, agentName) + request.Header.Add("Lambda-Extension-Accept-Feature", featuresHeader) + + responseRecorder := serveTestRequest(rapiServer, request) + + if agentName == "" { + assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidExtensionName") + return + } + + regReqStruct := struct { + handler.RegisterRequest + ConfigurationKeys []string `json:"configurationKeys"` + }{} + if err := json.Unmarshal(payload, &regReqStruct); 
err != nil { + assertForbiddenErrorType(t, responseRecorder, "InvalidRequestFormat") + return + } + + if containsInvalidEvent(external, regReqStruct.Events) { + assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidEventType") + return + } + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + + respBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + expectedResponse := map[string]interface{}{ + "functionName": functionMetadata.FunctionName, + "functionVersion": functionMetadata.FunctionVersion, + "handler": functionMetadata.Handler, + } + if featuresHeader == "accountId" && functionMetadata.AccountID != "" { + expectedResponse["accountId"] = functionMetadata.AccountID + } + + expectedRespBytes, err := json.Marshal(expectedResponse) + assert.NoError(t, err) + assert.JSONEq(t, string(expectedRespBytes), string(respBody)) + + if external { + agent, found := flowTest.RegistrationService.FindExternalAgentByName(agentName) + assert.True(t, found) + assert.Equal(t, agent.RegisteredState, agent.GetState()) + } else { + agent, found := flowTest.RegistrationService.FindInternalAgentByName(agentName) + assert.True(t, found) + assert.Equal(t, agent.RegisteredState, agent.GetState()) + } + }) +} + +func FuzzAgentNextHandler(f *testing.F) { + extensions.Enable() + defer extensions.Disable() + + regService := core.NewRegistrationService(core.NewInitFlowSynchronization(), core.NewInvokeFlowSynchronization()) + testAgent := makeExternalAgent(regService) + f.Add(testAgent.ID.String(), true, true) + f.Add(testAgent.ID.String(), true, false) + + f.Fuzz(func(t *testing.T, + agentIdentifierHeader string, + registered bool, + isInvokeEvent bool, + ) { + flowTest := testdata.NewFlowTest() + agent := makeExternalAgent(flowTest.RegistrationService) + + if registered { + agent.SetState(agent.RegisteredState) + agent.Release() + } + + configureRendererForEvent(flowTest, isInvokeEvent) + + rapiServer := makeRapiServer(flowTest) + + target := 
makeTargetURL("/extension/event/next", version20200101) + request := httptest.NewRequest("GET", target, nil) + request.Header.Set(handler.LambdaAgentIdentifier, agentIdentifierHeader) + + responseRecorder := serveTestRequest(rapiServer, request) + + if agentIdentifierHeader == "" { + assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierMissing) + return + } + if _, err := uuid.Parse(agentIdentifierHeader); err != nil { + assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierInvalid) + return + } + if agentIdentifierHeader != agent.ID.String() { + assertForbiddenErrorType(t, responseRecorder, "Extension.UnknownExtensionIdentifier") + return + } + if !registered { + assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidExtensionState") + return + } + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + + assertResponseEventType(t, isInvokeEvent, responseRecorder) + + assert.Equal(t, agent.RunningState, agent.GetState()) + }) +} + +func FuzzAgentInitErrorHandler(f *testing.F) { + fuzzErrorHandler(f, "/extension/init/error", fatalerror.AgentInitError) +} + +func FuzzAgentExitErrorHandler(f *testing.F) { + fuzzErrorHandler(f, "/extension/exit/error", fatalerror.AgentExitError) +} + +func fuzzErrorHandler(f *testing.F, handlerPath string, fatalErrorType fatalerror.ErrorType) { + extensions.Enable() + defer extensions.Disable() + + regService := core.NewRegistrationService(core.NewInitFlowSynchronization(), core.NewInvokeFlowSynchronization()) + testAgent := makeExternalAgent(regService) + f.Add(true, testAgent.ID.String(), "Extension.SomeError") + f.Add(false, testAgent.ID.String(), "Extension.SomeError") + + f.Fuzz(func(t *testing.T, + agentRegistered bool, + agentIdentifierHeader string, + errorType string, + ) { + flowTest := testdata.NewFlowTest() + + agent := makeExternalAgent(flowTest.RegistrationService) + + if agentRegistered { + agent.SetState(agent.RegisteredState) + } + + rapiServer := 
makeRapiServer(flowTest) + + target := makeTargetURL(handlerPath, version20200101) + + request := httptest.NewRequest("POST", target, nil) + request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) + request.Header.Set(handler.LambdaAgentIdentifier, agentIdentifierHeader) + request.Header.Set(handler.LambdaAgentFunctionErrorType, errorType) + + responseRecorder := serveTestRequest(rapiServer, request) + + if agentIdentifierHeader == "" { + assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierMissing) + return + } + + if _, e := uuid.Parse(agentIdentifierHeader); e != nil { + assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierInvalid) + return + } + + if errorType == "" { + assertForbiddenErrorType(t, responseRecorder, "Extension.MissingHeader") + return + } + if agentIdentifierHeader != agent.ID.String() { + assertForbiddenErrorType(t, responseRecorder, "Extension.UnknownExtensionIdentifier") + return + } + if !agentRegistered { + assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidExtensionState") + } else { + assertErrorAgentRegistered(t, responseRecorder, flowTest, fatalErrorType) + } + }) +} + +func assertErrorAgentRegistered(t *testing.T, responseRecorder *httptest.ResponseRecorder, flowTest *testdata.FlowTest, expectedErrType fatalerror.ErrorType) { + var response model.StatusResponse + + respBody, _ := io.ReadAll(responseRecorder.Body) + err := json.Unmarshal(respBody, &response) + assert.NoError(t, err) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.Equal(t, "OK", response.Status) + + v, found := appctx.LoadFirstFatalError(flowTest.AppCtx) + assert.True(t, found) + assert.Equal(t, expectedErrType, v) +} + +func assertForbiddenErrorType(t *testing.T, responseRecorder *httptest.ResponseRecorder, errType string) { + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + + var errorResponse model.ErrorResponse + + respBody, _ := io.ReadAll(responseRecorder.Body) + err := 
json.Unmarshal(respBody, &errorResponse) + assert.NoError(t, err) + + assert.Equal(t, errType, errorResponse.ErrorType) +} + +func createDummyFunctionMetadata() core.FunctionMetadata { + return core.FunctionMetadata{ + AccountID: "accID", + FunctionName: "myFunc", + FunctionVersion: "1.0", + Handler: "myHandler", + } +} + +func makeExternalAgent(registrationService core.RegistrationService) *core.ExternalAgent { + agent, err := registrationService.CreateExternalAgent("agent") + if err != nil { + log.Fatalf("failed to create external agent: %v", err) + return nil + } + + return agent +} + +func configureRendererForEvent(flowTest *testdata.FlowTest, isInvokeEvent bool) { + if isInvokeEvent { + invoke := createDummyInvoke() + + var buf bytes.Buffer + flowTest.RenderingService.SetRenderer( + rendering.NewInvokeRenderer( + context.Background(), + invoke, + &buf, + telemetry.NewNoOpTracer().BuildTracingHeader(), + )) + } else { + flowTest.RenderingService.SetRenderer( + &rendering.ShutdownRenderer{ + AgentEvent: model.AgentShutdownEvent{ + AgentEvent: &model.AgentEvent{ + EventType: "SHUTDOWN", + DeadlineMs: int64(10000), + }, + ShutdownReason: "spindown", + }, + }) + } +} + +func assertResponseEventType(t *testing.T, isInvokeEvent bool, responseRecorder *httptest.ResponseRecorder) { + if isInvokeEvent { + var response model.AgentInvokeEvent + + respBody, _ := io.ReadAll(responseRecorder.Body) + err := json.Unmarshal(respBody, &response) + assert.NoError(t, err) + + assert.Equal(t, "INVOKE", response.AgentEvent.EventType) + } else { + var response model.AgentShutdownEvent + + respBody, _ := io.ReadAll(responseRecorder.Body) + err := json.Unmarshal(respBody, &response) + assert.NoError(t, err) + + assert.Equal(t, "SHUTDOWN", response.AgentEvent.EventType) + } +} + +func containsInvalidEvent(external bool, events []core.Event) bool { + for _, e := range events { + if external { + if err := core.ValidateExternalAgentEvent(e); err != nil { + return true + } + } else if err 
:= core.ValidateInternalAgentEvent(e); err != nil { + return true + } + } + + return false +} diff --git a/lambda/rapi/handler/agentnext_test.go b/lambda/rapi/handler/agentnext_test.go index 003c4b6..417633e 100644 --- a/lambda/rapi/handler/agentnext_test.go +++ b/lambda/rapi/handler/agentnext_test.go @@ -4,6 +4,7 @@ package handler import ( + "bytes" "context" "encoding/json" "fmt" @@ -108,7 +109,8 @@ func TestRenderAgentInvokeNextHappy(t *testing.T) { } renderingService := rendering.NewRenderingService() - renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, telemetry.GetCustomerTracingHeader)) + var buf bytes.Buffer + renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, &buf, telemetry.NewNoOpTracer().BuildTracingHeader())) handler := NewAgentNextHandler(registrationService, renderingService) request := httptest.NewRequest("GET", "/", nil) @@ -157,7 +159,8 @@ func TestRenderAgentInternalInvokeNextHappy(t *testing.T) { } renderingService := rendering.NewRenderingService() - renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, telemetry.GetCustomerTracingHeader)) + var buf bytes.Buffer + renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, &buf, telemetry.NewNoOpTracer().BuildTracingHeader())) handler := NewAgentNextHandler(registrationService, renderingService) request := httptest.NewRequest("GET", "/", nil) @@ -287,7 +290,8 @@ func TestRenderAgentInvokeNextHappyEmptyTraceID(t *testing.T) { } renderingService := rendering.NewRenderingService() - renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, telemetry.GetCustomerTracingHeader)) + var buf bytes.Buffer + renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, &buf, telemetry.NewNoOpTracer().BuildTracingHeader())) handler := NewAgentNextHandler(registrationService, renderingService) request := 
httptest.NewRequest("GET", "/", nil) diff --git a/lambda/rapi/handler/agentregister.go b/lambda/rapi/handler/agentregister.go index 8882965..8da9e4c 100644 --- a/lambda/rapi/handler/agentregister.go +++ b/lambda/rapi/handler/agentregister.go @@ -8,6 +8,7 @@ import ( "errors" "io" "net/http" + "strings" log "github.com/sirupsen/logrus" "go.amzn.com/lambda/core" @@ -24,6 +25,20 @@ type RegisterRequest struct { Events []core.Event `json:"events"` } +const featuresHeader = "Lambda-Extension-Accept-Feature" + +type registrationFeature int + +const ( + accountFeature registrationFeature = iota + 1 +) + +var allowedFeatures = map[string]registrationFeature{ + "accountId": accountFeature, +} + +type responseModifier func(*model.ExtensionRegisterResponse) + func parseRegister(request *http.Request) (*RegisterRequest, error) { body, err := io.ReadAll(request.Body) if err != nil { @@ -53,6 +68,13 @@ func (h *agentRegisterHandler) ServeHTTP(writer http.ResponseWriter, request *ht return } + var responseModifiers []responseModifier + for _, f := range parseRegistrationFeatures(request) { + if f == accountFeature { + responseModifiers = append(responseModifiers, h.respondWithAccountID()) + } + } + registerRequest, err := parseRegister(request) if err != nil { rendering.RenderForbiddenWithTypeMsg(writer, request, errInvalidRequestFormat, err.Error()) @@ -60,32 +82,65 @@ func (h *agentRegisterHandler) ServeHTTP(writer http.ResponseWriter, request *ht } agent, found := h.registrationService.FindExternalAgentByName(agentName) - if found { - h.registerExternalAgent(agent, registerRequest, writer, request) + h.registerExternalAgent(agent, registerRequest, writer, request, responseModifiers...) } else { - h.registerInternalAgent(agentName, registerRequest, writer, request) + h.registerInternalAgent(agentName, registerRequest, writer, request, responseModifiers...) 
} } -func (h *agentRegisterHandler) renderResponse(agentID string, writer http.ResponseWriter, request *http.Request) { +func (h *agentRegisterHandler) respondWithAccountID() responseModifier { + return func(resp *model.ExtensionRegisterResponse) { + resp.AccountID = h.registrationService.GetFunctionMetadata().AccountID + } +} + +func parseRegistrationFeatures(request *http.Request) []registrationFeature { + rawFeatures := strings.Split(request.Header.Get(featuresHeader), ",") + + var features []registrationFeature + for _, feature := range rawFeatures { + feature = strings.TrimSpace(feature) + if v, found := allowedFeatures[feature]; found { + features = append(features, v) + } + } + + return features +} + +func (h *agentRegisterHandler) renderResponse( + agentID string, + writer http.ResponseWriter, + request *http.Request, + respModifiers ...responseModifier, +) { writer.Header().Set(LambdaAgentIdentifier, agentID) metadata := h.registrationService.GetFunctionMetadata() - resp := &model.ExtensionRegisterResponse{ FunctionVersion: metadata.FunctionVersion, FunctionName: metadata.FunctionName, Handler: metadata.Handler, } + for _, mod := range respModifiers { + mod(resp) + } + if err := rendering.RenderJSON(http.StatusOK, writer, request, resp); err != nil { log.WithError(err).Warn("Error while rendering response") http.Error(writer, err.Error(), http.StatusInternalServerError) } } -func (h *agentRegisterHandler) registerExternalAgent(agent *core.ExternalAgent, registerRequest *RegisterRequest, writer http.ResponseWriter, request *http.Request) { +func (h *agentRegisterHandler) registerExternalAgent( + agent *core.ExternalAgent, + registerRequest *RegisterRequest, + writer http.ResponseWriter, + request *http.Request, + respModifiers ...responseModifier, +) { for _, e := range registerRequest.Events { if err := core.ValidateExternalAgentEvent(e); err != nil { log.Warnf("Failed to register %s: event %s: %s", agent.Name, e, err) @@ -101,11 +156,17 @@ func (h 
*agentRegisterHandler) registerExternalAgent(agent *core.ExternalAgent, return } - h.renderResponse(agent.ID.String(), writer, request) + h.renderResponse(agent.ID.String(), writer, request, respModifiers...) log.Infof("External agent %s registered, subscribed to %v", agent.String(), registerRequest.Events) } -func (h *agentRegisterHandler) registerInternalAgent(agentName string, registerRequest *RegisterRequest, writer http.ResponseWriter, request *http.Request) { +func (h *agentRegisterHandler) registerInternalAgent( + agentName string, + registerRequest *RegisterRequest, + writer http.ResponseWriter, + request *http.Request, + respModifiers ...responseModifier, +) { for _, e := range registerRequest.Events { if err := core.ValidateInternalAgentEvent(e); err != nil { log.Warnf("Failed to register %s: event %s: %s", agentName, e, err) @@ -142,7 +203,7 @@ func (h *agentRegisterHandler) registerInternalAgent(agentName string, registerR return } - h.renderResponse(agent.ID.String(), writer, request) + h.renderResponse(agent.ID.String(), writer, request, respModifiers...) 
log.Infof("Internal agent %s registered, subscribed to %v", agent.String(), registerRequest.Events) } diff --git a/lambda/rapi/handler/agentregister_test.go b/lambda/rapi/handler/agentregister_test.go index 35456ee..7370c42 100644 --- a/lambda/rapi/handler/agentregister_test.go +++ b/lambda/rapi/handler/agentregister_test.go @@ -230,102 +230,167 @@ type ExtensionRegisterResponseWithConfig struct { Configuration map[string]string `json:"configuration"` } -var happyPathTests = []struct { - testName string - agentName string - external bool - registrationRequest RegisterRequest - functionMetadata *core.FunctionMetadata - expectedRegistrationResponse ExtensionRegisterResponseWithConfig -}{ - { - testName: "no-config-internal", - agentName: "internal", - external: false, - registrationRequest: RegisterRequest{}, - expectedRegistrationResponse: ExtensionRegisterResponseWithConfig{ - ExtensionRegisterResponse: model.ExtensionRegisterResponse{ - FunctionName: "my-func", - FunctionVersion: "$LATEST", - Handler: "lambda_handler", +func TestRenderAgentResponse(t *testing.T) { + defaultFunctionMetadata := core.FunctionMetadata{ + FunctionVersion: "$LATEST", + FunctionName: "my-func", + Handler: "lambda_handler", + } + + happyPathTests := map[string]struct { + agentName string + external bool + registrationRequest RegisterRequest + featuresHeader string + functionMetadata core.FunctionMetadata + expectedResponse string + }{ + "no-config-internal": { + agentName: "internal", + external: false, + functionMetadata: defaultFunctionMetadata, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "my-func", + "functionVersion": "$LATEST", + "handler": "lambda_handler" + }`, + }, + "no-config-external": { + agentName: "external", + external: true, + functionMetadata: defaultFunctionMetadata, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "my-func", + "functionVersion": "$LATEST", + "handler": "lambda_handler" + }`, + }, + 
"function-md-override": { + agentName: "external", + external: true, + functionMetadata: core.FunctionMetadata{FunctionName: "function-name", FunctionVersion: "1", Handler: "myHandler"}, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "function-name", + "functionVersion": "1", + "handler": "myHandler" + }`, + }, + "internal with account id feature": { + agentName: "internal", + external: false, + functionMetadata: core.FunctionMetadata{ + FunctionName: "function-name", + FunctionVersion: "1", + Handler: "myHandler", + AccountID: "0123", }, + featuresHeader: "accountId", + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "function-name", + "functionVersion": "1", + "handler": "myHandler", + "accountId": "0123" + }`, }, - }, - { - testName: "no-config-external", - agentName: "external", - external: true, - registrationRequest: RegisterRequest{}, - expectedRegistrationResponse: ExtensionRegisterResponseWithConfig{ - ExtensionRegisterResponse: model.ExtensionRegisterResponse{ - FunctionName: "my-func", - FunctionVersion: "$LATEST", - Handler: "lambda_handler", + "external with account id feature": { + agentName: "external", + external: true, + functionMetadata: core.FunctionMetadata{ + FunctionName: "function-name", + FunctionVersion: "1", + Handler: "myHandler", + AccountID: "0123", }, + featuresHeader: "accountId", + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "function-name", + "functionVersion": "1", + "handler": "myHandler", + "accountId": "0123" + }`, + }, + "with non-existing accept feature": { + agentName: "external", + external: true, + featuresHeader: "some_non_existing_feature,", + functionMetadata: defaultFunctionMetadata, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "my-func", + "functionVersion": "$LATEST", + "handler": "lambda_handler" + }`, }, - }, - { - testName: "function-md-override", - agentName: "external", - 
external: true, - functionMetadata: &core.FunctionMetadata{FunctionName: "function-name", FunctionVersion: "1", Handler: "myHandler"}, - registrationRequest: RegisterRequest{}, - expectedRegistrationResponse: ExtensionRegisterResponseWithConfig{ - ExtensionRegisterResponse: model.ExtensionRegisterResponse{ + "account id feature and some non-existing feature": { + agentName: "external", + external: true, + featuresHeader: "some_non_existing_feature,accountId,", + functionMetadata: core.FunctionMetadata{ FunctionName: "function-name", FunctionVersion: "1", Handler: "myHandler", + AccountID: "0123", }, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "function-name", + "functionVersion": "1", + "handler": "myHandler", + "accountId": "0123" + }`, + }, + "with empty account id data": { + agentName: "external", + external: true, + featuresHeader: "accountId", + functionMetadata: defaultFunctionMetadata, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "my-func", + "functionVersion": "$LATEST", + "handler": "lambda_handler" + }`, }, - }, -} - -func TestRenderAgentResponse(t *testing.T) { - defaultFunctionMetadata := core.FunctionMetadata{ - FunctionVersion: "$LATEST", - FunctionName: "my-func", - Handler: "lambda_handler", } - for _, tt := range happyPathTests { - t.Run(tt.testName, func(t *testing.T) { + for name, tt := range happyPathTests { + t.Run(name, func(t *testing.T) { registrationService := core.NewRegistrationService( core.NewInitFlowSynchronization(), core.NewInvokeFlowSynchronization(), ) registrationService.CreateExternalAgent("external") // external agent has to be pre-registered - if tt.functionMetadata != nil { - registrationService.SetFunctionMetadata(*tt.functionMetadata) - } else { - registrationService.SetFunctionMetadata(defaultFunctionMetadata) - } + registrationService.SetFunctionMetadata(tt.functionMetadata) handler := NewAgentRegisterHandler(registrationService) request := 
httptest.NewRequest("POST", "/extension/register", registerRequestReader(tt.registrationRequest)) request.Header.Add(LambdaAgentName, tt.agentName) + if tt.featuresHeader != "" { + request.Header.Add(featuresHeader, tt.featuresHeader) + } responseRecorder := httptest.NewRecorder() handler.ServeHTTP(responseRecorder, request) - require.Equal(t, http.StatusOK, responseRecorder.Code) - - registerResponse := ExtensionRegisterResponseWithConfig{} - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, ®isterResponse) - assert.Equal(t, tt.expectedRegistrationResponse.FunctionName, registerResponse.FunctionName) - assert.Equal(t, tt.expectedRegistrationResponse.FunctionVersion, registerResponse.FunctionVersion) - assert.Equal(t, tt.expectedRegistrationResponse.Handler, registerResponse.Handler) + assert.Equal(t, http.StatusOK, responseRecorder.Code) - require.Len(t, registerResponse.Configuration, 0) + respBody, err := io.ReadAll(responseRecorder.Body) + require.NoError(t, err) + assert.JSONEq(t, tt.expectedResponse, string(respBody)) if tt.external { agent, found := registrationService.FindExternalAgentByName(tt.agentName) - require.True(t, found) - require.Equal(t, agent.RegisteredState, agent.GetState()) + assert.True(t, found) + assert.Equal(t, agent.RegisteredState, agent.GetState()) } else { agent, found := registrationService.FindInternalAgentByName(tt.agentName) - require.True(t, found) - require.Equal(t, agent.RegisteredState, agent.GetState()) + assert.True(t, found) + assert.Equal(t, agent.RegisteredState, agent.GetState()) } }) } diff --git a/lambda/rapi/handler/initerror.go b/lambda/rapi/handler/initerror.go index d28e2d4..79daa1f 100644 --- a/lambda/rapi/handler/initerror.go +++ b/lambda/rapi/handler/initerror.go @@ -9,8 +9,8 @@ import ( "net/http" "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/telemetry" "go.amzn.com/lambda/core" "go.amzn.com/lambda/rapi/rendering" @@ 
-20,21 +20,40 @@ import ( type initErrorHandler struct { registrationService core.RegistrationService - eventsAPI telemetry.EventsAPI } func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { appCtx := appctx.FromRequest(request) - - server := appctx.LoadInteropServer(appCtx) - if server == nil { + interopServer := appctx.LoadInteropServer(appCtx) + if interopServer == nil { log.Panic("Invalid state, cannot access interop server") } + errorType := fatalerror.GetValidRuntimeOrFunctionErrorType(request.Header.Get("Lambda-Runtime-Function-Error-Type")) + fnError := interop.FunctionError{Type: errorType} + errorBody, err := io.ReadAll(request.Body) + if err != nil { + log.WithError(err).Warn("Failed to read error body") + } + headers := interop.InvokeResponseHeaders{ContentType: determineJSONContentType(errorBody)} + response := &interop.ErrorInvokeResponse{Headers: headers, FunctionError: fnError, Payload: errorBody} + runtime := h.registrationService.GetRuntime() - // the previousStateName is needed to define if the init/error is called for INIT or RESTORE - previousStateName := runtime.GetState().Name() + // remove once Languages team change the endpoint to /restore/error + // when an exception is throw while executing the restore hooks + if runtime.GetState() == runtime.RuntimeRestoringState { + if err := runtime.RestoreError(fnError); err != nil { + log.Warn(err) + rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, + runtime.GetState().Name(), core.RuntimeRestoreErrorStateName, err) + return + } + + appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) + rendering.RenderAccepted(writer, request) + return + } if err := runtime.InitError(); err != nil { log.Warn(err) @@ -43,42 +62,19 @@ func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.R return } - errorType := 
request.Header.Get("Lambda-Runtime-Function-Error-Type") - - errorBody, err := io.ReadAll(request.Body) - if err != nil { - log.WithError(err).Warn("Failed to read error body") - } - - if previousStateName == core.RuntimeRestoringStateName { - h.sendRestoreRuntimeDoneLogEvent() - } else { - h.sendInitRuntimeDoneLogEvent(appCtx) - } - - response := &interop.ErrorResponse{ - ErrorType: errorType, - Payload: errorBody, - ContentType: determineJSONContentType(errorBody), - } - - if err := server.SendInitErrorResponse(server.GetCurrentInvokeID(), response); err != nil { + if err := interopServer.SendInitErrorResponse(response); err != nil { rendering.RenderInteropError(writer, request, err) return } - appctx.StoreErrorResponse(appCtx, response) - + appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) rendering.RenderAccepted(writer, request) } // NewInitErrorHandler returns a new instance of http handler // for serving /runtime/init/error. -func NewInitErrorHandler(registrationService core.RegistrationService, eventsAPI telemetry.EventsAPI) http.Handler { - return &initErrorHandler{ - registrationService: registrationService, - eventsAPI: eventsAPI, - } +func NewInitErrorHandler(registrationService core.RegistrationService) http.Handler { + return &initErrorHandler{registrationService: registrationService} } func determineJSONContentType(body []byte) string { @@ -87,24 +83,3 @@ func determineJSONContentType(body []byte) string { } return "application/octet-stream" } - -func (h *initErrorHandler) sendInitRuntimeDoneLogEvent(appCtx appctx.ApplicationContext) { - // ToDo: Convert this to an enum for the whole package to increase readability. 
- initCachingEnabled := appctx.LoadInitType(appCtx) == appctx.InitCaching - - initSource := interop.InferTelemetryInitSource(initCachingEnabled, appctx.LoadSandboxType(appCtx)) - runtimeDoneData := &telemetry.InitRuntimeDoneData{ - InitSource: initSource, - Status: telemetry.RuntimeDoneFailure, - } - - if err := h.eventsAPI.SendInitRuntimeDone(runtimeDoneData); err != nil { - log.Errorf("Failed to send INITRD: %s", err) - } -} - -func (h *initErrorHandler) sendRestoreRuntimeDoneLogEvent() { - if err := h.eventsAPI.SendRestoreRuntimeDone(telemetry.RuntimeDoneFailure); err != nil { - log.Errorf("Failed to send RESTRD: %s", err) - } -} diff --git a/lambda/rapi/handler/initerror_test.go b/lambda/rapi/handler/initerror_test.go index c9a5a83..a9c4b94 100644 --- a/lambda/rapi/handler/initerror_test.go +++ b/lambda/rapi/handler/initerror_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" "go.amzn.com/lambda/appctx" - + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/testdata" ) @@ -27,7 +27,7 @@ func runTestInitErrorHandler(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - handler := NewInitErrorHandler(flowTest.RegistrationService, flowTest.EventsAPI) + handler := NewInitErrorHandler(flowTest.RegistrationService) responseRecorder := httptest.NewRecorder() appCtx := flowTest.AppCtx @@ -60,12 +60,12 @@ func runTestInitErrorHandler(t *testing.T) { // payload is not provided. This fallback is not part // of the RAPID API spec and is not available to // customers. - require.Equal(t, "", errorResponse.ErrorMessage) + require.Equal(t, "", errorResponse.FunctionError.Message) // Slicer falls back to using ErrorType when error // payload is not provided. Customers can set error // type via header to use this fallback. 
- require.Equal(t, errorType, errorResponse.ErrorType) + require.Equal(t, fatalerror.RuntimeUnknown, errorResponse.FunctionError.Type) // Payload is arbitrary data that customers submit - it's error response body. require.Equal(t, errorBody, errorResponse.Payload) diff --git a/lambda/rapi/handler/invocationerror.go b/lambda/rapi/handler/invocationerror.go index 170c0cb..d434461 100644 --- a/lambda/rapi/handler/invocationerror.go +++ b/lambda/rapi/handler/invocationerror.go @@ -9,6 +9,7 @@ import ( "io" "net/http" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/model" @@ -37,7 +38,7 @@ type invocationErrorHandler struct { func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { appCtx := appctx.FromRequest(request) - server := appctx.LoadInteropServer(appCtx) + server := appctx.LoadResponseSender(appCtx) if server == nil { log.Panic("Invalid state, cannot access interop server") } @@ -50,7 +51,7 @@ func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request * return } - errorType := h.getErrorType(request.Header) + errorType := fatalerror.GetValidRuntimeOrFunctionErrorType(h.getErrorType(request.Header)) var errorCause json.RawMessage var errorBody []byte @@ -75,20 +76,23 @@ func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request * log.WithError(err).Warn("Failed to parse error body") } - response := &interop.ErrorResponse{ - ErrorType: errorType, - Payload: errorBody, - ErrorCause: errorCause, + headers := interop.InvokeResponseHeaders{ ContentType: contentType, FunctionResponseMode: functionResponseMode, } + response := &interop.ErrorInvokeResponse{ + Headers: headers, + FunctionError: interop.FunctionError{Type: errorType}, + Payload: errorBody, + } + if err := server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), response); err != nil { rendering.RenderInteropError(writer, request, err) return } - appctx.StoreErrorResponse(appCtx, 
response) + appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{ErrorCause: errorCause}) if err := runtime.ResponseSent(); err != nil { log.Panic(err) diff --git a/lambda/rapi/handler/invocationerror_test.go b/lambda/rapi/handler/invocationerror_test.go index 2f177fe..72e6719 100644 --- a/lambda/rapi/handler/invocationerror_test.go +++ b/lambda/rapi/handler/invocationerror_test.go @@ -14,6 +14,7 @@ import ( "testing" "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/model" "go.amzn.com/lambda/testdata" @@ -87,12 +88,12 @@ func runTestInvocationErrorHandler(t *testing.T) { // payload is not provided. This fallback is not part // of the RAPID API spec and is not available to // customers. - assert.Equal(t, "", errorResponse.ErrorMessage) + assert.Equal(t, "", errorResponse.FunctionError.Message) // Slicer falls back to using ErrorType when error // payload is not provided. Customers can set error // type header to use this fallback. - assert.Equal(t, errorType, errorResponse.ErrorType) + assert.Equal(t, fatalerror.RuntimeUnknown, errorResponse.FunctionError.Type) // Payload is arbitrary data that customers submit - it's error response body. 
assert.Equal(t, errorBody, errorResponse.Payload) @@ -176,10 +177,10 @@ func TestInvocationErrorHandlerSendsErrorCauseToXRayForContentTypeErrorCause(t * handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) // Assert error response contains error cause - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - assert.JSONEq(t, string(errorCause), string(errorResponse.ErrorCause)) + assert.JSONEq(t, string(errorCause), string(invokeErrorTraceData.ErrorCause)) } func TestInvocationErrorHandlerSendsNullErrorCauseWhenErrorCauseFormatIsInvalidOrEmptyForContentTypeErrorCause(t *testing.T) { @@ -213,10 +214,10 @@ func TestInvocationErrorHandlerSendsNullErrorCauseWhenErrorCauseFormatIsInvalidO // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, json.RawMessage(nil), errorResponse.ErrorCause) + assert.Equal(t, json.RawMessage(nil), invokeErrorTraceData.ErrorCause) } } @@ -248,11 +249,11 @@ func TestInvocationErrorHandlerSendsCompactedErrorCauseWhenErrorCauseIsTooLargeF // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - errorCauseJSON, err := 
model.ValidatedErrorCauseJSON(errorResponse.ErrorCause) + errorCauseJSON, err := model.ValidatedErrorCauseJSON(invokeErrorTraceData.ErrorCause) assert.NoError(t, err, "expected cause sent x-ray to be valid") assert.True(t, len(errorCauseJSON) < model.MaxErrorCauseSizeBytes, "expected cause to be compacted to size") } @@ -277,12 +278,13 @@ func TestInvocationResponsePayloadIsDefaultErrorMessageWhenRequestParsingFailsFo // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) assert.Equal(t, "application/octet-stream", flowTest.InteropServer.ResponseContentType) assert.Equal(t, "function-response-mode", flowTest.InteropServer.FunctionResponseMode) + errorResponse := flowTest.InteropServer.ErrorResponse invokeResponsePayload := errorResponse.Payload expectedResponse, _ := json.Marshal(invalidErrorBodyMessage) @@ -311,10 +313,10 @@ func TestInvocationErrorHandlerSendsErrorCauseToXRayWhenXRayErrorCauseHeaderIsSe // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - assert.JSONEq(t, string(errorCause), string(errorResponse.ErrorCause)) + assert.JSONEq(t, string(errorCause), string(invokeErrorTraceData.ErrorCause)) } func TestInvocationErrorHandlerSendsNilCauseToXRayWhenXRayErrorCauseHeaderContainsInvalidCause(t *testing.T) { @@ -340,10 +342,10 @@ func 
TestInvocationErrorHandlerSendsNilCauseToXRayWhenXRayErrorCauseHeaderContai // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, json.RawMessage(nil), errorResponse.ErrorCause) + assert.Equal(t, json.RawMessage(nil), invokeErrorTraceData.ErrorCause) } } @@ -366,11 +368,11 @@ func TestInvocationErrorHandlerSendsCompactedErrorCauseToXRayWhenXRayErrorCauseI // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - errorCauseJSON, err := model.ValidatedErrorCauseJSON(errorResponse.ErrorCause) + errorCauseJSON, err := model.ValidatedErrorCauseJSON(invokeErrorTraceData.ErrorCause) assert.NoError(t, err, "expected cause sent x-ray to be valid") assert.True(t, len(errorCauseJSON) < model.MaxErrorCauseSizeBytes, "expected cause to be compacted to size") } @@ -391,10 +393,10 @@ func TestInvocationErrorHandlerSendsNilToXRayWhenXRayErrorCauseHeaderIsNotSet(t // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - assert.Nil(t, errorResponse.ErrorCause) + assert.Nil(t, 
invokeErrorTraceData.ErrorCause) } func TestInvocationErrorHandlerSendsErrorCauseToXRayWhenXRayErrorCauseContainsUTF8Characters(t *testing.T) { @@ -416,8 +418,8 @@ func TestInvocationErrorHandlerSendsErrorCauseToXRayWhenXRayErrorCauseContainsUT // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - assert.JSONEq(t, string(errorCause), string(errorResponse.ErrorCause)) + assert.JSONEq(t, string(errorCause), string(invokeErrorTraceData.ErrorCause)) } diff --git a/lambda/rapi/handler/invocationnext_test.go b/lambda/rapi/handler/invocationnext_test.go index 5bddb86..64ae057 100644 --- a/lambda/rapi/handler/invocationnext_test.go +++ b/lambda/rapi/handler/invocationnext_test.go @@ -4,6 +4,7 @@ package handler import ( + "bytes" "context" "errors" "fmt" @@ -19,6 +20,8 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" @@ -45,57 +48,65 @@ func TestRenderInvokeEmptyHeaders(t *testing.T) { assert.Equal(t, http.StatusOK, responseRecorder.Code) } -func TestRenderInvoke(t *testing.T) { +func TestRenderInvokeHappy(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - handler := NewInvocationNextHandler(flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := httptest.NewRecorder() appCtx := flowTest.AppCtx deadlineNs := 12345 - invokePayload := "Payload" invoke := &interop.Invoke{ TraceID: "Root=RootID;Parent=LambdaFrontend;Sampled=1", - ID: "ID", + ID: "", // updated in loop InvokedFunctionArn: "InvokedFunctionArn", CognitoIdentityID: "CognitoIdentityId1", 
CognitoIdentityPoolID: "CognitoIdentityPoolId1", ClientContext: "ClientContext", DeadlineNs: strconv.Itoa(deadlineNs), ContentType: "image/png", - Payload: strings.NewReader(invokePayload), + Payload: strings.NewReader(""), // updated in loop } ctx := telemetry.NewTraceContext(context.Background(), "RootID", "InvocationSubegmentID") - flowTest.ConfigureForInvoke(ctx, invoke) - - request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) - handler.ServeHTTP(responseRecorder, request) + var requestBuffer bytes.Buffer + for i := 0; i < 6; i++ { + handler := NewInvocationNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + responseRecorder := httptest.NewRecorder() + invoke.ID = fmt.Sprintf("ID-%d", i) + invokePayload := string(bytes.Repeat([]byte("a"), (i%3)*128*1024)) // vary payload size up and down across invokes + invoke.Payload = strings.NewReader(invokePayload) + + flowTest.ConfigureForInvoke(ctx, invoke) + flowTest.ConfigureInvokeRenderer(ctx, invoke, &requestBuffer) // reuse request buffer on each invoke + request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + handler.ServeHTTP(responseRecorder, request) - headers := responseRecorder.Header() - assert.Equal(t, invoke.InvokedFunctionArn, headers.Get("Lambda-Runtime-Invoked-Function-Arn")) - assert.Equal(t, invoke.ID, headers.Get("Lambda-Runtime-Aws-Request-Id")) - assert.Equal(t, invoke.ClientContext, headers.Get("Lambda-Runtime-Client-Context")) - expectedCognitoIdentityHeader := fmt.Sprintf("{\"cognitoIdentityId\":\"%s\",\"cognitoIdentityPoolId\":\"%s\"}", invoke.CognitoIdentityID, invoke.CognitoIdentityPoolID) - assert.JSONEq(t, expectedCognitoIdentityHeader, headers.Get("Lambda-Runtime-Cognito-Identity")) - assert.Equal(t, "Root=RootID;Parent=InvocationSubegmentID;Sampled=1", headers.Get("Lambda-Runtime-Trace-Id")) - - // Assert deadline precision. E.g. 
1999 ns and 2001 ns having diff of 2 ns - // would result in 1ms and 2ms deadline correspondingly. - expectedDeadline := metering.MonoToEpoch(int64(deadlineNs)) / int64(time.Millisecond) - receivedDeadline, _ := strconv.ParseInt(headers.Get("Lambda-Runtime-Deadline-Ms"), 10, 64) - assert.True(t, math.Abs(float64(expectedDeadline-receivedDeadline)) <= float64(1), - fmt.Sprintf("Expected: %v, received: %v", expectedDeadline, receivedDeadline)) - - assert.Equal(t, "image/png", headers.Get("Content-Type")) - assert.Len(t, headers, 7) - assert.Equal(t, invokePayload, responseRecorder.Body.String()) + headers := responseRecorder.Header() + assert.Equal(t, invoke.InvokedFunctionArn, headers.Get("Lambda-Runtime-Invoked-Function-Arn")) + assert.Equal(t, invoke.ID, headers.Get("Lambda-Runtime-Aws-Request-Id")) + assert.Equal(t, invoke.ClientContext, headers.Get("Lambda-Runtime-Client-Context")) + expectedCognitoIdentityHeader := fmt.Sprintf("{\"cognitoIdentityId\":\"%s\",\"cognitoIdentityPoolId\":\"%s\"}", invoke.CognitoIdentityID, invoke.CognitoIdentityPoolID) + assert.JSONEq(t, expectedCognitoIdentityHeader, headers.Get("Lambda-Runtime-Cognito-Identity")) + assert.Equal(t, "Root=RootID;Parent=InvocationSubegmentID;Sampled=1", headers.Get("Lambda-Runtime-Trace-Id")) + + // Assert deadline precision. E.g. 1999 ns and 2001 ns having diff of 2 ns + // would result in 1ms and 2ms deadline correspondingly. 
+ expectedDeadline := metering.MonoToEpoch(int64(deadlineNs)) / int64(time.Millisecond) + receivedDeadline, _ := strconv.ParseInt(headers.Get("Lambda-Runtime-Deadline-Ms"), 10, 64) + assert.True(t, math.Abs(float64(expectedDeadline-receivedDeadline)) <= float64(1), + fmt.Sprintf("Expected: %v, received: %v", expectedDeadline, receivedDeadline)) + + assert.Equal(t, "image/png", headers.Get("Content-Type")) + assert.Len(t, headers, 7) + responsePayload := responseRecorder.Body.String() + require.Equalf(t, len(invokePayload), len(responsePayload), "Unexpected payload for request %d", i) + assert.Equal(t, invokePayload, responsePayload) + } } // Cgo calls removed due to crashes while spawning threads under memory pressure. func TestRenderInvokeDoesNotCallCgo(t *testing.T) { cgoCallsBefore := runtime.NumCgoCall() - TestRenderInvoke(t) + TestRenderInvokeHappy(t) cgoCallsAfter := runtime.NumCgoCall() assert.Equal(t, cgoCallsBefore, cgoCallsAfter) } diff --git a/lambda/rapi/handler/invocationresponse.go b/lambda/rapi/handler/invocationresponse.go index 7e47d2e..d267775 100644 --- a/lambda/rapi/handler/invocationresponse.go +++ b/lambda/rapi/handler/invocationresponse.go @@ -8,6 +8,7 @@ import ( "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/rendering" @@ -17,7 +18,6 @@ import ( const ( StreamingFunctionResponseMode = "streaming" - ErrInvalidResponseModeHeader = "Runtime.InvalidResponseModeHeader" ) type invocationResponseHandler struct { @@ -27,7 +27,7 @@ type invocationResponseHandler struct { func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { appCtx := appctx.FromRequest(request) - server := appctx.LoadInteropServer(appCtx) + server := appctx.LoadResponseSender(appCtx) if server == nil { log.Panic("Invalid state, cannot access interop server") } @@ -48,25 +48,38 @@ func (h *invocationResponseHandler) ServeHTTP(writer 
http.ResponseWriter, reques case StreamingFunctionResponseMode: headers[functionResponseModeHeader] = functionResponseMode default: - errorResponse := &interop.ErrorResponse{ - ErrorType: ErrInvalidResponseModeHeader, + errHeaders := interop.InvokeResponseHeaders{ ContentType: request.Header.Get(contentTypeHeader), } - _ = server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), errorResponse) + fnError := interop.FunctionError{Type: fatalerror.RuntimeInvalidResponseModeHeader} + response := &interop.ErrorInvokeResponse{ + Headers: errHeaders, + FunctionError: fnError, + Payload: []byte{}, + } + + _ = server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), response) rendering.RenderInvalidFunctionResponseMode(writer, request) return } } - if err := server.SendResponse(invokeID, headers, request.Body, request.Trailer, &interop.CancellableRequest{Request: request}); err != nil { + response := &interop.StreamableInvokeResponse{ + Headers: headers, + Payload: request.Body, + Trailers: request.Trailer, + Request: &interop.CancellableRequest{Request: request}, + } + + if err := server.SendResponse(invokeID, response); err != nil { switch err := err.(type) { case *interop.ErrorResponseTooLarge: - if server.SendErrorResponse(invokeID, err.AsInteropError()) != nil { + if server.SendErrorResponse(invokeID, err.AsErrorResponse()) != nil { rendering.RenderInteropError(writer, request, err) return } - appctx.StoreErrorResponse(appCtx, err.AsInteropError()) + appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) if err := runtime.ResponseSent(); err != nil { log.Panic(err) diff --git a/lambda/rapi/handler/invocationresponse_test.go b/lambda/rapi/handler/invocationresponse_test.go index 7c0b220..dc29c10 100644 --- a/lambda/rapi/handler/invocationresponse_test.go +++ b/lambda/rapi/handler/invocationresponse_test.go @@ -17,6 +17,7 @@ import ( "github.com/aws/aws-lambda-go/events/test" "github.com/stretchr/testify/assert" 
"go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/testdata" ) @@ -62,13 +63,13 @@ func TestResponseTooLarge(t *testing.T) { errorResponse := flowTest.InteropServer.ErrorResponse assert.NotNil(t, errorResponse) assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, "Function.ResponseSizeTooLarge", errorResponse.ErrorType) - assert.Equal(t, "Response payload size (6291557 bytes) exceeded maximum allowed payload size (6291556 bytes).", errorResponse.ErrorMessage) + assert.Equal(t, fatalerror.FunctionOversizedResponse, errorResponse.FunctionError.Type) + assert.Equal(t, "Response payload size (6291557 bytes) exceeded maximum allowed payload size (6291556 bytes).", errorResponse.FunctionError.Message) var errorPayload map[string]interface{} assert.NoError(t, json.Unmarshal(errorResponse.Payload, &errorPayload)) - assert.Equal(t, errorResponse.ErrorType, errorPayload["errorType"]) - assert.Equal(t, errorResponse.ErrorMessage, errorPayload["errorMessage"]) + assert.Equal(t, string(errorResponse.FunctionError.Type), errorPayload["errorType"]) + assert.Equal(t, errorResponse.FunctionError.Message, errorPayload["errorMessage"]) } func TestResponseAccepted(t *testing.T) { @@ -193,7 +194,7 @@ func TestResponseWithDifferentFunctionResponseModes(t *testing.T) { if testCase.expectedErrorResponse { assert.NotNil(t, flowTest.InteropServer.ErrorResponse) assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, "Runtime.InvalidResponseModeHeader", flowTest.InteropServer.ErrorResponse.ErrorType) + assert.Equal(t, fatalerror.RuntimeInvalidResponseModeHeader, flowTest.InteropServer.ErrorResponse.FunctionError.Type) } else { assert.NotNil(t, flowTest.InteropServer.Response) assert.Nil(t, flowTest.InteropServer.ErrorResponse) diff --git a/lambda/rapi/handler/restoreerror.go b/lambda/rapi/handler/restoreerror.go new file mode 100644 index 0000000..eed97b2 --- /dev/null +++ b/lambda/rapi/handler/restoreerror.go 
@@ -0,0 +1,47 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "net/http" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapi/rendering" +) + +type restoreErrorHandler struct { + registrationService core.RegistrationService +} + +func (h *restoreErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + appCtx := appctx.FromRequest(request) + server := appctx.LoadInteropServer(appCtx) + if server == nil { + log.Panic("Invalid state, cannot access interop server") + } + + errorType := fatalerror.GetValidRuntimeOrFunctionErrorType(request.Header.Get("Lambda-Runtime-Function-Error-Type")) + fnError := interop.FunctionError{Type: errorType} + + runtime := h.registrationService.GetRuntime() + + if err := runtime.RestoreError(fnError); err != nil { + log.Warn(err) + rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, + runtime.GetState().Name(), core.RuntimeRestoreErrorStateName, err) + return + } + + appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) + + rendering.RenderAccepted(writer, request) +} + +func NewRestoreErrorHandler(registrationService core.RegistrationService) http.Handler { + return &restoreErrorHandler{registrationService: registrationService} +} diff --git a/lambda/rapi/handler/restoreerror_test.go b/lambda/rapi/handler/restoreerror_test.go new file mode 100644 index 0000000..57226fa --- /dev/null +++ b/lambda/rapi/handler/restoreerror_test.go @@ -0,0 +1,44 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/testdata" +) + +func TestRestoreErrorHandler(t *testing.T) { + t.Run("GA", func(t *testing.T) { runTestRestoreErrorHandler(t) }) +} + +func runTestRestoreErrorHandler(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForRestoring() + + handler := NewRestoreErrorHandler(flowTest.RegistrationService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + errorBody := []byte("My byte array is yours") + errorType := "ErrorType" + errorContentType := "application/MyBinaryType" + + request := appctx.RequestWithAppCtx(httptest.NewRequest("POST", "/", bytes.NewReader(errorBody)), appCtx) + + request.Header.Set("Content-Type", errorContentType) + request.Header.Set("Lambda-Runtime-Function-Error-Type", errorType) + + handler.ServeHTTP(responseRecorder, request) + + require.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", responseRecorder.Code, http.StatusAccepted) + require.JSONEq(t, fmt.Sprintf("{\"status\":\"%s\"}\n", "OK"), responseRecorder.Body.String()) + require.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) +} diff --git a/lambda/rapi/handler/runtimelogs.go b/lambda/rapi/handler/runtimelogs.go index 99941b0..6b8a67e 100644 --- a/lambda/rapi/handler/runtimelogs.go +++ b/lambda/rapi/handler/runtimelogs.go @@ -9,10 +9,10 @@ import ( "fmt" "io" "net/http" + "strings" "go.amzn.com/lambda/core" "go.amzn.com/lambda/rapi/rendering" - "go.amzn.com/lambda/rapidcore/telemetry/logsapi" "go.amzn.com/lambda/telemetry" "github.com/google/uuid" @@ -31,10 +31,10 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http switch err := err.(type) { case *ErrAgentIdentifierUnknown: 
rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension "+err.agentID.String()) - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) } return } @@ -45,21 +45,21 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http if err != nil { log.Error(err) rendering.RenderInternalServerError(writer, request) - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) return } - respBody, status, headers, err := h.telemetrySubscription.Subscribe(agentName, bytes.NewReader(body), request.Header) + respBody, status, headers, err := h.telemetrySubscription.Subscribe(agentName, bytes.NewReader(body), request.Header, request.RemoteAddr) if err != nil { log.Errorf("Telemetry API error: %s", err) switch err { - case logsapi.ErrTelemetryServiceOff: + case telemetry.ErrTelemetryServiceOff: rendering.RenderForbiddenWithTypeMsg(writer, request, h.telemetrySubscription.GetServiceClosedErrorType(), h.telemetrySubscription.GetServiceClosedErrorMessage()) - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) } return } @@ -67,11 +67,14 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http rendering.RenderRuntimeLogsResponse(writer, respBody, status, headers) switch 
status / 100 { case 2: // 2xx - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeSuccess, 1) + if strings.Contains(string(respBody), "OK") { + h.telemetrySubscription.RecordCounterMetric(telemetry.NumSubscribers, 1) + } + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeSuccess, 1) case 4: // 4xx - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) case 5: // 5xx - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) } } diff --git a/lambda/rapi/handler/runtimelogs_test.go b/lambda/rapi/handler/runtimelogs_test.go index 892d61e..cbb8b0b 100644 --- a/lambda/rapi/handler/runtimelogs_test.go +++ b/lambda/rapi/handler/runtimelogs_test.go @@ -9,23 +9,24 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/http/httptest" "testing" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/telemetry" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore/telemetry/logsapi" ) type mockSubscriptionAPI struct{ mock.Mock } -func (s *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { - args := s.Called(agentName, body, headers) +func (s *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) ([]byte, int, map[string][]string, error) { + args := s.Called(agentName, body, headers, remoteAddr) return args.Get(0).([]byte), args.Int(1), args.Get(2).(map[string][]string), args.Error(3) } @@ -61,10 +62,15 @@ func (s *mockSubscriptionAPI) GetServiceClosedErrorType() string { return args.Get(0).(string) } +func validIPPort(addr string) bool { + ip, _, err := 
net.SplitHostPort(addr) + return err == nil && net.ParseIP(ip) != nil +} + func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} respBody, respStatus, respHeaders := []byte(`barbaz`), http.StatusNotFound, map[string][]string{"K": []string{"V1", "V2"}} - clientErrMetric := logsapi.SubscribeClientErr + clientErrMetric := telemetry.SubscribeClientErr registrationService := core.NewRegistrationService( core.NewInitFlowSynchronization(), @@ -75,7 +81,7 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { assert.NoError(t, err) telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return(respBody, respStatus, respHeaders, nil) + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return(respBody, respStatus, respHeaders, nil) telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) @@ -91,7 +97,7 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders) + telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)) telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) recordedBody, err := io.ReadAll(responseRecorder.Body) @@ -102,10 +108,97 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { assert.Equal(t, http.Header(respHeaders), responseRecorder.Header()) } +func TestSuccessfulTelemetryAPIPutRequest(t *testing.T) { + agentName, reqBody, reqHeaders := "extensionName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", 
"VAL2"}} + respBody, respStatus, respHeaders := []byte(`"OK"`), http.StatusOK, map[string][]string{"K": []string{"V1", "V2"}} + numSubscribersMetric := telemetry.NumSubscribers + subscribeSuccessMetric := telemetry.SubscribeSuccess + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization(), + core.NewInvokeFlowSynchronization(), + ) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return(respBody, respStatus, respHeaders, nil) + telemetrySubscription.On("RecordCounterMetric", numSubscribersMetric, 1) + telemetrySubscription.On("RecordCounterMetric", subscribeSuccessMetric, 1) + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", numSubscribersMetric, 1) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", subscribeSuccessMetric, 1) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + assert.Equal(t, respStatus, responseRecorder.Code) + assert.Equal(t, respBody, recordedBody) + assert.Equal(t, http.Header(respHeaders), responseRecorder.Header()) +} + +func TestNumberOfSubscribersWhenAnExtensionIsAlreadySubscribed(t *testing.T) { + agentName, reqBody, reqHeaders := "extensionName", 
[]byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} + respBody, respStatus, respHeaders := []byte(`"AlreadySubcribed"`), http.StatusOK, map[string][]string{"K": []string{"V1", "V2"}} + numSubscribersMetric := telemetry.NumSubscribers + subscribeSuccessMetric := telemetry.SubscribeSuccess + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization(), + core.NewInvokeFlowSynchronization(), + ) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return(respBody, respStatus, respHeaders, nil) + telemetrySubscription.On("RecordCounterMetric", subscribeSuccessMetric, 1) + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", subscribeSuccessMetric, 1) + telemetrySubscription.AssertNotCalled(t, "RecordCounterMetric", numSubscribersMetric, mock.Anything) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + assert.Equal(t, respStatus, responseRecorder.Code) + assert.Equal(t, respBody, recordedBody) + assert.Equal(t, http.Header(respHeaders), responseRecorder.Header()) +} + func TestErrorUnregisteredAgentID(t *testing.T) { invalidAgentID := uuid.New() reqBody, reqHeaders := []byte(`foobar`), 
map[string][]string{"Key": []string{"VAL1", "VAL2"}} - clientErrMetric := logsapi.SubscribeClientErr + clientErrMetric := telemetry.SubscribeClientErr registrationService := core.NewRegistrationService( core.NewInitFlowSynchronization(), @@ -143,7 +236,7 @@ func TestErrorUnregisteredAgentID(t *testing.T) { func TestErrorTelemetryAPICallFailure(t *testing.T) { agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} apiError := errors.New("Error calling Telemetry API: connection refused") - serverErrMetric := logsapi.SubscribeServerErr + serverErrMetric := telemetry.SubscribeServerErr registrationService := core.NewRegistrationService( core.NewInitFlowSynchronization(), @@ -154,7 +247,7 @@ func TestErrorTelemetryAPICallFailure(t *testing.T) { assert.NoError(t, err) telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) telemetrySubscription.On("RecordCounterMetric", serverErrMetric, 1) handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) @@ -184,8 +277,8 @@ func TestErrorTelemetryAPICallFailure(t *testing.T) { func TestRenderLogsSubscriptionClosed(t *testing.T) { agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} - apiError := logsapi.ErrTelemetryServiceOff - clientErrMetric := logsapi.SubscribeClientErr + apiError := telemetry.ErrTelemetryServiceOff + clientErrMetric := telemetry.SubscribeClientErr registrationService := core.NewRegistrationService( core.NewInitFlowSynchronization(), @@ -196,7 +289,7 @@ func TestRenderLogsSubscriptionClosed(t *testing.T) { 
assert.NoError(t, err) telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Logs API subscription is closed already") telemetrySubscription.On("GetServiceClosedErrorType").Return("Logs.SubscriptionClosed") @@ -228,8 +321,8 @@ func TestRenderLogsSubscriptionClosed(t *testing.T) { func TestRenderTelemetrySubscriptionClosed(t *testing.T) { agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} - apiError := logsapi.ErrTelemetryServiceOff - clientErrMetric := logsapi.SubscribeClientErr + apiError := telemetry.ErrTelemetryServiceOff + clientErrMetric := telemetry.SubscribeClientErr registrationService := core.NewRegistrationService( core.NewInitFlowSynchronization(), @@ -240,7 +333,7 @@ func TestRenderTelemetrySubscriptionClosed(t *testing.T) { assert.NoError(t, err) telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Telemetry API subscription is closed already") telemetrySubscription.On("GetServiceClosedErrorType").Return("Telemetry.SubscriptionClosed") diff --git 
a/lambda/rapi/model/agentregisterresponse.go b/lambda/rapi/model/agentregisterresponse.go index 7e2eb86..fb9cacc 100644 --- a/lambda/rapi/model/agentregisterresponse.go +++ b/lambda/rapi/model/agentregisterresponse.go @@ -5,6 +5,7 @@ package model // ExtensionRegisterResponse is a response returned by the API server on extension/register post request type ExtensionRegisterResponse struct { + AccountID string `json:"accountId,omitempty"` FunctionName string `json:"functionName"` FunctionVersion string `json:"functionVersion"` Handler string `json:"handler"` diff --git a/lambda/rapi/model/errorresponse.go b/lambda/rapi/model/errorresponse.go index 621811c..4c95e6c 100644 --- a/lambda/rapi/model/errorresponse.go +++ b/lambda/rapi/model/errorresponse.go @@ -3,12 +3,6 @@ package model -import ( - "encoding/json" - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/interop" -) - // ErrorResponse is a standard invoke error response, // providing information about the error. type ErrorResponse struct { @@ -16,16 +10,3 @@ type ErrorResponse struct { ErrorType string `json:"errorType"` StackTrace []string `json:"stackTrace,omitempty"` } - -func (s *ErrorResponse) AsInteropError() *interop.ErrorResponse { - respJSON, err := json.Marshal(s) - if err != nil { - log.Panicf("Failed to marshal %#v: %s", *s, err) - } - - return &interop.ErrorResponse{ - ErrorType: s.ErrorType, - ErrorMessage: s.ErrorMessage, - Payload: respJSON, - } -} diff --git a/lambda/rapi/rapi_fuzz_test.go b/lambda/rapi/rapi_fuzz_test.go new file mode 100644 index 0000000..f1df47f --- /dev/null +++ b/lambda/rapi/rapi_fuzz_test.go @@ -0,0 +1,391 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package rapi + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "net/url" + "os" + "regexp" + "strings" + "testing" + "unicode" + + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/extensions" + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/telemetry" + "go.amzn.com/lambda/testdata" +) + +type runtimeFunctionErrStruct struct { + ErrorMessage string + ErrorType string + StackTrace []string +} + +func FuzzRuntimeAPIRouter(f *testing.F) { + extensions.Enable() + defer extensions.Disable() + + addSeedCorpusURLTargets(f) + + f.Fuzz(func(t *testing.T, rawPath string, payload []byte, isGetMethod bool) { + u, err := parseToURLStruct(rawPath) + if err != nil { + t.Skipf("error parsing url: %v. Skipping test.", err) + } + + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + + invoke := createDummyInvoke() + flowTest.ConfigureForInvoke(context.Background(), invoke) + + appctx.StoreInitType(flowTest.AppCtx, true) + + rapiServer := makeRapiServer(flowTest) + + method := "GET" + if !isGetMethod { + method = "POST" + } + + request := httptest.NewRequest(method, rawPath, bytes.NewReader(payload)) + responseRecorder := serveTestRequest(rapiServer, request) + + if isExpectedPath(u.Path, invoke.ID, isGetMethod) { + assertExpectedPathResponseCode(t, responseRecorder.Code, rawPath) + } else { + assertUnexpectedPathResponseCode(t, responseRecorder.Code, rawPath) + } + }) +} + +func FuzzInitErrorHandler(f *testing.F) { + addRuntimeFunctionErrorJSONCorpus(f) + + f.Fuzz(func(t *testing.T, errorBody []byte, errTypeHeader []byte) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL("/runtime/init/error", version20180601) + request := httptest.NewRequest("POST", target, bytes.NewReader(errorBody)) + 
request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) + request.Header.Set("Lambda-Runtime-Function-Error-Type", string(errTypeHeader)) + + responseRecorder := serveTestRequest(rapiServer, request) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.JSONEq(t, "{\"status\":\"OK\"}\n", responseRecorder.Body.String()) + assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) + + assertErrorResponsePersists(t, errorBody, errTypeHeader, flowTest) + }) +} + +func FuzzInvocationResponseHandler(f *testing.F) { + f.Add([]byte("SUCCESS"), []byte("application/json"), []byte("streaming")) + f.Add([]byte(strings.Repeat("a", interop.MaxPayloadSize+1)), []byte("application/json"), []byte("streaming")) + + f.Fuzz(func(t *testing.T, responseBody []byte, contentType []byte, responseMode []byte) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + flowTest.Runtime.Ready() + + invoke := createDummyInvoke() + flowTest.ConfigureForInvoke(context.Background(), invoke) + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL(fmt.Sprintf("/runtime/invocation/%s/response", invoke.ID), version20180601) + request := httptest.NewRequest("POST", target, bytes.NewReader(responseBody)) + request.Header.Set("Content-Type", string(contentType)) + request.Header.Set("Lambda-Runtime-Function-Response-Mode", string(responseMode)) + + request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) + + responseRecorder := serveTestRequest(rapiServer, request) + + if !isValidResponseMode(responseMode) { + assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) + return + } + + if len(responseBody) > interop.MaxPayloadSize { + assertInvocationResponseTooLarge(t, responseRecorder, flowTest, responseBody) + } else { + assertInvocationResponseAccepted(t, responseRecorder, flowTest, responseBody, contentType) + } + }) +} + +func FuzzInvocationErrorHandler(f *testing.F) { + addRuntimeFunctionErrorJSONCorpus(f) + + 
f.Fuzz(func(t *testing.T, errorBody []byte, errTypeHeader []byte) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + flowTest.Runtime.Ready() + appCtx := flowTest.AppCtx + + invoke := createDummyInvoke() + flowTest.ConfigureForInvoke(context.Background(), invoke) + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL(fmt.Sprintf("/runtime/invocation/%s/error", invoke.ID), version20180601) + request := httptest.NewRequest("POST", target, bytes.NewReader(errorBody)) + request = appctx.RequestWithAppCtx(request, appCtx) + + request.Header.Set("Lambda-Runtime-Function-Error-Type", string(errTypeHeader)) + + responseRecorder := serveTestRequest(rapiServer, request) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.JSONEq(t, "{\"status\":\"OK\"}\n", responseRecorder.Body.String()) + assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) + + assertErrorResponsePersists(t, errorBody, errTypeHeader, flowTest) + }) +} + +func FuzzRestoreErrorHandler(f *testing.F) { + f.Fuzz(func(t *testing.T, errorBody []byte, errTypeHeader []byte) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForRestoring() + + appctx.StoreInitType(flowTest.AppCtx, true) + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL("/runtime/restore/error", version20180601) + request := httptest.NewRequest("POST", target, bytes.NewReader(errorBody)) + request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) + + request.Header.Set("Lambda-Runtime-Function-Error-Type", string(errTypeHeader)) + + responseRecorder := serveTestRequest(rapiServer, request) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.JSONEq(t, "{\"status\":\"OK\"}\n", responseRecorder.Body.String()) + assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) + }) +} + +func makeRapiServer(flowTest *testdata.FlowTest) *Server { + return NewServer( + "127.0.0.1", + 0, + 
flowTest.AppCtx, + flowTest.RegistrationService, + flowTest.RenderingService, + true, + &telemetry.NoOpSubscriptionAPI{}, + flowTest.TelemetrySubscription, + flowTest.CredentialsService, + ) +} + +func createDummyInvoke() *interop.Invoke { + return &interop.Invoke{ + ID: "InvocationID1", + Payload: strings.NewReader("Payload1"), + } +} + +func makeTargetURL(path string, apiVersion string) string { + protocol := "http" + endpoint := os.Getenv("AWS_LAMBDA_RUNTIME_API") + baseurl := fmt.Sprintf("%s://%s%s", protocol, endpoint, apiVersion) + + return fmt.Sprintf("%s%s", baseurl, path) +} + +func serveTestRequest(rapiServer *Server, request *http.Request) *httptest.ResponseRecorder { + responseRecorder := httptest.NewRecorder() + rapiServer.server.Handler.ServeHTTP(responseRecorder, request) + log.Printf("test(%v) = %v", request.URL, responseRecorder.Code) + + return responseRecorder +} + +func addSeedCorpusURLTargets(f *testing.F) { + invoke := createDummyInvoke() + errStruct := runtimeFunctionErrStruct{ + ErrorMessage: "error occurred", + ErrorType: "Runtime.UnknownReason", + StackTrace: []string{}, + } + errJSON, _ := json.Marshal(errStruct) + f.Add(makeTargetURL("/runtime/init/error", version20180601), errJSON, false) + f.Add(makeTargetURL("/runtime/invocation/next", version20180601), []byte{}, true) + f.Add(makeTargetURL(fmt.Sprintf("/runtime/invocation/%s/response", invoke.ID), version20180601), []byte("SUCCESS"), false) + f.Add(makeTargetURL(fmt.Sprintf("/runtime/invocation/%s/error", invoke.ID), version20180601), errJSON, false) + f.Add(makeTargetURL("/runtime/restore/next", version20180601), []byte{}, true) + f.Add(makeTargetURL("/runtime/restore/error", version20180601), errJSON, false) + + f.Add(makeTargetURL("/extension/register", version20200101), []byte("register"), false) + f.Add(makeTargetURL("/extension/event/next", version20200101), []byte("next"), true) + f.Add(makeTargetURL("/extension/init/error", version20200101), []byte("init error"), false) + 
f.Add(makeTargetURL("/extension/exit/error", version20200101), []byte("exit error"), false) +} + +func addRuntimeFunctionErrorJSONCorpus(f *testing.F) { + runtimeFuncErr := runtimeFunctionErrStruct{ + ErrorMessage: "error", + ErrorType: "Runtime.Unknown", + StackTrace: []string{}, + } + data, _ := json.Marshal(runtimeFuncErr) + + f.Add(data, []byte("Runtime.Unknown")) +} + +func isExpectedPath(path string, invokeID string, isGetMethod bool) bool { + expectedPaths := make(map[string]bool) + + expectedPaths[fmt.Sprintf("%s/runtime/init/error", version20180601)] = false + expectedPaths[fmt.Sprintf("%s/runtime/invocation/next", version20180601)] = true + expectedPaths[fmt.Sprintf("%s/runtime/invocation/%s/response", version20180601, invokeID)] = false + expectedPaths[fmt.Sprintf("%s/runtime/invocation/%s/error", version20180601, invokeID)] = false + expectedPaths[fmt.Sprintf("%s/runtime/restore/next", version20180601)] = true + expectedPaths[fmt.Sprintf("%s/runtime/restore/error", version20180601)] = false + + expectedPaths[fmt.Sprintf("%s/extension/register", version20200101)] = false + expectedPaths[fmt.Sprintf("%s/extension/event/next", version20200101)] = true + expectedPaths[fmt.Sprintf("%s/extension/init/error", version20200101)] = false + expectedPaths[fmt.Sprintf("%s/extension/exit/error", version20200101)] = false + + val, found := expectedPaths[path] + return found && (val == isGetMethod) +} + +func parseToURLStruct(rawPath string) (*url.URL, error) { + invalidChars := regexp.MustCompile(`[ %]+`) + if invalidChars.MatchString(rawPath) { + return nil, errors.New("url must not contain spaces or %") + } + + for _, r := range rawPath { + if !unicode.IsGraphic(r) { + return nil, errors.New("url contains non-graphic runes") + } + } + + if _, err := url.ParseRequestURI(rawPath); err != nil { + return nil, err + } + + u, err := url.Parse(rawPath) + if err != nil { + return nil, err + } + + if u.Scheme == "" { + return nil, errors.New("blank url scheme") + } + + 
return u, nil +} + +func assertInvocationResponseAccepted(t *testing.T, responseRecorder *httptest.ResponseRecorder, + flowTest *testdata.FlowTest, responseBody []byte, contentType []byte) { + assert.Equal(t, http.StatusAccepted, responseRecorder.Code, + "Handler returned wrong status code: got %v expected %v", + responseRecorder.Code, http.StatusAccepted) + + expectedAPIResponse := "{\"status\":\"OK\"}\n" + body, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + assert.JSONEq(t, expectedAPIResponse, string(body)) + + response := flowTest.InteropServer.Response + assert.NotNil(t, response) + assert.Nil(t, flowTest.InteropServer.ErrorResponse) + + assert.Equal(t, string(contentType), flowTest.InteropServer.ResponseContentType) + + assert.Equal(t, responseBody, response, + "Persisted response data in app context must match the submitted.") +} + +func assertInvocationResponseTooLarge(t *testing.T, responseRecorder *httptest.ResponseRecorder, flowTest *testdata.FlowTest, responseBody []byte) { + assert.Equal(t, http.StatusRequestEntityTooLarge, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", + responseRecorder.Code, http.StatusRequestEntityTooLarge) + + expectedAPIResponse := fmt.Sprintf("{\"errorMessage\":\"Exceeded maximum allowed payload size (%d bytes).\",\"errorType\":\"RequestEntityTooLarge\"}\n", interop.MaxPayloadSize) + body, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + assert.JSONEq(t, expectedAPIResponse, string(body)) + + errorResponse := flowTest.InteropServer.ErrorResponse + assert.NotNil(t, errorResponse) + assert.Nil(t, flowTest.InteropServer.Response) + assert.Equal(t, fatalerror.FunctionOversizedResponse, errorResponse.FunctionError.Type) + assert.Equal(t, fmt.Sprintf("Response payload size (%v bytes) exceeded maximum allowed payload size (6291556 bytes).", len(responseBody)), errorResponse.FunctionError.Message) + + var errorPayload map[string]interface{} + assert.NoError(t, 
json.Unmarshal(errorResponse.Payload, &errorPayload)) + assert.Equal(t, string(errorResponse.FunctionError.Type), errorPayload["errorType"]) + assert.Equal(t, errorResponse.FunctionError.Message, errorPayload["errorMessage"]) +} + +func assertErrorResponsePersists(t *testing.T, errorBody []byte, errTypeHeader []byte, flowTest *testdata.FlowTest) { + errorResponse := flowTest.InteropServer.ErrorResponse + assert.NotNil(t, errorResponse) + assert.Nil(t, flowTest.InteropServer.Response) + + var runtimeFunctionErr runtimeFunctionErrStruct + var expectedErrMsg string + + // If input payload is a valid function error json object, + // assert that the error message persisted in the response + err := json.Unmarshal(errorBody, &runtimeFunctionErr) + if err != nil { + expectedErrMsg = runtimeFunctionErr.ErrorMessage + } + assert.Equal(t, expectedErrMsg, errorResponse.FunctionError.Message) + + // If input error type is valid (within the allow-listed value, + // assert that the error type persisted in the response + expectedErrType := fatalerror.GetValidRuntimeOrFunctionErrorType(string(errTypeHeader)) + assert.Equal(t, expectedErrType, errorResponse.FunctionError.Type) + + assert.Equal(t, errorBody, errorResponse.Payload) +} + +func isValidResponseMode(responseMode []byte) bool { + responseModeStr := string(responseMode) + return responseModeStr == "streaming" || + responseModeStr == "" +} + +func assertExpectedPathResponseCode(t *testing.T, code int, target string) { + if !(code == http.StatusOK || + code == http.StatusAccepted || + code == http.StatusForbidden) { + t.Errorf("Unexpected status code (%v) for target (%v)", code, target) + } +} + +func assertUnexpectedPathResponseCode(t *testing.T, code int, target string) { + if !(code == http.StatusNotFound || + code == http.StatusMethodNotAllowed || + code == http.StatusBadRequest) { + t.Errorf("Unexpected status code (%v) for target (%v)", code, target) + } +} diff --git a/lambda/rapi/rendering/render_error.go 
b/lambda/rapi/rendering/render_error.go new file mode 100644 index 0000000..151e606 --- /dev/null +++ b/lambda/rapi/rendering/render_error.go @@ -0,0 +1,88 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rendering + +import ( + "fmt" + "net/http" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapi/model" +) + +// RenderForbiddenWithTypeMsg method for rendering error response +func RenderForbiddenWithTypeMsg(w http.ResponseWriter, r *http.Request, errorType string, format string, args ...interface{}) { + if err := RenderJSON(http.StatusForbidden, w, r, &model.ErrorResponse{ + ErrorType: errorType, + ErrorMessage: fmt.Sprintf(format, args...), + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderInternalServerError method for rendering error response +func RenderInternalServerError(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusInternalServerError, w, r, &model.ErrorResponse{ + ErrorMessage: "Internal Server Error", + ErrorType: ErrorTypeInternalServerError, + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderRequestEntityTooLarge method for rendering error response +func RenderRequestEntityTooLarge(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusRequestEntityTooLarge, w, r, &model.ErrorResponse{ + ErrorMessage: fmt.Sprintf("Exceeded maximum allowed payload size (%d bytes).", interop.MaxPayloadSize), + ErrorType: ErrorTypeRequestEntityTooLarge, + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderTruncatedHTTPRequestError method for rendering error response +func 
RenderTruncatedHTTPRequestError(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "HTTP request detected as truncated", + ErrorType: ErrorTypeTruncatedHTTPRequest, + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderInvalidRequestID renders invalid request ID error response +func RenderInvalidRequestID(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "Invalid request ID", + ErrorType: "InvalidRequestID", + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderInvalidFunctionResponseMode renders invalid function response mode response +func RenderInvalidFunctionResponseMode(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "Invalid function response mode", + ErrorType: "InvalidFunctionResponseMode", + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderInteropError is a convenience method for interpreting interop errors +func RenderInteropError(writer http.ResponseWriter, request *http.Request, err error) { + if err == interop.ErrInvalidInvokeID || err == interop.ErrResponseSent { + RenderInvalidRequestID(writer, request) + } else { + log.Panic(err) + } +} diff --git a/lambda/rapi/rendering/render_json.go b/lambda/rapi/rendering/render_json.go index 8cea816..1afbfe8 100644 --- a/lambda/rapi/rendering/render_json.go +++ b/lambda/rapi/rendering/render_json.go @@ -6,8 +6,9 @@ package rendering import ( "bytes" "encoding/json" - log "github.com/sirupsen/logrus" "net/http" + + log "github.com/sirupsen/logrus" ) // 
RenderJSON: @@ -15,6 +16,7 @@ import ( // - sets the Content-Type as application/json // - sets the HTTP response status code // - returns an error if it occurred before writing to response +// TODO: r *http.Request is not used, remove it func RenderJSON(status int, w http.ResponseWriter, r *http.Request, v interface{}) error { buf := &bytes.Buffer{} enc := json.NewEncoder(buf) diff --git a/lambda/rapi/rendering/rendering.go b/lambda/rapi/rendering/rendering.go index 0edfb68..9a9d77b 100644 --- a/lambda/rapi/rendering/rendering.go +++ b/lambda/rapi/rendering/rendering.go @@ -4,10 +4,10 @@ package rendering import ( + "bytes" "context" "encoding/json" "errors" - "fmt" "io" "net/http" "strconv" @@ -50,6 +50,13 @@ type EventRenderingService struct { currentState RendererState } +// NewRenderingService returns new EventRenderingService. +func NewRenderingService() *EventRenderingService { + return &EventRenderingService{ + mutex: &sync.RWMutex{}, + } +} + // SetRenderer set current state func (s *EventRenderingService) SetRenderer(state RendererState) { s.mutex.Lock() @@ -77,11 +84,19 @@ func (s *EventRenderingService) RenderRuntimeEvent(w http.ResponseWriter, r *htt return s.currentState.RenderRuntimeEvent(w, r) } -// NewRenderingService returns new EventRenderingService. 
-func NewRenderingService() *EventRenderingService { - return &EventRenderingService{ - mutex: &sync.RWMutex{}, - } +type RestoreRenderer struct{} + +func NewRestoreRenderer() *RestoreRenderer { + return &RestoreRenderer{} +} + +func (s *RestoreRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { + writer.WriteHeader(http.StatusOK) + return nil +} + +func (s *RestoreRenderer) RenderAgentEvent(writer http.ResponseWriter, request *http.Request) error { + return nil } // InvokeRendererMetrics contains metrics of invoke request @@ -94,17 +109,26 @@ type InvokeRendererMetrics struct { type InvokeRenderer struct { ctx context.Context invoke *interop.Invoke - tracingHeaderParser func(context.Context, *interop.Invoke) string - requestBuffer []byte + tracingHeaderParser func(context.Context) string + requestBuffer *bytes.Buffer requestMutex sync.Mutex metrics InvokeRendererMetrics } -type RestoreRenderer struct { +// NewInvokeRenderer returns new invoke event renderer +func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, requestBuffer *bytes.Buffer, traceParser func(context.Context) string) *InvokeRenderer { + requestBuffer.Reset() // clear request buffer, since this can be reused across invokes + return &InvokeRenderer{ + invoke: invoke, + ctx: ctx, + tracingHeaderParser: traceParser, + requestBuffer: requestBuffer, + requestMutex: sync.Mutex{}, + } } -// NewAgentInvokeEvent forms a new AgentInvokeEvent from INVOKE request -func NewAgentInvokeEvent(req *interop.Invoke) (*model.AgentInvokeEvent, error) { +// newAgentInvokeEvent forms a new AgentInvokeEvent from INVOKE request +func newAgentInvokeEvent(req *interop.Invoke) (*model.AgentInvokeEvent, error) { deadlineMono, err := strconv.ParseInt(req.DeadlineNs, 10, 64) if err != nil { return nil, err @@ -123,7 +147,7 @@ func NewAgentInvokeEvent(req *interop.Invoke) (*model.AgentInvokeEvent, error) { // RenderAgentEvent renders invoke event json for agent. 
func (s *InvokeRenderer) RenderAgentEvent(writer http.ResponseWriter, request *http.Request) error { - event, err := NewAgentInvokeEvent(s.invoke) + event, err := newAgentInvokeEvent(s.invoke) if err != nil { return err } @@ -133,7 +157,11 @@ func (s *InvokeRenderer) RenderAgentEvent(writer http.ResponseWriter, request *h return err } - renderAgentInvokeHeaders(writer, uuid.New()) // TODO: check this thing + eventID := uuid.New() + headers := writer.Header() + headers.Set("Lambda-Extension-Event-Identifier", eventID.String()) + headers.Set("Content-Type", "application/json") + writer.WriteHeader(http.StatusOK) if _, err := writer.Write(bytes); err != nil { return err @@ -145,13 +173,13 @@ func (s *InvokeRenderer) bufferInvokeRequest() error { s.requestMutex.Lock() defer s.requestMutex.Unlock() var err error = nil - if nil == s.requestBuffer { + if s.requestBuffer.Len() == 0 { reader := io.LimitReader(s.invoke.Payload, interop.MaxPayloadSize) start := time.Now() - s.requestBuffer, err = io.ReadAll(reader) + _, err = s.requestBuffer.ReadFrom(reader) s.metrics = InvokeRendererMetrics{ ReadTime: time.Since(start), - SizeBytes: len(s.requestBuffer), + SizeBytes: s.requestBuffer.Len(), } } return err @@ -160,7 +188,7 @@ func (s *InvokeRenderer) bufferInvokeRequest() error { // RenderRuntimeEvent renders invoke payload for runtime. 
func (s *InvokeRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { invoke := s.invoke - customerTraceID := s.tracingHeaderParser(s.ctx, s.invoke) + customerTraceID := s.tracingHeaderParser(s.ctx) cognitoIdentityJSON := "" if len(invoke.CognitoIdentityID) != 0 || len(invoke.CognitoIdentityPoolID) != 0 { @@ -189,37 +217,13 @@ func (s *InvokeRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request if err := s.bufferInvokeRequest(); err != nil { return err } - _, err := writer.Write(s.requestBuffer) + _, err := writer.Write(s.requestBuffer.Bytes()) return err } return nil } -func (s *RestoreRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { - writer.WriteHeader(http.StatusOK) - return nil -} - -func (s *RestoreRenderer) RenderAgentEvent(writer http.ResponseWriter, request *http.Request) error { - return nil -} - -// NewInvokeRenderer returns new invoke event renderer -func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, traceParser func(context.Context, *interop.Invoke) string) *InvokeRenderer { - return &InvokeRenderer{ - invoke: invoke, - ctx: ctx, - tracingHeaderParser: traceParser, - requestBuffer: nil, - requestMutex: sync.Mutex{}, - } -} - -func NewRestoreRenderer() *RestoreRenderer { - return &RestoreRenderer{} -} - func (s *InvokeRenderer) GetMetrics() InvokeRendererMetrics { s.requestMutex.Lock() defer s.requestMutex.Unlock() @@ -248,22 +252,15 @@ func (s *ShutdownRenderer) RenderRuntimeEvent(w http.ResponseWriter, r *http.Req panic("We should SIGTERM runtime") } -func setHeaderIfNotEmpty(headers http.Header, key string, value string) { - if len(value) != 0 { - headers.Set(key, value) - } -} +func renderInvokeHeaders(writer http.ResponseWriter, invokeID string, customerTraceID string, clientContext string, + cognitoIdentity string, invokedFunctionArn string, deadlineMs string, contentType string) { -func setHeaderOrDefault(headers http.Header, key, val, defaultVal 
string) { - if val == "" { - headers.Set(key, defaultVal) - return + setHeaderIfNotEmpty := func(headers http.Header, key string, value string) { + if value != "" { + headers.Set(key, value) + } } - headers.Set(key, val) -} -func renderInvokeHeaders(writer http.ResponseWriter, invokeID string, customerTraceID string, clientContext string, - cognitoIdentity string, invokedFunctionArn string, deadlineMs string, contentType string) { headers := writer.Header() setHeaderIfNotEmpty(headers, "Lambda-Runtime-Aws-Request-Id", invokeID) setHeaderIfNotEmpty(headers, "Lambda-Runtime-Trace-Id", customerTraceID) @@ -271,7 +268,10 @@ func renderInvokeHeaders(writer http.ResponseWriter, invokeID string, customerTr setHeaderIfNotEmpty(headers, "Lambda-Runtime-Cognito-Identity", cognitoIdentity) setHeaderIfNotEmpty(headers, "Lambda-Runtime-Invoked-Function-Arn", invokedFunctionArn) setHeaderIfNotEmpty(headers, "Lambda-Runtime-Deadline-Ms", deadlineMs) - setHeaderOrDefault(headers, "Content-Type", contentType, "application/json") + if contentType == "" { + contentType = "application/json" + } + headers.Set("Content-Type", contentType) writer.WriteHeader(http.StatusOK) } @@ -290,79 +290,6 @@ func RenderRuntimeLogsResponse(w http.ResponseWriter, respBody []byte, status in return err } -func renderAgentInvokeHeaders(writer http.ResponseWriter, eventID uuid.UUID) { - headers := writer.Header() - headers.Set("Lambda-Extension-Event-Identifier", eventID.String()) - headers.Set("Content-Type", "application/json") - writer.WriteHeader(http.StatusOK) -} - -// RenderForbiddenWithTypeMsg method for rendering error response -func RenderForbiddenWithTypeMsg(w http.ResponseWriter, r *http.Request, errorType string, format string, args ...interface{}) { - if err := RenderJSON(http.StatusForbidden, w, r, &model.ErrorResponse{ - ErrorType: errorType, - ErrorMessage: fmt.Sprintf(format, args...), - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, 
err.Error(), http.StatusInternalServerError) - } -} - -// RenderInternalServerError method for rendering error response -func RenderInternalServerError(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusInternalServerError, w, r, &model.ErrorResponse{ - ErrorMessage: "Internal Server Error", - ErrorType: ErrorTypeInternalServerError, - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderRequestEntityTooLarge method for rendering error response -func RenderRequestEntityTooLarge(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusRequestEntityTooLarge, w, r, &model.ErrorResponse{ - ErrorMessage: fmt.Sprintf("Exceeded maximum allowed payload size (%d bytes).", interop.MaxPayloadSize), - ErrorType: ErrorTypeRequestEntityTooLarge, - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderTruncatedHTTPRequestError method for rendering error response -func RenderTruncatedHTTPRequestError(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ - ErrorMessage: "HTTP request detected as truncated", - ErrorType: ErrorTypeTruncatedHTTPRequest, - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderInvalidRequestID renders invalid request ID error response -func RenderInvalidRequestID(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ - ErrorMessage: "Invalid request ID", - ErrorType: "InvalidRequestID", - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderInvalidFunctionResponseMode renders invalid 
function response mode response -func RenderInvalidFunctionResponseMode(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ - ErrorMessage: "Invalid function response mode", - ErrorType: "InvalidFunctionResponseMode", - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - // RenderAccepted method for rendering accepted status response func RenderAccepted(w http.ResponseWriter, r *http.Request) { if err := RenderJSON(http.StatusAccepted, w, r, &model.StatusResponse{ @@ -372,12 +299,3 @@ func RenderAccepted(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) } } - -// RenderInteropError is a convenience method for interpreting interop errors -func RenderInteropError(writer http.ResponseWriter, request *http.Request, err error) { - if err == interop.ErrInvalidInvokeID || err == interop.ErrResponseSent { - RenderInvalidRequestID(writer, request) - } else { - log.Panic(err) - } -} diff --git a/lambda/rapi/router.go b/lambda/rapi/router.go index 5c2a56d..dc036bc 100644 --- a/lambda/rapi/router.go +++ b/lambda/rapi/router.go @@ -19,7 +19,7 @@ import ( // NewRouter returns a new instance of chi router implementing // Runtime API specification. 
-func NewRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService, eventsAPI telemetry.EventsAPI) http.Handler { +func NewRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { router := chi.NewRouter() router.Use(middleware.AppCtxMiddleware(appCtx)) @@ -45,11 +45,11 @@ func NewRouter(appCtx appctx.ApplicationContext, registrationService core.Regist middleware.AwsRequestIDValidator( handler.NewInvocationErrorHandler(registrationService)).ServeHTTP) - router.Post("/runtime/init/error", - handler.NewInitErrorHandler(registrationService, eventsAPI).ServeHTTP) + router.Post("/runtime/init/error", handler.NewInitErrorHandler(registrationService).ServeHTTP) if appctx.LoadInitType(appCtx) == appctx.InitCaching { router.Get("/runtime/restore/next", handler.NewRestoreNextHandler(registrationService, renderingService).ServeHTTP) + router.Post("/runtime/restore/error", handler.NewRestoreErrorHandler(registrationService).ServeHTTP) } return router diff --git a/lambda/rapi/router_test.go b/lambda/rapi/router_test.go index 73cbde1..276fa53 100644 --- a/lambda/rapi/router_test.go +++ b/lambda/rapi/router_test.go @@ -69,7 +69,7 @@ func assertResponseErrorType(t *testing.T, expectedErrorType string, response *h func TestAcceptXML(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) responseRecorder := httptest.NewRecorder() request := httptest.NewRequest("POST", "/runtime/invocation/x-y-z/error", bytes.NewReader([]byte(""))) // Tell server that client side accepts "application/xml". 
@@ -90,7 +90,7 @@ func TestAcceptXML(t *testing.T) { func Test404PageNotFound(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/unsupported", bytes.NewReader([]byte("")))) assert.Equal(t, http.StatusNotFound, responseRecorder.Code) assert.Equal(t, "404 page not found\n", responseRecorder.Body.String()) @@ -99,7 +99,7 @@ func Test404PageNotFound(t *testing.T) { func Test405MethodNotAllowed(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("DELETE", "/runtime/invocation/ABC/error", bytes.NewReader([]byte("")))) assert.Equal(t, http.StatusMethodNotAllowed, responseRecorder.Code) } @@ -107,7 +107,7 @@ func Test405MethodNotAllowed(t *testing.T) { func TestInitErrorAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/init/error", bytes.NewReader([]byte("{}")))) assert.Equal(t, http.StatusAccepted, responseRecorder.Code) } @@ -115,7 +115,7 @@ func TestInitErrorAccepted(t *testing.T) { func TestInitErrorForbidden(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := 
NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -126,7 +126,7 @@ func TestInitErrorForbidden(t *testing.T) { func TestInvokeResponseAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -137,7 +137,7 @@ func TestInvokeResponseAccepted(t *testing.T) { func TestInvokeErrorResponseAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -148,7 +148,7 @@ func TestInvokeErrorResponseAccepted(t *testing.T) { func TestInvokeNextTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := 
NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -159,7 +159,7 @@ func TestInvokeNextTwice(t *testing.T) { func TestInvokeResponseInvalidRequestID(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -171,7 +171,7 @@ func TestInvokeResponseInvalidRequestID(t *testing.T) { func TestInvokeErrorResponseInvalidRequestID(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -183,7 +183,7 @@ func TestInvokeErrorResponseInvalidRequestID(t *testing.T) { func TestInvokeResponseTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) 
flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -197,7 +197,7 @@ func TestInvokeResponseTwice(t *testing.T) { func TestInvokeErrorResponseTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -211,7 +211,7 @@ func TestInvokeErrorResponseTwice(t *testing.T) { func TestInvokeResponseAfterErrorResponse(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -225,7 +225,7 @@ func TestInvokeResponseAfterErrorResponse(t *testing.T) { func TestInvokeErrorResponseAfterResponse(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := 
makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -239,7 +239,7 @@ func TestInvokeErrorResponseAfterResponse(t *testing.T) { func TestMoreThanOneInvoke(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) var responseRecorder *httptest.ResponseRecorder for _, id := range []string{"A", "B", "C"} { flowTest.ConfigureForInvoke(context.Background(), createInvoke(id)) @@ -253,7 +253,7 @@ func TestMoreThanOneInvoke(t *testing.T) { func TestInitCachingAPIDisabledForPlainInit(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) var responseRecorder *httptest.ResponseRecorder responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/restore/next", nil)) @@ -263,12 +263,13 @@ func TestInitCachingAPIDisabledForPlainInit(t *testing.T) { assert.Equal(t, http.StatusNotFound, responseRecorder.Code) } -func benchmarkInvoke(b *testing.B, payload []byte) { +func benchmarkInvokeResponse(b *testing.B, payload []byte) { b.StopTimer() + b.ResetTimer() // does not restart timer, only resets state b.ReportAllocs() flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) for i := 0; i < b.N; i++ { id := uuid.New().String() flowTest.ConfigureForInvoke(context.Background(), createInvoke(id)) 
@@ -277,30 +278,76 @@ func benchmarkInvoke(b *testing.B, payload []byte) { } } -func BenchmarkInvokeWithEmptyPayload(b *testing.B) { - benchmarkInvoke(b, []byte("")) +func BenchmarkInvokeResponseWithEmptyPayload(b *testing.B) { + benchmarkInvokeResponse(b, []byte("")) } -func BenchmarkInvokeWith4KBPayload(b *testing.B) { - benchmarkInvoke(b, bytes.Repeat([]byte("a"), 4*1024)) +func BenchmarkInvokeResponseWith4KBPayload(b *testing.B) { + benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 4*1024)) } -func BenchmarkInvokeWith512KBPayload(b *testing.B) { - benchmarkInvoke(b, bytes.Repeat([]byte("a"), 512*1024)) +func BenchmarkInvokeResponseWith512KBPayload(b *testing.B) { + benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 512*1024)) } -func BenchmarkInvokeWith1MBPayload(b *testing.B) { - benchmarkInvoke(b, bytes.Repeat([]byte("a"), 1*1024*1024)) +func BenchmarkInvokeResponseWith1MBPayload(b *testing.B) { + benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 1*1024*1024)) } -func BenchmarkInvokeWith2MBPayload(b *testing.B) { - benchmarkInvoke(b, bytes.Repeat([]byte("a"), 2*1024*1024)) +func BenchmarkInvokeResponseWith2MBPayload(b *testing.B) { + benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 2*1024*1024)) } -func BenchmarkInvokeWith4MBPayload(b *testing.B) { - benchmarkInvoke(b, bytes.Repeat([]byte("a"), 4*1024*1024)) +func BenchmarkInvokeResponseWith4MBPayload(b *testing.B) { + benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 4*1024*1024)) } -func BenchmarkInvokeWith6MBPayload(b *testing.B) { - benchmarkInvoke(b, bytes.Repeat([]byte("a"), 6*1024*1024)) +func BenchmarkInvokeResponseWith6MBPayload(b *testing.B) { + benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 6*1024*1024)) +} + +func benchmarkInvokeRequest(b *testing.B, payload []byte) { + b.StopTimer() + b.ResetTimer() // does not restart timer, only resets state + b.ReportAllocs() + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + router := NewRouter(flowTest.AppCtx, 
flowTest.RegistrationService, flowTest.RenderingService) + var requestBuf bytes.Buffer + for i := 0; i < b.N; i++ { + id := uuid.New().String() + ctx, invoke := context.Background(), createInvoke(id) + flowTest.ConfigureForInvoke(ctx, invoke) // set invoke ID and initialize barriers + flowTest.ConfigureInvokeRenderer(ctx, invoke, &requestBuf) // override invoke renderer to reuse buffer + makeBenchRequest(b, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) + makeBenchRequest(b, router, httptest.NewRequest("POST", fmt.Sprintf("/runtime/invocation/%s/response", id), bytes.NewReader(payload))) + } +} + +func BenchmarkInvokeRequestWithEmptyPayload(b *testing.B) { + benchmarkInvokeRequest(b, []byte("")) +} + +func BenchmarkInvokeRequestWith4KBPayload(b *testing.B) { + benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 4*1024)) +} + +func BenchmarkInvokeRequestWith512KBPayload(b *testing.B) { + benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 512*1024)) +} + +func BenchmarkInvokeRequestWith1MBPayload(b *testing.B) { + benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 1*1024*1024)) +} + +func BenchmarkInvokeRequestWith2MBPayload(b *testing.B) { + benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 2*1024*1024)) +} + +func BenchmarkInvokeRequestWith4MBPayload(b *testing.B) { + benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 4*1024*1024)) +} + +func BenchmarkInvokeRequestWith6MBPayload(b *testing.B) { + benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 6*1024*1024)) } diff --git a/lambda/rapi/security_test.go b/lambda/rapi/security_test.go index 5312b43..3f869d5 100644 --- a/lambda/rapi/security_test.go +++ b/lambda/rapi/security_test.go @@ -20,7 +20,7 @@ func TestInvokeValidId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, 
flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) @@ -53,7 +53,7 @@ func TestSecurityInvokeResponseBadRequestId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) @@ -100,7 +100,7 @@ func TestSecurityInvokeErrorBadRequestId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) diff --git a/lambda/rapi/server.go b/lambda/rapi/server.go index dd027f4..d17270a 100644 --- a/lambda/rapi/server.go +++ b/lambda/rapi/server.go @@ -46,16 +46,22 @@ func SaveConnInContext(ctx context.Context, c net.Conn) context.Context { // should happen before provided runtime is started. // // When port is 0, OS will dynamically allocate the listening port. 
-func NewServer(host string, port int, appCtx appctx.ApplicationContext, +func NewServer( + host string, + port int, + appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService, telemetryAPIEnabled bool, - logsSubscriptionAPI telemetry.SubscriptionAPI, telemetrySubscriptionAPI telemetry.SubscriptionAPI, credentialsService core.CredentialsService, eventsAPI telemetry.EventsAPI) *Server { + logsSubscriptionAPI telemetry.SubscriptionAPI, + telemetrySubscriptionAPI telemetry.SubscriptionAPI, + credentialsService core.CredentialsService, +) *Server { exitErrors := make(chan error, 1) router := chi.NewRouter() - router.Mount(version20180601, NewRouter(appCtx, registrationService, renderingService, eventsAPI)) + router.Mount(version20180601, NewRouter(appCtx, registrationService, renderingService)) router.Mount(version20200101, ExtensionsRouter(appCtx, registrationService, renderingService)) if telemetryAPIEnabled { diff --git a/lambda/rapi/telemetry_logs_fuzz_test.go b/lambda/rapi/telemetry_logs_fuzz_test.go new file mode 100644 index 0000000..89adbd1 --- /dev/null +++ b/lambda/rapi/telemetry_logs_fuzz_test.go @@ -0,0 +1,185 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package rapi + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/extensions" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapi/handler" + "go.amzn.com/lambda/telemetry" + "go.amzn.com/lambda/testdata" +) + +const ( + logsHandlerPath = "/logs" + telemetryHandlerPath = "/telemetry" + + samplePayload = `{"foo" : "bar"}` +) + +func FuzzTelemetryLogRouters(f *testing.F) { + extensions.Enable() + defer extensions.Disable() + + f.Add(makeTargetURL(logsHandlerPath, version20200815), []byte(samplePayload)) + f.Add(makeTargetURL(telemetryHandlerPath, version20220701), []byte(samplePayload)) + + logsPath := fmt.Sprintf("%s%s", version20200815, logsHandlerPath) + telemetryPath := fmt.Sprintf("%s%s", version20220701, telemetryHandlerPath) + + f.Fuzz(func(t *testing.T, rawPath string, payload []byte) { + u, err := parseToURLStruct(rawPath) + if err != nil { + t.Skipf("error parsing url: %v. 
Skipping test.", err) + } + + flowTest := testdata.NewFlowTest() + + rapiServer := makeRapiServerWithMockSubscriptionAPI(flowTest, newMockSubscriptionAPI(true), newMockSubscriptionAPI(true)) + + request := httptest.NewRequest("PUT", rawPath, bytes.NewReader(payload)) + responseRecorder := serveTestRequest(rapiServer, request) + + if u.Path == logsPath || u.Path == telemetryPath { + assertExpectedPathResponseCode(t, responseRecorder.Code, rawPath) + } else { + assertUnexpectedPathResponseCode(t, responseRecorder.Code, rawPath) + } + }) +} + +func FuzzLogsHandler(f *testing.F) { + extensions.Enable() + defer extensions.Disable() + + fuzzSubscriptionAPIHandler(f, logsHandlerPath, version20200815) +} + +func FuzzTelemetryHandler(f *testing.F) { + extensions.Enable() + defer extensions.Disable() + + fuzzSubscriptionAPIHandler(f, telemetryHandlerPath, version20220701) +} + +func fuzzSubscriptionAPIHandler(f *testing.F, handlerPath string, apiVersion string) { + flowTest := testdata.NewFlowTest() + agent := makeExternalAgent(flowTest.RegistrationService) + f.Add([]byte(samplePayload), agent.ID.String(), true) + f.Add([]byte(samplePayload), agent.ID.String(), false) + + f.Fuzz(func(t *testing.T, payload []byte, agentIdentifierHeader string, serviceOn bool) { + telemetrySubscriptionAPI := newMockSubscriptionAPI(serviceOn) + logsSubscriptionAPI := newMockSubscriptionAPI(serviceOn) + rapiServer := makeRapiServerWithMockSubscriptionAPI(flowTest, logsSubscriptionAPI, telemetrySubscriptionAPI) + + apiUnderTest := telemetrySubscriptionAPI + if handlerPath == logsHandlerPath { + apiUnderTest = logsSubscriptionAPI + } + + target := makeTargetURL(handlerPath, apiVersion) + request := httptest.NewRequest("PUT", target, bytes.NewReader(payload)) + request.Header.Set(handler.LambdaAgentIdentifier, agentIdentifierHeader) + + responseRecorder := serveTestRequest(rapiServer, request) + + if agentIdentifierHeader == "" { + assertForbiddenErrorType(t, responseRecorder, 
handler.ErrAgentIdentifierMissing) + return + } + + if _, err := uuid.Parse(agentIdentifierHeader); err != nil { + assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierInvalid) + return + } + + if agentIdentifierHeader != agent.ID.String() { + assertForbiddenErrorType(t, responseRecorder, "Extension.UnknownExtensionIdentifier") + return + } + + if !serviceOn { + assertForbiddenErrorType(t, responseRecorder, apiUnderTest.GetServiceClosedErrorType()) + return + } + + // assert that payload has been stored in the mock subscription api after the handler calls Subscribe() + assert.Equal(t, payload, apiUnderTest.receivedPayload) + }) +} + +func makeRapiServerWithMockSubscriptionAPI( + flowTest *testdata.FlowTest, + logsSubscription telemetry.SubscriptionAPI, + telemetrySubscription telemetry.SubscriptionAPI) *Server { + return NewServer( + "127.0.0.1", + 0, + flowTest.AppCtx, + flowTest.RegistrationService, + flowTest.RenderingService, + true, + logsSubscription, + telemetrySubscription, + flowTest.CredentialsService, + ) +} + +type mockSubscriptionAPI struct { + serviceOn bool + receivedPayload []byte +} + +func newMockSubscriptionAPI(serviceOn bool) *mockSubscriptionAPI { + return &mockSubscriptionAPI{ + serviceOn: serviceOn, + } +} + +func (m *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) ([]byte, int, map[string][]string, error) { + if !m.serviceOn { + return nil, 0, map[string][]string{}, telemetry.ErrTelemetryServiceOff + } + + bodyBytes, err := io.ReadAll(body) + if err != nil { + return nil, 0, map[string][]string{}, fmt.Errorf("error Reading the body of subscription request: %s", err) + } + + m.receivedPayload = bodyBytes + + return []byte("OK"), http.StatusOK, map[string][]string{}, nil +} + +func (m *mockSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} + +func (m *mockSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { + return nil +} + 
+func (m *mockSubscriptionAPI) Clear() {} + +func (m *mockSubscriptionAPI) TurnOff() {} + +func (m *mockSubscriptionAPI) GetEndpointURL() string { + return "/subscribe" +} + +func (m *mockSubscriptionAPI) GetServiceClosedErrorMessage() string { + return "Subscription API is closed" +} + +func (m *mockSubscriptionAPI) GetServiceClosedErrorType() string { + return "SubscriptionClosed" +} diff --git a/lambda/rapid/exit.go b/lambda/rapid/exit.go index e45f3a4..a601efc 100644 --- a/lambda/rapid/exit.go +++ b/lambda/rapid/exit.go @@ -4,31 +4,22 @@ package rapid import ( - "fmt" "time" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/extensions" "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/model" + "go.amzn.com/lambda/telemetry" log "github.com/sirupsen/logrus" ) func handleInvokeError(execCtx *rapidContext, invokeRequest *interop.Invoke, invokeMx *invokeMetrics, err error) *interop.InvokeFailure { invokeFailure := newInvokeFailureMsg(execCtx, invokeRequest, invokeMx, err) - resp := model.ErrorResponse{ - ErrorType: string(invokeFailure.ErrorType), - ErrorMessage: fmt.Sprintf("Error: %v", invokeFailure.ErrorMessage), - } - - if invokeRequest.ID != "" { - resp.ErrorMessage = fmt.Sprintf("RequestId: %s Error: %v", invokeRequest.ID, invokeFailure.ErrorMessage) - } // This is the default error response that gets sent back as the function response in failure cases - invokeFailure.DefaultErrorResponse = resp.AsInteropError() + invokeFailure.DefaultErrorResponse = interop.GetErrorResponseWithFormattedErrorMessage(invokeFailure.ErrorType, invokeFailure.ErrorMessage, invokeRequest.ID) // Invoke with extensions disabled maintains behaviour parity with pre-extensions rapid if !extensions.AreEnabled() { @@ -50,7 +41,7 @@ func handleInvokeError(execCtx *rapidContext, invokeRequest *interop.Invoke, inv func newInvokeFailureMsg(execCtx *rapidContext, invokeRequest *interop.Invoke, invokeMx *invokeMetrics, err error) *interop.InvokeFailure { 
errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) if !found { - errorType = fatalerror.Unknown + errorType = fatalerror.SandboxFailure } invokeFailure := &interop.InvokeFailure{ @@ -64,6 +55,7 @@ func newInvokeFailureMsg(execCtx *rapidContext, invokeRequest *interop.Invoke, i } if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { + invokeFailure.ResponseMetrics.RuntimeResponseLatencyMs = telemetry.CalculateDuration(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics.StartReadingResponseMonoTimeMs) invokeFailure.ResponseMetrics.RuntimeTimeThrottledMs = invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) invokeFailure.ResponseMetrics.RuntimeProducedBytes = invokeRequest.InvokeResponseMetrics.ProducedBytes invokeFailure.ResponseMetrics.RuntimeOutboundThroughputBps = invokeRequest.InvokeResponseMetrics.OutboundThroughputBps @@ -80,13 +72,15 @@ func newInvokeFailureMsg(execCtx *rapidContext, invokeRequest *interop.Invoke, i invokeFailure.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } + invokeFailure.InvokeResponseMode = invokeRequest.InvokeResponseMode + return invokeFailure } func generateInitFailureMsg(execCtx *rapidContext, err error) interop.InitFailure { errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) if !found { - errorType = fatalerror.Unknown + errorType = fatalerror.SandboxFailure } initFailureMsg := interop.InitFailure{ diff --git a/lambda/rapid/start.go b/lambda/rapid/handlers.go similarity index 58% rename from lambda/rapid/start.go rename to lambda/rapid/handlers.go index 76337af..f379c4c 100644 --- a/lambda/rapid/start.go +++ b/lambda/rapid/handlers.go @@ -5,6 +5,7 @@ package rapid import ( + "bytes" "context" "errors" "fmt" @@ -22,22 +23,20 @@ import ( "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" 
"go.amzn.com/lambda/rapi" - "go.amzn.com/lambda/rapi/model" "go.amzn.com/lambda/rapi/rendering" + "go.amzn.com/lambda/rapidcore/env" supvmodel "go.amzn.com/lambda/supervisor/model" "go.amzn.com/lambda/telemetry" "github.com/google/uuid" - log "github.com/sirupsen/logrus" ) const ( - RuntimeDomain = "runtime" - OperatorDomain = "operator" - defaultAgentLocation = "/opt/extensions" - disableExtensionsFile = "/opt/disable-extensions-jwigqn8j" - runtimeProcessName = "runtime" + RuntimeDomain = "runtime" + OperatorDomain = "operator" + defaultAgentLocation = "/opt/extensions" + runtimeProcessName = "runtime" ) const ( @@ -48,14 +47,17 @@ const ( var errResetReceived = errors.New("errResetReceived") +type processSupervisor struct { + supvmodel.ProcessSupervisor + RootPath string +} + type rapidContext struct { interopServer interop.Server server *rapi.Server appCtx appctx.ApplicationContext - preLoadTimeNs int64 - postLoadTimeNs int64 initDone bool - supervisor supvmodel.Supervisor + supervisor processSupervisor runtimeDomainGeneration uint32 initFlow core.InitFlowSynchronization invokeFlow core.InvokeFlowSynchronization @@ -67,12 +69,16 @@ type rapidContext struct { logsEgressAPI telemetry.StdLogsEgressAPI xray telemetry.Tracer standaloneMode bool - eventsAPI telemetry.EventsAPI + eventsAPI interop.EventsAPI initCachingEnabled bool credentialsService core.CredentialsService - signalCtx context.Context - executionMutex sync.Mutex + handlerExecutionMutex sync.Mutex shutdownContext *shutdownContext + logStreamName string + + RuntimeStartedTime int64 + RuntimeOverheadStartedTime int64 + InvokeResponseMetrics *interop.InvokeResponseMetrics } // Validate interface compliance @@ -105,7 +111,13 @@ func (c *rapidContext) GetExtensionNames() string { func logAgentsInitStatus(execCtx *rapidContext) { for _, agent := range execCtx.registrationService.AgentsInfo() { - execCtx.eventsAPI.SendExtensionInit(agent.Name, agent.State, agent.ErrorType, agent.Subscriptions) + 
extensionInitData := interop.ExtensionInitData{ + AgentName: agent.Name, + State: agent.State, + ErrorType: agent.ErrorType, + Subscriptions: agent.Subscriptions, + } + execCtx.eventsAPI.SendExtensionInit(extensionInitData) } } @@ -116,7 +128,7 @@ func agentLaunchError(agent *core.ExternalAgent, appCtx appctx.ApplicationContex appctx.StoreFirstFatalError(appCtx, fatalerror.AgentLaunchError) } -func doInitExtensions(domain string, agentPaths []string, execCtx *rapidContext, env interop.EnvironmentVariables) error { +func doInitExtensions(domain string, agentPaths []string, execCtx *rapidContext, env *env.Environment) error { initFlow := execCtx.registrationService.InitFlow() // we don't bring it into the loop below because we don't want unnecessary broadcasts on agent gate @@ -127,7 +139,6 @@ func doInitExtensions(domain string, agentPaths []string, execCtx *rapidContext, for _, agentPath := range agentPaths { // Using path.Base(agentPath) not agentName because the agent name is contact, as standalone can get the internal state. 
agent, err := execCtx.registrationService.CreateExternalAgent(path.Base(agentPath)) - if err != nil { return err } @@ -140,21 +151,27 @@ func doInitExtensions(domain string, agentPaths []string, execCtx *rapidContext, env := env.AgentExecEnv() agentStdoutWriter, agentStderrWriter, err := execCtx.logsEgressAPI.GetExtensionSockets() - if err != nil { return err } agentName := fmt.Sprintf("extension-%s-%d", path.Base(agentPath), execCtx.runtimeDomainGeneration) - err = execCtx.supervisor.Exec(&supvmodel.ExecRequest{ - Domain: domain, - Name: agentName, - Path: agentPath, - Env: &env, + err = execCtx.supervisor.Exec(context.Background(), &supvmodel.ExecRequest{ + Domain: domain, + Name: agentName, + Path: agentPath, + Env: &env, + Logging: supvmodel.Logging{ + Managed: supvmodel.ManagedLogging{ + Topic: supvmodel.RtExtensionManagedLoggingTopic, + Formats: []supvmodel.ManagedLoggingFormat{ + supvmodel.LineBasedManagedLogging, + }, + }, + }, StdoutWriter: agentStdoutWriter, StderrWriter: agentStderrWriter, }) - if err != nil { agentLaunchError(agent, execCtx.appCtx, err) return err @@ -177,7 +194,7 @@ func doRuntimeBootstrap(execCtx *rapidContext, sbInfoFromInit interop.SandboxInf if err != nil { if fatalError, formattedLog, hasError := runtimeBootstrap.CachedFatalError(err); hasError { appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.eventsAPI.SendImageErrorLog(formattedLog) + execCtx.eventsAPI.SendImageErrorLog(interop.ImageErrorLogData(formattedLog)) } else { appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) } @@ -189,7 +206,7 @@ func doRuntimeBootstrap(execCtx *rapidContext, sbInfoFromInit interop.SandboxInf if err != nil { if fatalError, formattedLog, hasError := runtimeBootstrap.CachedFatalError(err); hasError { appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.eventsAPI.SendImageErrorLog(formattedLog) + execCtx.eventsAPI.SendImageErrorLog(interop.ImageErrorLogData(formattedLog)) } else { 
appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidWorkingDir) } @@ -201,95 +218,69 @@ func doRuntimeBootstrap(execCtx *rapidContext, sbInfoFromInit interop.SandboxInf return bootstrapCmd, bootstrapEnv, bootstrapCwd, bootstrapExtraFiles, nil } -func (c *rapidContext) setupEventsWatcher(events <-chan supvmodel.Event) { - go func() { - for event := range events { - var err error = nil - log.Debugf("The events handler received the event %+v.", event) - if loss := event.Event.EventLoss(); loss != nil { - log.Panicf("Lost %d events from supervisor", *loss) - } - termination := event.Event.ProcessTerminated() - - // If we are not shutting down then we care if an unexpected exit happens. - if !c.shutdownContext.isShuttingDown() { - runtimeProcessName := fmt.Sprintf("%s-%d", runtimeProcessName, c.runtimeDomainGeneration) - - // If event from the runtime. - if *termination.Name == runtimeProcessName { - if termination.Success() { - err = fmt.Errorf("Runtime exited without providing a reason") - } else { - err = fmt.Errorf("Runtime exited with error: %s", termination.String()) - } - appctx.StoreFirstFatalError(c.appCtx, fatalerror.RuntimeExit) - } else { - if termination.Success() { - err = fmt.Errorf("exit code 0") - } else { - err = fmt.Errorf(termination.String()) - } +func (c *rapidContext) watchEvents(events <-chan supvmodel.Event) { + for event := range events { + var err error + log.Debugf("The events handler received the event %+v.", event) + if loss := event.Event.EventLoss(); loss != nil { + log.Panicf("Lost %d events from supervisor", *loss) + } + termination := event.Event.ProcessTerminated() - appctx.StoreFirstFatalError(c.appCtx, fatalerror.AgentCrash) + // If we are not shutting down then we care if an unexpected exit happens. + if !c.shutdownContext.isShuttingDown() { + runtimeProcessName := fmt.Sprintf("%s-%d", runtimeProcessName, c.runtimeDomainGeneration) + + // If event from the runtime. 
+ if *termination.Name == runtimeProcessName { + if termination.Success() { + err = fmt.Errorf("Runtime exited without providing a reason") + } else { + err = fmt.Errorf("Runtime exited with error: %s", termination.String()) + } + appctx.StoreFirstFatalError(c.appCtx, fatalerror.RuntimeExit) + } else { + if termination.Success() { + err = fmt.Errorf("exit code 0") + } else { + err = fmt.Errorf(termination.String()) } - log.Warnf("Process %s exited: %+v", *termination.Name, termination) + appctx.StoreFirstFatalError(c.appCtx, fatalerror.AgentCrash) } - // At the moment we only get termination events. - // When their are other event types then we would need to be selective, - // about what we send to handleShutdownEvent(). - c.shutdownContext.handleProcessExit(*termination) - c.registrationService.CancelFlows(err) + log.Warnf("Process %s exited: %+v", *termination.Name, termination) } - }() -} - -func doOperatorDomainInit(ctx context.Context, execCtx *rapidContext, operatorDomainExtraConfig interop.DynamicDomainConfig) error { - events, err := execCtx.supervisor.Events() - if err != nil { - log.WithError(err).Panic("Could not get events stream from supervsior") - } - execCtx.setupEventsWatcher(events) - - log.Info("Configuring and starting Operator Domain") - conf := operatorDomainExtraConfig - err = execCtx.supervisor.Configure(&supvmodel.ConfigureRequest{ - Domain: OperatorDomain, - AdditionalStartHooks: conf.AdditionalStartHooks, - Mounts: conf.Mounts, - }) - - if err != nil { - log.WithError(err).Error("Failed to configure operator domain") - return err - } - - err = execCtx.supervisor.Start(&supvmodel.StartRequest{ - Domain: OperatorDomain, - }) - if err != nil { - log.WithError(err).Error("Failed to start operator domain") - return err + // At the moment we only get termination events. + // When their are other event types then we would need to be selective, + // about what we send to handleShutdownEvent(). 
+ c.shutdownContext.handleProcessExit(*termination) + c.registrationService.CancelFlows(err) } +} - // we configure the runtime domain only once and not at - // every init phase (e.g., suppressed or reset). - err = execCtx.supervisor.Configure(&supvmodel.ConfigureRequest{ +// subscribe to /events for runtime domain in supervisor +func setupEventsWatcher(execCtx *rapidContext) error { + eventsRequest := supvmodel.EventsRequest{ Domain: RuntimeDomain, - }) + } + events, err := execCtx.supervisor.Events(context.Background(), &eventsRequest) if err != nil { - log.WithError(err).Error("Failed to configure operator domain") + log.Errorf("Could not get events stream from supervisor: %s", err) return err } + go execCtx.watchEvents(events) return nil - } -func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromInit interop.SandboxInfoFromInit) error { +func doRuntimeDomainInit(execCtx *rapidContext, sbInfoFromInit interop.SandboxInfoFromInit, phase interop.LifecyclePhase) error { + initStartTime := metering.Monotime() + sendInitStartLogEvent(execCtx, sbInfoFromInit.SandboxType, phase) + defer sendInitReportLogEvent(execCtx, sbInfoFromInit.SandboxType, initStartTime, phase) + execCtx.xray.RecordInitStartTime() defer execCtx.xray.RecordInitEndTime() @@ -299,18 +290,11 @@ func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromI } }() - log.Info("Starting runtime domain") - err := execCtx.supervisor.Start(&supvmodel.StartRequest{ - Domain: RuntimeDomain, - }) - if err != nil { - log.WithError(err).Panic("Failed configuring runtime domain") - } execCtx.runtimeDomainGeneration++ if extensions.AreEnabled() { runtimeExtensions := agents.ListExternalAgentPaths(defaultAgentLocation, - execCtx.supervisor.RuntimeConfig.RootPath) + execCtx.supervisor.RootPath) if err := doInitExtensions(RuntimeDomain, runtimeExtensions, execCtx, sbInfoFromInit.EnvironmentVariables); err != nil { return err } @@ -328,20 +312,17 @@ func 
doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromI // runtime is implicitly subscribed for certain lifecycle events. log.Debug("Preregister runtime") registrationService := execCtx.registrationService - err = registrationService.PreregisterRuntime(runtime) - + err := registrationService.PreregisterRuntime(runtime) if err != nil { return err } bootstrapCmd, bootstrapEnv, bootstrapCwd, bootstrapExtraFiles, err := doRuntimeBootstrap(execCtx, sbInfoFromInit) - if err != nil { return err } runtimeStdoutWriter, runtimeStderrWriter, err := execCtx.logsEgressAPI.GetRuntimeSockets() - if err != nil { return err } @@ -349,13 +330,23 @@ func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromI log.Debug("Start runtime") checkCredentials(execCtx, bootstrapEnv) name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) - err = execCtx.supervisor.Exec(&supvmodel.ExecRequest{ - Domain: RuntimeDomain, - Name: name, - Cwd: &bootstrapCwd, - Path: bootstrapCmd[0], - Args: bootstrapCmd[1:], - Env: &bootstrapEnv, + + err = execCtx.supervisor.Exec(context.Background(), &supvmodel.ExecRequest{ + Domain: RuntimeDomain, + Name: name, + Cwd: &bootstrapCwd, + Path: bootstrapCmd[0], + Args: bootstrapCmd[1:], + Env: &bootstrapEnv, + Logging: supvmodel.Logging{ + Managed: supvmodel.ManagedLogging{ + Topic: supvmodel.RuntimeManagedLoggingTopic, + Formats: []supvmodel.ManagedLoggingFormat{ + supvmodel.LineBasedManagedLogging, + supvmodel.MessageBasedManagedLogging, + }, + }, + }, StdoutWriter: runtimeStdoutWriter, StderrWriter: runtimeStderrWriter, ExtraFiles: &bootstrapExtraFiles, @@ -364,25 +355,25 @@ func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromI runtimeDoneStatus := telemetry.RuntimeDoneSuccess defer func() { - sendInitRuntimeDoneLogEvent(execCtx, sbInfoFromInit.SandboxType, runtimeDoneStatus) + sendInitRuntimeDoneLogEvent(execCtx, sbInfoFromInit.SandboxType, runtimeDoneStatus, phase) }() 
if err != nil { if fatalError, formattedLog, hasError := sbInfoFromInit.RuntimeBootstrap.CachedFatalError(err); hasError { appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.eventsAPI.SendImageErrorLog(formattedLog) + execCtx.eventsAPI.SendImageErrorLog(interop.ImageErrorLogData(formattedLog)) } else { appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) } - runtimeDoneStatus = telemetry.RuntimeDoneFailure + runtimeDoneStatus = telemetry.RuntimeDoneError return err } execCtx.shutdownContext.createExitedChannel(name) if err := initFlow.AwaitRuntimeRestoreReady(); err != nil { - runtimeDoneStatus = telemetry.RuntimeDoneFailure + runtimeDoneStatus = telemetry.RuntimeDoneError return err } @@ -396,6 +387,7 @@ func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromI return err } if err := initFlow.AwaitAgentsReady(); err != nil { + runtimeDoneStatus = telemetry.RuntimeDoneError return err } } @@ -411,25 +403,34 @@ func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromI return nil } -func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop.Invoke, mx *invokeMetrics, sbInfoFromInit interop.SandboxInfoFromInit) error { - execCtx.eventsAPI.SetCurrentRequestID(invokeRequest.ID) +func doInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, mx *invokeMetrics, sbInfoFromInit interop.SandboxInfoFromInit, requestBuffer *bytes.Buffer) error { + execCtx.eventsAPI.SetCurrentRequestID(interop.RequestID(invokeRequest.ID)) appCtx := execCtx.appCtx xray := execCtx.xray xray.Configure(invokeRequest) + ctx := context.Background() + return xray.CaptureInvokeSegment(ctx, xray.WithErrorCause(ctx, appCtx, func(ctx context.Context) error { + telemetryTracingCtx := xray.BuildTracingCtxForStart() + if !execCtx.initDone { // do inline init if err := xray.CaptureInitSubsegment(ctx, func(ctx context.Context) error { - return doRuntimeDomainInit(ctx, execCtx, sbInfoFromInit) + return 
doRuntimeDomainInit(execCtx, sbInfoFromInit, interop.LifecyclePhaseInvoke) }); err != nil { + sendInvokeStartLogEvent(execCtx, invokeRequest.ID, telemetryTracingCtx) return err } - } else if sbInfoFromInit.SandboxType != interop.SandboxPreWarmed { + } else if sbInfoFromInit.SandboxType != interop.SandboxPreWarmed && !execCtx.initCachingEnabled { xray.SendInitSubsegmentWithRecordedTimesOnce(ctx) } + xray.SendRestoreSubsegmentWithRecordedTimesOnce(ctx) + + sendInvokeStartLogEvent(execCtx, invokeRequest.ID, telemetryTracingCtx) + invokeFlow := execCtx.invokeFlow log.Debug("Initialize invoke flow barriers") err := invokeFlow.InitializeBarriers() @@ -453,7 +454,7 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop // Invoke if err := xray.CaptureInvokeSubsegment(ctx, xray.WithError(ctx, appCtx, func(ctx context.Context) error { log.Debug("Set renderer for invoke") - renderer := rendering.NewInvokeRenderer(ctx, invokeRequest, xray.TracingHeaderParser()) + renderer := rendering.NewInvokeRenderer(ctx, invokeRequest, requestBuffer, xray.BuildTracingHeader()) defer func() { mx.rendererMetrics = renderer.GetMetrics() }() @@ -473,6 +474,7 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop log.Debug("Release runtime condition") //TODO handle Supervisors listening channel + execCtx.SetRuntimeStartedTime(metering.Monotime()) runtime.Release() log.Debug("Await runtime response") //TODO handle Supervisors listening channel @@ -484,6 +486,7 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop // Runtime overhead if err := xray.CaptureOverheadSubsegment(ctx, func(ctx context.Context) error { log.Debug("Await runtime ready") + execCtx.SetRuntimeOverheadStartedTime(metering.Monotime()) //TODO handle Supervisors listening channel return invokeFlow.AwaitRuntimeReady() }); err != nil { @@ -491,19 +494,21 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop } 
mx.runtimeReadyTime = metering.Monotime() - runtimeDoneEventData := telemetry.InvokeRuntimeDoneData{ + runtimeDoneEventData := interop.InvokeRuntimeDoneData{ Status: telemetry.RuntimeDoneSuccess, - Metrics: telemetry.GetRuntimeDoneInvokeMetrics(invokeRequest.InvokeReceivedTime, invokeRequest.InvokeResponseMetrics, mx.runtimeReadyTime), + Metrics: telemetry.GetRuntimeDoneInvokeMetrics(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics, mx.runtimeReadyTime), InternalMetrics: invokeRequest.InvokeResponseMetrics, - Tracing: telemetry.BuildTracingCtx(model.XRayTracingType, invokeRequest.TraceID, invokeRequest.LambdaSegmentID), - Spans: telemetry.GetRuntimeDoneSpans(invokeRequest.InvokeReceivedTime, invokeRequest.InvokeResponseMetrics), + Tracing: xray.BuildTracingCtxAfterInvokeComplete(), + Spans: execCtx.eventsAPI.GetRuntimeDoneSpans(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics, execCtx.RuntimeOverheadStartedTime, mx.runtimeReadyTime), } - if err := execCtx.eventsAPI.SendRuntimeDone(runtimeDoneEventData); err != nil { - log.Errorf("Failed to send RUNDONE: %s", err) + log.Info(runtimeDoneEventData.String()) + if err := execCtx.eventsAPI.SendInvokeRuntimeDone(runtimeDoneEventData); err != nil { + log.Errorf("Failed to send INVOKE RTDONE: %s", err) } // Extensions overhead if execCtx.HasActiveExtensions() { + extensionOverheadStartTime := metering.Monotime() execCtx.interopServer.SendRuntimeReady() log.Debug("Await agents ready") //TODO handle Supervisors listening channel @@ -511,18 +516,21 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop log.Warnf("AwaitAgentsReady() = %s", err) return err } + extensionOverheadEndTime := metering.Monotime() + extensionOverheadMsSpan := interop.Span{ + Name: "extensionOverhead", + Start: telemetry.GetEpochTimeInISO8601FormatFromMonotime(extensionOverheadStartTime), + DurationMs: telemetry.CalculateDuration(extensionOverheadStartTime, extensionOverheadEndTime), + } + if 
err := execCtx.eventsAPI.SendReportSpan(extensionOverheadMsSpan); err != nil { + log.WithError(err).Error("Failed to create REPORT Span") + } } return nil })) } -func extensionsDisabledByLayer() bool { - _, err := os.Stat(disableExtensionsFile) - log.Infof("extensionsDisabledByLayer(%s) -> %s", disableExtensionsFile, err) - return err == nil -} - // acceptInitRequest is a second initialization phase, performed after receiving START // initialized entities: _HANDLER, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN func (c *rapidContext) acceptInitRequest(initRequest *interop.Init) *interop.Init { @@ -535,15 +543,14 @@ func (c *rapidContext) acceptInitRequest(initRequest *interop.Init) *interop.Ini initRequest.FunctionName, initRequest.FunctionVersion) c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ - FunctionName: initRequest.FunctionName, - FunctionVersion: initRequest.FunctionVersion, - Handler: initRequest.Handler, - RuntimeInfo: initRequest.RuntimeInfo, + AccountID: initRequest.AccountID, + FunctionName: initRequest.FunctionName, + FunctionVersion: initRequest.FunctionVersion, + InstanceMaxMemory: initRequest.InstanceMaxMemory, + Handler: initRequest.Handler, + RuntimeInfo: initRequest.RuntimeInfo, }) - - if extensionsDisabledByLayer() { - extensions.Disable() - } + c.SetLogStreamName(initRequest.LogStreamName) return initRequest } @@ -568,26 +575,21 @@ func (c *rapidContext) acceptInitRequestForInitCaching(initRequest *interop.Init initCachingToken) c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ - FunctionName: initRequest.FunctionName, - FunctionVersion: initRequest.FunctionVersion, - Handler: initRequest.Handler, + AccountID: initRequest.AccountID, + FunctionName: initRequest.FunctionName, + FunctionVersion: initRequest.FunctionVersion, + InstanceMaxMemory: initRequest.InstanceMaxMemory, + Handler: initRequest.Handler, + RuntimeInfo: initRequest.RuntimeInfo, }) + c.SetLogStreamName(initRequest.LogStreamName) 
c.credentialsService.SetCredentials(initCachingToken, initRequest.AwsKey, initRequest.AwsSecret, initRequest.AwsSession, initRequest.CredentialsExpiry) - if extensionsDisabledByLayer() { - extensions.Disable() - } - return initRequest, nil } -func handleInit(execCtx *rapidContext, initRequest *interop.Init, - initStartedResponse chan<- interop.InitStarted, - initSuccessResponse chan<- interop.InitSuccess, - initFailureResponse chan<- interop.InitFailure) { - ctx := execCtx.signalCtx - +func handleInit(execCtx *rapidContext, initRequest *interop.Init, initSuccessResponse chan<- interop.InitSuccess, initFailureResponse chan<- interop.InitFailure) { if execCtx.initCachingEnabled { var err error if initRequest, err = execCtx.acceptInitRequestForInitCaching(initRequest); err != nil { @@ -600,23 +602,7 @@ func handleInit(execCtx *rapidContext, initRequest *interop.Init, initRequest = execCtx.acceptInitRequest(initRequest) } - initStartedMsg := interop.InitStarted{ - PreLoadTimeNs: execCtx.preLoadTimeNs, - PostLoadTimeNs: execCtx.postLoadTimeNs, - WaitStartTimeNs: execCtx.postLoadTimeNs, - WaitEndTimeNs: metering.Monotime(), - ExtensionsEnabled: extensions.AreEnabled(), - Ack: make(chan struct{}), - } - - initStartedResponse <- initStartedMsg - <-initStartedMsg.Ack - - // Operator domain init happens only once, it's never suppressed, - // and it's terminal in case of failures - if err := doOperatorDomainInit(ctx, execCtx, initRequest.OperatorDomainExtraConfig); err != nil { - // TODO: I believe we need to handle this specially, because we want - // to consider any failure here as terminal + if err := setupEventsWatcher(execCtx); err != nil { handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) return } @@ -628,7 +614,7 @@ func handleInit(execCtx *rapidContext, initRequest *interop.Init, SandboxType: initRequest.SandboxType, RuntimeBootstrap: initRequest.Bootstrap, } - if err := doRuntimeDomainInit(ctx, execCtx, sbInfo); err != nil { + if err := 
doRuntimeDomainInit(execCtx, sbInfo, interop.LifecyclePhaseInit); err != nil { handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) return } @@ -649,16 +635,18 @@ func handleInit(execCtx *rapidContext, initRequest *interop.Init, <-initSuccessMsg.Ack } -func handleInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { - ctx := execCtx.signalCtx +func handleInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit, requestBuffer *bytes.Buffer, responseSender interop.InvokeResponseSender) (interop.InvokeSuccess, *interop.InvokeFailure) { + appctx.StoreResponseSender(execCtx.appCtx, responseSender) invokeMx := invokeMetrics{} - if err := doInvoke(ctx, execCtx, invokeRequest, &invokeMx, sbInfoFromInit); err != nil { + if err := doInvoke(execCtx, invokeRequest, &invokeMx, sbInfoFromInit, requestBuffer); err != nil { log.WithError(err).WithField("InvokeID", invokeRequest.ID).Error("Invoke failed") invokeFailure := handleInvokeError(execCtx, invokeRequest, &invokeMx, err) + invokeFailure.InvokeResponseMode = invokeRequest.InvokeResponseMode if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { invokeFailure.ResponseMetrics = interop.ResponseMetrics{ + RuntimeResponseLatencyMs: telemetry.CalculateDuration(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics.StartReadingResponseMonoTimeMs), RuntimeTimeThrottledMs: invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond), RuntimeProducedBytes: invokeRequest.InvokeResponseMetrics.ProducedBytes, RuntimeOutboundThroughputBps: invokeRequest.InvokeResponseMetrics.OutboundThroughputBps, @@ -683,10 +671,12 @@ func handleInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, sbInfoFr }, InvokeCompletionTimeNs: invokeCompletionTimeNs, InvokeReceivedTime: 
invokeRequest.InvokeReceivedTime, + InvokeResponseMode: invokeRequest.InvokeResponseMode, } if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { invokeSuccessMsg.ResponseMetrics = interop.ResponseMetrics{ + RuntimeResponseLatencyMs: telemetry.CalculateDuration(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics.StartReadingResponseMonoTimeMs), RuntimeTimeThrottledMs: invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond), RuntimeProducedBytes: invokeRequest.InvokeResponseMetrics.ProducedBytes, RuntimeOutboundThroughputBps: invokeRequest.InvokeResponseMetrics.OutboundThroughputBps, @@ -701,7 +691,7 @@ func handleInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, sbInfoFr } func reinitialize(execCtx *rapidContext) { - execCtx.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) + execCtx.appCtx.Delete(appctx.AppCtxInvokeErrorTraceDataKey) execCtx.appCtx.Delete(appctx.AppCtxRuntimeReleaseKey) execCtx.appCtx.Delete(appctx.AppCtxFirstFatalErrorKey) execCtx.renderingService.SetRenderer(nil) @@ -716,32 +706,46 @@ func reinitialize(execCtx *rapidContext) { } // handle notification of reset -func handleReset(execCtx *rapidContext, resetEvent *interop.Reset, invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { +func handleReset(execCtx *rapidContext, resetEvent *interop.Reset, runtimeStartedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { log.Warnf("Reset initiated: %s", resetEvent.Reason) // Only send RuntimeDone event if we get a reset during an Invoke if resetEvent.Reason == "failure" || resetEvent.Reason == "timeout" { - runtimeDoneEventData := telemetry.InvokeRuntimeDoneData{ - Status: resetEvent.Reason, + var errorType *string + if resetEvent.Reason == "failure" { + firstFatalError, found := 
appctx.LoadFirstFatalError(execCtx.appCtx) + if !found { + firstFatalError = fatalerror.SandboxFailure + } + stringifiedError := string(firstFatalError) + errorType = &stringifiedError + } + + var status string + if resetEvent.Reason == "timeout" { + status = "timeout" + } else if strings.HasPrefix(*errorType, "Sandbox.") { + status = "failure" + } else { + status = "error" + } + + var runtimeReadyTime int64 = metering.Monotime() + runtimeDoneEventData := interop.InvokeRuntimeDoneData{ + Status: status, InternalMetrics: invokeResponseMetrics, - Metrics: telemetry.GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, metering.Monotime()), - Tracing: telemetry.BuildTracingCtx(model.XRayTracingType, resetEvent.TraceID, resetEvent.LambdaSegmentID), - Spans: telemetry.GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics), + Metrics: telemetry.GetRuntimeDoneInvokeMetrics(runtimeStartedTime, invokeResponseMetrics, runtimeReadyTime), + Tracing: execCtx.xray.BuildTracingCtxAfterInvokeComplete(), + Spans: execCtx.eventsAPI.GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics, execCtx.RuntimeOverheadStartedTime, runtimeReadyTime), + ErrorType: errorType, } - if err := execCtx.eventsAPI.SendRuntimeDone(runtimeDoneEventData); err != nil { - log.Errorf("Failed to send RUNDONE: %s", err) + if err := execCtx.eventsAPI.SendInvokeRuntimeDone(runtimeDoneEventData); err != nil { + log.Errorf("Failed to send INVOKE RTDONE: %s", err) } } extensionsResetMs, resetTimeout, _ := execCtx.shutdownContext.shutdown(execCtx, resetEvent.DeadlineNs, resetEvent.Reason) - log.Info("Starting runtime domain") - err := execCtx.supervisor.Start(&supvmodel.StartRequest{ - Domain: RuntimeDomain, - }) - if err != nil { - log.WithError(err).Panic("Failed booting runtime domain") - } execCtx.runtimeDomainGeneration++ // Only used by standalone for more indepth assertions. 
@@ -751,8 +755,12 @@ func handleReset(execCtx *rapidContext, resetEvent *interop.Reset, invokeReceive fatalErrorType, _ = appctx.LoadFirstFatalError(execCtx.appCtx) } + // TODO: move interop.ResponseMetrics{} to a factory method and initialize it there. + // Initialization is very similar in handleInvoke's invokeFailure.ResponseMetrics and + // invokeSuccessMsg.ResponseMetrics var responseMetrics interop.ResponseMetrics if resetEvent.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(resetEvent.InvokeResponseMetrics) { + responseMetrics.RuntimeResponseLatencyMs = telemetry.CalculateDuration(execCtx.RuntimeStartedTime, resetEvent.InvokeResponseMetrics.StartReadingResponseMonoTimeMs) responseMetrics.RuntimeTimeThrottledMs = resetEvent.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) responseMetrics.RuntimeProducedBytes = resetEvent.InvokeResponseMetrics.ProducedBytes responseMetrics.RuntimeOutboundThroughputBps = resetEvent.InvokeResponseMetrics.OutboundThroughputBps @@ -760,16 +768,18 @@ func handleReset(execCtx *rapidContext, resetEvent *interop.Reset, invokeReceive if resetTimeout { return interop.ResetSuccess{}, &interop.ResetFailure{ - ExtensionsResetMs: extensionsResetMs, - ErrorType: fatalErrorType, - ResponseMetrics: responseMetrics, + ExtensionsResetMs: extensionsResetMs, + ErrorType: fatalErrorType, + ResponseMetrics: responseMetrics, + InvokeResponseMode: resetEvent.InvokeResponseMode, } } return interop.ResetSuccess{ - ExtensionsResetMs: extensionsResetMs, - ErrorType: fatalErrorType, - ResponseMetrics: responseMetrics, + ExtensionsResetMs: extensionsResetMs, + ErrorType: fatalErrorType, + ResponseMetrics: responseMetrics, + InvokeResponseMode: resetEvent.InvokeResponseMode, }, nil } @@ -789,75 +799,199 @@ func handleShutdown(execCtx *rapidContext, shutdownEvent *interop.Shutdown, reas return interop.ShutdownSuccess{ErrorType: fatalErrorType} } -func handleRestore(execCtx *rapidContext, restore *interop.Restore) error { +func 
handleRestore(execCtx *rapidContext, restore *interop.Restore) (interop.RestoreResult, error) { err := execCtx.credentialsService.UpdateCredentials(restore.AwsKey, restore.AwsSecret, restore.AwsSession, restore.CredentialsExpiry) restoreStatus := telemetry.RuntimeDoneSuccess + restoreResult := interop.RestoreResult{} + defer func() { sendRestoreRuntimeDoneLogEvent(execCtx, restoreStatus) }() if err != nil { - return fmt.Errorf("error when updating credentials: %s", err) + log.Infof("error when updating credentials: %s", err) + return restoreResult, interop.ErrRestoreUpdateCredentials } + renderer := rendering.NewRestoreRenderer() execCtx.renderingService.SetRenderer(renderer) registrationService := execCtx.registrationService runtime := registrationService.GetRuntime() + execCtx.SetLogStreamName(restore.LogStreamName) + // If runtime has not called /restore/next then just return // instead of releasing the Runtime since there is no need to release. // Then the runtime should be released only during Invoke if runtime.GetState() != runtime.RuntimeRestoreReadyState { restoreStatus = telemetry.RuntimeDoneSuccess log.Infof("Runtime is in state: %s just returning", runtime.GetState().Name()) - return nil + + return restoreResult, nil } + deadlineNs := time.Now().Add(time.Duration(restore.RestoreHookTimeoutMs) * time.Millisecond).UnixNano() + + ctx, ctxCancel := context.WithDeadline(context.Background(), time.Unix(0, deadlineNs)) + + defer ctxCancel() + + startTime := metering.Monotime() + runtime.Release() initFlow := execCtx.initFlow - err = initFlow.AwaitRuntimeReady() + err = initFlow.AwaitRuntimeReadyWithDeadline(ctx) + + fatalErrorType, fatalErrorFound := appctx.LoadFirstFatalError(execCtx.appCtx) + + // If an error occurred while waiting for the runtime to complete the restore hook execution, + // check if there is any error stored in appctx to get the root cause error type + // Runtime.ExitError is an example of such a scenario + if fatalErrorFound { + err = 
fmt.Errorf(string(fatalErrorType)) + } if err != nil { - restoreStatus = telemetry.RuntimeDoneFailure - } else { - restoreStatus = telemetry.RuntimeDoneSuccess + restoreStatus = telemetry.RuntimeDoneError } - return err + endTime := metering.Monotime() + restoreDuration := time.Duration(endTime - startTime) + restoreResult.RestoreMs = restoreDuration.Milliseconds() + + return restoreResult, err } -func start(signalCtx context.Context, execCtx *rapidContext) { +func startRuntimeAPI(ctx context.Context, execCtx *rapidContext) { // Start Runtime API Server err := execCtx.server.Listen() if err != nil { log.WithError(err).Panic("Runtime API Server failed to listen") } - go func() { execCtx.server.Serve(signalCtx) }() + execCtx.server.Serve(ctx) // blocking until server exits // Note, most of initialization code should run before blocking to receive START, // code before START runs in parallel with code downloads. } +func getFirstFatalError(execCtx *rapidContext, status string) *string { + if status == telemetry.RuntimeDoneSuccess { + return nil + } + + firstFatalError, found := appctx.LoadFirstFatalError(execCtx.appCtx) + if !found { + // We will set errorType to "Runtime.Unknown" in case of INIT timeout and RESTORE timeout + // This is a trade-off we are willing to make. 
We will improve this later + firstFatalError = fatalerror.RuntimeUnknown + } + stringifiedError := string(firstFatalError) + return &stringifiedError +} + func sendRestoreRuntimeDoneLogEvent(execCtx *rapidContext, status string) { - if err := execCtx.eventsAPI.SendRestoreRuntimeDone(status); err != nil { - log.Errorf("Failed to send RESTRD: %s", err) + firstFatalError := getFirstFatalError(execCtx, status) + + restoreRuntimeDoneData := interop.RestoreRuntimeDoneData{ + Status: status, + ErrorType: firstFatalError, + } + + if err := execCtx.eventsAPI.SendRestoreRuntimeDone(restoreRuntimeDoneData); err != nil { + log.Errorf("Failed to send RESTORE RTDONE: %s", err) + } +} + +func sendInitStartLogEvent(execCtx *rapidContext, sandboxType interop.SandboxType, phase interop.LifecyclePhase) { + initPhase, err := telemetry.InitPhaseFromLifecyclePhase(phase) + if err != nil { + log.Errorf("failed to convert lifecycle phase into init phase: %s", err) + return + } + + functionMetadata := execCtx.registrationService.GetFunctionMetadata() + initStartData := interop.InitStartData{ + InitializationType: telemetry.InferInitType(execCtx.initCachingEnabled, sandboxType), + RuntimeVersion: functionMetadata.RuntimeInfo.Version, + RuntimeVersionArn: functionMetadata.RuntimeInfo.Arn, + FunctionName: functionMetadata.FunctionName, + FunctionVersion: functionMetadata.FunctionVersion, + // based on https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/resource/semantic_conventions/faas.md + // we're sending the logStream as the instance id + InstanceID: execCtx.logStreamName, + InstanceMaxMemory: functionMetadata.InstanceMaxMemory, + Phase: initPhase, + } + log.Info(initStartData.String()) + + if err := execCtx.eventsAPI.SendInitStart(initStartData); err != nil { + log.Errorf("Failed to send INIT START: %s", err) } } -func sendInitRuntimeDoneLogEvent(execCtx *rapidContext, sandboxType interop.SandboxType, status string) { - initSource := 
interop.InferTelemetryInitSource(execCtx.initCachingEnabled, sandboxType) +func sendInitRuntimeDoneLogEvent(execCtx *rapidContext, sandboxType interop.SandboxType, status string, phase interop.LifecyclePhase) { + initPhase, err := telemetry.InitPhaseFromLifecyclePhase(phase) + if err != nil { + log.Errorf("failed to convert lifecycle phase into init phase: %s", err) + return + } + + firstFatalError := getFirstFatalError(execCtx, status) + + initRuntimeDoneData := interop.InitRuntimeDoneData{ + InitializationType: telemetry.InferInitType(execCtx.initCachingEnabled, sandboxType), + Status: status, + Phase: initPhase, + ErrorType: firstFatalError, + } + + log.Info(initRuntimeDoneData.String()) + + if err := execCtx.eventsAPI.SendInitRuntimeDone(initRuntimeDoneData); err != nil { + log.Errorf("Failed to send INIT RTDONE: %s", err) + } +} + +func sendInitReportLogEvent( + execCtx *rapidContext, + sandboxType interop.SandboxType, + initStartMonotime int64, + phase interop.LifecyclePhase, +) { + initPhase, err := telemetry.InitPhaseFromLifecyclePhase(phase) + if err != nil { + log.Errorf("failed to convert lifecycle phase into init phase: %s", err) + return + } + + initReportData := interop.InitReportData{ + InitializationType: telemetry.InferInitType(execCtx.initCachingEnabled, sandboxType), + Metrics: interop.InitReportMetrics{ + DurationMs: telemetry.CalculateDuration(initStartMonotime, metering.Monotime()), + }, + Phase: initPhase, + } + log.Info(initReportData.String()) + + if err = execCtx.eventsAPI.SendInitReport(initReportData); err != nil { + log.Errorf("Failed to send INIT REPORT: %s", err) + } +} - runtimeDoneData := &telemetry.InitRuntimeDoneData{ - InitSource: initSource, - Status: status, +func sendInvokeStartLogEvent(execCtx *rapidContext, invokeRequestID string, tracingCtx *interop.TracingCtx) { + invokeStartData := interop.InvokeStartData{ + RequestID: invokeRequestID, + Version: execCtx.registrationService.GetFunctionMetadata().FunctionVersion, + 
Tracing: tracingCtx, } + log.Info(invokeStartData.String()) - if err := execCtx.eventsAPI.SendInitRuntimeDone(runtimeDoneData); err != nil { - log.Errorf("Failed to send INITRD: %s", err) + if err := execCtx.eventsAPI.SendInvokeStart(invokeStartData); err != nil { + log.Errorf("Failed to send INVOKE START: %s", err) } } diff --git a/lambda/rapid/handlers_test.go b/lambda/rapid/handlers_test.go new file mode 100644 index 0000000..089dbb7 --- /dev/null +++ b/lambda/rapid/handlers_test.go @@ -0,0 +1,341 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapid + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "regexp" + "strconv" + "strings" + "sync" + "testing" + "time" + + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapi" + "go.amzn.com/lambda/rapi/handler" + "go.amzn.com/lambda/rapi/rendering" + "go.amzn.com/lambda/rapidcore/env" + "go.amzn.com/lambda/supervisor/model" + "go.amzn.com/lambda/telemetry" + "go.amzn.com/lambda/testdata" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func BenchmarkChannelsSelect10(b *testing.B) { + c1 := make(chan int) + c2 := make(chan int) + c3 := make(chan int) + c4 := make(chan int) + c5 := make(chan int) + c6 := make(chan int) + c7 := make(chan int) + c8 := make(chan int) + c9 := make(chan int) + c10 := make(chan int) + + for n := 0; n < b.N; n++ { + select { + case <-c1: + case <-c2: + case <-c3: + case <-c4: + case <-c5: + case <-c6: + case <-c7: + case <-c8: + case <-c9: + case <-c10: + default: + } + } +} + +func BenchmarkChannelsSelect2(b *testing.B) { + c1 := make(chan int) + c2 := make(chan int) + + for n := 0; n < b.N; n++ { + select { + case <-c1: + case <-c2: + default: + } + } +} + +func TestGetExtensionNamesWithNoExtensions(t *testing.T) { + rs := core.NewRegistrationService(nil, 
nil) + + c := &rapidContext{ + registrationService: rs, + } + + assert.Equal(t, "", c.GetExtensionNames()) +} + +func TestGetExtensionNamesWithMultipleExtensions(t *testing.T) { + rs := core.NewRegistrationService(nil, nil) + _, _ = rs.CreateExternalAgent("Example1") + _, _ = rs.CreateInternalAgent("Example2") + _, _ = rs.CreateExternalAgent("Example3") + _, _ = rs.CreateInternalAgent("Example4") + + c := &rapidContext{ + registrationService: rs, + } + + r := regexp.MustCompile(`^(Example\d;){3}(Example\d)$`) + assert.True(t, r.MatchString(c.GetExtensionNames())) +} + +func TestGetExtensionNamesWithTooManyExtensions(t *testing.T) { + rs := core.NewRegistrationService(nil, nil) + for i := 10; i < 60; i++ { + _, _ = rs.CreateExternalAgent("E" + strconv.Itoa(i)) + } + + c := &rapidContext{ + registrationService: rs, + } + + output := c.GetExtensionNames() + + r := regexp.MustCompile(`^(E\d\d;){30}(E\d\d)$`) + assert.LessOrEqual(t, len(output), maxExtensionNamesLength) + assert.True(t, r.MatchString(output)) +} + +func TestGetExtensionNamesWithTooLongExtensionName(t *testing.T) { + rs := core.NewRegistrationService(nil, nil) + for i := 10; i < 60; i++ { + _, _ = rs.CreateExternalAgent(strings.Repeat("E", 130)) + } + + c := &rapidContext{ + registrationService: rs, + } + + assert.Equal(t, "", c.GetExtensionNames()) +} + +// This test confirms our assumption that http client can establish a tcp connection +// to a listening server. 
+func TestListen(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + flowTest.ConfigureForInvoke(context.Background(), &interop.Invoke{ID: "ID", DeadlineNs: "1", Payload: strings.NewReader("MyTest")}) + + ctx := context.Background() + telemetryAPIEnabled := true + server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.TelemetrySubscription, flowTest.TelemetrySubscription, flowTest.CredentialsService) + err := server.Listen() + assert.NoError(t, err) + + defer server.Close() + + go func() { + time.Sleep(time.Second) + fmt.Println("Serving...") + server.Serve(ctx) + }() + + done := make(chan struct{}) + + go func() { + fmt.Println("Connecting...") + resp, err1 := http.Get(fmt.Sprintf("http://%s:%d/2018-06-01/runtime/invocation/next", server.Host(), server.Port())) + assert.Nil(t, err1) + + body, err2 := io.ReadAll(resp.Body) + assert.Nil(t, err2) + + assert.Equal(t, "MyTest", string(body)) + + done <- struct{}{} + }() + + <-done +} + +func makeRapidContext(appCtx appctx.ApplicationContext, initFlow core.InitFlowSynchronization, invokeFlow core.InvokeFlowSynchronization, registrationService core.RegistrationService, supervisor *processSupervisor) *rapidContext { + + appctx.StoreInitType(appCtx, true) + appctx.StoreInteropServer(appCtx, MockInteropServer{}) + + renderingService := rendering.NewRenderingService() + + credentialsService := core.NewCredentialsService() + credentialsService.SetCredentials("token", "key", "secret", "session", time.Now()) + + // Runtime state machine + runtime := core.NewRuntime(initFlow, invokeFlow) + + registrationService.PreregisterRuntime(runtime) + runtime.SetState(runtime.RuntimeRestoreReadyState) + + rapidCtx := &rapidContext{ + // Internally initialized configurations + appCtx: appCtx, + initDone: true, + initFlow: initFlow, + invokeFlow: invokeFlow, + registrationService: registrationService, + renderingService: 
renderingService, + credentialsService: credentialsService, + handlerExecutionMutex: sync.Mutex{}, + shutdownContext: newShutdownContext(), + eventsAPI: &telemetry.NoOpEventsAPI{}, + } + if supervisor != nil { + rapidCtx.supervisor = *supervisor + } + + return rapidCtx +} + +const hookErrorType = "Runtime.RestoreHookUserErrorType" + +func makeRequest(appCtx appctx.ApplicationContext) *http.Request { + errorBody := []byte("My byte array is yours") + + request := appctx.RequestWithAppCtx(httptest.NewRequest("POST", "/", bytes.NewReader(errorBody)), appCtx) + + request.Header.Set("Content-Type", "application/MyBinaryType") + request.Header.Set("Lambda-Runtime-Function-Error-Type", hookErrorType) + + return request +} + +type MockInteropServer struct{} + +func (server MockInteropServer) GetCurrentInvokeID() string { + return "" +} + +func (server MockInteropServer) SendRuntimeReady() error { + return nil +} + +func (server MockInteropServer) SendInitErrorResponse(response *interop.ErrorInvokeResponse) error { + return nil +} + +func TestRestoreErrorAndAwaitRestoreCompletionRaceCondition(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + invokeFlow := core.NewInvokeFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow, invokeFlow) + + rapidCtx := makeRapidContext(appCtx, initFlow, invokeFlow, registrationService, nil /* don't set process supervisor */) + + // Runtime state machine + runtime := core.NewRuntime(initFlow, invokeFlow) + registrationService.PreregisterRuntime(runtime) + runtime.SetState(runtime.RuntimeRestoreReadyState) + + restore := &interop.Restore{ + AwsKey: "key", + AwsSecret: "secret", + AwsSession: "session", + CredentialsExpiry: time.Now(), + RestoreHookTimeoutMs: 10 * 1000, + } + + var wg sync.WaitGroup + + wg.Add(1) + + go func() { + defer wg.Done() + _, err := rapidCtx.HandleRestore(restore) + assert.Equal(t, err.Error(), "errRestoreHookUserError") + v, ok := 
err.(interop.ErrRestoreHookUserError) + assert.True(t, ok) + assert.Equal(t, v.UserError.Type, fatalerror.ErrorType(hookErrorType)) + }() + + responseRecorder := httptest.NewRecorder() + + handler := handler.NewRestoreErrorHandler(registrationService) + + request := makeRequest(appCtx) + + wg.Add(1) + + time.Sleep(1 * time.Second) + runtime.SetState(runtime.RuntimeRestoringState) + + go func() { + defer wg.Done() + handler.ServeHTTP(responseRecorder, request) + }() + + wg.Wait() +} + +type MockedProcessSupervisor struct { + mock.Mock +} + +func (supv *MockedProcessSupervisor) Exec(ctx context.Context, req *model.ExecRequest) error { + args := supv.Called(req) + return args.Error(0) +} + +func (supv *MockedProcessSupervisor) Events(ctx context.Context, req *model.EventsRequest) (<-chan model.Event, error) { + args := supv.Called(req) + err := args.Error(1) + if err != nil { + return nil, err + } + return args.Get(0).(<-chan model.Event), nil +} + +func (supv *MockedProcessSupervisor) Terminate(ctx context.Context, req *model.TerminateRequest) error { + args := supv.Called(req) + return args.Error(0) +} + +func (supv *MockedProcessSupervisor) Kill(ctx context.Context, req *model.KillRequest) error { + args := supv.Called(req) + return args.Error(0) +} + +var _ model.ProcessSupervisor = (*MockedProcessSupervisor)(nil) + +func TestSetupEventWatcherErrorHandling(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + invokeFlow := core.NewInvokeFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow, invokeFlow) + mockedProcessSupervisor := &MockedProcessSupervisor{} + mockedProcessSupervisor.On("Events", mock.Anything).Return(nil, fmt.Errorf("events call failed")) + procSupv := &processSupervisor{ProcessSupervisor: mockedProcessSupervisor} + + rapidCtx := makeRapidContext(appCtx, initFlow, invokeFlow, registrationService, procSupv) + + initSuccessResponseChan := make(chan 
interop.InitSuccess) + initFailureResponseChan := make(chan interop.InitFailure) + init := &interop.Init{EnvironmentVariables: env.NewEnvironment()} + + go assert.NotPanics(t, func() { + rapidCtx.HandleInit(init, initSuccessResponseChan, initFailureResponseChan) + }) + + failure := <-initFailureResponseChan + failure.Ack <- struct{}{} + errorType := interop.InitFailure(failure).ErrorType + assert.Equal(t, fatalerror.SandboxFailure, errorType) +} diff --git a/lambda/rapid/sandbox.go b/lambda/rapid/sandbox.go index 9259514..26eaff0 100644 --- a/lambda/rapid/sandbox.go +++ b/lambda/rapid/sandbox.go @@ -4,22 +4,19 @@ package rapid import ( + "bytes" "context" "fmt" "io" "sync" - "time" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapi" "go.amzn.com/lambda/rapi/rendering" supvmodel "go.amzn.com/lambda/supervisor/model" "go.amzn.com/lambda/telemetry" - - log "github.com/sirupsen/logrus" ) type Sandbox struct { @@ -32,18 +29,31 @@ type Sandbox struct { LogsEgressAPI telemetry.StdLogsEgressAPI RuntimeStdoutWriter io.Writer RuntimeStderrWriter io.Writer - PreLoadTimeNs int64 Handler string - SignalCtx context.Context - EventsAPI telemetry.EventsAPI + EventsAPI interop.EventsAPI InitCachingEnabled bool - Supervisor supvmodel.Supervisor + Supervisor supvmodel.ProcessSupervisor + RuntimeFsRootPath string // path to the root of the domain within the root mnt namespace. Required to find extensions RuntimeAPIHost string RuntimeAPIPort int } -// Start is a public version of start() that exports only configurable parameters -func Start(s *Sandbox) (interop.RapidContext, interop.InternalStateGetter, string) { +// Start pings Supervisor, and starts the Runtime API server. 
It allows the caller to configure: +// - Supervisor implementation: performs container construction & process management +// - Telemetry API and Logs API implementation: handling /logs and /telemetry of Runtime API +// - Events API implementation: handles platform log events emitted by Rapid (e.g. RuntimeDone, InitStart) +// - Logs Egress implementation: handling stdout/stderr logs from extension & runtime processes (TODO: remove & unify with Supervisor) +// - Tracer implementation: handling trace segments generated by platform (TODO: remove & unify with Events API) +// - InteropServer implementation: legacy interface for sending internal protocol messages, today only RuntimeReady remains (TODO: move RuntimeReady outside Core) +// - Feature flags: +// - StandaloneMode: indicates if being called by Rapid Core's standalone HTTP frontend (TODO: remove after unifying error reporting) +// - InitCachingEnabled: indicates if handlers must run Init Caching specific logic +// - EnableTelemetryAPI: indicates if /telemetry and /logs endpoint HTTP handlers must be mounted +// +// - Contexts & Data: +// - ctx is used to gracefully terminate Runtime API HTTP Server on exit +func Start(ctx context.Context, s *Sandbox) (interop.RapidContext, interop.InternalStateGetter, string) { + // Initialize internal state objects required by Rapid handlers appCtx := appctx.NewApplicationContext() initFlow := core.NewInitFlowSynchronization() invokeFlow := core.NewInvokeFlowSynchronization() @@ -53,26 +63,27 @@ func Start(s *Sandbox) (interop.RapidContext, interop.InternalStateGetter, strin appctx.StoreInitType(appCtx, s.InitCachingEnabled) - server := rapi.NewServer(s.RuntimeAPIHost, s.RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, s.LogsSubscriptionAPI, s.TelemetrySubscriptionAPI, credentialsService, s.EventsAPI) + server := rapi.NewServer(s.RuntimeAPIHost, s.RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, 
s.LogsSubscriptionAPI, s.TelemetrySubscriptionAPI, credentialsService) runtimeAPIAddr := fmt.Sprintf("%s:%d", server.Host(), server.Port()) - postLoadTimeNs := metering.Monotime() - // TODO: pass this directly down to HTTP servers and handlers, instead of using // global state to share the interop server implementation appctx.StoreInteropServer(appCtx, s.InteropServer) execCtx := &rapidContext{ - server: server, - appCtx: appCtx, - postLoadTimeNs: postLoadTimeNs, - initDone: false, - initFlow: initFlow, - invokeFlow: invokeFlow, - registrationService: registrationService, - renderingService: renderingService, - credentialsService: credentialsService, - + // Internally initialized configurations + server: server, + appCtx: appCtx, + initDone: false, + initFlow: initFlow, + invokeFlow: invokeFlow, + registrationService: registrationService, + renderingService: renderingService, + credentialsService: credentialsService, + handlerExecutionMutex: sync.Mutex{}, + shutdownContext: newShutdownContext(), + + // Externally specified configurations (i.e. 
via SandboxBuilder) telemetryAPIEnabled: s.EnableTelemetryAPI, logsSubscriptionAPI: s.LogsSubscriptionAPI, telemetrySubscriptionAPI: s.TelemetrySubscriptionAPI, @@ -80,77 +91,84 @@ func Start(s *Sandbox) (interop.RapidContext, interop.InternalStateGetter, strin interopServer: s.InteropServer, xray: s.Tracer, standaloneMode: s.StandaloneMode, - preLoadTimeNs: s.PreLoadTimeNs, eventsAPI: s.EventsAPI, initCachingEnabled: s.InitCachingEnabled, - signalCtx: s.SignalCtx, - supervisor: s.Supervisor, - executionMutex: sync.Mutex{}, - shutdownContext: newShutdownContext(), + supervisor: processSupervisor{ + ProcessSupervisor: s.Supervisor, + RootPath: s.RuntimeFsRootPath, + }, + + RuntimeStartedTime: -1, + RuntimeOverheadStartedTime: -1, + InvokeResponseMetrics: nil, } - // We call /ping on Supervisor before starting Rapid, since Rapid - // depends on Supervisor setting up networking dependencies - var startupErr error - for retries := 1; retries <= 5; retries++ { - if startupErr = s.Supervisor.Ping(); startupErr == nil { - break - } - // Retry timeout: 5s, same order-of-mag as test client PING retries - // TODO: revisit retry timeout, identify appropriate value for prod. 
- time.Sleep(1000 * time.Millisecond) - } - - if startupErr != nil { - log.Panicf("Application ping to Supervisor failed, terminating Rapid Startup: %s", startupErr) - } - - go start(s.SignalCtx, execCtx) + go startRuntimeAPI(ctx, execCtx) return execCtx, registrationService.GetInternalStateDescriptor(appCtx), runtimeAPIAddr } -func (r *rapidContext) HandleInit(init *interop.Init, initStartedResponseChan chan<- interop.InitStarted, initSuccessResponseChan chan<- interop.InitSuccess, initFailureResponseChan chan<- interop.InitFailure) { - r.executionMutex.Lock() - defer r.executionMutex.Unlock() - handleInit(r, init, initStartedResponseChan, initSuccessResponseChan, initFailureResponseChan) +func (r *rapidContext) HandleInit(init *interop.Init, initSuccessResponseChan chan<- interop.InitSuccess, initFailureResponseChan chan<- interop.InitFailure) { + r.handlerExecutionMutex.Lock() + defer r.handlerExecutionMutex.Unlock() + handleInit(r, init, initSuccessResponseChan, initFailureResponseChan) } -func (r *rapidContext) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { - r.executionMutex.Lock() - defer r.executionMutex.Unlock() - // Clear the context used by the last invok - r.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) - return handleInvoke(r, invoke, sbInfoFromInit) +func (r *rapidContext) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit, requestBuffer *bytes.Buffer, responseSender interop.InvokeResponseSender) (interop.InvokeSuccess, *interop.InvokeFailure) { + r.handlerExecutionMutex.Lock() + defer r.handlerExecutionMutex.Unlock() + // Clear the context used by the last invoke + r.appCtx.Delete(appctx.AppCtxInvokeErrorTraceDataKey) + return handleInvoke(r, invoke, sbInfoFromInit, requestBuffer, responseSender) } -func (r *rapidContext) HandleReset(reset *interop.Reset, invokeReceivedTime int64, InvokeResponseMetrics *interop.InvokeResponseMetrics) 
(interop.ResetSuccess, *interop.ResetFailure) { +func (r *rapidContext) HandleReset(reset *interop.Reset) (interop.ResetSuccess, *interop.ResetFailure) { // In the event of a Reset during init/invoke, CancelFlows cancels execution // flows and return with the errResetReceived err - this error is special-cased // and not handled by the init/invoke (unexpected) error handling functions r.registrationService.CancelFlows(errResetReceived) // Wait until invoke error handling has returned before continuing execution - r.executionMutex.Lock() - defer r.executionMutex.Unlock() + r.handlerExecutionMutex.Lock() + defer r.handlerExecutionMutex.Unlock() - // Clear the context used by the last invoke, i.e. error message etc. - r.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) - return handleReset(r, reset, invokeReceivedTime, InvokeResponseMetrics) + // Clear the context used by the last invoke + r.appCtx.Delete(appctx.AppCtxInvokeErrorTraceDataKey) + return handleReset(r, reset, r.RuntimeStartedTime, r.InvokeResponseMetrics) } func (r *rapidContext) HandleShutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { // Wait until invoke error handling has returned before continuing execution - r.executionMutex.Lock() - defer r.executionMutex.Unlock() + r.handlerExecutionMutex.Lock() + defer r.handlerExecutionMutex.Unlock() // Shutdown doesn't cancel flows, so it can block forever return handleShutdown(r, shutdown, standaloneShutdownReason) } -func (r *rapidContext) HandleRestore(restore *interop.Restore) error { +func (r *rapidContext) HandleRestore(restore *interop.Restore) (interop.RestoreResult, error) { return handleRestore(r, restore) } func (r *rapidContext) Clear() { reinitialize(r) } + +func (r *rapidContext) SetRuntimeStartedTime(runtimeStartedTime int64) { + r.RuntimeStartedTime = runtimeStartedTime +} + +func (r *rapidContext) SetRuntimeOverheadStartedTime(runtimeOverheadStartedTime int64) { + r.RuntimeOverheadStartedTime = runtimeOverheadStartedTime +} + +func 
(r *rapidContext) SetInvokeResponseMetrics(metrics *interop.InvokeResponseMetrics) { + r.InvokeResponseMetrics = metrics +} + +func (r *rapidContext) SetLogStreamName(logStreamName string) { + r.logStreamName = logStreamName +} + +func (r *rapidContext) SetEventsAPI(eventsAPI interop.EventsAPI) { + r.eventsAPI = eventsAPI +} diff --git a/lambda/rapid/shutdown.go b/lambda/rapid/shutdown.go index fe23a9f..05695e3 100644 --- a/lambda/rapid/shutdown.go +++ b/lambda/rapid/shutdown.go @@ -5,6 +5,8 @@ package rapid import ( + "context" + "errors" "fmt" "sync" "time" @@ -21,18 +23,20 @@ import ( const ( // supervisor shutdown and kill operations block until the exit status of the - // interested process has been collected, or until the specified timeotuw - // expires (in which case the operation fails). - // Note that this timeout is mainly relevant when any of the domain + // interested process has been collected, or until the specified deadline expires + // Note that this deadline is mainly relevant when any of the domain // processes are in uninterruptible sleep state (notable examples: syscall - // to read/write a newtorked driver) + // to read/write a networked driver) // // We set a non nil value for these timeouts so that RAPID doesn't block // forever in one of the cases above. supervisorBlockingMaxMillis = 9000 runtimeDeadlineShare = 0.3 + + maxProcessExitWait = 2 * time.Second ) +// TODO: aggregate struct's methods into an interface, so that we can mock in tests type shutdownContext struct { // Adding a mutex around shuttingDown because there may be concurrent reads/writes. // Because the code in shutdown() and the seperate go routine created in setupEventsWatcher() @@ -130,11 +134,15 @@ func (s *shutdownContext) createExitedChannel(name string) { // Blocks until all the processes in the runtime domain generation have exited. 
// This helps us have a nice sync point on Shutdown where we know for sure that -// all the processes have exited and the state has been cleared. +// all the processes have exited and the state has been cleared. The exception +// to that rule is that if any of the processes don't exit within +// maxProcessExitWait from the beginning of the waiting period, an error is +// returned, in order to prevent it from waiting forever if any of the processes +// cannot be killed. // // It is OK not to hold the lock because we know that this is called only during // shutdown and nobody will start a new process during shutdown -func (s *shutdownContext) clearExitedChannel() { +func (s *shutdownContext) clearExitedChannel() error { s.runtimeDomainExitedMutex.Lock() mapLen := len(s.runtimeDomainExited) channels := make([]chan struct{}, 0, mapLen) @@ -143,26 +151,32 @@ func (s *shutdownContext) clearExitedChannel() { } s.runtimeDomainExitedMutex.Unlock() + exitTimeout := time.After(maxProcessExitWait) for _, v := range channels { - <-v + select { + case <-v: + case <-exitTimeout: + return errors.New("timed out waiting for runtime processes to exit") + } } s.runtimeDomainExitedMutex.Lock() s.runtimeDomainExited = make(map[string]chan struct{}, mapLen) s.runtimeDomainExitedMutex.Unlock() + return nil } func (s *shutdownContext) shutdownRuntime(execCtx *rapidContext, start time.Time, deadline time.Time) { // If runtime is started: - // 1. SIGTERM and wait until timeout - // 2. SIGKILL on timeout + // 1. SIGTERM and wait until deadline + // 2. 
SIGKILL on deadline log.Debug("Shutting down the runtime.") name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) exitedChannel, found := s.getExitedChannel(name) if found { - err := execCtx.supervisor.Terminate(&supvmodel.TerminateRequest{ + err := execCtx.supervisor.Terminate(context.Background(), &supvmodel.TerminateRequest{ Domain: RuntimeDomain, Name: name, }) @@ -172,17 +186,17 @@ func (s *shutdownContext) shutdownRuntime(execCtx *rapidContext, start time.Time log.WithError(err).Warn("Failed sending Termination signal to runtime") } - runtimeTimeout := deadline.Sub(start) - log.Tracef("The runtime timeout is %v.", runtimeTimeout) - runtimeTimer := time.NewTimer(runtimeTimeout) + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + select { - case <-runtimeTimer.C: - log.Warnf("Timeout: The runtime did not exit after %d ms; Killing it.", int64(runtimeTimeout/time.Millisecond)) - supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) - err = execCtx.supervisor.Kill(&supvmodel.KillRequest{ - Domain: RuntimeDomain, - Name: name, - Timeout: &supervisorBlockingMaxMillis, + case <-ctx.Done(): + log.Warnf("Deadline: The runtime did not exit after deadline %s; Killing it.", deadline) + + err = execCtx.supervisor.Kill(context.Background(), &supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Deadline: time.Now().Add(time.Millisecond * supervisorBlockingMaxMillis), }) if err != nil { @@ -201,8 +215,8 @@ func (s *shutdownContext) shutdownRuntime(execCtx *rapidContext, start time.Time func (s *shutdownContext) shutdownAgents(execCtx *rapidContext, start time.Time, deadline time.Time, reason string) { // For each external agent, if agent is launched: // 1. Send Shutdown event if subscribed for it, else send SIGKILL to process group - // 2. Wait for all Shutdown-subscribed agents to exit with timeout - // 3. Send SIGKILL to process group for Shutdown-subscribed agents on timeout + // 2. 
Wait for all Shutdown-subscribed agents to exit with deadline + // 3. Send SIGKILL to process group for Shutdown-subscribed agents on deadline log.Debug("Shutting down the agents.") execCtx.renderingService.SetRenderer( @@ -224,7 +238,6 @@ func (s *shutdownContext) shutdownAgents(execCtx *rapidContext, start time.Time, for _, a := range execCtx.registrationService.GetExternalAgents() { name := fmt.Sprintf("extension-%s-%d", a.Name, execCtx.runtimeDomainGeneration) exitedChannel, found := s.getExitedChannel(name) - supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) if !found { log.Warnf("Agent %s failed to launch, therefore skipping shutting it down.", a) @@ -242,24 +255,25 @@ func (s *shutdownContext) shutdownAgents(execCtx *rapidContext, start time.Time, agent.Release() - agentTimeout := deadline.Sub(start) - var agentTimeoutChan <-chan time.Time + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() if execCtx.standaloneMode { - agentTimeoutChan = time.NewTimer(agentTimeout).C + ctx, cancel = context.WithDeadline(ctx, deadline) + defer cancel() } select { - case <-agentTimeoutChan: - log.Warnf("Timeout: the agent %s did not exit after %d ms; Killing it.", name, int64(agentTimeout/time.Millisecond)) - err := execCtx.supervisor.Kill(&supvmodel.KillRequest{ - Domain: RuntimeDomain, - Name: name, - Timeout: &supervisorBlockingMaxMillis, + case <-ctx.Done(): + log.Warnf("Deadline: the agent %s did not exit after deadline %s; Killing it.", name, deadline) + err := execCtx.supervisor.Kill(context.Background(), &supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Deadline: time.Now().Add(time.Millisecond * supervisorBlockingMaxMillis), }) if err != nil { // We are not reporting the error upstream because we will anyway // shut the domain out at the end of the shutdown sequence - log.WithError(err).Warn("Failed sending Kill signal to runtime") + log.WithError(err).Warn("Failed sending Kill signal to agent") } case 
<-exitedChannel: } @@ -270,11 +284,14 @@ func (s *shutdownContext) shutdownAgents(execCtx *rapidContext, start time.Time, go func(name string) { defer wg.Done() - execCtx.supervisor.Kill(&supvmodel.KillRequest{ - Domain: RuntimeDomain, - Name: name, - Timeout: &supervisorBlockingMaxMillis, + err := execCtx.supervisor.Kill(context.Background(), &supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Deadline: time.Now().Add(time.Millisecond * supervisorBlockingMaxMillis), }) + if err != nil { + log.WithError(err).Warn("Failed sending Kill signal to agent") + } }(name) } } @@ -295,7 +312,6 @@ func (s *shutdownContext) shutdown(execCtx *rapidContext, deadlineNs int64, reas execCtx.appCtx.Delete(appctx.AppCtxFirstFatalErrorKey) runtimeDomainProfiler := &metering.ExtensionsResetDurationProfiler{} - supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) // We do not spend any compute time on runtime graceful shutdown if there are no agents if execCtx.registrationService.CountAgents() == 0 { @@ -305,10 +321,10 @@ func (s *shutdownContext) shutdown(execCtx *rapidContext, deadlineNs int64, reas if found { log.Debug("SIGKILLing the runtime as no agents are registered.") - err = execCtx.supervisor.Kill(&supvmodel.KillRequest{ - Domain: RuntimeDomain, - Name: name, - Timeout: &supervisorBlockingMaxMillis, + err = execCtx.supervisor.Kill(context.Background(), &supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Deadline: time.Now().Add(time.Millisecond * supervisorBlockingMaxMillis), }) if err != nil { // We are not reporting the error upstream because we will anyway @@ -340,27 +356,13 @@ func (s *shutdownContext) shutdown(execCtx *rapidContext, deadlineNs int64, reas runtimeDomainProfiler.NumAgentsRegisteredForShutdown = len(s.agentsAwaitingExit) } - log.Info("Stopping runtime domain") - err = execCtx.supervisor.Stop(&supvmodel.StopRequest{ - Domain: RuntimeDomain, - Timeout: &supervisorBlockingMaxMillis, - }) - if err != nil { - 
log.WithError(err).Error("Failed shutting runtime domain down") - } else { - log.Info("Waiting for runtime domain processes termination") - s.clearExitedChannel() - log.Info("Stopping operator domain") - err = execCtx.supervisor.Stop(&supvmodel.StopRequest{ - Domain: OperatorDomain, - Timeout: &supervisorBlockingMaxMillis, - }) - if err != nil { - log.WithError(err).Error("Failed shutting operator domain down") - } + + log.Info("Waiting for runtime domain processes termination") + if err := s.clearExitedChannel(); err != nil { + log.Error(err) } runtimeDomainProfiler.Stop() - extensionsRestMs, timeout := runtimeDomainProfiler.CalculateExtensionsResetMs() - return extensionsRestMs, timeout, err + extensionsResetMs, timeout := runtimeDomainProfiler.CalculateExtensionsResetMs() + return extensionsResetMs, timeout, err } diff --git a/lambda/rapid/start_test.go b/lambda/rapid/start_test.go deleted file mode 100644 index ffb446f..0000000 --- a/lambda/rapid/start_test.go +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -package rapid - -import ( - "context" - "fmt" - "go.amzn.com/lambda/core" - "io" - "net/http" - "regexp" - "strconv" - "strings" - "testing" - "time" - - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi" - "go.amzn.com/lambda/testdata" - - "github.com/stretchr/testify/assert" -) - -func BenchmarkChannelsSelect10(b *testing.B) { - c1 := make(chan int) - c2 := make(chan int) - c3 := make(chan int) - c4 := make(chan int) - c5 := make(chan int) - c6 := make(chan int) - c7 := make(chan int) - c8 := make(chan int) - c9 := make(chan int) - c10 := make(chan int) - - for n := 0; n < b.N; n++ { - select { - case <-c1: - break - case <-c2: - break - case <-c3: - break - case <-c4: - break - case <-c5: - break - case <-c6: - break - case <-c7: - break - case <-c8: - break - case <-c9: - break - case <-c10: - break - default: - break - } - } -} - -func BenchmarkChannelsSelect2(b *testing.B) { - c1 := make(chan int) - c2 := make(chan int) - - for n := 0; n < b.N; n++ { - select { - case <-c1: - break - case <-c2: - break - default: - break - } - } -} - -func TestGetExtensionNamesWithNoExtensions(t *testing.T) { - rs := core.NewRegistrationService(nil, nil) - - c := &rapidContext{ - registrationService: rs, - } - - assert.Equal(t, "", c.GetExtensionNames()) -} - -func TestGetExtensionNamesWithMultipleExtensions(t *testing.T) { - rs := core.NewRegistrationService(nil, nil) - _, _ = rs.CreateExternalAgent("Example1") - _, _ = rs.CreateInternalAgent("Example2") - _, _ = rs.CreateExternalAgent("Example3") - _, _ = rs.CreateInternalAgent("Example4") - - c := &rapidContext{ - registrationService: rs, - } - - r := regexp.MustCompile(`^(Example\d;){3}(Example\d)$`) - assert.True(t, r.MatchString(c.GetExtensionNames())) -} - -func TestGetExtensionNamesWithTooManyExtensions(t *testing.T) { - rs := core.NewRegistrationService(nil, nil) - for i := 10; i < 60; i++ { - _, _ = rs.CreateExternalAgent("E" + strconv.Itoa(i)) - } - - c := 
&rapidContext{ - registrationService: rs, - } - - output := c.GetExtensionNames() - - r := regexp.MustCompile(`^(E\d\d;){30}(E\d\d)$`) - assert.LessOrEqual(t, len(output), maxExtensionNamesLength) - assert.True(t, r.MatchString(output)) -} - -func TestGetExtensionNamesWithTooLongExtensionName(t *testing.T) { - rs := core.NewRegistrationService(nil, nil) - for i := 10; i < 60; i++ { - _, _ = rs.CreateExternalAgent(strings.Repeat("E", 130)) - } - - c := &rapidContext{ - registrationService: rs, - } - - assert.Equal(t, "", c.GetExtensionNames()) -} - -// This test confirms our assumption that http client can establish a tcp connection -// to a listening server. -func TestListen(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.ConfigureForInvoke(context.Background(), &interop.Invoke{ID: "ID", DeadlineNs: "1", Payload: strings.NewReader("MyTest")}) - - ctx := context.Background() - telemetryAPIEnabled := true - server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.TelemetrySubscription, flowTest.TelemetrySubscription, flowTest.CredentialsService, flowTest.EventsAPI) - err := server.Listen() - assert.NoError(t, err) - - defer server.Close() - - go func() { - time.Sleep(time.Second) - fmt.Println("Serving...") - server.Serve(ctx) - }() - - done := make(chan struct{}) - - go func() { - fmt.Println("Connecting...") - resp, err1 := http.Get(fmt.Sprintf("http://%s:%d/2018-06-01/runtime/invocation/next", server.Host(), server.Port())) - assert.Nil(t, err1) - - body, err2 := io.ReadAll(resp.Body) - assert.Nil(t, err2) - - assert.Equal(t, "MyTest", string(body)) - - done <- struct{}{} - }() - - <-done -} - -func TestInferSandboxInitTypeOnDemand(t *testing.T) { - initCachingEnabled := false - sandboxType := interop.SandboxClassic - initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) - assert.Equal(t, "on-demand", 
initSource) -} - -func TestInferSandboxInitTypeProvisionedConcurrency(t *testing.T) { - initCachingEnabled := false - sandboxType := interop.SandboxPreWarmed - initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) - assert.Equal(t, "provisioned-concurrency", initSource) -} - -func TestInferSandboxInitTypeInitCaching(t *testing.T) { - initCachingEnabled := true - sandboxType := interop.SandboxClassic - initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) - assert.Equal(t, "snap-start", initSource) -} - -func TestInferSandboxInitTypeInitCachingWithPC(t *testing.T) { - initCachingEnabled := true - sandboxType := interop.SandboxPreWarmed - initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) - assert.Equal(t, "snap-start", initSource) -} diff --git a/lambda/rapidcore/bootstrap.go b/lambda/rapidcore/bootstrap.go deleted file mode 100644 index 165f532..0000000 --- a/lambda/rapidcore/bootstrap.go +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "fmt" - "os" - "path" - "path/filepath" - "strings" - - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - - log "github.com/sirupsen/logrus" -) - -type LogFormatter func(error) string -type BootstrapError func() (fatalerror.ErrorType, LogFormatter) - -// Bootstrap represents a list of executable bootstrap -// candidates in order of priority and exec metadata -type Bootstrap struct { - runtimeDomainRoot string - orderedLookupPaths []string - validCmd []string - workingDir string - cmdCandidates [][]string - extraFiles []*os.File - bootstrapError BootstrapError -} - -// Validate interface compliance -var _ interop.Bootstrap = (*Bootstrap)(nil) - -// NewBootstrap returns an instance of bootstrap defined by given params -func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string, runtimeDomainRoot string) *Bootstrap { - var orderedLookupBootstrapPaths []string - for _, args := range cmdCandidates { - // Empty args is an error, but we want to detect it later (in Cmd() call) when we are able to report a descriptive error - if len(args) != 0 { - orderedLookupBootstrapPaths = append(orderedLookupBootstrapPaths, args[0]) - } - } - - if currentWorkingDir == "" { - // use the root directory as the default working directory - currentWorkingDir = "/" - } - - if runtimeDomainRoot == "" { - runtimeDomainRoot = "/" - } - - return &Bootstrap{ - orderedLookupPaths: orderedLookupBootstrapPaths, - workingDir: currentWorkingDir, - cmdCandidates: cmdCandidates, - runtimeDomainRoot: runtimeDomainRoot, - } -} - -func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string, runtimeDomainRoot string) *Bootstrap { - if currentWorkingDir == "" { - // use the root directory as the default working directory - currentWorkingDir = "/" - } - if runtimeDomainRoot == "" { - runtimeDomainRoot = "/" - } - - // a single candidate command makes it automatically valid - return &Bootstrap{ - validCmd: 
cmd, - workingDir: currentWorkingDir, - runtimeDomainRoot: runtimeDomainRoot, - } -} - -// locateBootstrap sets the first occurrence of an -// actual bootstrap, given a list of possible files -func (b *Bootstrap) locateBootstrap() error { - for i, bootstrapCandidate := range b.orderedLookupPaths { - // validate path relatively to the domain's root - candidatPath := path.Join(b.runtimeDomainRoot, bootstrapCandidate) - file, err := os.Stat(candidatPath) - if err != nil { - if !os.IsNotExist(err) { - log.WithError(err).Warnf("Could not validate %s. Ignoring it.", bootstrapCandidate) - } - continue - } - if file.IsDir() { - log.Warnf("%s is a directory. Ignoring it", bootstrapCandidate) - continue - } - b.validCmd = b.cmdCandidates[i] - return nil - } - log.WithField("bootstrapPathsChecked", b.orderedLookupPaths).Warn("Couldn't find valid bootstrap(s)") - return fmt.Errorf("Couldn't find valid bootstrap(s): %s", b.orderedLookupPaths) -} - -// Cmd returns the args of bootstrap, relative to the -// chroot idenfied by `root`, where args[0] -// is the path to executable -func (b *Bootstrap) Cmd() ([]string, error) { - if len(b.validCmd) > 0 { - return b.validCmd, nil - } - - if err := b.locateBootstrap(); err != nil { - return []string{}, err - } - - log.Debug("Located runtime bootstrap", b.validCmd[0]) - return b.validCmd, nil -} - -// Env returns the environment variables available to -// the bootstrap process -func (b *Bootstrap) Env(e interop.EnvironmentVariables) map[string]string { - return e.RuntimeExecEnv() -} - -// Cwd returns the working directory of the bootstrap process -// The path is validated against the chroot identified by `root` -func (b *Bootstrap) Cwd() (string, error) { - if !filepath.IsAbs(b.workingDir) { - return "", fmt.Errorf("the working directory '%s' is invalid, it needs to be an absolute path", b.workingDir) - } - - // evaluate the path relatively to the domain's mnt namespace root - domainPath := path.Join(b.runtimeDomainRoot, b.workingDir) - 
if _, err := os.Stat(domainPath); os.IsNotExist(err) { - return "", fmt.Errorf("the working directory doesn't exist: %s", domainPath) - } - - return b.workingDir, nil -} - -// SetExtraFiles sets the extra file descriptors apart from 1 & 2 to be passed to runtime -func (b *Bootstrap) SetExtraFiles(extraFiles []*os.File) { - b.extraFiles = extraFiles -} - -// ExtraFiles returns the extra file descriptors apart from 1 & 2 to be passed to runtime -func (b *Bootstrap) ExtraFiles() []*os.File { - return b.extraFiles -} - -// CachedFatalError returns a bootstrap error that occurred during startup and before init -// so that it can be reported back to the customer in a later phase -func (b *Bootstrap) CachedFatalError(err error) (fatalerror.ErrorType, string, bool) { - if b.bootstrapError == nil { - return fatalerror.ErrorType(""), "", false - } - - fatalError, logFunc := b.bootstrapError() - - return fatalError, logFunc(err), true -} - -// SetCachedFatalError sets a cached fatal error that occurred during startup and before init -// so that it can be reported back to the customer in a later phase -func (b *Bootstrap) SetCachedFatalError(bootstrapErrFn BootstrapError) { - b.bootstrapError = bootstrapErrFn -} - -// BootstrapErrInvalidLCISTaskConfig represents an error while parsing LCIS task config -func BootstrapErrInvalidLCISTaskConfig(err error) BootstrapError { - return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidTaskConfig, SupernovaInvalidTaskConfigRepr(err) - } -} - -// BootstrapErrInvalidLCISEntrypoint represents an invalid LCIS entrypoint error -func BootstrapErrInvalidLCISEntrypoint(entrypoint []string, cmd []string, workingdir string) BootstrapError { - return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidEntrypoint, SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) - } -} - -func BootstrapErrInvalidLCISWorkingDir(entrypoint []string, cmd []string, workingdir string) BootstrapError { - return func() 
(fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidWorkingDir, SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) - } -} - -func SupernovaInvalidTaskConfigRepr(err error) func(error) string { - return func(unused error) string { - return fmt.Sprintf("IMAGE\tInvalid task config: %s", err) - } -} - -func SupernovaLaunchErrorRepr(entrypoint []string, cmd []string, workingDir string) func(error) string { - return func(err error) string { - return fmt.Sprintf("IMAGE\tLaunch error: %s\tEntrypoint: [%s]\tCmd: [%s]\tWorkingDir: [%s]", - err, - strings.Join(entrypoint, ","), - strings.Join(cmd, ","), - workingDir) - } -} diff --git a/lambda/rapidcore/bootstrap_test.go b/lambda/rapidcore/bootstrap_test.go deleted file mode 100644 index b43520d..0000000 --- a/lambda/rapidcore/bootstrap_test.go +++ /dev/null @@ -1,280 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "os" - "path" - "path/filepath" - "reflect" - "testing" - - "go.amzn.com/lambda/rapidcore/env" - - "github.com/stretchr/testify/assert" -) - -func TestBootstrap(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") - assert.NoError(t, err) - defer os.RemoveAll(tmpDir) - - tmpFile, err := os.CreateTemp("", "lcis-test-bootstrap") - assert.NoError(t, err) - defer os.Remove(tmpFile.Name()) - - // Setup cmd candidates - nonExistent := []string{"/foo/bar/baz"} - dir := []string{tmpDir, "--arg1", "foo"} - file := []string{tmpFile.Name(), "--arg1 s", "foo"} - cmdCandidates := [][]string{nonExistent, dir, file} - - // Setup working dir - cwd, err := os.Getwd() - assert.NoError(t, err) - - // Setup environment - environment := env.NewEnvironment() - environment.StoreRuntimeAPIEnvironmentVariable("host:port") - environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") - - // Test - b := NewBootstrap(cmdCandidates, cwd, "") - bCwd, err := b.Cwd() 
- assert.NoError(t, err) - assert.Equal(t, cwd, bCwd) - assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) - - cmd, err := b.Cmd() - assert.NoError(t, err) - assert.Equal(t, file, cmd) -} - -// When running bootstraps in separate mount namespaces -// we want to verify and discover paths relative to -// a root different from "/" -func TestBootstrapChroot(t *testing.T) { - tmpRoot, err := os.MkdirTemp(os.TempDir(), "domain-root") - assert.NoError(t, err) - defer os.RemoveAll(tmpRoot) - tmpDir, err := os.MkdirTemp(tmpRoot, "lcis-test-invalid-bootstrap") - assert.NoError(t, err) - defer os.RemoveAll(tmpDir) - - tmpFile, err := os.CreateTemp(tmpRoot, "lcis-test-bootstrap") - assert.NoError(t, err) - defer os.Remove(tmpFile.Name()) - - // Setup cmd candidates - nonExistent := []string{"/foo/bar/baz"} - baseName := filepath.Base(tmpDir) - dir := []string{"/" + baseName, "--arg1", "foo"} - baseName = filepath.Base(tmpFile.Name()) - file := []string{"/" + baseName, "--arg1 s", "foo"} - cmdCandidates := [][]string{nonExistent, dir, file} - - // Setup working dir - cwd, err := os.MkdirTemp(tmpRoot, "cwd") - assert.NoError(t, err) - defer os.RemoveAll(cwd) - - // Setup environment - environment := env.NewEnvironment() - environment.StoreRuntimeAPIEnvironmentVariable("host:port") - environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") - - // Test - baseName = filepath.Base(cwd) - b := NewBootstrap(cmdCandidates, "/"+baseName, tmpRoot) - bCwd, err := b.Cwd() - assert.NoError(t, err) - assert.Equal(t, cwd, path.Join(tmpRoot, bCwd)) - assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) - - cmd, err := b.Cmd() - assert.NoError(t, err) - assert.Equal(t, file, cmd) -} - -func TestBootstrapEmptyCandidate(t *testing.T) { - // we expect newBootstrap to succeed and bootstrap.Cmd() to fail. 
- // We want to postpone the failure to be able to propagate error description to slicer and write it to customer log - invalidBootstrapCandidate := []string{} - bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/", "") - _, err := bs.Cmd() - assert.Error(t, err) -} - -func TestBootstrapChrootNonExistingRoot(t *testing.T) { - invalidBootstrapCandidate := []string{"/bin/bash", "-c"} - bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/", "/does_not_exist") - _, err := bs.Cmd() - assert.Error(t, err) -} - -func TestBootstrapSingleCmd(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") - assert.NoError(t, err) - defer os.RemoveAll(tmpDir) - - tmpFile, err := os.CreateTemp("", "lcis-test-bootstrap") - assert.NoError(t, err) - defer os.Remove(tmpFile.Name()) - - // Setup single cmd candidate - file := []string{tmpFile.Name(), "--arg1 s", "foo"} - cmdCandidate := file - - // Setup working dir - cwd, err := os.Getwd() - assert.NoError(t, err) - - // Setup environment - environment := env.NewEnvironment() - environment.StoreRuntimeAPIEnvironmentVariable("host:port") - environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") - - // Test - b := NewBootstrapSingleCmd(cmdCandidate, cwd, "") - bCwd, err := b.Cwd() - assert.NoError(t, err) - assert.Equal(t, cwd, bCwd) - assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) - - cmd, err := b.Cmd() - assert.NoError(t, err) - assert.Equal(t, file, cmd) -} - -func TestBootstrapSingleCmdNonExistingCandidate(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") - assert.NoError(t, err) - defer os.RemoveAll(tmpDir) - - // Setup inexistent single cmd candidate - file := []string{"/foo/bar", "--arg1 s", "foo"} - cmdCandidate := file - - // Setup working dir - cwd, err := os.Getwd() - assert.NoError(t, err) - - // Setup environment - environment := env.NewEnvironment() - 
environment.StoreRuntimeAPIEnvironmentVariable("host:port") - environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") - - // Test - b := NewBootstrapSingleCmd(cmdCandidate, cwd, "") - bCwd, err := b.Cwd() - assert.NoError(t, err) - assert.Equal(t, cwd, bCwd) - assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) - - // No validations run against single candidates - cmd, err := b.Cmd() - assert.NoError(t, err) - assert.Equal(t, file, cmd) -} - -// Test our ability to locate bootstrap files in the file system -func TestFindCustomRuntimeIfExists(t *testing.T) { - tmpFile, err := os.CreateTemp(os.TempDir(), "tmp-") - if err != nil { - t.Fatal("Cannot create temporary file", err) - } - defer os.Remove(tmpFile.Name()) - - tmpFile2, err := os.CreateTemp(os.TempDir(), "tmp-") - if err != nil { - t.Fatal("Cannot create temporary file", err) - } - defer os.Remove(tmpFile2.Name()) - - // one bootstrap argument was given and it exists - bootstrap := NewBootstrap([][]string{{tmpFile.Name()}}, "/", "") - cmd, err := bootstrap.Cmd() - assert.NoError(t, err) - assert.Equal(t, []string{tmpFile.Name()}, cmd) - assert.Nil(t, err) - - // two bootstrap arguments given, both exist but first one is returned - bootstrap = NewBootstrap([][]string{{tmpFile.Name()}, {tmpFile2.Name()}}, "/", "") - cmd, err = bootstrap.Cmd() - assert.NoError(t, err) - assert.Equal(t, []string{tmpFile.Name()}, cmd) - assert.Nil(t, err) - - // two bootstrap arguments given, first one does not exist, second exists and is returned - bootstrap = NewBootstrap([][]string{{"mk"}, {tmpFile2.Name()}}, "/", "") - cmd, err = bootstrap.Cmd() - assert.NoError(t, err) - assert.Equal(t, []string{tmpFile2.Name()}, cmd) - assert.Nil(t, err) - - // two bootstrap arguments given, none exists - bootstrap = NewBootstrap([][]string{{"mk"}, {"mk2"}}, "/", "") - cmd, err = bootstrap.Cmd() - assert.EqualError(t, err, "Couldn't find valid bootstrap(s): [mk mk2]") - 
assert.Equal(t, []string{}, cmd) -} - -func TestCwdIsAbsolute(t *testing.T) { - tmpFile, err := os.CreateTemp(os.TempDir(), "tmp-") - if err != nil { - t.Fatal("Cannot create temporary file", err) - } - defer os.Remove(tmpFile.Name()) - - cmdCandidates := [][]string{{tmpFile.Name()}} - - // no errors when currentWorkingDir is absolute - bootstrap := NewBootstrap(cmdCandidates, "/tmp", "") - cwd, err := bootstrap.Cwd() - assert.Nil(t, err) - assert.Equal(t, "/tmp", cwd) - - bootstrap = NewBootstrap(cmdCandidates, "tmp", "") - _, err = bootstrap.Cwd() - assert.EqualError(t, err, "the working directory 'tmp' is invalid, it needs to be an absolute path") - - bootstrap = NewBootstrap(cmdCandidates, "./", "") - _, err = bootstrap.Cwd() - assert.EqualError(t, err, "the working directory './' is invalid, it needs to be an absolute path") -} - -func TestBootstrapMissingWorkingDirectory(t *testing.T) { - tmpFile, err := os.CreateTemp(os.TempDir(), "cwd-test-bootstrap") - assert.NoError(t, err) - defer os.Remove(tmpFile.Name()) - - tmpDir, err := os.MkdirTemp("", "cwd-test") - assert.NoError(t, err) - defer os.RemoveAll(tmpDir) - - // cwd argument exists - bootstrap := NewBootstrap([][]string{{tmpFile.Name()}}, tmpDir, "") - cwd, err := bootstrap.Cwd() - assert.Equal(t, cwd, tmpDir) - assert.NoError(t, err) - - // cwd argument doesn't exist - bootstrap = NewBootstrap([][]string{{tmpFile.Name()}}, "/foo", "") - _, err = bootstrap.Cwd() - assert.EqualError(t, err, "the working directory doesn't exist: /foo") -} - -func TestDefaultWorkeringDirectory(t *testing.T) { - bootstrap := NewBootstrap([][]string{{}}, "", "") - cwd, err := bootstrap.Cwd() - assert.NoError(t, err) - assert.Equal(t, "/", cwd) -} - -func TestBootstrapSingleCmdDefaultWorkingDir(t *testing.T) { - b := NewBootstrapSingleCmd([]string{}, "", "") - bCwd, err := b.Cwd() - assert.NoError(t, err) - assert.Equal(t, "/", bCwd) -} diff --git a/lambda/rapidcore/env/environment.go b/lambda/rapidcore/env/environment.go 
index be0584c..fbe0ef2 100644 --- a/lambda/rapidcore/env/environment.go +++ b/lambda/rapidcore/env/environment.go @@ -6,9 +6,7 @@ package env import ( "fmt" "os" - "strconv" "strings" - "syscall" log "github.com/sirupsen/logrus" ) @@ -16,37 +14,24 @@ import ( const runtimeAPIAddressKey = "AWS_LAMBDA_RUNTIME_API" const handlerEnvKey = "_HANDLER" const executionEnvKey = "AWS_EXECUTION_ENV" +const taskRootEnvKey = "LAMBDA_TASK_ROOT" +const runtimeDirEnvKey = "LAMBDA_RUNTIME_DIR" // Environment holds env vars for runtime, agents, and for // internal use, parsed during startup and from START msg type Environment struct { - RAPID map[string]string // env vars req'd internally by RAPID - Platform map[string]string // reserved platform env vars as per Lambda docs - Runtime map[string]string // reserved runtime env vars as per Lambda docs - PlatformUnreserved map[string]string // unreserved platform env vars that customers can override - Credentials map[string]string // reserved env vars for credentials, set on INIT - Customer map[string]string // customer & unreserved platform env vars, set on INIT + Customer map[string]string // customer & unreserved platform env vars, set on INIT + + rapid map[string]string // env vars req'd internally by RAPID + platform map[string]string // reserved platform env vars as per Lambda docs + runtime map[string]string // reserved runtime env vars as per Lambda docs + platformUnreserved map[string]string // unreserved platform env vars that customers can override + credentials map[string]string // reserved env vars for credentials, set on INIT runtimeAPISet bool initEnvVarsSet bool } -// RapidConfig holds config req'd for RAPID's internal -// operation, parsed from internal env vars. 
-type RapidConfig struct { - SbID string - LogFd int - ShmFd int - CtrlFd int - CnslFd int - DirectInvokeFd int - LambdaTaskRoot string - XrayDaemonAddress string - PreLoadTimeNs int64 - FunctionName string - TelemetryAPIPassphrase string -} - func lookupEnv(keys map[string]bool) map[string]string { res := map[string]string{} for key := range keys { @@ -61,13 +46,13 @@ func lookupEnv(keys map[string]bool) map[string]string { // NewEnvironment parses environment variables into an Environment object func NewEnvironment() *Environment { return &Environment{ - RAPID: lookupEnv(predefinedInternalEnvVarKeys()), - Platform: lookupEnv(predefinedPlatformEnvVarKeys()), - Runtime: lookupEnv(predefinedRuntimeEnvVarKeys()), - PlatformUnreserved: lookupEnv(predefinedPlatformUnreservedEnvVarKeys()), + rapid: lookupEnv(predefinedInternalEnvVarKeys()), + platform: lookupEnv(predefinedPlatformEnvVarKeys()), + runtime: lookupEnv(predefinedRuntimeEnvVarKeys()), + platformUnreserved: lookupEnv(predefinedPlatformUnreservedEnvVarKeys()), - Credentials: map[string]string{}, Customer: map[string]string{}, + credentials: map[string]string{}, runtimeAPISet: false, initEnvVarsSet: false, @@ -77,44 +62,49 @@ func NewEnvironment() *Environment { // StoreRuntimeAPIEnvironmentVariable stores value for AWS_LAMBDA_RUNTIME_API func (e *Environment) StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress string) { - e.Platform[runtimeAPIAddressKey] = runtimeAPIAddress + e.platform[runtimeAPIAddressKey] = runtimeAPIAddress e.runtimeAPISet = true } -// GetHandler turns the current setting for handler -func (e *Environment) GetHandler() string { - return e.Runtime[handlerEnvKey] -} - // SetHandler sets _HANDLER env variable value for Runtime func (e *Environment) SetHandler(handler string) { - e.Runtime[handlerEnvKey] = handler + e.runtime[handlerEnvKey] = handler } // GetExecutionEnv returns the current setting for AWS_EXECUTION_ENV func (e *Environment) GetExecutionEnv() string { - return 
e.Runtime[executionEnvKey] + return e.runtime[executionEnvKey] } // SetExecutionEnv sets AWS_EXECUTION_ENV variable value for Runtime func (e *Environment) SetExecutionEnv(executionEnv string) { - e.Runtime[executionEnvKey] = executionEnv + e.runtime[executionEnvKey] = executionEnv +} + +// SetTaskRoot sets the LAMBDA_TASK_ROOT environment variable for Runtime +func (e *Environment) SetTaskRoot(taskRoot string) { + e.runtime[taskRootEnvKey] = taskRoot +} + +// SetRuntimeDir sets the LAMBDA_RUNTIME_DIR environment variable for Runtime +func (e *Environment) SetRuntimeDir(runtimeDir string) { + e.runtime[runtimeDirEnvKey] = runtimeDir } // StoreEnvironmentVariablesFromInit sets the environment variables // for credentials & _HANDLER which are received in the START message func (e *Environment) StoreEnvironmentVariablesFromInit(customerEnv map[string]string, handler, awsKey, awsSecret, awsSession, funcName, funcVer string) { - e.Credentials["AWS_ACCESS_KEY_ID"] = awsKey - e.Credentials["AWS_SECRET_ACCESS_KEY"] = awsSecret - e.Credentials["AWS_SESSION_TOKEN"] = awsSession + e.credentials["AWS_ACCESS_KEY_ID"] = awsKey + e.credentials["AWS_SECRET_ACCESS_KEY"] = awsSecret + e.credentials["AWS_SESSION_TOKEN"] = awsSession e.storeNonCredentialEnvironmentVariablesFromInit(customerEnv, handler, funcName, funcVer) } func (e *Environment) StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) { - e.Credentials["AWS_CONTAINER_CREDENTIALS_FULL_URI"] = fmt.Sprintf("http://%s:%d/2021-04-23/credentials", host, port) - e.Credentials["AWS_CONTAINER_AUTHORIZATION_TOKEN"] = token + e.credentials["AWS_CONTAINER_CREDENTIALS_FULL_URI"] = fmt.Sprintf("http://%s:%d/2021-04-23/credentials", host, port) + e.credentials["AWS_CONTAINER_AUTHORIZATION_TOKEN"] = token e.storeNonCredentialEnvironmentVariablesFromInit(customerEnv, handler, funcName, funcVer) } @@ -125,11 +115,11 @@ func (e *Environment) 
storeNonCredentialEnvironmentVariablesFromInit(customerEnv } if funcName != "" { - e.Platform["AWS_LAMBDA_FUNCTION_NAME"] = funcName + e.platform["AWS_LAMBDA_FUNCTION_NAME"] = funcName } if funcVer != "" { - e.Platform["AWS_LAMBDA_FUNCTION_VERSION"] = funcVer + e.platform["AWS_LAMBDA_FUNCTION_VERSION"] = funcVer } e.mergeCustomerEnvironmentVariables(customerEnv) // overrides env vars from CLI options @@ -154,7 +144,7 @@ func (e *Environment) RuntimeExecEnv() map[string]string { log.Fatal("credentials, customer and runtime API address must be set") } - return mapUnion(e.Customer, e.PlatformUnreserved, e.Credentials, e.Runtime, e.Platform) + return mapUnion(e.Customer, e.platformUnreserved, e.credentials, e.runtime, e.platform) } // AgentExecEnv returns the key=value strings of all environment variables @@ -166,74 +156,7 @@ func (e *Environment) AgentExecEnv() map[string]string { excludedKeys := extensionExcludedKeys() excludeCondition := func(key string) bool { return excludedKeys[key] || strings.HasPrefix(key, "_") } - return mapExclude(mapUnion(e.Customer, e.Credentials, e.Platform), excludeCondition) -} - -// RAPIDInternalConfig returns the rapid config parsed from environment vars -func (e *Environment) RAPIDInternalConfig() RapidConfig { - return RapidConfig{ - SbID: e.getStrEnvVarOrDie(e.RAPID, "_LAMBDA_SB_ID"), - LogFd: e.getSocketEnvVarOrDie(e.RAPID, "_LAMBDA_LOG_FD"), - ShmFd: e.getSocketEnvVarOrDie(e.RAPID, "_LAMBDA_SHARED_MEM_FD"), - CtrlFd: e.getSocketEnvVarOrDie(e.RAPID, "_LAMBDA_CONTROL_SOCKET"), - CnslFd: e.getSocketEnvVarOrDie(e.RAPID, "_LAMBDA_CONSOLE_SOCKET"), - DirectInvokeFd: e.getOptionalSocketEnvVar(e.RAPID, "_LAMBDA_DIRECT_INVOKE_SOCKET"), - PreLoadTimeNs: e.getInt64EnvVarOrDie(e.RAPID, "_LAMBDA_RUNTIME_LOAD_TIME"), - LambdaTaskRoot: e.getStrEnvVarOrDie(e.Runtime, "LAMBDA_TASK_ROOT"), - XrayDaemonAddress: e.getStrEnvVarOrDie(e.PlatformUnreserved, "AWS_XRAY_DAEMON_ADDRESS"), - FunctionName: e.getStrEnvVarOrDie(e.Platform, 
"AWS_LAMBDA_FUNCTION_NAME"), - TelemetryAPIPassphrase: e.RAPID["_LAMBDA_TELEMETRY_API_PASSPHRASE"], // TODO: Die if not set - } -} - -func (e *Environment) getStrEnvVarOrDie(env map[string]string, name string) string { - val, ok := env[name] - if !ok { - log.WithField("name", name).Fatal("Environment variable is not set") - } - return val -} - -func (e *Environment) getInt64EnvVarOrDie(env map[string]string, name string) int64 { - strval := e.getStrEnvVarOrDie(env, name) - val, err := strconv.ParseInt(strval, 10, 64) - if err != nil { - log.WithError(err).WithField("name", name).Fatal("Unable to parse int env var.") - } - return val -} - -func (e *Environment) getIntEnvVarOrDie(env map[string]string, name string) int { - return int(e.getInt64EnvVarOrDie(env, name)) -} - -// getSocketEnvVarOrDie reads and returns an int value of the -// environment variable or dies, when unable to do so. -// It also makes CloseOnExec for this value. -func (e *Environment) getSocketEnvVarOrDie(env map[string]string, name string) int { - sock := e.getIntEnvVarOrDie(env, name) - syscall.CloseOnExec(sock) - return sock -} - -// returns -1 if env variable was not set. 
Exits if it holds unexpected (non-int) value -func (e *Environment) getOptionalSocketEnvVar(env map[string]string, name string) int { - val, found := env[name] - if !found { - return -1 - } - - sock, err := strconv.Atoi(val) - if err != nil { - log.WithError(err).WithField("name", name).Fatal("Unable to parse socket env var.") - } - - if sock < 0 { - log.WithError(err).WithField("name", name).Fatal("Negative socket descriptor value") - } - - syscall.CloseOnExec(sock) - return sock + return mapExclude(mapUnion(e.Customer, e.credentials, e.platform), excludeCondition) } func mapUnion(maps ...map[string]string) map[string]string { diff --git a/lambda/rapidcore/env/environment_test.go b/lambda/rapidcore/env/environment_test.go index ed3043c..04c0494 100644 --- a/lambda/rapidcore/env/environment_test.go +++ b/lambda/rapidcore/env/environment_test.go @@ -34,7 +34,7 @@ func TestRAPIDInternalConfig(t *testing.T) { os.Setenv("AWS_LAMBDA_FUNCTION_NAME", "a") os.Setenv("_LAMBDA_TELEMETRY_API_PASSPHRASE", "a") os.Setenv("_LAMBDA_DIRECT_INVOKE_SOCKET", "1") - NewEnvironment().RAPIDInternalConfig() + NewRapidConfig(NewEnvironment()) } func TestEnvironmentParsing(t *testing.T) { @@ -59,11 +59,11 @@ func TestEnvironmentParsing(t *testing.T) { env.StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress) env.StoreEnvironmentVariablesFromInit(customerEnv, runtimeEnvVal, credsEnvVal, credsEnvVal, credsEnvVal, platformEnvVal, platformEnvVal) - for _, val := range env.RAPID { + for _, val := range env.rapid { assert.Equal(t, internalEnvVal, val) } - for key, val := range env.Platform { + for key, val := range env.platform { if key == runtimeAPIAddressKey { assert.Equal(t, runtimeAPIAddress, val) } else { @@ -71,16 +71,16 @@ func TestEnvironmentParsing(t *testing.T) { } } - for _, val := range env.Runtime { + for _, val := range env.runtime { assert.Equal(t, runtimeEnvVal, val) } - for key, val := range env.Credentials { + for key, val := range env.credentials { assert.Equal(t, credsEnvVal, 
val) assert.NotContains(t, env.Customer, key) } - for _, val := range env.PlatformUnreserved { + for _, val := range env.platformUnreserved { assert.Equal(t, customerEnvVal, val) } @@ -94,10 +94,10 @@ func TestEnvironmentParsingUnsetPlatformAndInternalEnvVarKeysAreDeleted(t *testi os.Clearenv() env := NewEnvironment() - assert.Len(t, env.RAPID, 0) - assert.Len(t, env.Platform, 0) - assert.Len(t, env.PlatformUnreserved, 0) - assert.Len(t, env.Credentials, 0) // uninitialized + assert.Len(t, env.rapid, 0) + assert.Len(t, env.platform, 0) + assert.Len(t, env.platformUnreserved, 0) + assert.Len(t, env.credentials, 0) // uninitialized assert.Len(t, env.Customer, 0) // uninitialized } @@ -136,30 +136,30 @@ func TestRuntimeExecEnvironmentVariables(t *testing.T) { rapidEnvVarsSlice := envToSlice(rapidEnvVars) - for key := range env.RAPID { + for key := range env.rapid { assert.NotContains(t, rapidEnvKeys, key) } - for key, val := range env.Runtime { + for key, val := range env.runtime { assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } - for key, val := range env.Platform { + for key, val := range env.platform { assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } - for key, val := range env.PlatformUnreserved { + for key, val := range env.platformUnreserved { assert.Contains(t, rapidEnvVarsSlice, key+"="+val) assert.NotContains(t, env.Customer, key) } - for key, val := range env.Credentials { + for key, val := range env.credentials { assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } for key, val := range env.Customer { assert.Contains(t, rapidEnvVarsSlice, key+"="+val) - assert.NotContains(t, env.PlatformUnreserved, key) + assert.NotContains(t, env.platformUnreserved, key) } } @@ -195,11 +195,11 @@ func TestRuntimeExecEnvironmentVariablesPriority(t *testing.T) { env.StoreEnvironmentVariablesFromCLIOptions(cliOptionsEnv) env.StoreEnvironmentVariablesFromInit(customerEnv, runtimeEnvVal, credsEnvVal, credsEnvVal, credsEnvVal, platformEnvVal, platformEnvVal) - 
assert.Equal(t, len(predefinedPlatformEnvVarKeys()), len(env.Platform)) - assert.Equal(t, len(predefinedCredentialsEnvVarKeys()), len(env.Credentials)) - assert.Equal(t, len(predefinedPlatformUnreservedEnvVarKeys()), len(env.PlatformUnreserved)) - assert.Equal(t, len(predefinedInternalEnvVarKeys()), len(env.RAPID)) - assert.Equal(t, len(predefinedRuntimeEnvVarKeys()), len(env.Runtime)) + assert.Equal(t, len(predefinedPlatformEnvVarKeys()), len(env.platform)) + assert.Equal(t, len(predefinedCredentialsEnvVarKeys()), len(env.credentials)) + assert.Equal(t, len(predefinedPlatformUnreservedEnvVarKeys()), len(env.platformUnreserved)) + assert.Equal(t, len(predefinedInternalEnvVarKeys()), len(env.rapid)) + assert.Equal(t, len(predefinedRuntimeEnvVarKeys()), len(env.runtime)) rapidEnvVars := envToSlice(env.RuntimeExecEnv()) @@ -266,15 +266,15 @@ func TestAgentExecEnvironmentVariables(t *testing.T) { agentEnvVarsSlice := envToSlice(agentEnvVars) - for key := range env.RAPID { + for key := range env.rapid { assert.NotContains(t, agentEnvKeys, key) } - for key, val := range env.Runtime { + for key, val := range env.runtime { assert.NotContains(t, agentEnvVarsSlice, key+"="+val) } - for key := range env.Platform { + for key := range env.platform { assert.Contains(t, agentEnvKeys, key) } @@ -282,11 +282,11 @@ func TestAgentExecEnvironmentVariables(t *testing.T) { assert.Contains(t, agentEnvKeys, key) } - for key, val := range env.Credentials { + for key, val := range env.credentials { assert.Contains(t, agentEnvVarsSlice, key+"="+val) } - assert.Contains(t, agentEnvVarsSlice, runtimeAPIAddressKey+"="+env.Platform[runtimeAPIAddressKey]) + assert.Contains(t, agentEnvVarsSlice, runtimeAPIAddressKey+"="+env.platform[runtimeAPIAddressKey]) } func TestStoreEnvironmentVariablesFromInitCaching(t *testing.T) { @@ -301,11 +301,11 @@ func TestStoreEnvironmentVariablesFromInitCaching(t *testing.T) { env.StoreEnvironmentVariablesFromInitForInitCaching("samplehost", 1234, customerEnv, 
handler, funcName, funcVer, token) - assert.Equal(t, fmt.Sprintf("http://%s:%d/2021-04-23/credentials", host, port), env.Credentials["AWS_CONTAINER_CREDENTIALS_FULL_URI"]) - assert.Equal(t, token, env.Credentials["AWS_CONTAINER_AUTHORIZATION_TOKEN"]) - assert.Equal(t, funcName, env.Platform["AWS_LAMBDA_FUNCTION_NAME"]) - assert.Equal(t, funcVer, env.Platform["AWS_LAMBDA_FUNCTION_VERSION"]) - assert.Equal(t, handler, env.Runtime["_HANDLER"]) + assert.Equal(t, fmt.Sprintf("http://%s:%d/2021-04-23/credentials", host, port), env.credentials["AWS_CONTAINER_CREDENTIALS_FULL_URI"]) + assert.Equal(t, token, env.credentials["AWS_CONTAINER_AUTHORIZATION_TOKEN"]) + assert.Equal(t, funcName, env.platform["AWS_LAMBDA_FUNCTION_NAME"]) + assert.Equal(t, funcVer, env.platform["AWS_LAMBDA_FUNCTION_VERSION"]) + assert.Equal(t, handler, env.runtime["_HANDLER"]) } func setAll(keys map[string]bool, value string) { diff --git a/lambda/rapidcore/env/rapidenv.go b/lambda/rapidcore/env/rapidenv.go new file mode 100644 index 0000000..bc1a6ad --- /dev/null +++ b/lambda/rapidcore/env/rapidenv.go @@ -0,0 +1,96 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package env + +import ( + "strconv" + "syscall" + + log "github.com/sirupsen/logrus" +) + +// RapidConfig holds config req'd for RAPID's internal +// operation, parsed from internal env vars. +// It should be built using `NewRapidConfig` to make sure that all the +// internal invariants are respected. 
+type RapidConfig struct { + SbID string + LogFd int + ShmFd int + CtrlFd int + CnslFd int + DirectInvokeFd int + LambdaTaskRoot string + XrayDaemonAddress string + PreLoadTimeNs int64 + FunctionName string + TelemetryAPIPassphrase string +} + +// Build the `RapidConfig` struct checking all the internal invariants +func NewRapidConfig(e *Environment) RapidConfig { + return RapidConfig{ + SbID: getStrEnvVarOrDie(e.rapid, "_LAMBDA_SB_ID"), + LogFd: getSocketEnvVarOrDie(e.rapid, "_LAMBDA_LOG_FD"), + ShmFd: getSocketEnvVarOrDie(e.rapid, "_LAMBDA_SHARED_MEM_FD"), + CtrlFd: getSocketEnvVarOrDie(e.rapid, "_LAMBDA_CONTROL_SOCKET"), + CnslFd: getSocketEnvVarOrDie(e.rapid, "_LAMBDA_CONSOLE_SOCKET"), + DirectInvokeFd: getOptionalSocketEnvVar(e.rapid, "_LAMBDA_DIRECT_INVOKE_SOCKET"), + PreLoadTimeNs: getInt64EnvVarOrDie(e.rapid, "_LAMBDA_RUNTIME_LOAD_TIME"), + LambdaTaskRoot: getStrEnvVarOrDie(e.runtime, "LAMBDA_TASK_ROOT"), + XrayDaemonAddress: getStrEnvVarOrDie(e.platformUnreserved, "AWS_XRAY_DAEMON_ADDRESS"), + FunctionName: getStrEnvVarOrDie(e.platform, "AWS_LAMBDA_FUNCTION_NAME"), + TelemetryAPIPassphrase: e.rapid["_LAMBDA_TELEMETRY_API_PASSPHRASE"], // TODO: Die if not set + } +} + +func getStrEnvVarOrDie(env map[string]string, name string) string { + val, ok := env[name] + if !ok { + log.WithField("name", name).Fatal("Environment variable is not set") + } + return val +} + +func getInt64EnvVarOrDie(env map[string]string, name string) int64 { + strval := getStrEnvVarOrDie(env, name) + val, err := strconv.ParseInt(strval, 10, 64) + if err != nil { + log.WithError(err).WithField("name", name).Fatal("Unable to parse int env var.") + } + return val +} + +func getIntEnvVarOrDie(env map[string]string, name string) int { + return int(getInt64EnvVarOrDie(env, name)) +} + +// getSocketEnvVarOrDie reads and returns an int value of the +// environment variable or dies, when unable to do so. +// It also makes CloseOnExec for this value. 
+func getSocketEnvVarOrDie(env map[string]string, name string) int { + sock := getIntEnvVarOrDie(env, name) + syscall.CloseOnExec(sock) + return sock +} + +// returns -1 if env variable was not set. Exits if it holds unexpected (non-int) value +func getOptionalSocketEnvVar(env map[string]string, name string) int { + val, found := env[name] + if !found { + return -1 + } + + sock, err := strconv.Atoi(val) + if err != nil { + log.WithError(err).WithField("name", name).Fatal("Unable to parse socket env var.") + } + + if sock < 0 { + log.WithError(err).WithField("name", name).Fatal("Negative socket descriptor value") + } + + syscall.CloseOnExec(sock) + return sock +} diff --git a/lambda/rapidcore/runtime_release.go b/lambda/rapidcore/runtime_release.go new file mode 100644 index 0000000..3875209 --- /dev/null +++ b/lambda/rapidcore/runtime_release.go @@ -0,0 +1,68 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "bufio" + "fmt" + "os" + "strings" +) + +type Logging string + +const ( + AmznStdout Logging = "amzn-stdout" + AmznStdoutTLV Logging = "amzn-stdout-tlv" +) + +// RuntimeRelease stores runtime identification data +type RuntimeRelease struct { + Name string + Version string + Logging Logging +} + +const RuntimeReleasePath = "/var/runtime/runtime-release" + +// GetRuntimeRelease reads Runtime identification data from config file and parses it into a struct +func GetRuntimeRelease(path string) (*RuntimeRelease, error) { + pairs, err := ParsePropertiesFile(path) + if err != nil { + return nil, fmt.Errorf("could not parse %s: %w", path, err) + } + + return &RuntimeRelease{pairs["NAME"], pairs["VERSION"], Logging(pairs["LOGGING"])}, nil +} + +// ParsePropertiesFile reads key-value pairs from file in newline-separated list of environment-like +// shell-compatible variable assignments. 
+// Format: https://www.freedesktop.org/software/systemd/man/os-release.html +// Value quotes are trimmed. Latest write wins for duplicated keys. +func ParsePropertiesFile(path string) (map[string]string, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("could not open %s: %w", path, err) + } + defer f.Close() + + pairs := make(map[string]string) + + s := bufio.NewScanner(f) + for s.Scan() { + if s.Text() == "" || strings.HasPrefix(s.Text(), "#") { + continue + } + k, v, found := strings.Cut(s.Text(), "=") + if !found { + return nil, fmt.Errorf("could not parse key-value pair from a line: %s", s.Text()) + } + pairs[k] = strings.Trim(v, "'\"") + } + if err := s.Err(); err != nil { + return nil, fmt.Errorf("failed to read properties file: %w", err) + } + + return pairs, nil +} diff --git a/lambda/rapidcore/runtime_release_test.go b/lambda/rapidcore/runtime_release_test.go new file mode 100644 index 0000000..7397140 --- /dev/null +++ b/lambda/rapidcore/runtime_release_test.go @@ -0,0 +1,97 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetRuntimeRelease(t *testing.T) { + tests := []struct { + name string + content string + want *RuntimeRelease + }{ + { + "simple", + "NAME=foo\nVERSION=bar\nLOGGING=baz\n", + &RuntimeRelease{"foo", "bar", "baz"}, + }, + { + "no trailing new line", + "NAME=foo\nVERSION=bar\nLOGGING=baz", + &RuntimeRelease{"foo", "bar", "baz"}, + }, + { + "nonexistent keys", + "LOGGING=baz\n", + &RuntimeRelease{"", "", "baz"}, + }, + { + "empty value", + "NAME=\nVERSION=\nLOGGING=\n", + &RuntimeRelease{"", "", ""}, + }, + { + "delimiter in value", + "NAME=Foo=Bar\nVERSION=bar\nLOGGING=baz\n", + &RuntimeRelease{"Foo=Bar", "bar", "baz"}, + }, + { + "empty file", + "", + &RuntimeRelease{"", "", ""}, + }, + { + "quotes", + "NAME=\"foo\"\nVERSION='bar'\n", + &RuntimeRelease{"foo", "bar", ""}, + }, + { + "double quotes", + "NAME='\"foo\"'\nVERSION=\"'bar'\"\n", + &RuntimeRelease{"foo", "bar", ""}, + }, + { + "empty lines", // production runtime-release files have empty line in the end of the file + "\nNAME=foo\n\nVERSION=bar\n\nLOGGING=baz\n\n", + &RuntimeRelease{"foo", "bar", "baz"}, + }, + { + "comments", + "# comment 1\nNAME=foo\n# comment 2\nVERSION=bar\n# comment 3\nLOGGING=baz\n# comment 4\n", + &RuntimeRelease{"foo", "bar", "baz"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, err := os.CreateTemp(os.TempDir(), "runtime-release") + require.NoError(t, err) + _, err = f.WriteString(tt.content) + require.NoError(t, err) + got, err := GetRuntimeRelease(f.Name()) + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestGetRuntimeRelease_NotFound(t *testing.T) { + _, err := GetRuntimeRelease("/sys/not-exists") + assert.Error(t, err) +} + +func TestGetRuntimeRelease_InvalidLine(t *testing.T) { + f, err := os.CreateTemp(os.TempDir(), 
"runtime-release") + require.NoError(t, err) + _, err = f.WriteString("NAME=foo\nVERSION=bar\nLOGGING=baz\nSOMETHING") + require.NoError(t, err) + _, err = GetRuntimeRelease(f.Name()) + assert.Error(t, err) +} diff --git a/lambda/rapidcore/sandbox_api.go b/lambda/rapidcore/sandbox_api.go index 0c7052e..2e8d713 100644 --- a/lambda/rapidcore/sandbox_api.go +++ b/lambda/rapidcore/sandbox_api.go @@ -4,6 +4,9 @@ package rapidcore import ( + "bytes" + + "go.amzn.com/lambda/extensions" "go.amzn.com/lambda/interop" ) @@ -14,23 +17,26 @@ type SandboxContext struct { rapidCtx interop.RapidContext handler string runtimeAPIAddress string - - InvokeReceivedTime int64 - InvokeResponseMetrics *interop.InvokeResponseMetrics } +// initContext and its methods model the initialization lifecycle +// of the Sandbox, which persist across invocations type initContext struct { - initSuccessChan chan interop.InitSuccess - initFailureChan chan interop.InitFailure - rapidCtx interop.RapidContext - sbInfoFromInit interop.SandboxInfoFromInit // contains data that needs to be persisted from init for suppressed inits during invoke + initSuccessChan chan interop.InitSuccess + initFailureChan chan interop.InitFailure + rapidCtx interop.RapidContext + sbInfoFromInit interop.SandboxInfoFromInit // contains data that needs to be persisted from init for suppressed inits during invoke + invokeRequestBuffer *bytes.Buffer // byte buffer used to store the invoke request rendered to runtime (reused until reset) } +// invokeContext and its methods model the invocation lifecycle type invokeContext struct { - rapidCtx interop.RapidContext - invokeRequestChan chan *interop.Invoke - invokeSuccessChan chan interop.InvokeSuccess - invokeFailureChan chan interop.InvokeFailure + rapidCtx interop.RapidContext + invokeRequestChan chan *interop.Invoke + invokeSuccessChan chan interop.InvokeSuccess + invokeFailureChan chan interop.InvokeFailure + sbInfoFromInit interop.SandboxInfoFromInit // contains data that needs to 
be persisted from init for suppressed inits during invoke + invokeRequestBuffer *bytes.Buffer // byte buffer used to store the invoke request rendered to runtime (reused until reset) } // Validate interface compliance @@ -38,8 +44,9 @@ var _ interop.SandboxContext = (*SandboxContext)(nil) var _ interop.InitContext = (*initContext)(nil) var _ interop.InvokeContext = (*invokeContext)(nil) -func (s SandboxContext) Init(init *interop.Init, timeoutMs int64) (interop.InitStarted, interop.InitContext) { - initStartedResponseChan := make(chan interop.InitStarted) +// Init starts the runtime domain initialization in a separate goroutine. +// Return value indicates that init request has been accepted and started. +func (s SandboxContext) Init(init *interop.Init, timeoutMs int64) interop.InitContext { initSuccessResponseChan := make(chan interop.InitSuccess) initFailureResponseChan := make(chan interop.InitFailure) @@ -48,49 +55,67 @@ func (s SandboxContext) Init(init *interop.Init, timeoutMs int64) (interop.InitS } init.EnvironmentVariables.StoreRuntimeAPIEnvironmentVariable(s.runtimeAPIAddress) + extensions.DisableViaMagicLayer() - go s.rapidCtx.HandleInit(init, initStartedResponseChan, initSuccessResponseChan, initFailureResponseChan) - initStarted := <-initStartedResponseChan + // We start initialization handling in a separate goroutine so that control can be returned back to + // caller, which can do work (e.g. notifying further upstream that initialization has started), and + // call initCtx.Wait() to wait async for completion of initialization phase. 
+ go s.rapidCtx.HandleInit(init, initSuccessResponseChan, initFailureResponseChan) sbMetadata := interop.SandboxInfoFromInit{ EnvironmentVariables: init.EnvironmentVariables, SandboxType: init.SandboxType, RuntimeBootstrap: init.Bootstrap, } - return initStarted, newInitContext(s.rapidCtx, sbMetadata, initSuccessResponseChan, initFailureResponseChan) + return newInitContext(s.rapidCtx, sbMetadata, initSuccessResponseChan, initFailureResponseChan) } +// Reset triggers a reset. In case of timeouts, the reset handler cancels all flows which triggers +// ongoing invoke handlers to return before proceeding with invoke +// TODO: move this method to the initialization context, since reset is conceptually on RT domain func (s SandboxContext) Reset(reset *interop.Reset) (interop.ResetSuccess, *interop.ResetFailure) { defer s.rapidCtx.Clear() - return s.rapidCtx.HandleReset(reset, s.InvokeReceivedTime, s.InvokeResponseMetrics) + return s.rapidCtx.HandleReset(reset) } +// Reset triggers a shutdown. 
This is similar to a reset, except that this is a terminal state +// and no further invokes are allowed func (s SandboxContext) Shutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { return s.rapidCtx.HandleShutdown(shutdown) } -func (s SandboxContext) Restore(restore *interop.Restore) error { +func (s SandboxContext) Restore(restore *interop.Restore) (interop.RestoreResult, error) { return s.rapidCtx.HandleRestore(restore) } -func (s *SandboxContext) SetInvokeReceivedTime(invokeReceivedTime int64) { - s.InvokeReceivedTime = invokeReceivedTime +func (s *SandboxContext) SetRuntimeStartedTime(runtimeStartedTime int64) { + s.rapidCtx.SetRuntimeStartedTime(runtimeStartedTime) } func (s *SandboxContext) SetInvokeResponseMetrics(metrics *interop.InvokeResponseMetrics) { - s.InvokeResponseMetrics = metrics + s.rapidCtx.SetInvokeResponseMetrics(metrics) } func newInitContext(r interop.RapidContext, sbMetadata interop.SandboxInfoFromInit, initSuccessChan chan interop.InitSuccess, initFailureChan chan interop.InitFailure) initContext { + + // Invocation request buffer is initialized once per initialization + // to reduce memory usage & GC CPU time across invocations + var requestBuffer bytes.Buffer + return initContext{ - initSuccessChan: initSuccessChan, - initFailureChan: initFailureChan, - rapidCtx: r, - sbInfoFromInit: sbMetadata, + initSuccessChan: initSuccessChan, + initFailureChan: initFailureChan, + rapidCtx: r, + sbInfoFromInit: sbMetadata, + invokeRequestBuffer: &requestBuffer, } } +// Wait awaits until initialization phase is complete, i.e. 
one of: +// - until all runtime domain process call /next +// - any one of the runtime domain processes exit (init failure) +// Timeout handling is managed upstream entirely func (i initContext) Wait() (interop.InitSuccess, *interop.InitFailure) { select { case initSuccess, isOpen := <-i.initSuccessChan: @@ -108,35 +133,44 @@ func (i initContext) Wait() (interop.InitSuccess, *interop.InitFailure) { } } +// Reserve is used to initialize invoke-related state func (i initContext) Reserve() interop.InvokeContext { - invokeRequestChan := make(chan *interop.Invoke) invokeSuccessChan := make(chan interop.InvokeSuccess) invokeFailureChan := make(chan interop.InvokeFailure) + return invokeContext{ + rapidCtx: i.rapidCtx, + invokeRequestChan: invokeRequestChan, + invokeSuccessChan: invokeSuccessChan, + invokeFailureChan: invokeFailureChan, + sbInfoFromInit: i.sbInfoFromInit, + invokeRequestBuffer: i.invokeRequestBuffer, + } +} + +// SendRequest starts the invocation request handling in a separate goroutine, +// i.e. 
sending the request payload via /next response, +// and waiting for the synchronization points +func (invCtx invokeContext) SendRequest(invoke *interop.Invoke, responseSender interop.InvokeResponseSender) { + // Invoke handling needs to be in a separate goroutine so that control can + // be returned immediately to calling goroutine, which can do work and + // asynchronously call invCtx.Wait() to await completion of the invoke phase go func() { - invoke := <-invokeRequestChan // For suppressed inits, invoke needs the runtime and agent env vars - invokeSuccess, invokeFailure := i.rapidCtx.HandleInvoke(invoke, i.sbInfoFromInit) + invokeSuccess, invokeFailure := invCtx.rapidCtx.HandleInvoke(invoke, invCtx.sbInfoFromInit, invCtx.invokeRequestBuffer, responseSender) if invokeFailure != nil { - invokeFailureChan <- *invokeFailure + invCtx.invokeFailureChan <- *invokeFailure } else { - invokeSuccessChan <- invokeSuccess + invCtx.invokeSuccessChan <- invokeSuccess } }() - - return invokeContext{ - rapidCtx: i.rapidCtx, - invokeRequestChan: invokeRequestChan, - invokeSuccessChan: invokeSuccessChan, - invokeFailureChan: invokeFailureChan, - } -} - -func (invCtx invokeContext) SendRequest(i *interop.Invoke) { - invCtx.invokeRequestChan <- i } +// Wait awaits invoke completion, i.e. 
one of the following cases: +// - until all runtime domain process call /next +// - until a process exit (that notifies upstream to trigger a reset due to "failure") +// - until a timeout (triggered by a reset from upstream due to "timeout") func (invCtx invokeContext) Wait() (interop.InvokeSuccess, *interop.InvokeFailure) { select { case invokeSuccess := <-invCtx.invokeSuccessChan: diff --git a/lambda/rapidcore/sandbox_builder.go b/lambda/rapidcore/sandbox_builder.go index ce016a0..f51acda 100644 --- a/lambda/rapidcore/sandbox_builder.go +++ b/lambda/rapidcore/sandbox_builder.go @@ -33,7 +33,7 @@ type SandboxBuilder struct { lambdaInvokeAPI LambdaInvokeAPI defaultInteropServer *Server useCustomInteropServer bool - shutdownFuncs []context.CancelFunc + shutdownFuncs []func() handler string } @@ -45,42 +45,45 @@ const ( ) func NewSandboxBuilder() *SandboxBuilder { - defaultInteropServer := NewServer(context.Background()) - signalCtx, cancelSignalCtx := context.WithCancel(context.Background()) + defaultInteropServer := NewServer() + localSv := supervisor.NewLocalSupervisor() b := &SandboxBuilder{ sandbox: &rapid.Sandbox{ - PreLoadTimeNs: 0, // TODO StandaloneMode: true, LogsEgressAPI: &telemetry.NoOpLogsEgressAPI{}, EnableTelemetryAPI: false, Tracer: telemetry.NewNoOpTracer(), - SignalCtx: signalCtx, EventsAPI: &telemetry.NoOpEventsAPI{}, InitCachingEnabled: false, - Supervisor: supervisor.NewLocalSupervisor(), + Supervisor: localSv, + RuntimeFsRootPath: localSv.RootPath, RuntimeAPIHost: "127.0.0.1", RuntimeAPIPort: 9001, }, defaultInteropServer: defaultInteropServer, - shutdownFuncs: []context.CancelFunc{}, + shutdownFuncs: []func(){}, lambdaInvokeAPI: NewEmulatorAPI(defaultInteropServer), } - b.AddShutdownFunc(context.CancelFunc(func() { + b.AddShutdownFunc(func() { log.Info("Shutting down...") defaultInteropServer.Reset("SandboxTerminated", defaultSigtermResetTimeoutMs) - cancelSignalCtx() - })) + }) return b } -func (b *SandboxBuilder) SetSupervisor(supervisor 
supvmodel.Supervisor) *SandboxBuilder { +func (b *SandboxBuilder) SetSupervisor(supervisor supvmodel.ProcessSupervisor) *SandboxBuilder { b.sandbox.Supervisor = supervisor return b } +func (b *SandboxBuilder) SetRuntimeFsRootPath(rootPath string) *SandboxBuilder { + b.sandbox.RuntimeFsRootPath = rootPath + return b +} + func (b *SandboxBuilder) SetRuntimeAPIAddress(runtimeAPIAddress string) *SandboxBuilder { host, port, err := net.SplitHostPort(runtimeAPIAddress) if err != nil { @@ -105,7 +108,7 @@ func (b *SandboxBuilder) SetInteropServer(interopServer interop.Server) *Sandbox return b } -func (b *SandboxBuilder) SetEventsAPI(eventsAPI telemetry.EventsAPI) *SandboxBuilder { +func (b *SandboxBuilder) SetEventsAPI(eventsAPI interop.EventsAPI) *SandboxBuilder { b.sandbox.EventsAPI = eventsAPI return b } @@ -134,11 +137,6 @@ func (b *SandboxBuilder) SetInitCachingFlag(initCachingEnabled bool) *SandboxBui return b } -func (b *SandboxBuilder) SetPreLoadTimeNs(preLoadTimeNs int64) *SandboxBuilder { - b.sandbox.PreLoadTimeNs = preLoadTimeNs - return b -} - func (b *SandboxBuilder) SetTelemetrySubscription(logsSubscriptionAPI telemetry.SubscriptionAPI, telemetrySubscriptionAPI telemetry.SubscriptionAPI) *SandboxBuilder { b.sandbox.EnableTelemetryAPI = true b.sandbox.LogsSubscriptionAPI = logsSubscriptionAPI @@ -156,7 +154,7 @@ func (b *SandboxBuilder) SetHandler(handler string) *SandboxBuilder { return b } -func (b *SandboxBuilder) AddShutdownFunc(shutdownFunc context.CancelFunc) *SandboxBuilder { +func (b *SandboxBuilder) AddShutdownFunc(shutdownFunc func()) *SandboxBuilder { b.shutdownFuncs = append(b.shutdownFuncs, shutdownFunc) return b } @@ -166,16 +164,20 @@ func (b *SandboxBuilder) Create() (interop.SandboxContext, interop.InternalState b.sandbox.InteropServer = b.defaultInteropServer } - go signalHandler(b.shutdownFuncs) + ctx, cancel := context.WithCancel(context.Background()) + + // cancel is called when handling termination signals as a cancellation + // signal 
to the Runtime API sever to terminate gracefully + go signalHandler(cancel, b.shutdownFuncs) - rapidCtx, internalStateFn, runtimeAPIAddr := rapid.Start(b.sandbox) + // rapid.Start, among other things, starts the Runtime API server and + // terminates it gracefully if the cxt is canceled + rapidCtx, internalStateFn, runtimeAPIAddr := rapid.Start(ctx, b.sandbox) b.sandboxContext = &SandboxContext{ - rapidCtx: rapidCtx, - handler: b.handler, - runtimeAPIAddress: runtimeAPIAddr, - InvokeReceivedTime: int64(0), - InvokeResponseMetrics: nil, + rapidCtx: rapidCtx, + handler: b.handler, + runtimeAPIAddress: runtimeAPIAddr, } return b.sandboxContext, internalStateFn @@ -205,8 +207,10 @@ func SetInternalLogOutput(w io.Writer) { logging.SetOutput(w) } -// Trap SIGINT and SIGTERM signals and call shutdown function -func signalHandler(shutdownFuncs []context.CancelFunc) { +// Trap SIGINT and SIGTERM signals, call shutdown function, and cancel the +// ctx to terminate gracefully the Runtime API server +func signalHandler(cancel context.CancelFunc, shutdownFuncs []func()) { + defer cancel() sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) sigReceived := <-sig diff --git a/lambda/rapidcore/sandbox_emulator_api.go b/lambda/rapidcore/sandbox_emulator_api.go index 6737631..4cc2183 100644 --- a/lambda/rapidcore/sandbox_emulator_api.go +++ b/lambda/rapidcore/sandbox_emulator_api.go @@ -31,6 +31,7 @@ func NewEmulatorAPI(s *Server) *EmulatorAPI { // Init method is only used by the Runtime interface emulator func (l *EmulatorAPI) Init(i *interop.Init, timeoutMs int64) { l.server.Init(&interop.Init{ + AccountID: i.AccountID, Handler: i.Handler, AwsKey: i.AwsKey, AwsSecret: i.AwsSecret, diff --git a/lambda/rapidcore/server.go b/lambda/rapidcore/server.go index e652130..e903ebe 100644 --- a/lambda/rapidcore/server.go +++ b/lambda/rapidcore/server.go @@ -33,12 +33,6 @@ const ( resetDefaultTimeoutMs = 2000 ) -const ( - contentTypeHeader = "Content-Type" - 
errorTypeHeader = "Error-Type" - functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" -) - type rapidPhase int const ( @@ -84,7 +78,6 @@ type Server struct { initChanOut chan *interop.Init interruptedResponseChan chan *interop.Reset - sendRunningChan chan *interop.InitStarted sendResponseChan chan *interop.InvokeResponseMetrics doneChan chan *interop.Done @@ -107,7 +100,7 @@ type Server struct { initContext interop.InitContext invoker interop.InvokeContext initFailures chan interop.InitFailure - cachedInitErrorResponse *interop.ErrorResponse + cachedInitErrorResponse *interop.ErrorInvokeResponse } // Validate interface compliance @@ -266,7 +259,7 @@ func (s *Server) Release() error { s.reservationCancel() } - s.sandboxContext.SetInvokeReceivedTime(0) + s.sandboxContext.SetRuntimeStartedTime(-1) s.sandboxContext.SetInvokeResponseMetrics(nil) s.invokeCtx = nil return nil @@ -295,7 +288,7 @@ func (s *Server) SetInternalStateGetter(cb interop.InternalStateGetter) { s.InternalStateGetter = cb } -func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[string]string, status int, payload io.Reader, trailers http.Header, request *interop.CancellableRequest, runtimeCalledResponse bool) error { +func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[string]string, payload io.Reader, trailers http.Header, request *interop.CancellableRequest, runtimeCalledResponse bool) error { if s.invokeCtx == nil || invokeID != s.invokeCtx.Token.InvokeID { return interop.ErrInvalidInvokeID } @@ -310,7 +303,7 @@ func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[strin var reportedErr error if s.invokeCtx.Direct { - if err := directinvoke.SendDirectInvokeResponse(additionalHeaders, payload, trailers, s.invokeCtx.ReplyStream, s.interruptedResponseChan, s.sendResponseChan, request, runtimeCalledResponse); err != nil { + if err := directinvoke.SendDirectInvokeResponse(additionalHeaders, payload, trailers, 
s.invokeCtx.ReplyStream, s.interruptedResponseChan, s.sendResponseChan, request, runtimeCalledResponse, invokeID); err != nil { // TODO: Do we need to drain the reader in case of a large payload and connection reuse? log.Errorf("Failed to write response to %s: %s", invokeID, err) reportedErr = err @@ -328,7 +321,7 @@ func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[strin } startReadingResponseMonoTimeMs := metering.Monotime() - s.invokeCtx.ReplyStream.Header().Add(contentTypeHeader, additionalHeaders[contentTypeHeader]) + s.invokeCtx.ReplyStream.Header().Add(directinvoke.ContentTypeHeader, additionalHeaders[directinvoke.ContentTypeHeader]) written, err := s.invokeCtx.ReplyStream.Write(data) if err != nil { return fmt.Errorf("Failed to write response to %s: %s", invokeID, err) @@ -355,19 +348,19 @@ func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[strin return reportedErr } -func (s *Server) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error { +func (s *Server) SendResponse(invokeID string, resp *interop.StreamableInvokeResponse) error { s.setRuntimeState(runtimeInvokeResponseSent) s.mutex.Lock() defer s.mutex.Unlock() runtimeCalledResponse := true - return s.sendResponseUnsafe(invokeID, headers, http.StatusOK, reader, trailers, request, runtimeCalledResponse) + return s.sendResponseUnsafe(invokeID, resp.Headers, resp.Payload, resp.Trailers, resp.Request, runtimeCalledResponse) } -func (s *Server) SendInitErrorResponse(invokeID string, resp *interop.ErrorResponse) error { - log.Debugf("Sending Init Error Response: %s", resp.ErrorType) +func (s *Server) SendInitErrorResponse(resp *interop.ErrorInvokeResponse) error { + log.Debugf("Sending Init Error Response: %s", resp.FunctionError.Type) if s.getRapidPhase() == phaseInvoking { // This branch occurs during suppressed init - return s.SendErrorResponse(invokeID, resp) + return 
s.SendErrorResponse(s.GetCurrentInvokeID(), resp) } // Handle an /init/error outside of the invoke phase @@ -376,17 +369,20 @@ func (s *Server) SendInitErrorResponse(invokeID string, resp *interop.ErrorRespo return nil } -func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorResponse) error { - log.Debugf("Sending Error Response: %s", resp.ErrorType) +func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorInvokeResponse) error { + log.Debugf("Sending Error Response: %s", resp.FunctionError.Type) s.setRuntimeState(runtimeInvokeError) s.mutex.Lock() defer s.mutex.Unlock() - additionalHeaders := map[string]string{contentTypeHeader: resp.ContentType, errorTypeHeader: resp.ErrorType} - if functionResponseMode := resp.FunctionResponseMode; functionResponseMode != "" { - additionalHeaders[functionResponseModeHeader] = functionResponseMode + additionalHeaders := map[string]string{ + directinvoke.ContentTypeHeader: resp.Headers.ContentType, + directinvoke.ErrorTypeHeader: string(resp.FunctionError.Type), + } + if functionResponseMode := resp.Headers.FunctionResponseMode; functionResponseMode != "" { + additionalHeaders[directinvoke.FunctionResponseModeHeader] = functionResponseMode } runtimeCalledResponse := false // we are sending an error here, so runtime called /error or crashed/timeout - return s.sendResponseUnsafe(invokeID, additionalHeaders, http.StatusInternalServerError, bytes.NewReader(resp.Payload), nil, nil, runtimeCalledResponse) + return s.sendResponseUnsafe(invokeID, additionalHeaders, bytes.NewReader(resp.Payload), nil, nil, runtimeCalledResponse) } func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) { @@ -409,10 +405,21 @@ func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescript s.setRuntimeState(runtimeNotStarted) var meta interop.DoneMetadata - if reset.InvokeResponseMetrics != nil { + if reset.InvokeResponseMetrics != nil && 
interop.IsResponseStreamingMetrics(reset.InvokeResponseMetrics) { meta.RuntimeTimeThrottledMs = reset.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) meta.RuntimeProducedBytes = reset.InvokeResponseMetrics.ProducedBytes meta.RuntimeOutboundThroughputBps = reset.InvokeResponseMetrics.OutboundThroughputBps + meta.MetricsDimensions = interop.DoneMetadataMetricsDimensions{ + InvokeResponseMode: reset.InvokeResponseMode, + } + + // These metrics aren't present in reset struct, therefore we need to get + // them from s.sandboxContext.Reset() response + if resetFailure != nil { + meta.RuntimeResponseLatencyMs = resetFailure.ResponseMetrics.RuntimeResponseLatencyMs + } else { + meta.RuntimeResponseLatencyMs = resetSuccess.ResponseMetrics.RuntimeResponseLatencyMs + } } if resetFailure != nil { @@ -431,15 +438,24 @@ func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescript return nil, errors.New(string(done.ErrorType)) } - return &statejson.ResetDescription{ExtensionsResetMs: done.Meta.ExtensionsResetMs}, nil + return &statejson.ResetDescription{ + ExtensionsResetMs: done.Meta.ExtensionsResetMs, + ResponseMetrics: statejson.ResponseMetrics{ + RuntimeResponseLatencyMs: done.Meta.RuntimeResponseLatencyMs, + Dimensions: statejson.ResponseMetricsDimensions{ + InvokeResponseMode: statejson.InvokeResponseMode( + done.Meta.MetricsDimensions.InvokeResponseMode, + ), + }, + }, + }, nil } -func NewServer(ctx context.Context) *Server { +func NewServer() *Server { s := &Server{ initChanOut: make(chan *interop.Init), interruptedResponseChan: make(chan *interop.Reset), - sendRunningChan: make(chan *interop.InitStarted), sendResponseChan: make(chan *interop.InvokeResponseMetrics), doneChan: make(chan *interop.Done), @@ -500,18 +516,15 @@ func (s *Server) Init(i *interop.Init, invokeTimeoutMs int64) error { s.SetInvokeTimeout(time.Duration(invokeTimeoutMs) * time.Millisecond) s.setRapidPhase(phaseInitializing) s.setInitFailuresChan() - initStarted, 
initCtx := s.sandboxContext.Init(i, invokeTimeoutMs) - initStarted.Ack <- struct{}{} + initCtx := s.sandboxContext.Init(i, invokeTimeoutMs) s.initContext = initCtx go s.awaitInitCompletion() - log.Debugf("Received RUNNING: %v", initStarted) return nil } func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error { - s.sandboxContext.SetInvokeReceivedTime(i.InvokeReceivedTime) invokeID, err := s.setReplyStream(w, direct) if err != nil { return err @@ -536,7 +549,7 @@ func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct boo s.setRuntimeState(runtimeInvokeComplete) return } - s.invoker.SendRequest(i) + s.invoker.SendRequest(i, s) invokeSuccess, invokeFailure := s.invoker.Wait() if invokeFailure != nil { if invokeFailure.ResetReceived { @@ -579,19 +592,19 @@ func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct boo return nil } -func (s *Server) setCachedInitErrorResponse(errResp *interop.ErrorResponse) { +func (s *Server) setCachedInitErrorResponse(errResp *interop.ErrorInvokeResponse) { s.mutex.Lock() defer s.mutex.Unlock() s.cachedInitErrorResponse = errResp } -func (s *Server) getCachedInitErrorResponse() *interop.ErrorResponse { +func (s *Server) getCachedInitErrorResponse() *interop.ErrorInvokeResponse { s.mutex.Lock() defer s.mutex.Unlock() return s.cachedInitErrorResponse } -func (s *Server) trySendDefaultErrorResponse(resp *interop.ErrorResponse) { +func (s *Server) trySendDefaultErrorResponse(resp *interop.ErrorInvokeResponse) { if err := s.SendErrorResponse(s.GetCurrentInvokeID(), resp); err != nil { if err != interop.ErrResponseSent { log.Panicf("Failed to send default error response: %s", err) @@ -658,9 +671,15 @@ func (s *Server) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invo // For init failures, cache the response so they can be checked later // We check if they have not already been set by a call to /init/error by runtime if s.getCachedInitErrorResponse() == 
nil { - errType, errMsg := string(initCompletionResp.InitErrorType), initCompletionResp.InitErrorMessage.Error() - s.setCachedInitErrorResponse(&interop.ErrorResponse{ErrorType: errType, ErrorMessage: errMsg}) + errType, errMsg := initCompletionResp.InitErrorType, initCompletionResp.InitErrorMessage.Error() + headers := interop.InvokeResponseHeaders{} + fnError := interop.FunctionError{Type: errType, Message: errMsg} + s.setCachedInitErrorResponse(&interop.ErrorInvokeResponse{Headers: headers, FunctionError: fnError, Payload: []byte{}}) } + + // Init failed, so we explicitly shutdown runtime (cleanup unused extensions). + // Because following fast invoke will start new (supressed) Init phase without reset call + s.Shutdown(&interop.Shutdown{DeadlineNs: metering.Monotime() + int64(resetDefaultTimeoutMs*1000*1000)}) } } @@ -759,7 +778,7 @@ func (s *Server) AwaitInitialized() error { return nil } -func (s *Server) AwaitRelease() (*statejson.InternalStateDescription, error) { +func (s *Server) AwaitRelease() (*statejson.ReleaseResponse, error) { defer func() { s.setRapidPhase(phaseIdle) s.setRuntimeState(runtimeInvokeComplete) @@ -776,8 +795,20 @@ func (s *Server) AwaitRelease() (*statejson.InternalStateDescription, error) { return nil, ErrInvokeDoneFailed } + releaseResponse := statejson.ReleaseResponse{ + InternalStateDescription: &doneWithState.State, + ResponseMetrics: statejson.ResponseMetrics{ + RuntimeResponseLatencyMs: doneWithState.Meta.RuntimeResponseLatencyMs, + Dimensions: statejson.ResponseMetricsDimensions{ + InvokeResponseMode: statejson.InvokeResponseMode( + doneWithState.Meta.MetricsDimensions.InvokeResponseMode, + ), + }, + }, + } + s.Release() - return &doneWithState.State, nil + return &releaseResponse, nil case <-s.reservationContext.Done(): return nil, ErrReleaseReservationDone @@ -806,7 +837,7 @@ func (s *Server) InternalState() (*statejson.InternalStateDescription, error) { return &state, nil } -func (s *Server) Restore(restore *interop.Restore) 
error { +func (s *Server) Restore(restore *interop.Restore) (interop.RestoreResult, error) { return s.sandboxContext.Restore(restore) } @@ -822,10 +853,14 @@ func doneFromInvokeSuccess(successMsg interop.InvokeSuccess) *interop.Done { InvokeCompletionTimeNs: successMsg.InvokeCompletionTimeNs, InvokeReceivedTime: successMsg.InvokeReceivedTime, + RuntimeResponseLatencyMs: successMsg.ResponseMetrics.RuntimeResponseLatencyMs, RuntimeTimeThrottledMs: successMsg.ResponseMetrics.RuntimeTimeThrottledMs, RuntimeProducedBytes: successMsg.ResponseMetrics.RuntimeProducedBytes, RuntimeOutboundThroughputBps: successMsg.ResponseMetrics.RuntimeOutboundThroughputBps, LogsAPIMetrics: successMsg.LogsAPIMetrics, + MetricsDimensions: interop.DoneMetadataMetricsDimensions{ + InvokeResponseMode: successMsg.InvokeResponseMode, + }, }, } } @@ -838,6 +873,7 @@ func doneFailFromInvokeFailure(failureMsg *interop.InvokeFailure) *interop.DoneF NumActiveExtensions: failureMsg.NumActiveExtensions, InvokeReceivedTime: failureMsg.InvokeReceivedTime, + RuntimeResponseLatencyMs: failureMsg.ResponseMetrics.RuntimeResponseLatencyMs, RuntimeTimeThrottledMs: failureMsg.ResponseMetrics.RuntimeTimeThrottledMs, RuntimeProducedBytes: failureMsg.ResponseMetrics.RuntimeProducedBytes, RuntimeOutboundThroughputBps: failureMsg.ResponseMetrics.RuntimeOutboundThroughputBps, @@ -848,6 +884,10 @@ func doneFailFromInvokeFailure(failureMsg *interop.InvokeFailure) *interop.DoneF ExtensionNames: failureMsg.ExtensionNames, LogsAPIMetrics: failureMsg.LogsAPIMetrics, + + MetricsDimensions: interop.DoneMetadataMetricsDimensions{ + InvokeResponseMode: failureMsg.InvokeResponseMode, + }, }, } } diff --git a/lambda/rapidcore/server_test.go b/lambda/rapidcore/server_test.go index 88eea3f..68ac30c 100644 --- a/lambda/rapidcore/server_test.go +++ b/lambda/rapidcore/server_test.go @@ -27,12 +27,6 @@ func waitForChanWithTimeout(channel <-chan error, timeout time.Duration) error { } } -func sendInitStartedResponse(responseChannel 
chan<- interop.InitStarted, msg interop.InitStarted) { - msg.Ack = make(chan struct{}) - responseChannel <- msg - <-msg.Ack -} - func sendInitSuccessResponse(responseChannel chan<- interop.InitSuccess, msg interop.InitSuccess) { msg.Ack = make(chan struct{}) responseChannel <- msg @@ -46,20 +40,20 @@ func sendInitFailureResponse(responseChannel chan<- interop.InitFailure, msg int } type mockRapidCtx struct { - initHandler func(start chan<- interop.InitStarted, success chan<- interop.InitSuccess, fail chan<- interop.InitFailure) + initHandler func(success chan<- interop.InitSuccess, fail chan<- interop.InitFailure) invokeHandler func() (interop.InvokeSuccess, *interop.InvokeFailure) resetHandler func() (interop.ResetSuccess, *interop.ResetFailure) } -func (r *mockRapidCtx) HandleInit(init *interop.Init, startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - r.initHandler(startResp, successResp, failureResp) +func (r *mockRapidCtx) HandleInit(init *interop.Init, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + r.initHandler(successResp, failureResp) } -func (r *mockRapidCtx) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { +func (r *mockRapidCtx) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit, buf *bytes.Buffer, responseSender interop.InvokeResponseSender) (interop.InvokeSuccess, *interop.InvokeFailure) { return r.invokeHandler() } -func (r *mockRapidCtx) HandleReset(reset *interop.Reset, invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { +func (r *mockRapidCtx) HandleReset(reset *interop.Reset) (interop.ResetSuccess, *interop.ResetFailure) { return r.resetHandler() } @@ -67,25 +61,33 @@ func (r *mockRapidCtx) HandleShutdown(shutdown *interop.Shutdown) interop.Shutdo return 
interop.ShutdownSuccess{} } -func (r *mockRapidCtx) HandleRestore(restore *interop.Restore) error { - return nil +func (r *mockRapidCtx) HandleRestore(restore *interop.Restore) (interop.RestoreResult, error) { + return interop.RestoreResult{}, nil } func (r *mockRapidCtx) Clear() {} +func (r *mockRapidCtx) SetRuntimeStartedTime(a int64) { +} + +func (r *mockRapidCtx) SetInvokeResponseMetrics(a *interop.InvokeResponseMetrics) { +} + +func (r *mockRapidCtx) SetEventsAPI(e interop.EventsAPI) { +} + func TestReserveDoesNotDeadlockWhenCalledMultipleTimes(t *testing.T) { - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { sendInitSuccessResponse(successResp, interop.InitSuccess{}) } srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ initHandler, func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, - }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + }, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) @@ -112,18 +114,17 @@ func TestReserveDoesNotDeadlockWhenCalledMultipleTimes(t *testing.T) { } func TestInitSuccess(t *testing.T) { - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- 
interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { sendInitSuccessResponse(successResp, interop.InitSuccess{}) } srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ initHandler, func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, - }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + }, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) @@ -134,13 +135,12 @@ func TestInitSuccess(t *testing.T) { func TestInitErrorBeforeReserve(t *testing.T) { // Rapid thread sending init failure should not be blocked even if reserve hasn't arrived - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) initErrorResponseSent := make(chan error) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) - require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) sendInitFailureResponse(failureResp, interop.InitFailure{}) initErrorResponseSent <- errors.New("initErrorResponseSent") } @@ -148,7 +148,7 @@ func TestInitErrorBeforeReserve(t *testing.T) 
{ initHandler, func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, - }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + }, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) @@ -169,19 +169,18 @@ func TestInitErrorBeforeReserve(t *testing.T) { } func TestInitErrorDuringReserve(t *testing.T) { - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) - require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) sendInitFailureResponse(failureResp, interop.InitFailure{}) } srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ initHandler, func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, - }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + }, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) resp, err := srv.Reserve("", "", "") @@ -197,24 +196,24 @@ func TestInitErrorDuringReserve(t *testing.T) { } func TestInvokeSuccess(t *testing.T) { - srv := 
NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) releaseRuntimeInit := make(chan struct{}) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { <-releaseRuntimeInit sendInitSuccessResponse(successResp, interop.InitSuccess{}) } invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader([]byte("response")), nil, nil)) + response := &interop.StreamableInvokeResponse{Headers: map[string]string{"Content-Type": "application/json"}, Payload: bytes.NewReader([]byte("response"))} + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), response)) require.NoError(t, srv.SendRuntimeReady()) return interop.InvokeSuccess{}, nil } resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) @@ -239,16 +238,16 @@ func TestInvokeSuccess(t *testing.T) { } func TestInvokeError(t *testing.T) { - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return 
statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { sendInitSuccessResponse(successResp, interop.InitSuccess{}) } invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }"), ContentType: "application/json"})) + headers := interop.InvokeResponseHeaders{ContentType: "application/json"} + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }"), Headers: headers})) require.NoError(t, srv.SendRuntimeReady()) return interop.InvokeSuccess{}, nil } @@ -257,7 +256,7 @@ func TestInvokeError(t *testing.T) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) @@ -291,19 +290,19 @@ func TestInvokeWithSuppressedInitSuccess(t *testing.T) { // Reserve() returns ErrInitAlreadyDone, since the server implementation // closes the InitDone channel after the first InitDone message. 
- srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) initErrorCompleted := make(chan error) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) - require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) sendInitFailureResponse(failureResp, interop.InitFailure{}) initErrorCompleted <- errors.New("initErrorSequenceCompleted") } invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), nil, bytes.NewReader([]byte("response")), nil, nil)) + response := &interop.StreamableInvokeResponse{Payload: bytes.NewReader([]byte("response"))} + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), response)) return interop.InvokeSuccess{}, nil } @@ -311,7 +310,7 @@ func TestInvokeWithSuppressedInitSuccess(t *testing.T) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) @@ -356,27 +355,26 @@ func TestInvokeWithSuppressedInitSuccess(t *testing.T) { func 
TestInvokeWithSuppressedInitErrorDueToInitError(t *testing.T) { // Tests init/error followed by init/error during suppressed init - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) - require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) sendInitFailureResponse(failureResp, interop.InitFailure{}) } releaseChan := make(chan error) invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) releaseChan <- nil - return interop.InvokeSuccess{}, &interop.InvokeFailure{ErrorType: "A.B", RequestReset: true, DefaultErrorResponse: &interop.ErrorResponse{}} + return interop.InvokeSuccess{}, &interop.InvokeFailure{ErrorType: "A.B", RequestReset: true, DefaultErrorResponse: &interop.ErrorInvokeResponse{}} } resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, 
resetHandler}, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) @@ -411,16 +409,15 @@ func TestInvokeWithSuppressedInitErrorDueToInitError(t *testing.T) { func TestInvokeWithSuppressedInitErrorDueToInvokeError(t *testing.T) { // Tests init/error followed by init/error during suppressed init - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) - require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) sendInitFailureResponse(failureResp, interop.InitFailure{}) } invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'B.C' }")})) + require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'B.C' }")})) require.NoError(t, srv.SendRuntimeReady()) return interop.InvokeSuccess{}, nil } @@ -429,7 +426,7 @@ func TestInvokeWithSuppressedInitErrorDueToInvokeError(t *testing.T) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, 
resetHandler}, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) @@ -461,16 +458,16 @@ func TestInvokeWithSuppressedInitErrorDueToInvokeError(t *testing.T) { } func TestMultipleInvokeSuccess(t *testing.T) { - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { sendInitSuccessResponse(successResp, interop.InitSuccess{}) } i := 0 invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), nil, bytes.NewReader([]byte("response-"+fmt.Sprint(i))), nil, nil)) + response := &interop.StreamableInvokeResponse{Payload: bytes.NewReader([]byte("response-" + fmt.Sprint(i)))} + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), response)) require.NoError(t, srv.SendRuntimeReady()) i++ return interop.InvokeSuccess{}, nil @@ -480,7 +477,7 @@ func TestMultipleInvokeSuccess(t *testing.T) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) @@ -505,11 +502,42 @@ func 
TestMultipleInvokeSuccess(t *testing.T) { } } +func TestAwaitReleaseOnSuccess(t *testing.T) { + srv := NewServer() + + // mocks + internalStateDescription := statejson.InternalStateDescription{} + srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return internalStateDescription }) + doneWithState := DoneWithState{ + State: internalStateDescription, + Done: &interop.Done{ + Meta: interop.DoneMetadata{ + RuntimeResponseLatencyMs: 12345, + MetricsDimensions: interop.DoneMetadataMetricsDimensions{ + InvokeResponseMode: interop.InvokeResponseModeStreaming, + }, + }, + }, + } + srv.InvokeDoneChan <- doneWithState + srv.reservationContext, srv.reservationCancel = context.WithCancel(context.Background()) + + // under test + responseAwaitRelease, err := srv.AwaitRelease() + + // assertions + require.NoError(t, err) + require.Equal(t, doneWithState.Done.Meta.RuntimeResponseLatencyMs, responseAwaitRelease.ResponseMetrics.RuntimeResponseLatencyMs) + require.Equal(t, string(doneWithState.Done.Meta.MetricsDimensions.InvokeResponseMode), string(responseAwaitRelease.ResponseMetrics.Dimensions.InvokeResponseMode)) + require.Equal(t, &doneWithState.State, responseAwaitRelease.InternalStateDescription) +} + /* Unit tests remaining: - Shutdown behaviour - Reset behaviour during various phases - Runtime / extensions process exit sequences - Invoke() and Init() api tests +- How can we add handleRestore test here? 
See PlantUML state diagram for potential other uncovered paths through the state machine diff --git a/lambda/rapidcore/standalone/eventLogHandler.go b/lambda/rapidcore/standalone/eventLogHandler.go index 156db99..e5bf7ac 100644 --- a/lambda/rapidcore/standalone/eventLogHandler.go +++ b/lambda/rapidcore/standalone/eventLogHandler.go @@ -8,11 +8,11 @@ import ( "fmt" "net/http" - "go.amzn.com/lambda/rapidcore/telemetry" + "go.amzn.com/lambda/rapidcore/standalone/telemetry" ) -func EventLogHandler(w http.ResponseWriter, r *http.Request, eventLog *telemetry.EventLog) { - bytes, err := json.Marshal(eventLog) +func EventLogHandler(w http.ResponseWriter, r *http.Request, eventsAPI *telemetry.StandaloneEventsAPI) { + bytes, err := json.Marshal(eventsAPI.EventLog()) if err != nil { http.Error(w, fmt.Sprintf("marshalling error: %s", err), http.StatusInternalServerError) return diff --git a/lambda/rapidcore/standalone/executeHandler.go b/lambda/rapidcore/standalone/executeHandler.go index 9bac400..0c7162b 100644 --- a/lambda/rapidcore/standalone/executeHandler.go +++ b/lambda/rapidcore/standalone/executeHandler.go @@ -27,19 +27,21 @@ func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.LambdaInv switch err { // Reserve errors: case rapidcore.ErrAlreadyReserved: - log.Errorf("Failed to reserve: %s", err) + log.WithError(err).Error("Failed to reserve as it is already reserved.") w.WriteHeader(400) case rapidcore.ErrInternalServerError: + log.WithError(err).Error("Failed to reserve from an internal server error.") w.WriteHeader(http.StatusInternalServerError) // Invoke errors: case rapidcore.ErrNotReserved, rapidcore.ErrAlreadyReplied, rapidcore.ErrAlreadyInvocating: - log.Errorf("Failed to set reply stream: %s", err) + log.WithError(err).Error("Failed to invoke from setting the reply stream.") w.WriteHeader(400) case rapidcore.ErrInvokeResponseAlreadyWritten: return case rapidcore.ErrInvokeTimeout, rapidcore.ErrInitResetReceived: + 
log.WithError(err).Error("Failed to invoke from an invoke timeout.") w.WriteHeader(http.StatusGatewayTimeout) // DONE failures: @@ -50,6 +52,7 @@ func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.LambdaInv return // Reservation canceled errors case rapidcore.ErrReserveReservationDone, rapidcore.ErrInvokeReservationDone, rapidcore.ErrReleaseReservationDone, rapidcore.ErrInitNotStarted: + log.WithError(err).Error("Failed to cancel reservation.") w.WriteHeader(http.StatusGatewayTimeout) } diff --git a/lambda/rapidcore/standalone/invokeHandler.go b/lambda/rapidcore/standalone/invokeHandler.go index 3e9768c..48a3a03 100644 --- a/lambda/rapidcore/standalone/invokeHandler.go +++ b/lambda/rapidcore/standalone/invokeHandler.go @@ -6,6 +6,7 @@ package standalone import ( "fmt" "net/http" + "strconv" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" @@ -22,12 +23,30 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { return } + restoreDurationHeader := r.Header.Get("restore-duration") + restoreStartHeader := r.Header.Get("restore-start-time") + + var restoreDurationNs int64 = 0 + var restoreStartTimeMonotime int64 = 0 + if restoreDurationHeader != "" && restoreStartHeader != "" { + var err1, err2 error + restoreDurationNs, err1 = strconv.ParseInt(restoreDurationHeader, 10, 64) + restoreStartTimeMonotime, err2 = strconv.ParseInt(restoreStartHeader, 10, 64) + if err1 != nil || err2 != nil { + log.Errorf("Failed to parse 'restore-duration' from '%s' and/or 'restore-start-time' from '%s'", restoreDurationHeader, restoreStartHeader) + restoreDurationNs = 0 + restoreStartTimeMonotime = 0 + } + } + invokePayload := &interop.Invoke{ - TraceID: r.Header.Get("X-Amzn-Trace-Id"), - LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: r.Body, - DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), - InvokeReceivedTime: metering.Monotime(), + TraceID: r.Header.Get("X-Amzn-Trace-Id"), + 
LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), + Payload: r.Body, + DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), + InvokeReceivedTime: metering.Monotime(), + RestoreDurationNs: restoreDurationNs, + RestoreStartTimeMonotime: restoreStartTimeMonotime, } if err := s.AwaitInitialized(); err != nil { diff --git a/lambda/rapidcore/standalone/restoreHandler.go b/lambda/rapidcore/standalone/restoreHandler.go index 190b6d8..fdf7a5d 100644 --- a/lambda/rapidcore/standalone/restoreHandler.go +++ b/lambda/rapidcore/standalone/restoreHandler.go @@ -4,7 +4,9 @@ package standalone import ( + "encoding/json" "net/http" + "strconv" "time" log "github.com/sirupsen/logrus" @@ -12,10 +14,11 @@ import ( ) type RestoreBody struct { - AwsKey string `json:"awskey"` - AwsSecret string `json:"awssecret"` - AwsSession string `json:"awssession"` - CredentialsExpiry time.Time `json:"credentialsExpiry"` + AwsKey string `json:"awskey"` + AwsSecret string `json:"awssecret"` + AwsSession string `json:"awssession"` + CredentialsExpiry time.Time `json:"credentialsExpiry"` + RestoreHookTimeoutMs int64 `json:"restoreHookTimeoutMs"` } func RestoreHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { @@ -26,16 +29,30 @@ func RestoreHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { } restore := &interop.Restore{ - AwsKey: restoreRequest.AwsKey, - AwsSecret: restoreRequest.AwsSecret, - AwsSession: restoreRequest.AwsSession, - CredentialsExpiry: restoreRequest.CredentialsExpiry, + AwsKey: restoreRequest.AwsKey, + AwsSecret: restoreRequest.AwsSecret, + AwsSession: restoreRequest.AwsSession, + CredentialsExpiry: restoreRequest.CredentialsExpiry, + RestoreHookTimeoutMs: restoreRequest.RestoreHookTimeoutMs, } - err := s.Restore(restore) + restoreResult, err := s.Restore(restore) + + responseMap := make(map[string]string) + + responseMap["restoreMs"] = strconv.FormatInt(restoreResult.RestoreMs, 10) if err != nil { log.Errorf("Failed to 
restore: %s", err) + responseMap["restoreError"] = err.Error() w.WriteHeader(http.StatusBadGateway) } + + responseJSON, err := json.Marshal(responseMap) + + if err != nil { + log.Panicf("Cannot marshal the response map for RESTORE, %v", responseMap) + } + + w.Write(responseJSON) } diff --git a/lambda/rapidcore/standalone/router.go b/lambda/rapidcore/standalone/router.go index f1712ea..7957c32 100644 --- a/lambda/rapidcore/standalone/router.go +++ b/lambda/rapidcore/standalone/router.go @@ -10,7 +10,7 @@ import ( "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" - "go.amzn.com/lambda/rapidcore/telemetry" + "go.amzn.com/lambda/rapidcore/standalone/telemetry" "github.com/go-chi/chi" ) @@ -21,14 +21,14 @@ type InteropServer interface { FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) - AwaitRelease() (*statejson.InternalStateDescription, error) + AwaitRelease() (*statejson.ReleaseResponse, error) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription InternalState() (*statejson.InternalStateDescription, error) CurrentToken() *interop.Token - Restore(restore *interop.Restore) error + Restore(restore *interop.Restore) (interop.RestoreResult, error) } -func NewHTTPRouter(ipcSrv InteropServer, lambdaInvokeAPI rapidcore.LambdaInvokeAPI, eventLog *telemetry.EventLog, shutdownFunc context.CancelFunc, bs interop.Bootstrap) *chi.Mux { +func NewHTTPRouter(ipcSrv InteropServer, lambdaInvokeAPI rapidcore.LambdaInvokeAPI, eventsAPI *telemetry.StandaloneEventsAPI, shutdownFunc context.CancelFunc, bs interop.Bootstrap) *chi.Mux { r := chi.NewRouter() r.Use(standaloneAccessLogDecorator) @@ -43,7 +43,7 @@ func NewHTTPRouter(ipcSrv InteropServer, lambdaInvokeAPI rapidcore.LambdaInvokeA r.Post("/test/shutdown", func(w http.ResponseWriter, r 
*http.Request) { ShutdownHandler(w, r, ipcSrv, shutdownFunc) }) r.Post("/test/directInvoke/{reservationtoken}", func(w http.ResponseWriter, r *http.Request) { DirectInvokeHandler(w, r, ipcSrv) }) r.Get("/test/internalState", func(w http.ResponseWriter, r *http.Request) { InternalStateHandler(w, r, ipcSrv) }) - r.Get("/test/eventLog", func(w http.ResponseWriter, r *http.Request) { EventLogHandler(w, r, eventLog) }) + r.Get("/test/eventLog", func(w http.ResponseWriter, r *http.Request) { EventLogHandler(w, r, eventsAPI) }) r.Post("/test/restore", func(w http.ResponseWriter, r *http.Request) { RestoreHandler(w, r, ipcSrv) }) return r } diff --git a/lambda/rapidcore/standalone/telemetry/agent_writer.go b/lambda/rapidcore/standalone/telemetry/agent_writer.go new file mode 100644 index 0000000..6ff2581 --- /dev/null +++ b/lambda/rapidcore/standalone/telemetry/agent_writer.go @@ -0,0 +1,30 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "bufio" + "bytes" +) + +type SandboxAgentWriter struct { + eventType string // 'runtime' or 'extension' + eventsAPI *StandaloneEventsAPI +} + +func NewSandboxAgentWriter(api *StandaloneEventsAPI, source string) *SandboxAgentWriter { + return &SandboxAgentWriter{ + eventType: source, + eventsAPI: api, + } +} + +func (w *SandboxAgentWriter) Write(logline []byte) (int, error) { + scanner := bufio.NewScanner(bytes.NewReader(logline)) + scanner.Split(bufio.ScanLines) + for scanner.Scan() { + w.eventsAPI.sendLogEvent(w.eventType, scanner.Text()) + } + return len(logline), nil +} diff --git a/lambda/rapidcore/standalone/telemetry/eventLog.go b/lambda/rapidcore/standalone/telemetry/eventLog.go new file mode 100644 index 0000000..0ab7c44 --- /dev/null +++ b/lambda/rapidcore/standalone/telemetry/eventLog.go @@ -0,0 +1,13 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +type EventLog struct { + Events []SandboxEvent `json:"events,omitempty"` // populated by the StandaloneEventLog object + Traces []TracingEvent `json:"traces,omitempty"` +} + +func NewEventLog() *EventLog { + return &EventLog{} +} diff --git a/lambda/rapidcore/standalone/telemetry/events_api.go b/lambda/rapidcore/standalone/telemetry/events_api.go new file mode 100644 index 0000000..dcac7a3 --- /dev/null +++ b/lambda/rapidcore/standalone/telemetry/events_api.go @@ -0,0 +1,293 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "encoding/json" + "sort" + "sync" + "time" + + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/telemetry" +) + +type EventType = string + +const ( + PlatformInitStart = EventType("platform.initStart") + PlatformInitRuntimeDone = EventType("platform.initRuntimeDone") + PlatformInitReport = EventType("platform.initReport") + PlatformRestoreRuntimeDone = EventType("platform.restoreRuntimeDone") + PlatformStart = EventType("platform.start") + PlatformRuntimeDone = EventType("platform.runtimeDone") + PlatformExtension = EventType("platform.extension") + PlatformEnd = EventType("platform.end") + PlatformReport = EventType("platform.report") + PlatformFault = EventType("platform.fault") +) + +/* +SandboxEvent represents a generic sandbox event. For example: + + { + "time": "2021-03-16T13:10:42.358Z", + "type": "platform.extension", + "platformEvent": { "name": "foo bar", "state": "Ready", "events": ["INVOKE", "SHUTDOWN"]} + } + +Or: + + { + "time": "2021-03-16T13:10:42.358Z", + "type": "extension", + "logMessage": "raw agent console output" + } + +FluxPump produces entries with a single field 'record', containing either an object or a string. +We make the distinction explicit by providing separate fields for the two cases, 'PlatformEvent' and 'LogMessage'. 
+Either one of the two would be populated, but not both. This makes code cleaner, but requires test client to merge +two fields back, producing a single 'record' entry again -- to match the FluxPump format that tests actually check. +*/ +type SandboxEvent struct { + Time string `json:"time"` + Type EventType `json:"type"` + PlatformEvent map[string]interface{} `json:"platformEvent,omitempty"` + LogMessage string `json:"logMessage,omitempty"` +} + +type tailLogs struct { + Events []SandboxEvent `json:"events,omitempty"` +} + +type StandaloneEventsAPI struct { + lock sync.Mutex + requestID interop.RequestID + eventLog EventLog +} + +func (s *StandaloneEventsAPI) LogTrace(entry TracingEvent) { + s.lock.Lock() + defer s.lock.Unlock() + s.eventLog.Traces = append(s.eventLog.Traces, entry) +} + +func (s *StandaloneEventsAPI) EventLog() *EventLog { + return &s.eventLog +} + +func (s *StandaloneEventsAPI) SetCurrentRequestID(requestID interop.RequestID) { + s.requestID = requestID +} + +func (s *StandaloneEventsAPI) SendInitStart(data interop.InitStartData) error { + record := map[string]interface{}{ + "initializationType": data.InitializationType, + "runtimeVersion": data.RuntimeVersion, + "runtimeArn": data.RuntimeVersionArn, + "runtimeVersionArn": data.RuntimeVersionArn, + "functionArn": data.FunctionArn, + "functionName": data.FunctionName, + "functionVersion": data.FunctionVersion, + "instanceId": data.InstanceID, + "instanceMaxMemory": data.InstanceMaxMemory, + "phase": data.Phase, + } + + s.addTracingToRecord(data.Tracing, record) + + return s.sendPlatformEvent(PlatformInitStart, record) +} + +func (s *StandaloneEventsAPI) SendInitRuntimeDone(data interop.InitRuntimeDoneData) error { + record := map[string]interface{}{ + "initializationType": data.InitializationType, + "status": data.Status, + "phase": data.Phase, + } + + s.addTracingToRecord(data.Tracing, record) + + if data.ErrorType != nil { + record["errorType"] = data.ErrorType + } + + return 
s.sendPlatformEvent(PlatformInitRuntimeDone, record) +} + +func (s *StandaloneEventsAPI) SendInitReport(data interop.InitReportData) error { + record := map[string]interface{}{ + "initializationType": data.InitializationType, + "metrics": data.Metrics, + "phase": data.Phase, + } + + s.addTracingToRecord(data.Tracing, record) + + return s.sendPlatformEvent(PlatformInitReport, record) +} + +func (s *StandaloneEventsAPI) SendRestoreRuntimeDone(data interop.RestoreRuntimeDoneData) error { + record := map[string]interface{}{"status": data.Status} + + s.addTracingToRecord(data.Tracing, record) + + if data.ErrorType != nil { + record["errorType"] = data.ErrorType + } + + return s.sendPlatformEvent(PlatformRestoreRuntimeDone, record) +} + +func (s *StandaloneEventsAPI) SendInvokeStart(data interop.InvokeStartData) error { + record := map[string]interface{}{ + "version": data.Version, + "requestId": data.RequestID, + } + + s.addTracingToRecord(data.Tracing, record) + + return s.sendPlatformEvent(PlatformStart, record) +} + +func (s *StandaloneEventsAPI) SendInvokeRuntimeDone(data interop.InvokeRuntimeDoneData) error { + record := map[string]interface{}{ + "requestId": s.requestID, + "status": data.Status, + "metrics": data.Metrics, + "internalMetrics": data.InternalMetrics, + "spans": data.Spans, + } + + if data.ErrorType != nil { + record["errorType"] = data.ErrorType + } + + s.addTracingToRecord(data.Tracing, record) + + return s.sendPlatformEvent(PlatformRuntimeDone, record) +} + +func (s *StandaloneEventsAPI) SendExtensionInit(data interop.ExtensionInitData) error { + sort.Strings(data.Subscriptions) + record := map[string]interface{}{ + "name": data.AgentName, + "state": data.State, + "events": data.Subscriptions, + } + if len(data.ErrorType) > 0 { + record["errorType"] = data.ErrorType + } + return s.sendPlatformEvent(PlatformExtension, record) +} + +func (s *StandaloneEventsAPI) SendImageErrorLog(interop.ImageErrorLogData) { + // Called on bootstrap exec errors for 
OCI error modes, e.g. InvalidEntrypoint etc. +} + +func (s *StandaloneEventsAPI) SendEnd(data interop.EndData) error { + record := map[string]interface{}{ + "requestId": data.RequestID, + } + + return s.sendPlatformEvent(PlatformEnd, record) +} + +func (s *StandaloneEventsAPI) SendReportSpan(interop.Span) error { + return nil +} + +func (s *StandaloneEventsAPI) SendReport(data interop.ReportData) error { + record := map[string]interface{}{ + "requestId": s.requestID, + "status": data.Status, + "metrics": data.Metrics, + "spans": data.Spans, + "tracing": data.Tracing, + } + if data.ErrorType != nil { + record["errorType"] = data.ErrorType + } + + return s.sendPlatformEvent(PlatformReport, record) +} + +func (s *StandaloneEventsAPI) SendFault(data interop.FaultData) error { + record := map[string]interface{}{ + "fault": data.String(), + } + + return s.sendPlatformEvent(PlatformFault, record) +} + +func (s *StandaloneEventsAPI) FetchTailLogs(string) (string, error) { + s.lock.Lock() + defer s.lock.Unlock() + + if len(s.eventLog.Events) == 0 { + return "", nil + } + + logs := tailLogs{Events: s.eventLog.Events} + logsBytes, err := json.Marshal(logs) + if err != nil { + return "", err + } + + s.eventLog.Events = nil + + return string(logsBytes), nil +} + +func (s *StandaloneEventsAPI) GetRuntimeDoneSpans( + runtimeStartedTime int64, + invokeResponseMetrics *interop.InvokeResponseMetrics, + runtimeOverheadStartedTime int64, + runtimeReadyTime int64, +) []interop.Span { + spans := telemetry.GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics) + return spans +} + +func (s *StandaloneEventsAPI) sendPlatformEvent(eventType string, record map[string]interface{}) error { + e := SandboxEvent{ + Time: time.Now().Format(time.RFC3339), + Type: eventType, + PlatformEvent: record, + } + s.appendEvent(e) + s.logEvent(e) + return nil +} + +func (s *StandaloneEventsAPI) sendLogEvent(eventType, logMessage string) error { + e := SandboxEvent{ + Time: 
time.Now().Format(time.RFC3339), + Type: eventType, + LogMessage: logMessage, + } + s.appendEvent(e) + s.logEvent(e) + return nil +} + +func (s *StandaloneEventsAPI) appendEvent(event SandboxEvent) { + s.lock.Lock() + defer s.lock.Unlock() + s.eventLog.Events = append(s.eventLog.Events, event) +} + +func (s *StandaloneEventsAPI) logEvent(e SandboxEvent) { + log.WithField("event", e).Info("sandbox event") +} + +func (s *StandaloneEventsAPI) addTracingToRecord(tracingData *interop.TracingCtx, record map[string]interface{}) { + if tracingData != nil { + record["tracing"] = map[string]string{ + "spanId": tracingData.SpanID, + "type": string(tracingData.Type), + "value": tracingData.Value, + } + } +} diff --git a/lambda/rapidcore/standalone/telemetry/logs_egress_api.go b/lambda/rapidcore/standalone/telemetry/logs_egress_api.go new file mode 100644 index 0000000..0f42dd1 --- /dev/null +++ b/lambda/rapidcore/standalone/telemetry/logs_egress_api.go @@ -0,0 +1,26 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import "io" + +type StandaloneLogsEgressAPI struct { + api *StandaloneEventsAPI +} + +func NewStandaloneLogsEgressAPI(api *StandaloneEventsAPI) *StandaloneLogsEgressAPI { + return &StandaloneLogsEgressAPI{ + api: api, + } +} + +func (s *StandaloneLogsEgressAPI) GetExtensionSockets() (io.Writer, io.Writer, error) { + w := NewSandboxAgentWriter(s.api, "extension") + return w, w, nil +} + +func (s *StandaloneLogsEgressAPI) GetRuntimeSockets() (io.Writer, io.Writer, error) { + w := NewSandboxAgentWriter(s.api, "function") + return w, w, nil +} diff --git a/lambda/rapidcore/standalone/telemetry/structured_logger.go b/lambda/rapidcore/standalone/telemetry/structured_logger.go new file mode 100644 index 0000000..8d9382b --- /dev/null +++ b/lambda/rapidcore/standalone/telemetry/structured_logger.go @@ -0,0 +1,21 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "github.com/sirupsen/logrus" + "os" +) + +var log = getLogger() + +func getLogger() *logrus.Logger { + formatter := logrus.JSONFormatter{} + formatter.DisableTimestamp = true + logger := new(logrus.Logger) + logger.Out = os.Stdout + logger.Formatter = &formatter + logger.Level = logrus.InfoLevel + return logger +} diff --git a/lambda/rapidcore/standalone/telemetry/tracer.go b/lambda/rapidcore/standalone/telemetry/tracer.go new file mode 100644 index 0000000..ba7f32d --- /dev/null +++ b/lambda/rapidcore/standalone/telemetry/tracer.go @@ -0,0 +1,216 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" + "go.amzn.com/lambda/rapi/model" + "go.amzn.com/lambda/telemetry" + + "github.com/sirupsen/logrus" +) + +// InitSubsegmentName provides name attribute for Init subsegment +const InitSubsegmentName = "Initialization" + +// RestoreSubsegmentName provides name attribute for Restore subsegment +const RestoreSubsegmentName = "Restore" + +// InvokeSubsegmentName provides name attribute for Invoke subsegment +const InvokeSubsegmentName = "Invocation" + +// OverheadSubsegmentName provides name attribute for Overhead subsegment +const OverheadSubsegmentName = "Overhead" + +type StandaloneTracer struct { + startFunction func(ctx context.Context, invoke *interop.Invoke, segmentName string, timestamp int64) + endFunction func(ctx context.Context, invoke *interop.Invoke, segmentName string, timestamp int64) + invoke *interop.Invoke + tracingHeader string + rootTraceID string + parent string + sampled string + lineage string + invocationSubsegmentID string + initStartTime int64 + initEndTime int64 + restoreStartTime int64 + restoreEndTime int64 + restorePresent bool +} + +type 
TracingEvent struct { + Message string `json:"message"` + TraceID string `json:"trace_id"` + SegmentName string `json:"segment_name"` + SegmentID string `json:"segment_id"` + Timestamp int64 `json:"timestamp"` +} + +func (t *StandaloneTracer) Configure(invoke *interop.Invoke) { + t.invoke = invoke + t.tracingHeader = invoke.TraceID + t.invocationSubsegmentID = "" + t.rootTraceID, t.parent, t.sampled, t.lineage = telemetry.ParseTracingHeader(invoke.TraceID) + if invoke.RestoreDurationNs == 0 { + t.restorePresent = false + } else { + t.restorePresent = true + t.restoreStartTime = metering.MonoToEpoch(invoke.RestoreStartTimeMonotime) + t.restoreEndTime = t.restoreStartTime + invoke.RestoreDurationNs + } +} + +func (t *StandaloneTracer) CaptureInvokeSegment(ctx context.Context, criticalFunction func(context.Context) error) error { + return t.withStartAndEnd(ctx, criticalFunction, "STANDALONE_FUNCTION_NAME") +} + +func (t *StandaloneTracer) CaptureInitSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { + return t.withStartAndEnd(ctx, criticalFunction, InitSubsegmentName) +} + +func (t *StandaloneTracer) CaptureInvokeSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { + t.invocationSubsegmentID = InvokeSubsegmentName + return t.withStartAndEnd(ctx, criticalFunction, InvokeSubsegmentName) +} + +func (t *StandaloneTracer) CaptureOverheadSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { + return t.withStartAndEnd(ctx, criticalFunction, OverheadSubsegmentName) +} + +func (t *StandaloneTracer) withStartAndEnd(ctx context.Context, criticalFunction func(context.Context) error, segmentName string) error { + ctx = telemetry.NewTraceContext(ctx, t.rootTraceID, segmentName) + t.startFunction(ctx, t.invoke, segmentName, time.Now().UnixNano()) + err := criticalFunction(ctx) + t.endFunction(ctx, t.invoke, segmentName, time.Now().UnixNano()) + return err +} + +func (t 
*StandaloneTracer) RecordInitStartTime() { + t.initStartTime = time.Now().UnixNano() +} + +func (t *StandaloneTracer) RecordInitEndTime() { + t.initEndTime = time.Now().UnixNano() + +} + +func (t *StandaloneTracer) sendPrepSubsegment(ctx context.Context, subsegmentName string, startTime int64, endTime int64) { + ctx = telemetry.NewTraceContext(ctx, t.rootTraceID, subsegmentName) + t.startFunction(ctx, t.invoke, subsegmentName, startTime) + t.endFunction(ctx, t.invoke, subsegmentName, endTime) +} + +func (t *StandaloneTracer) SendInitSubsegmentWithRecordedTimesOnce(ctx context.Context) { + t.sendPrepSubsegment(ctx, InitSubsegmentName, t.initStartTime, t.initEndTime) +} +func (t *StandaloneTracer) SendRestoreSubsegmentWithRecordedTimesOnce(ctx context.Context) { + if t.restorePresent { + t.sendPrepSubsegment(ctx, RestoreSubsegmentName, t.restoreStartTime, t.restoreEndTime) + } +} +func (t *StandaloneTracer) MarkError(ctx context.Context) {} +func (t *StandaloneTracer) AttachErrorCause(ctx context.Context, errorCause json.RawMessage) {} + +func (t *StandaloneTracer) WithErrorCause(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { + return criticalFunction +} +func (t *StandaloneTracer) WithError(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { + return criticalFunction +} + +func (t *StandaloneTracer) BuildTracingHeader() func(ctx context.Context) string { + // extract root trace ID and parent from context and build the tracing header + return func(ctx context.Context) string { + var parent string + var ok bool + + if parent, ok = ctx.Value(telemetry.DocumentIDKey).(string); !ok || parent == "" { + return t.invoke.TraceID + } + + if t.rootTraceID == "" || t.sampled == "" { + return "" + } + + var tracingHeader = "Root=%s;Parent=%s;Sampled=%s" + + if t.lineage == "" { + return 
fmt.Sprintf(tracingHeader, t.rootTraceID, parent, t.sampled) + } + + return fmt.Sprintf(tracingHeader+";Lineage=%s", t.rootTraceID, parent, t.sampled, t.lineage) + } +} + +func (t *StandaloneTracer) BuildTracingCtxForStart() *interop.TracingCtx { + if t.rootTraceID == "" || t.sampled != model.XRaySampled { + return nil + } + + return &interop.TracingCtx{ + SpanID: t.parent, + Type: model.XRayTracingType, + Value: telemetry.BuildFullTraceID(t.rootTraceID, t.invoke.LambdaSegmentID, t.sampled), + } +} +func (t *StandaloneTracer) BuildTracingCtxAfterInvokeComplete() *interop.TracingCtx { + if t.rootTraceID == "" || t.sampled != model.XRaySampled || t.invocationSubsegmentID == "" { + return nil + } + + return &interop.TracingCtx{ + SpanID: t.invocationSubsegmentID, + Type: model.XRayTracingType, + Value: t.tracingHeader, + } +} + +func isTracingEnabled(root, parent, sampled string) bool { + return len(root) != 0 && len(parent) != 0 && sampled == "1" +} + +func NewStandaloneTracer(api *StandaloneEventsAPI) *StandaloneTracer { + startCaptureFn := func(ctx context.Context, i *interop.Invoke, segmentName string, timestamp int64) { + root, parent, sampled, _ := telemetry.ParseTracingHeader(i.TraceID) + if isTracingEnabled(root, parent, sampled) { + e := TracingEvent{ + Message: "START", + TraceID: root, + SegmentName: segmentName, + SegmentID: parent, + Timestamp: timestamp / int64(time.Millisecond), + } + api.LogTrace(e) + log.WithFields(logrus.Fields{"trace": e}).Info("sandbox trace") + } + } + + endCaptureFn := func(ctx context.Context, i *interop.Invoke, segmentName string, timestamp int64) { + root, parent, sampled, _ := telemetry.ParseTracingHeader(i.TraceID) + if isTracingEnabled(root, parent, sampled) { + e := TracingEvent{ + Message: "END", + TraceID: root, + SegmentName: "", + SegmentID: parent, + Timestamp: timestamp / int64(time.Millisecond), + } + api.LogTrace(e) + log.WithFields(logrus.Fields{"trace": e}).Info("sandbox trace") + } + } + + return 
&StandaloneTracer{ + startFunction: startCaptureFn, + endFunction: endCaptureFn, + } +} diff --git a/lambda/rapidcore/standalone/waitUntilReleaseHandler.go b/lambda/rapidcore/standalone/waitUntilReleaseHandler.go index 0a756dd..1caeb8c 100644 --- a/lambda/rapidcore/standalone/waitUntilReleaseHandler.go +++ b/lambda/rapidcore/standalone/waitUntilReleaseHandler.go @@ -10,7 +10,7 @@ import ( ) func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { - internalState, err := s.AwaitRelease() + releaseAwait, err := s.AwaitRelease() if err != nil { switch err { case rapidcore.ErrInvokeDoneFailed: @@ -22,10 +22,10 @@ func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s InteropSe return case rapidcore.ErrInitDoneFailed: w.WriteHeader(DoneFailedHTTPCode) - w.Write(internalState.AsJSON()) + w.Write(releaseAwait.AsJSON()) return } } - w.Write(internalState.AsJSON()) + w.Write(releaseAwait.AsJSON()) } diff --git a/lambda/rapidcore/telemetry/eventLog.go b/lambda/rapidcore/telemetry/eventLog.go deleted file mode 100644 index 2f809fa..0000000 --- a/lambda/rapidcore/telemetry/eventLog.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "strings" - "sync" - "time" -) - -// TODO: Refactor to represent event structs below as a form of Events API entity - -type XrayEvent struct { - Msg string `json:"msg"` - TraceID string `json:"traceID"` - SegmentName string `json:"segmentName"` - SegmentID string `json:"segmentID"` - Timestamp int64 `json:"timestamp"` -} - -// PlatformLogEvent represents a platform-generated customer log entry -type PlatformLogEvent struct { - Name string `json:"name"` - State string `json:"state"` - ErrorType string `json:"errorType"` - Subscriptions []string `json:"subscriptions"` -} - -// FunctionLogEvent represents a runtime-generated customer log entry -type FunctionLogEvent struct{} - -// ExtensionLogEvent represents an agent-generated customer log entry -type ExtensionLogEvent struct{} - -type EventLog struct { - Events []SandboxEvent `json:"events,omitempty"` // populated by the StandaloneEventLog object - Xray []XrayEvent `json:"xray,omitempty"` - PlatformLog []PlatformLogEvent `json:"platformLogs,omitempty"` - Logs []string `json:"rawLogs,omitempty"` - mutex sync.Mutex -} - -func parseLogString(s string) []string { - elems := strings.Split(s, "\t")[1:] - for i, e := range elems { - elems[i] = strings.Split(e, ": ")[1] - elems[i] = strings.TrimSuffix(elems[i], "\n") - elems[i] = strings.TrimPrefix(elems[i], "[") - elems[i] = strings.TrimSuffix(elems[i], "]") - } - return elems -} - -func (p *EventLog) dispatchLogEvent(logStr string) { - elems := parseLogString(logStr) - if strings.HasPrefix(logStr, "XRAY") { - // format: 'XRAY\tMessage: %s\tTraceID: %s\tSegmentName: %s\tSegmentID: %s' - msg, traceID, segmentName, segmentID := elems[0], elems[1], elems[2], elems[3] - p.Xray = append(p.Xray, XrayEvent{Msg: msg, TraceID: traceID, SegmentName: segmentName, SegmentID: segmentID, Timestamp: time.Now().UnixNano() / int64(time.Millisecond)}) - } -} - -func (p *EventLog) Write(logline []byte) (int, error) { 
- p.mutex.Lock() - defer p.mutex.Unlock() - - logStr := string(logline) - p.Logs = append(p.Logs, logStr) - - p.dispatchLogEvent(logStr) - - return len(logline), nil -} - -func NewEventLog() *EventLog { - return &EventLog{} -} diff --git a/lambda/rapidcore/telemetry/events_api.go b/lambda/rapidcore/telemetry/events_api.go deleted file mode 100644 index 7a882fd..0000000 --- a/lambda/rapidcore/telemetry/events_api.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "sort" - "time" - - "go.amzn.com/lambda/telemetry" -) - -// EventType indicates the type of SandboxEvent. See full list: -type EventType = string - -const ( - PlatformInitRuntimeDone = EventType("platform.initRuntimeDone") - PlatformRestoreRuntimeDone = EventType("platform.restoreRuntimeDone") - PlatformRuntimeDone = EventType("platform.runtimeDone") - PlatformExtension = EventType("platform.extension") -) - -/* - SandboxEvent represents a generic sandbox event. 
For example: - {'time': '2021-03-16T13:10:42.358Z', - 'type': 'platform.extension', - 'record': { "name": "foo bar", "state": "Ready", "events": ["INVOKE", "SHUTDOWN"]}} -*/ -type SandboxEvent struct { - Time string `json:"time"` - Type EventType `json:"type"` - Record map[string]interface{} `json:"record"` -} - -type StandaloneEventLog struct { - requestID string - eventLog *EventLog -} - -func (s *StandaloneEventLog) SetCurrentRequestID(requestID string) { - s.requestID = requestID -} - -func (s *StandaloneEventLog) SendInitRuntimeDone(data *telemetry.InitRuntimeDoneData) error { - record := map[string]interface{}{"initializationType": data.InitSource, "status": data.Status} - s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformInitRuntimeDone, record}) - return nil -} - -func (s *StandaloneEventLog) SendRestoreRuntimeDone(status string) error { - record := map[string]interface{}{"status": status} - s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformRestoreRuntimeDone, record}) - return nil -} - -func (s *StandaloneEventLog) SendRuntimeDone(data telemetry.InvokeRuntimeDoneData) error { - // e.g. 'record': {'requestId': '1506eb3053d148f3bb7ec0fabe6f8d91','status': 'success', 'metrics': {...}, 'tracing': {...}} - record := map[string]interface{}{ - "requestId": s.requestID, - "status": data.Status, - "metrics": data.Metrics, - "internalMetrics": data.InternalMetrics, - "spans": data.Spans, - } - - if data.Tracing != nil { - record["tracing"] = map[string]string{ - "spanId": data.Tracing.SpanID, - "type": string(data.Tracing.Type), - "value": data.Tracing.Value, - } - } - - s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformRuntimeDone, record}) - return nil -} - -func (s *StandaloneEventLog) SendExtensionInit(agentName, state, errorType string, subscriptions []string) error { - // e.g. 
'record': { "name": "", "state": "", errorType: "", events: [""] } - sort.Strings(subscriptions) - record := map[string]interface{}{"name": agentName, "state": state, "events": subscriptions} - if len(errorType) > 0 { - record["errorType"] = errorType - } - s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformExtension, record}) - return nil -} - -func (s *StandaloneEventLog) SendImageErrorLog(logline string) { - // Called on bootstrap exec errors for OCI error modes, e.g. InvalidEntrypoint etc. -} - -func NewStandaloneEventLog(eventLog *EventLog) *StandaloneEventLog { - return &StandaloneEventLog{ - eventLog: eventLog, - } -} diff --git a/lambda/rapidcore/telemetry/xray.go b/lambda/rapidcore/telemetry/xray.go deleted file mode 100644 index d7a6842..0000000 --- a/lambda/rapidcore/telemetry/xray.go +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "context" - "encoding/json" - "fmt" - "io" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/telemetry" -) - -// InitSubsegmentName provides name attribute for Init subsegment -const InitSubsegmentName = "Initialization" - -// InvokeSubsegmentName provides name attribute for Invoke subsegment -const InvokeSubsegmentName = "Invocation" - -// OverheadSubsegmentName provides name attribute for Overhead subsegment -const OverheadSubsegmentName = "Overhead" - -type traceContextKey int - -const ( - traceIDKey traceContextKey = iota - documentIDKey -) - -type StandaloneTracer struct { - startFunction func(ctx context.Context, invoke *interop.Invoke, segmentName string) - endFunction func(ctx context.Context, invoke *interop.Invoke, segmentName string) - functionName string - invoke *interop.Invoke -} - -func (t *StandaloneTracer) Configure(invoke *interop.Invoke) { - - t.invoke = invoke -} - -func (t *StandaloneTracer) 
CaptureInvokeSegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return t.withStartAndEnd(ctx, criticalFunction, t.functionName) -} - -func (t *StandaloneTracer) CaptureInitSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return t.withStartAndEnd(ctx, criticalFunction, InitSubsegmentName) -} - -func (t *StandaloneTracer) CaptureInvokeSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return t.withStartAndEnd(ctx, criticalFunction, InvokeSubsegmentName) -} - -func (t *StandaloneTracer) CaptureOverheadSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return t.withStartAndEnd(ctx, criticalFunction, OverheadSubsegmentName) -} - -func (t *StandaloneTracer) withStartAndEnd(ctx context.Context, criticalFunction func(context.Context) error, segmentName string) error { - t.startFunction(ctx, t.invoke, segmentName) - err := criticalFunction(ctx) - t.endFunction(ctx, t.invoke, segmentName) - return err -} - -func (t *StandaloneTracer) RecordInitStartTime() {} -func (t *StandaloneTracer) RecordInitEndTime() {} -func (t *StandaloneTracer) SendInitSubsegmentWithRecordedTimesOnce(ctx context.Context) {} -func (t *StandaloneTracer) MarkError(ctx context.Context) {} -func (t *StandaloneTracer) AttachErrorCause(ctx context.Context, errorCause json.RawMessage) {} - -func (t *StandaloneTracer) WithErrorCause(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { - return criticalFunction -} -func (t *StandaloneTracer) WithError(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { - return criticalFunction -} -func (t *StandaloneTracer) TracingHeaderParser() func(context.Context, *interop.Invoke) string { - getCustomerTracingHeader := func(ctx context.Context, invoke 
*interop.Invoke) string { - var root, parent string - var ok bool - - if root, ok = ctx.Value(traceIDKey).(string); !ok { - return invoke.TraceID - } - - if parent, ok = ctx.Value(documentIDKey).(string); !ok { - return invoke.TraceID - } - - return fmt.Sprintf("Root=%s;Parent=%s;Sampled=1", root, parent) - } - - return getCustomerTracingHeader -} - -func isTracingEnabled(root, parent, sampled string) bool { - return len(root) != 0 && len(parent) != 0 && sampled == "1" -} - -func NewStandaloneTracer(eventLog io.Writer, functionName string) *StandaloneTracer { - traceFormat := "XRAY\tMessage: %s\tTraceID: %s\tSegmentName: %s\tSegmentID: %s" - startCaptureFn := func(ctx context.Context, i *interop.Invoke, segmentName string) { - root, parent, sampled := telemetry.ParseTraceID(i.TraceID) - if isTracingEnabled(root, parent, sampled) { - fmt.Fprintf(eventLog, traceFormat, "START", root, segmentName, parent) - } - } - - endCaptureFn := func(ctx context.Context, i *interop.Invoke, segmentName string) { - root, parent, sampled := telemetry.ParseTraceID(i.TraceID) - if isTracingEnabled(root, parent, sampled) { - fmt.Fprintf(eventLog, traceFormat, "END", root, "", parent) - } - } - - return &StandaloneTracer{ - startFunction: startCaptureFn, - endFunction: endCaptureFn, - functionName: functionName, - } -} diff --git a/lambda/supervisor/local_supervisor.go b/lambda/supervisor/local_supervisor.go index 1174089..4405686 100644 --- a/lambda/supervisor/local_supervisor.go +++ b/lambda/supervisor/local_supervisor.go @@ -4,9 +4,11 @@ package supervisor import ( + "context" "errors" "fmt" "os/exec" + "runtime" "sync" "syscall" "time" @@ -27,33 +29,31 @@ type process struct { } type LocalSupervisor struct { - events chan model.Event - processMapLock sync.Mutex - processMap map[string]process + events chan model.Event + processMapLock sync.Mutex + processMap map[string]process + freezeThawCycleStart time.Time + + RootPath string } -func NewLocalSupervisor() model.Supervisor { - 
return model.Supervisor{ - SupervisorClient: &LocalSupervisor{ - events: make(chan model.Event), - processMap: make(map[string]process), - }, - OperatorConfig: model.DomainConfig{ - RootPath: "/", - }, - RuntimeConfig: model.DomainConfig{ - RootPath: "/", - }, +func NewLocalSupervisor() *LocalSupervisor { + return &LocalSupervisor{ + events: make(chan model.Event), + processMap: make(map[string]process), + RootPath: "/", } } -func (*LocalSupervisor) Start(req *model.StartRequest) error { +func (*LocalSupervisor) Start(ctx context.Context, req *model.StartRequest) error { return nil } -func (*LocalSupervisor) Configure(req *model.ConfigureRequest) error { +func (*LocalSupervisor) Configure(ctx context.Context, req *model.ConfigureRequest) error { return nil } -func (s *LocalSupervisor) Exec(req *model.ExecRequest) error { +func (*LocalSupervisor) Exit(ctx context.Context) {} + +func (s *LocalSupervisor) Exec(ctx context.Context, req *model.ExecRequest) error { if req.Domain != "runtime" { log.Debug("Exec is a no op if domain != runtime") return nil @@ -97,6 +97,9 @@ func (s *LocalSupervisor) Exec(req *model.ExecRequest) error { } s.processMapLock.Unlock() + // The first freeze thaw cycle starts on Exec() at init time + s.freezeThawCycleStart = time.Now() + go func() { err = command.Wait() // close the termination channel to unblock whoever's blocked on @@ -141,11 +144,11 @@ func (s *LocalSupervisor) Exec(req *model.ExecRequest) error { return nil } -func kill(p process, name string, timeout *time.Duration) error { +func kill(p process, name string, deadline time.Time) error { // kill should report success if the process terminated by the time //supervisor receives the request. 
select { - // ifthis case is selected, the channel is closed, + // if this case is selected, the channel is closed, // which means the process is terminated case <-p.termination: log.Debugf("Process %s already terminated.", name) @@ -154,8 +157,8 @@ func kill(p process, name string, timeout *time.Duration) error { log.Infof("Sending SIGKILL to %s(%d).", name, p.pid) } - if timeout != nil && *timeout <= 0 { - return fmt.Errorf("Timed out while trying to SIGKILL %s", name) + if (time.Since(deadline)) > 0 { + return fmt.Errorf("invalid timeout while killing %s", name) } pgid, err := syscall.Getpgid(p.pid) @@ -167,23 +170,20 @@ func kill(p process, name string, timeout *time.Duration) error { syscall.Kill(p.pid, syscall.SIGKILL) } - // the nil channel blocks forever - var timer <-chan time.Time - if timeout != nil { - timer = time.After(*timeout) - } + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() // block until the (main) process exits // or the timeout fires select { case <-p.termination: return nil - case <-timer: - return fmt.Errorf("Timed out while trying to SIGKILL %s", name) + case <-ctx.Done(): + return fmt.Errorf("timed out while trying to SIGKILL %s", name) } } -func (s *LocalSupervisor) Kill(req *model.KillRequest) error { +func (s *LocalSupervisor) Kill(ctx context.Context, req *model.KillRequest) error { if req.Domain != "runtime" { log.Debug("Kill is a no op if domain != runtime") return nil @@ -198,12 +198,11 @@ func (s *LocalSupervisor) Kill(req *model.KillRequest) error { Message: &msg, } } - timeout := convertTimeout(req.Timeout) - return kill(process, req.Name, timeout) + return kill(process, req.Name, req.Deadline) } -func (s *LocalSupervisor) Terminate(req *model.TerminateRequest) error { +func (s *LocalSupervisor) Terminate(ctx context.Context, req *model.TerminateRequest) error { if req.Domain != "runtime" { log.Debug("Terminate is no op if domain != runtime") return nil @@ -235,12 +234,11 @@ func (s 
*LocalSupervisor) Terminate(req *model.TerminateRequest) error { return nil } -func (s *LocalSupervisor) Stop(req *model.StopRequest) error { +func (s *LocalSupervisor) Stop(ctx context.Context, req *model.StopRequest) (*model.StopResponse, error) { if req.Domain != "runtime" { log.Debug("Shutdown is no op if domain != runtime") - return nil + return &model.StopResponse{}, nil } - timeout := convertTimeout(req.Timeout) // shut down kills all the processes in the map s.processMapLock.Lock() @@ -253,7 +251,7 @@ func (s *LocalSupervisor) Stop(req *model.StopRequest) error { for name, proc := range s.processMap { go func(n string, p process) { log.Debugf("Killing %s", n) - err := kill(p, n, timeout) + err := kill(p, n, req.Deadline) if err != nil { errors <- err } else { @@ -269,34 +267,37 @@ func (s *LocalSupervisor) Stop(req *model.StopRequest) error { case <-successes: case e := <-errors: if err == nil { - err = fmt.Errorf("Shutdown failed: %s", e.Error()) + err = fmt.Errorf("shutdown failed: %s", e.Error()) } } } s.processMap = make(map[string]process) - return err + return nil, err } -func (*LocalSupervisor) Freeze(req *model.FreezeRequest) error { - return nil + +func (s *LocalSupervisor) Freeze(ctx context.Context, req *model.FreezeRequest) (*model.FreezeResponse, error) { + // We return mocked freeze/thaw cycle metrics to mimic usage metrics in standalone mode + var m runtime.MemStats + runtime.ReadMemStats(&m) + return &model.FreezeResponse{ + CycleDeltaMetrics: model.CycleDeltaMetrics{ + DomainCPURunNs: uint64(time.Since(s.freezeThawCycleStart).Nanoseconds()), + DomainRunNs: uint64(time.Since(s.freezeThawCycleStart).Nanoseconds()), + DomainMaxMemoryUsageBytes: m.Alloc, + MicrovmCPURunNs: uint64(time.Since(s.freezeThawCycleStart).Nanoseconds()), + }, + }, nil } -func (*LocalSupervisor) Thaw(req *model.ThawRequest) error { +func (s *LocalSupervisor) Thaw(ctx context.Context, req *model.ThawRequest) error { + s.freezeThawCycleStart = time.Now() return nil } 
-func (s *LocalSupervisor) Ping() error { +func (s *LocalSupervisor) Ping(ctx context.Context) error { return nil } -func (s *LocalSupervisor) Events() (<-chan model.Event, error) { +func (s *LocalSupervisor) Events(ctx context.Context, req *model.EventsRequest) (<-chan model.Event, error) { return s.events, nil } - -func convertTimeout(millis *uint64) *time.Duration { - var timeout *time.Duration - if millis != nil { - t := time.Duration(*millis) * time.Millisecond - timeout = &t - } - return timeout -} diff --git a/lambda/supervisor/local_supervisor_test.go b/lambda/supervisor/local_supervisor_test.go index 8b3336b..02a06f6 100644 --- a/lambda/supervisor/local_supervisor_test.go +++ b/lambda/supervisor/local_supervisor_test.go @@ -4,6 +4,7 @@ package supervisor import ( + "context" "errors" "fmt" "syscall" @@ -18,7 +19,7 @@ import ( func TestRuntimeDomainExec(t *testing.T) { supv := NewLocalSupervisor() - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/bash", @@ -29,7 +30,7 @@ func TestRuntimeDomainExec(t *testing.T) { func TestInvalidRuntimeDomainExec(t *testing.T) { supv := NewLocalSupervisor() - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/none", @@ -40,10 +41,14 @@ func TestInvalidRuntimeDomainExec(t *testing.T) { func TestEvents(t *testing.T) { supv := NewLocalSupervisor() - client := supv.SupervisorClient.(*LocalSupervisor) sync := make(chan struct{}) go func() { - evt, ok := <-client.events + eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ + Domain: "runtime", + }) + require.NoError(t, err) + + evt, ok := <-eventCh require.True(t, ok) termination := evt.Event.ProcessTerminated() require.NotNil(t, termination) @@ -52,7 +57,7 @@ func TestEvents(t *testing.T) { sync <- struct{}{} }() - err := supv.Exec(&model.ExecRequest{ + err := 
supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/bash", @@ -63,8 +68,7 @@ func TestEvents(t *testing.T) { func TestTerminate(t *testing.T) { supv := NewLocalSupervisor() - client := supv.SupervisorClient.(*LocalSupervisor) - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/bash", @@ -72,13 +76,18 @@ func TestTerminate(t *testing.T) { }) require.NoError(t, err) time.Sleep(100 * time.Millisecond) - err = supv.Terminate(&model.TerminateRequest{ + err = supv.Terminate(context.Background(), &model.TerminateRequest{ Domain: "runtime", Name: "agent", }) require.NoError(t, err) // wait for process exit notification - ev := <-client.events + eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ + Domain: "runtime", + }) + require.NoError(t, err) + ev := <-eventCh + require.NotNil(t, ev.Event.ProcessTerminated()) term := *ev.Event.ProcessTerminated() require.Nil(t, term.Exited()) @@ -89,7 +98,7 @@ func TestTerminate(t *testing.T) { // Termiante should not fail if the message is not delivered func TestTerminateExited(t *testing.T) { supv := NewLocalSupervisor() - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/bash", @@ -97,7 +106,7 @@ func TestTerminateExited(t *testing.T) { require.NoError(t, err) // wait a short bit for bash to exit time.Sleep(100 * time.Millisecond) - err = supv.Terminate(&model.TerminateRequest{ + err = supv.Terminate(context.Background(), &model.TerminateRequest{ Domain: "runtime", Name: "agent", }) @@ -106,22 +115,27 @@ func TestTerminateExited(t *testing.T) { func TestKill(t *testing.T) { supv := NewLocalSupervisor() - client := supv.SupervisorClient.(*LocalSupervisor) - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: 
"agent", Path: "/bin/bash", Args: []string{"-c", "sleep 10s"}, }) require.NoError(t, err) - err = supv.Kill(&model.KillRequest{ - Domain: "runtime", - Name: "agent", + err = supv.Kill(context.Background(), &model.KillRequest{ + Domain: "runtime", + Name: "agent", + Deadline: time.Now().Add(time.Second), }) require.NoError(t, err) timer := time.NewTimer(50 * time.Millisecond) + eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ + Domain: "runtime", + }) + require.NoError(t, err) + select { - case _, ok := <-client.events: + case _, ok := <-eventCh: assert.True(t, ok) case <-timer.C: require.Fail(t, "Process should have exited by the time kill returns") @@ -130,27 +144,32 @@ func TestKill(t *testing.T) { func TestKillExited(t *testing.T) { supv := NewLocalSupervisor() - client := supv.SupervisorClient.(*LocalSupervisor) - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/bash", }) require.NoError(t, err) //wait for natural exit event - <-client.events - err = supv.Kill(&model.KillRequest{ + eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ Domain: "runtime", - Name: "agent", + }) + require.NoError(t, err) + <-eventCh + err = supv.Kill(context.Background(), &model.KillRequest{ + Domain: "runtime", + Name: "agent", + Deadline: time.Now().Add(time.Second), }) require.NoError(t, err, "Kill should succeed for exited processes") } func TestKillUnknown(t *testing.T) { supv := NewLocalSupervisor() - err := supv.Kill(&model.KillRequest{ - Domain: "runtime", - Name: "unknown", + err := supv.Kill(context.Background(), &model.KillRequest{ + Domain: "runtime", + Name: "unknown", + Deadline: time.Now().Add(time.Second), }) require.Error(t, err) var supvError *model.SupervisorError @@ -160,10 +179,9 @@ func TestKillUnknown(t *testing.T) { func TestShutdown(t *testing.T) { supv := NewLocalSupervisor() - client := 
supv.SupervisorClient.(*LocalSupervisor) log.Debug("hello") // start a bunch of processes, some short running, some longer running - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent-0", Path: "/bin/bash", @@ -171,14 +189,14 @@ func TestShutdown(t *testing.T) { }) require.NoError(t, err) - err = supv.Exec(&model.ExecRequest{ + err = supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent-1", Path: "/bin/bash", }) require.NoError(t, err) - err = supv.Exec(&model.ExecRequest{ + err = supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent-2", Path: "/bin/bash", @@ -186,8 +204,9 @@ func TestShutdown(t *testing.T) { }) require.NoError(t, err) time.Sleep(100 * time.Millisecond) - err = supv.Stop(&model.StopRequest{ - Domain: "runtime", + _, err = supv.Stop(context.Background(), &model.StopRequest{ + Domain: "runtime", + Deadline: time.Now().Add(time.Second), }) require.NoError(t, err) // Shutdown is expected to block untill all processes have exited @@ -198,9 +217,13 @@ func TestShutdown(t *testing.T) { } done := false timer := time.NewTimer(200 * time.Millisecond) + eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ + Domain: "runtime", + }) + require.NoError(t, err) for !done { select { - case ev := <-client.events: + case ev := <-eventCh: data := ev.Event.ProcessTerminated() assert.NotNil(t, data) _, ok := expected[*data.Name] diff --git a/lambda/supervisor/model/model.go b/lambda/supervisor/model/model.go index 384726d..d89ec18 100644 --- a/lambda/supervisor/model/model.go +++ b/lambda/supervisor/model/model.go @@ -4,41 +4,73 @@ package model import ( + "context" "encoding/json" "fmt" "io" "os" "syscall" + "time" ) -type Supervisor struct { - SupervisorClient - OperatorConfig DomainConfig - RuntimeConfig DomainConfig +// Start, Stop and Configure methods are not used in Core anymore. 
+// Client interface splitted into Launcher and Executer parts for backward compatibility of dependent packages. +type ContainerSupervisor interface { + Start(context.Context, *StartRequest) error + Configure(context.Context, *ConfigureRequest) error + Stop(context.Context, *StopRequest) (*StopResponse, error) + Freeze(context.Context, *FreezeRequest) (*FreezeResponse, error) + Thaw(context.Context, *ThawRequest) error + Exit(context.Context) } -type DomainConfig struct { - // path to the root of the domain within the root mnt namespace - RootPath string +type ProcessSupervisor interface { + Exec(context.Context, *ExecRequest) error + Terminate(context.Context, *TerminateRequest) error + Kill(context.Context, *KillRequest) error + Events(context.Context, *EventsRequest) (<-chan Event, error) } type SupervisorClient interface { - Start(req *StartRequest) error - Configure(req *ConfigureRequest) error - Exec(req *ExecRequest) error - Terminate(req *TerminateRequest) error - Kill(req *KillRequest) error - Stop(req *StopRequest) error - Freeze(req *FreezeRequest) error - Thaw(req *ThawRequest) error - Ping() error - Events() (<-chan Event, error) + ContainerSupervisor + ProcessSupervisor + Ping(ctx context.Context) error } type StartRequest struct { Domain string `json:"domain"` - // name of the cgroup profile to start the domain in - CgroupProfile *string `json:"cgroup_profile,omitempty"` +} + +type Mount struct { + DriveMount DriveMount + BindMount BindMount + MountType MountType +} + +type MountType int + +const ( + _ MountType = iota + MountTypeDrive + MountTypeBind +) + +type CgroupProfileName string + +const ( + Throttled CgroupProfileName = "throttled" + Unthrottled CgroupProfileName = "unthrottled" +) + +func (m *Mount) MarshalJSON() ([]byte, error) { + switch m.MountType { + case MountTypeDrive: + return m.DriveMount.MarshalJSON() + case MountTypeBind: + return m.BindMount.MarshalJSON() + default: + return nil, fmt.Errorf("invalid mount type: %v", m.MountType) 
+ } } // Mount in lockhard::mnt is a Rust enum, an algebraic type, where each case has different set of fields. @@ -66,6 +98,24 @@ func (m *DriveMount) MarshalJSON() ([]byte, error) { }) } +type BindMount struct { + Source string `json:"source,omitempty"` + Destination string `json:"destination,omitempty"` + Options []string `json:"options,omitempty"` +} + +func (m *BindMount) MarshalJSON() ([]byte, error) { + type bindMountAlias BindMount + + return json.Marshal(&struct { + Type string `json:"type,omitempty"` + *bindMountAlias + }{ + Type: "bind", + bindMountAlias: (*bindMountAlias)(m), + }) +} + type Capabilities struct { Ambient []string `json:"ambient,omitempty"` Bounding []string `json:"bounding,omitempty"` @@ -74,10 +124,14 @@ type Capabilities struct { Permitted []string `json:"permitted,omitempty"` } -type CgroupProfile struct { - Name string `json:"name"` - CPUPct *float64 `json:"cpu_pct,omitempty"` - MemMaxBytes *uint64 `json:"mem_max,omitempty"` +type CgroupProfiles struct { + Throttled CgroupProfileConfig `json:"throttled"` + Unthrottled CgroupProfileConfig `json:"unthrottled"` +} + +type CgroupProfileConfig struct { + CPULimit float64 `json:"cpu_limit"` + MemoryLimitBytes uint64 `json:"memory_limit_bytes"` } type ExecUser struct { @@ -88,12 +142,15 @@ type ExecUser struct { type ConfigureRequest struct { // domain to configure Domain string `json:"domain"` - Mounts []DriveMount `json:"mounts,omitempty"` + Mounts []Mount `json:"mounts,omitempty"` Capabilities *Capabilities `json:"capabilities,omitempty"` SeccompFilters []string `json:"seccomp_filters,omitempty"` // list of cgroup profiles available for the domain - // cgroup profiles are set on boot or thaw requests - CgroupProfiles []CgroupProfile `json:"cgroup_profiles,omitempty"` + // cgroup profiles are set on start and thaw request. 
Start profile + // if configured (as it can vary), thaw profile is always the same (throttled) + CgroupProfiles *CgroupProfiles `json:"cgroup_profiles,omitempty"` + // name of the cgroup profile to enforce at domain start + StartProfile CgroupProfileName `json:"start_profile,omitempty"` // uid and gid of the user the spawned process runs as (w.r.t. the domain user namespace). // If nil, Supervisor will use the ExecUser specified in the domain configuration file ExecUser *ExecUser `json:"exec_user,omitempty"` @@ -101,6 +158,10 @@ type ConfigureRequest struct { AdditionalStartHooks []Hook `json:"additional_start_hooks,omitempty"` } +type EventsRequest struct { + Domain string `json:"domain"` +} + type Event struct { Time uint64 `json:"timestamp_millis"` Event EventData `json:"event"` @@ -188,9 +249,6 @@ type Hook struct { Args []string `json:"args,omitempty"` // Map of ENV variables to set when running the hook Env *map[string]string `json:"envs,omitempty"` - // Maximum time for the hook to run. The hook will be considered failed - // if it takes more than this value (default 10_000) - TimeoutMillis *uint64 `json:"timeout_millis,omitempty"` } type ExecRequest struct { @@ -203,16 +261,38 @@ type ExecRequest struct { Path string `json:"path"` Args []string `json:"args,omitempty"` // If nil, root of the domain - Cwd *string `json:"cwd,omitempty"` - Env *map[string]string `json:"env,omitempty"` - // If not nil, points to the socket that Supervisor - // uses to get the processes stdout and stderr. 
- LogsSock *string `json:"logs_sock,omitempty"` - StdoutWriter io.Writer `json:"-"` - StderrWriter io.Writer `json:"-"` - ExtraFiles *[]*os.File `json:"-"` + Cwd *string `json:"cwd,omitempty"` + Env *map[string]string `json:"env,omitempty"` + Logging Logging `json:"log_config"` + StdoutWriter io.Writer `json:"-"` + StderrWriter io.Writer `json:"-"` + ExtraFiles *[]*os.File `json:"-"` +} + +// Logging specifies where Supervisor should send Command's logs to +type Logging struct { + Managed ManagedLogging `json:"managed"` } +type ManagedLogging struct { + Topic ManagedLoggingTopic `json:"topic"` + Formats []ManagedLoggingFormat `json:"formats"` +} + +type ManagedLoggingTopic string + +const ( + RuntimeManagedLoggingTopic ManagedLoggingTopic = "runtime" + RtExtensionManagedLoggingTopic ManagedLoggingTopic = "runtime_extension" +) + +type ManagedLoggingFormat string + +const ( + LineBasedManagedLogging ManagedLoggingFormat = "line" + MessageBasedManagedLogging ManagedLoggingFormat = "message" +) + type ErrorKind string const ( @@ -243,27 +323,54 @@ type TerminateRequest struct { // Force terminate a process (SIGKILL) // Block until process is exited or timeout -// If timeout is 0 or nil, block forever +// Deadline needs to be in the future type KillRequest struct { - Name string `json:"name"` - Domain string `json:"domain"` - Timeout *uint64 `json:",omitempty"` + Name string `json:"name"` + Domain string `json:"domain"` + Deadline time.Time `json:"deadline"` } -// Stop the domain. Supervisor will first try to -// cleanly terminate the domain's init process. If unsuccessful, -// within Timeout seconds, it will send SIGKILL. +// Stop the domain. 
type StopRequest struct { - Domain string `json:"domain"` - Timeout *uint64 `json:",omitempty"` + Domain string `json:"domain"` + Deadline time.Time `json:"deadline"` +} + +type StopResponse struct { + CycleDeltaMetrics CycleDeltaMetrics `json:"cycle_delta_metrics"` } type FreezeRequest struct { Domain string `json:"domain"` } +type FreezeResponse struct { + CycleDeltaMetrics CycleDeltaMetrics `json:"cycle_delta_metrics"` +} + +type MicrovmNetworkInterfaceMetrics struct { + ReceivedBytes uint64 `json:"received_bytes"` + TransmittedBytes uint64 `json:"transmitted_bytes"` +} + +type CycleDeltaMetrics struct { + // CPU time (in nanoseconds) obtained by domain cgroup from cpuacct.usage + // https://www.kernel.org/doc/Documentation/cgroup-v1/cpuacct.txt + DomainCPURunNs uint64 `json:"domain_cpu_run_ns"` + // time (in nanoseconds) for domain cycle + DomainRunNs uint64 `json:"domain_run_ns"` + // CPU delta time for service cgroup + ServiceCPURunNs uint64 `json:"service_cpu_run_ns"` + // Maximum memory used (in bytes) for domain + DomainMaxMemoryUsageBytes uint64 `json:"domain_max_memory_usage_bytes"` + // CPU delta time (in nanoseconds) obtained from /sys/fs/cgroup/cpu,cpuacct/cpuacct.usage + MicrovmCPURunNs uint64 `json:"microvm_cpu_run_ns"` + // Map with network interface name as key and network metrics as a value + MicrovmNetworksBytes map[string]MicrovmNetworkInterfaceMetrics `json:"microvm_network_interfaces"` + // time ( in nanoseconds ) for idle cpu time + InvokeIdleCPURunNs uint64 `json:"idle_cpu_run_ns"` +} + type ThawRequest struct { Domain string `json:"domain"` - // if not nil, changes the cgroup profile of the domain upon thawing. - CgroupProfile *string `json:"cgroup_profile,omitempty"` } diff --git a/lambda/supervisor/model/model_test.go b/lambda/supervisor/model/model_test.go new file mode 100644 index 0000000..ea39580 --- /dev/null +++ b/lambda/supervisor/model/model_test.go @@ -0,0 +1,31 @@ +// Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "encoding/json" + "testing" + "time" +) + +// LockHard accepts deadlines encoded as RFC3339 - we enforce this with a test +func Test_KillDeadlineIsMarshalledIntoRFC3339(t *testing.T) { + deadline, err := time.Parse(time.RFC3339, "2022-12-21T10:00:00Z") + if err != nil { + t.Error(err) + } + k := KillRequest{ + Name: "", + Domain: "", + Deadline: deadline, + } + bytes, err := json.Marshal(k) + if err != nil { + t.Error(err) + } + exepected := `{"name":"","domain":"","deadline":"2022-12-21T10:00:00Z"}` + if string(bytes) != exepected { + t.Errorf("error in marshaling `KillRequest` it does not match the expected string (Expected(%q) != Got(%q))", exepected, string(bytes)) + } +} diff --git a/lambda/rapidcore/telemetry/logsapi/constants.go b/lambda/telemetry/constants.go similarity index 94% rename from lambda/rapidcore/telemetry/logsapi/constants.go rename to lambda/telemetry/constants.go index f54e415..0198660 100644 --- a/lambda/rapidcore/telemetry/logsapi/constants.go +++ b/lambda/telemetry/constants.go @@ -1,17 +1,17 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package logsapi +package telemetry import "errors" -// ErrTelemetryServiceOff returned on attempt to subscribe after telemetry service has been turned off. -var ErrTelemetryServiceOff = errors.New("ErrTelemetryServiceOff") - -// Metrics const ( + // Metrics SubscribeSuccess = "logs_api_subscribe_success" SubscribeClientErr = "logs_api_subscribe_client_err" SubscribeServerErr = "logs_api_subscribe_server_err" NumSubscribers = "logs_api_num_subscribers" ) + +// ErrTelemetryServiceOff returned on attempt to subscribe after telemetry service has been turned off. 
+var ErrTelemetryServiceOff = errors.New("ErrTelemetryServiceOff") diff --git a/lambda/telemetry/events_api.go b/lambda/telemetry/events_api.go index e7c5c36..371f439 100644 --- a/lambda/telemetry/events_api.go +++ b/lambda/telemetry/events_api.go @@ -4,135 +4,151 @@ package telemetry import ( + "fmt" "time" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapi/model" ) -type RuntimeDoneInvokeMetrics struct { - ProducedBytes int64 - DurationMs float64 -} - -func GetRuntimeDoneInvokeMetrics(invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics, runtimeDoneTime int64) *RuntimeDoneInvokeMetrics { - if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && invokeReceivedTime != 0 { - return &RuntimeDoneInvokeMetrics{ +func GetRuntimeDoneInvokeMetrics(runtimeStartedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics, runtimeDoneTime int64) *interop.RuntimeDoneInvokeMetrics { + // time taken from sending the invoke to the sandbox until the runtime calls GET /next + duration := CalculateDuration(runtimeStartedTime, runtimeDoneTime) + if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && runtimeStartedTime != -1 { + return &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: invokeResponseMetrics.ProducedBytes, - // time taken from sending the invoke to the sandbox until the runtime calls GET /next - DurationMs: float64((runtimeDoneTime - invokeReceivedTime) / int64(time.Millisecond)), + DurationMs: duration, } } // when we get a reset before runtime called /response - if invokeReceivedTime != 0 { - return &RuntimeDoneInvokeMetrics{ + if runtimeStartedTime != -1 { + return &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: int64(0), - DurationMs: float64((runtimeDoneTime - invokeReceivedTime) / int64(time.Millisecond)), + DurationMs: duration, } } // We didn't have time to register the invokeReceiveTime, which means we crash/reset very early, // too early 
for the runtime to actual run. In such case, the runtimeDone event shouldn't be sent // Not returning Nil even in this improbable case guarantees that we will always have some metrics to send to FluxPump - return &RuntimeDoneInvokeMetrics{ + return &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: int64(0), DurationMs: float64(0), } } -type InitRuntimeDoneData struct { - InitSource string - Status string -} - -type InvokeRuntimeDoneData struct { - Status string - Metrics *RuntimeDoneInvokeMetrics - InternalMetrics *interop.InvokeResponseMetrics - Tracing *TracingCtx - Spans []Span -} +const ( + InitInsideInitPhase interop.InitPhase = "init" + InitInsideInvokePhase interop.InitPhase = "invoke" +) -type Span struct { - Name string - Start string - DurationMs float64 +func InitPhaseFromLifecyclePhase(phase interop.LifecyclePhase) (interop.InitPhase, error) { + switch phase { + case interop.LifecyclePhaseInit: + return InitInsideInitPhase, nil + case interop.LifecyclePhaseInvoke: + return InitInsideInvokePhase, nil + default: + return interop.InitPhase(""), fmt.Errorf("unexpected lifecycle phase: %v", phase) + } } -func GetRuntimeDoneSpans(invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) []Span { - if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && invokeReceivedTime != 0 { +func GetRuntimeDoneSpans(runtimeStartedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) []interop.Span { + if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && runtimeStartedTime != -1 { // time span from when the invoke is received in the sandbox to the moment the runtime calls PUT /response - responseLatencyMsSpan := Span{ + responseLatencyMsSpan := interop.Span{ Name: "responseLatency", - Start: getEpochTimeInISO8601FormatFromMonotime(invokeReceivedTime), - DurationMs: float64((invokeResponseMetrics.StartReadingResponseMonoTimeMs - invokeReceivedTime) / int64(time.Millisecond)), + Start: 
GetEpochTimeInISO8601FormatFromMonotime(runtimeStartedTime), + DurationMs: CalculateDuration(runtimeStartedTime, invokeResponseMetrics.StartReadingResponseMonoTimeMs), } // time span from when the runtime called PUT /response to the moment the body of the response is fully sent - responseDurationMsSpan := Span{ + responseDurationMsSpan := interop.Span{ Name: "responseDuration", - Start: getEpochTimeInISO8601FormatFromMonotime(invokeResponseMetrics.StartReadingResponseMonoTimeMs), - DurationMs: float64((invokeResponseMetrics.FinishReadingResponseMonoTimeMs - invokeResponseMetrics.StartReadingResponseMonoTimeMs) / int64(time.Millisecond)), + Start: GetEpochTimeInISO8601FormatFromMonotime(invokeResponseMetrics.StartReadingResponseMonoTimeMs), + DurationMs: CalculateDuration(invokeResponseMetrics.StartReadingResponseMonoTimeMs, invokeResponseMetrics.FinishReadingResponseMonoTimeMs), } - return []Span{responseLatencyMsSpan, responseDurationMsSpan} + return []interop.Span{responseLatencyMsSpan, responseDurationMsSpan} } - return []Span{} + return []interop.Span{} } -func getEpochTimeInISO8601FormatFromMonotime(monotime int64) string { - return time.Unix(0, metering.MonoToEpoch(monotime)).Format("2006-01-02T15:04:05.000Z") +// CalculateDuration calculates duration between two moments. +// The result is milliseconds with microsecond precision. +// Two assumptions here: +// 1. the passed values are nanoseconds +// 2. 
endNs > startNs +func CalculateDuration(startNs, endNs int64) float64 { + microseconds := int64(endNs-startNs) / int64(time.Microsecond) + return float64(microseconds) / 1000 } -type TracingCtx struct { - SpanID string - Type model.TracingType - Value string -} +const ( + InitTypeOnDemand interop.InitType = "on-demand" + InitTypeProvisionedConcurrency interop.InitType = "provisioned-concurrency" + InitTypeInitCaching interop.InitType = "snap-start" +) -func BuildTracingCtx(tracingType model.TracingType, traceID string, lambdaSegmentID string) *TracingCtx { - // it takes current tracing context and change its parent value with the provided lambda segment id - root, currentParent, sample := ParseTraceID(traceID) - if root == "" || sample != model.XRaySampled { - return nil - } +func InferInitType(initCachingEnabled bool, sandboxType interop.SandboxType) interop.InitType { + initSource := InitTypeOnDemand - return &TracingCtx{ - SpanID: currentParent, - Type: tracingType, - Value: BuildFullTraceID(root, lambdaSegmentID, sample), + // ToDo: Unify this selection of SandboxType by using the START message + // after having a roadmap on the combination of INIT modes + if initCachingEnabled { + initSource = InitTypeInitCaching + } else if sandboxType == interop.SandboxPreWarmed { + initSource = InitTypeProvisionedConcurrency } + + return initSource +} + +func GetEpochTimeInISO8601FormatFromMonotime(monotime int64) string { + return time.Unix(0, metering.MonoToEpoch(monotime)).Format("2006-01-02T15:04:05.000Z") } const ( RuntimeDoneSuccess = "success" - RuntimeDoneFailure = "failure" + RuntimeDoneError = "error" ) -type EventsAPI interface { - SetCurrentRequestID(requestID string) - SendInitRuntimeDone(data *InitRuntimeDoneData) error - SendRestoreRuntimeDone(status string) error - SendRuntimeDone(data InvokeRuntimeDoneData) error - SendExtensionInit(agentName, state, errorType string, subscriptions []string) error - SendImageErrorLog(logline string) -} - type NoOpEventsAPI 
struct{} -func (s *NoOpEventsAPI) SetCurrentRequestID(requestID string) {} -func (s *NoOpEventsAPI) SendInitRuntimeDone(data *InitRuntimeDoneData) error { - return nil -} -func (s *NoOpEventsAPI) SendRestoreRuntimeDone(status string) error { - return nil -} -func (s *NoOpEventsAPI) SendRuntimeDone(data InvokeRuntimeDoneData) error { - return nil -} -func (s *NoOpEventsAPI) SendExtensionInit(agentName, state, errorType string, subscriptions []string) error { - return nil +func (s *NoOpEventsAPI) SetCurrentRequestID(interop.RequestID) {} + +func (s *NoOpEventsAPI) SendInitStart(interop.InitStartData) error { return nil } + +func (s *NoOpEventsAPI) SendInitRuntimeDone(interop.InitRuntimeDoneData) error { return nil } + +func (s *NoOpEventsAPI) SendInitReport(interop.InitReportData) error { return nil } + +func (s *NoOpEventsAPI) SendRestoreRuntimeDone(interop.RestoreRuntimeDoneData) error { return nil } + +func (s *NoOpEventsAPI) SendInvokeStart(interop.InvokeStartData) error { return nil } + +func (s *NoOpEventsAPI) SendInvokeRuntimeDone(interop.InvokeRuntimeDoneData) error { return nil } + +func (s *NoOpEventsAPI) SendExtensionInit(interop.ExtensionInitData) error { return nil } + +func (s *NoOpEventsAPI) SendEnd(interop.EndData) error { return nil } + +func (s *NoOpEventsAPI) SendReportSpan(interop.Span) error { return nil } + +func (s *NoOpEventsAPI) SendReport(interop.ReportData) error { return nil } + +func (s *NoOpEventsAPI) SendFault(interop.FaultData) error { return nil } + +func (s *NoOpEventsAPI) SendImageErrorLog(interop.ImageErrorLogData) {} + +func (s *NoOpEventsAPI) FetchTailLogs(string) (string, error) { return "", nil } + +func (s *NoOpEventsAPI) GetRuntimeDoneSpans( + runtimeStartedTime int64, + invokeResponseMetrics *interop.InvokeResponseMetrics, + runtimeOverheadStartedTime int64, + runtimeReadyTime int64, +) []interop.Span { + return []interop.Span{} } -func (s *NoOpEventsAPI) SendImageErrorLog(logline string) {} diff --git 
a/lambda/telemetry/events_api_test.go b/lambda/telemetry/events_api_test.go index b943be9..f69e4ea 100644 --- a/lambda/telemetry/events_api_test.go +++ b/lambda/telemetry/events_api_test.go @@ -15,65 +15,66 @@ import ( func TestGetRuntimeDoneInvokeMetrics(t *testing.T) { now := metering.Monotime() - invokeReceivedTime := now + runtimeStartedTime := now invokeResponseMetrics := &interop.InvokeResponseMetrics{ ProducedBytes: int64(100), RuntimeCalledResponse: true, } runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) - expected := &RuntimeDoneInvokeMetrics{ + expected := &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: int64(100), DurationMs: float64(10), } - assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, runtimeDoneTime)) + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(runtimeStartedTime, invokeResponseMetrics, runtimeDoneTime)) } func TestGetRuntimeDoneInvokeMetricsWhenRuntimeCalledError(t *testing.T) { now := metering.Monotime() - invokeReceivedTime := now + runtimeStartedTime := now invokeResponseMetrics := &interop.InvokeResponseMetrics{ ProducedBytes: int64(100), RuntimeCalledResponse: false, } - runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + // validating microsecond precision + runtimeDoneTime := now + int64(time.Duration(10)*time.Millisecond+time.Duration(50)*time.Microsecond) - expected := &RuntimeDoneInvokeMetrics{ + expected := &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: int64(0), - DurationMs: float64(10), + DurationMs: float64(10.05), } - assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, runtimeDoneTime)) + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(runtimeStartedTime, invokeResponseMetrics, runtimeDoneTime)) } -func TestGetRuntimeDoneInvokeMetricsWhenInvokeReceivedTimeIsZero(t *testing.T) { - now := int64(0) // January 1st, 1970 at 00:00:00 UTC - invokeReceivedTime := now +func 
TestGetRuntimeDoneInvokeMetricsWhenRuntimeStartedTimeIsMinusOne(t *testing.T) { + now := int64(-1) + runtimeStartedTime := now runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) - expected := &RuntimeDoneInvokeMetrics{ + expected := &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: int64(0), DurationMs: float64(0), } - actual := GetRuntimeDoneInvokeMetrics(invokeReceivedTime, nil, runtimeDoneTime) + actual := GetRuntimeDoneInvokeMetrics(runtimeStartedTime, nil, runtimeDoneTime) assert.Equal(t, expected, actual) } func TestGetRuntimeDoneInvokeMetricsWhenInvokeResponseMetricsIsNil(t *testing.T) { now := metering.Monotime() - invokeReceivedTime := now + runtimeStartedTime := now runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) - expected := &RuntimeDoneInvokeMetrics{ + expected := &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: int64(0), DurationMs: float64(10), } - assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, nil, runtimeDoneTime)) + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(runtimeStartedTime, nil, runtimeDoneTime)) } func TestGetRuntimeDoneSpans(t *testing.T) { @@ -81,29 +82,29 @@ func TestGetRuntimeDoneSpans(t *testing.T) { startReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(5)) finishReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(7)) - invokeReceivedTime := now + runtimeStartedTime := now invokeResponseMetrics := &interop.InvokeResponseMetrics{ StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, FinishReadingResponseMonoTimeMs: finishReadingResponseMonoTimeMs, RuntimeCalledResponse: true, } - expectedResponseLatencyMsStartTime := getEpochTimeInISO8601FormatFromMonotime(now) - expectedResponseDurationMsStartTime := getEpochTimeInISO8601FormatFromMonotime(startReadingResponseMonoTimeMs) - expected := []Span{ - Span{ + expectedResponseLatencyMsStartTime := GetEpochTimeInISO8601FormatFromMonotime(now) + 
expectedResponseDurationMsStartTime := GetEpochTimeInISO8601FormatFromMonotime(startReadingResponseMonoTimeMs) + expected := []interop.Span{ + { Name: "responseLatency", Start: expectedResponseLatencyMsStartTime, DurationMs: 5, }, - Span{ + { Name: "responseDuration", Start: expectedResponseDurationMsStartTime, DurationMs: 2, }, } - assert.Equal(t, expected, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) + assert.Equal(t, expected, GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics)) } func TestGetRuntimeDoneSpansWhenRuntimeCalledError(t *testing.T) { @@ -111,29 +112,101 @@ func TestGetRuntimeDoneSpansWhenRuntimeCalledError(t *testing.T) { startReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(5)) finishReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(7)) - invokeReceivedTime := now + runtimeStartedTime := now invokeResponseMetrics := &interop.InvokeResponseMetrics{ StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, FinishReadingResponseMonoTimeMs: finishReadingResponseMonoTimeMs, RuntimeCalledResponse: false, } - assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) + assert.Equal(t, []interop.Span{}, GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics)) } func TestGetRuntimeDoneSpansWhenInvokeResponseMetricsNil(t *testing.T) { - invokeReceivedTime := metering.Monotime() + runtimeStartedTime := metering.Monotime() - assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, nil)) + assert.Equal(t, []interop.Span{}, GetRuntimeDoneSpans(runtimeStartedTime, nil)) } -func TestGetRuntimeDoneSpansWhenInvokeReceivedTimeIsZero(t *testing.T) { - now := int64(0) // January 1st, 1970 at 00:00:00 UTC - invokeReceivedTime := now +func TestGetRuntimeDoneSpansWhenRuntimeStartedTimeIsMinusOne(t *testing.T) { + now := int64(-1) + runtimeStartedTime := now invokeResponseMetrics := &interop.InvokeResponseMetrics{ StartReadingResponseMonoTimeMs: 
now + int64(time.Millisecond*time.Duration(5)), FinishReadingResponseMonoTimeMs: now + int64(time.Millisecond*time.Duration(7)), } - assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) + assert.Equal(t, []interop.Span{}, GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics)) +} + +func TestInferInitType(t *testing.T) { + testCases := map[string]struct { + initCachingEnabled bool + sandboxType interop.SandboxType + expected interop.InitType + }{ + "on demand": { + initCachingEnabled: false, + sandboxType: interop.SandboxClassic, + expected: InitTypeOnDemand, + }, + "pc": { + initCachingEnabled: false, + sandboxType: interop.SandboxPreWarmed, + expected: InitTypeProvisionedConcurrency, + }, + "snap-start for OD": { + initCachingEnabled: true, + sandboxType: interop.SandboxClassic, + expected: InitTypeInitCaching, + }, + "snap-start for PC": { + initCachingEnabled: true, + sandboxType: interop.SandboxPreWarmed, + expected: InitTypeInitCaching, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + initType := InferInitType(tc.initCachingEnabled, tc.sandboxType) + assert.Equal(t, tc.expected, initType) + }) + } +} + +func TestCalculateDuration(t *testing.T) { + testCases := map[string]struct { + start int64 + end int64 + expected float64 + }{ + "milliseconds only": { + start: int64(100 * time.Millisecond), + end: int64(120 * time.Millisecond), + expected: 20, + }, + "with microseconds": { + start: int64(100 * time.Millisecond), + end: int64(210*time.Millisecond + 65*time.Microsecond), + expected: 110.065, + }, + "nanoseconds must be dropped": { + start: int64(100 * time.Millisecond), + end: int64(140*time.Millisecond + 999*time.Nanosecond), + expected: 40, + }, + "microseconds presented, nanoseconds dropped": { + start: int64(100 * time.Millisecond), + end: int64(150*time.Millisecond + 2*time.Microsecond + 999*time.Nanosecond), + expected: 50.002, + }, + } + + for name, tc := range testCases { 
+ t.Run(name, func(t *testing.T) { + actual := CalculateDuration(tc.start, tc.end) + assert.Equal(t, tc.expected, actual) + }) + } } diff --git a/lambda/telemetry/logs_egress_api.go b/lambda/telemetry/logs_egress_api.go index 7e84fe2..f4da62d 100644 --- a/lambda/telemetry/logs_egress_api.go +++ b/lambda/telemetry/logs_egress_api.go @@ -29,3 +29,5 @@ func (s *NoOpLogsEgressAPI) GetRuntimeSockets() (io.Writer, io.Writer, error) { // os.Stderr can not be used for the stderrWriter because stderr is for internal logging (not customer visible). return os.Stdout, os.Stdout, nil } + +var _ StdLogsEgressAPI = (*NoOpLogsEgressAPI)(nil) diff --git a/lambda/telemetry/logs_subscription_api.go b/lambda/telemetry/logs_subscription_api.go index 6ee9490..2fa39f0 100644 --- a/lambda/telemetry/logs_subscription_api.go +++ b/lambda/telemetry/logs_subscription_api.go @@ -12,7 +12,7 @@ import ( // SubscriptionAPI represents interface that implementations of Telemetry API have to satisfy to be RAPID-compatible type SubscriptionAPI interface { - Subscribe(agentName string, body io.Reader, headers map[string][]string) (resp []byte, status int, respHeaders map[string][]string, err error) + Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) (resp []byte, status int, respHeaders map[string][]string, err error) RecordCounterMetric(metricName string, count int) FlushMetrics() interop.TelemetrySubscriptionMetrics Clear() @@ -25,7 +25,7 @@ type SubscriptionAPI interface { type NoOpSubscriptionAPI struct{} // Subscribe writes response to a shared memory -func (m *NoOpSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { +func (m *NoOpSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) ([]byte, int, map[string][]string, error) { return []byte(`{}`), http.StatusOK, map[string][]string{}, nil } diff --git 
a/lambda/telemetry/tracer.go b/lambda/telemetry/tracer.go index affca60..889682b 100644 --- a/lambda/telemetry/tracer.go +++ b/lambda/telemetry/tracer.go @@ -17,8 +17,8 @@ import ( type traceContextKey int const ( - traceIDKey traceContextKey = iota - documentIDKey + TraceIDKey traceContextKey = iota + DocumentIDKey ) type Tracer interface { @@ -30,11 +30,14 @@ type Tracer interface { RecordInitStartTime() RecordInitEndTime() SendInitSubsegmentWithRecordedTimesOnce(ctx context.Context) + SendRestoreSubsegmentWithRecordedTimesOnce(ctx context.Context) MarkError(ctx context.Context) AttachErrorCause(ctx context.Context, errorCause json.RawMessage) WithErrorCause(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error WithError(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error - TracingHeaderParser() func(context.Context, *interop.Invoke) string + BuildTracingHeader() func(context.Context) string + BuildTracingCtxForStart() *interop.TracingCtx + BuildTracingCtxAfterInvokeComplete() *interop.TracingCtx } type NoOpTracer struct{} @@ -42,28 +45,25 @@ type NoOpTracer struct{} func (t *NoOpTracer) Configure(invoke *interop.Invoke) {} func (t *NoOpTracer) CaptureInvokeSegment(ctx context.Context, criticalFunction func(context.Context) error) error { - criticalFunction(ctx) - return nil + return criticalFunction(ctx) } func (t *NoOpTracer) CaptureInitSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - criticalFunction(ctx) - return nil + return criticalFunction(ctx) } func (t *NoOpTracer) CaptureInvokeSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - criticalFunction(ctx) - return nil + return criticalFunction(ctx) } func (t *NoOpTracer) CaptureOverheadSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { 
- criticalFunction(ctx) - return nil + return criticalFunction(ctx) } func (t *NoOpTracer) RecordInitStartTime() {} func (t *NoOpTracer) RecordInitEndTime() {} func (t *NoOpTracer) SendInitSubsegmentWithRecordedTimesOnce(ctx context.Context) {} +func (t *NoOpTracer) SendRestoreSubsegmentWithRecordedTimesOnce(ctx context.Context) {} func (t *NoOpTracer) MarkError(ctx context.Context) {} func (t *NoOpTracer) AttachErrorCause(ctx context.Context, errorCause json.RawMessage) {} @@ -73,8 +73,25 @@ func (t *NoOpTracer) WithErrorCause(ctx context.Context, appCtx appctx.Applicati func (t *NoOpTracer) WithError(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { return criticalFunction } -func (t *NoOpTracer) TracingHeaderParser() func(context.Context, *interop.Invoke) string { - return GetCustomerTracingHeader +func (t *NoOpTracer) BuildTracingHeader() func(context.Context) string { + // extract root trace ID and parent from context and build the tracing header + return func(ctx context.Context) string { + root, _ := ctx.Value(TraceIDKey).(string) + parent, _ := ctx.Value(DocumentIDKey).(string) + + if root != "" && parent != "" { + return fmt.Sprintf("Root=%s;Parent=%s;Sampled=1", root, parent) + } + + return "" + } +} + +func (t *NoOpTracer) BuildTracingCtxForStart() *interop.TracingCtx { + return nil +} +func (t *NoOpTracer) BuildTracingCtxAfterInvokeComplete() *interop.TracingCtx { + return nil } func NewNoOpTracer() *NoOpTracer { @@ -83,49 +100,31 @@ func NewNoOpTracer() *NoOpTracer { // NewTraceContext returns new derived context with trace config set for testing func NewTraceContext(ctx context.Context, root string, parent string) context.Context { - ctxWithRoot := context.WithValue(ctx, traceIDKey, root) - return context.WithValue(ctxWithRoot, documentIDKey, parent) + ctxWithRoot := context.WithValue(ctx, TraceIDKey, root) + return context.WithValue(ctxWithRoot, DocumentIDKey, 
parent) } -// GetCustomerTracingHeader extracts the trace config from trace context and constructs header -func GetCustomerTracingHeader(ctx context.Context, invoke *interop.Invoke) string { - var root, parent string - var ok bool - - if root, ok = ctx.Value(traceIDKey).(string); !ok { - return invoke.TraceID - } - - if parent, ok = ctx.Value(documentIDKey).(string); !ok { - return invoke.TraceID - } - - return fmt.Sprintf("Root=%s;Parent=%s;Sampled=1", root, parent) - -} - -// ParseTraceID helps client to get TraceID, ParentID, Sampled information from a full trace -func ParseTraceID(fullTraceID string) (rootID, parentID, sample string) { - traceIDInfo := strings.Split(fullTraceID, ";") - for i := 0; i < len(traceIDInfo); i++ { - if len(traceIDInfo[i]) == 0 { - continue - } else { - var key string - var value string - keyValuePair := strings.Split(traceIDInfo[i], "=") - if len(keyValuePair) == 2 { - key = keyValuePair[0] - value = keyValuePair[1] - } - switch key { - case "Root": - rootID = value - case "Parent": - parentID = value - case "Sampled": - sample = value - } +// ParseTracingHeader extracts RootTraceID, ParentID, Sampled, and Lineage from a tracing header. 
+// Tracing header format is defined here: +// https://docs.aws.amazon.com/xray/latest/devguide/xray-concepts.html#xray-concepts-tracingheader +func ParseTracingHeader(tracingHeader string) (rootID, parentID, sampled, lineage string) { + keyValuePairs := strings.Split(tracingHeader, ";") + for _, pair := range keyValuePairs { + var key, value string + keyValue := strings.Split(pair, "=") + if len(keyValue) == 2 { + key = keyValue[0] + value = keyValue[1] + } + switch key { + case "Root": + rootID = value + case "Parent": + parentID = value + case "Sampled": + sampled = value + case "Lineage": + lineage = value } } return diff --git a/lambda/telemetry/tracer_test.go b/lambda/telemetry/tracer_test.go index c31653f..d67c389 100644 --- a/lambda/telemetry/tracer_test.go +++ b/lambda/telemetry/tracer_test.go @@ -4,35 +4,65 @@ package telemetry import ( + "context" + "fmt" + "strings" "testing" "go.amzn.com/lambda/rapi/model" ) +var BigString = strings.Repeat("a", 255) + var parserTests = []struct { - traceIDIn string - rootIDOut string - parentIDOut string - sampledOut string + tracingHeaderIn string + rootIDOut string + parentIDOut string + sampledOut string + lineageOut string }{ - {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", "1-5b3cc918-939afd635f8891ba6a9e1df6", "c88d77b0aef840e9", "1"}, - {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9", "1-5b3cc918-939afd635f8891ba6a9e1df6", "c88d77b0aef840e9", ""}, - {"1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", "", "c88d77b0aef840e9", "1"}, - {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6", "1-5b3cc918-939afd635f8891ba6a9e1df6", "", ""}, + {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", "1-5b3cc918-939afd635f8891ba6a9e1df6", "c88d77b0aef840e9", "1", ""}, + {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9", "1-5b3cc918-939afd635f8891ba6a9e1df6", "c88d77b0aef840e9", "", ""}, + 
{"1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", "", "c88d77b0aef840e9", "1", ""}, + {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6", "1-5b3cc918-939afd635f8891ba6a9e1df6", "", "", ""}, + {"", "", "", "", ""}, + {"abc;;", "", "", "", ""}, + {"abc", "", "", "", ""}, + {"abc;asd", "", "", "", ""}, + {"abc=as;asd=as", "", "", "", ""}, + {"Root=abc", "abc", "", "", ""}, + {"Root=abc;Parent=zxc;Sampled=1", "abc", "zxc", "1", ""}, + {"Root=root;Parent=par", "root", "par", "", ""}, + {"Root=root;Par", "root", "", "", ""}, + {"Root=", "", "", "", ""}, + {";Root=root;;", "root", "", "", ""}, + {"Root=root;Parent=parent;", "root", "parent", "", ""}, + {"Root=;Parent=parent;Sampled=1", "", "parent", "1", ""}, + {"Root=abc;Parent=zxc;Sampled=1;Lineage", "abc", "zxc", "1", ""}, + {"Root=abc;Parent=zxc;Sampled=1;Lineage=", "abc", "zxc", "1", ""}, + {"Root=abc;Parent=zxc;Sampled=1;Lineage=foo:1|bar:65535", "abc", "zxc", "1", "foo:1|bar:65535"}, + {"Root=abc;Parent=zxc;Lineage=foo:1|bar:65535;Sampled=1", "abc", "zxc", "1", "foo:1|bar:65535"}, + {fmt.Sprintf("Root=%s;Parent=%s;Sampled=1;Lineage=%s", BigString, BigString, BigString), BigString, BigString, "1", BigString}, } -func TestParseTraceID(t *testing.T) { +func TestParseTracingHeader(t *testing.T) { for _, tt := range parserTests { - t.Run(tt.traceIDIn, func(t *testing.T) { - rootID, parentID, sampled := ParseTraceID(tt.traceIDIn) + t.Run(tt.tracingHeaderIn, func(t *testing.T) { + rootID, parentID, sampled, lineage := ParseTracingHeader(tt.tracingHeaderIn) if rootID != tt.rootIDOut { - t.Errorf("got %q, wanted %q", rootID, tt.rootIDOut) + t.Errorf("Parsing %q got %q, wanted %q", tt.tracingHeaderIn, rootID, tt.rootIDOut) } if parentID != tt.parentIDOut { - t.Errorf("got %q, wanted %q", rootID, tt.parentIDOut) + t.Errorf("Parsing %q got %q, wanted %q", tt.tracingHeaderIn, parentID, tt.parentIDOut) } if sampled != tt.sampledOut { - t.Errorf("got %q, wanted %q", sampled, tt.sampledOut) + t.Errorf("Parsing %q 
got %q, wanted %q", tt.tracingHeaderIn, sampled, tt.sampledOut) + } + if lineage != tt.lineageOut { + t.Errorf("Parsing %q got %q, wanted %q", tt.tracingHeaderIn, lineage, tt.lineageOut) + } + if lineage != tt.lineageOut { + t.Errorf("got %q, wanted %q", lineage, tt.lineageOut) } }) } @@ -81,3 +111,45 @@ func TestBuildFullTraceID(t *testing.T) { }) } } + +func TestTracerDoesntSwallowErrorsFromCriticalFunctions(t *testing.T) { + ctx := context.Background() + + testCases := []struct { + name string + tracer Tracer + expectedError error + }{ + { + name: "NoOpTracer-success", + tracer: &NoOpTracer{}, + expectedError: nil, + }, + { + name: "NoOpTracer-fail", + tracer: &NoOpTracer{}, + expectedError: fmt.Errorf("invoke error"), + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + criticalFunction := func(ctx context.Context) error { + return test.expectedError + } + + if err := test.tracer.CaptureInvokeSegment(ctx, criticalFunction); err != test.expectedError { + t.Errorf("CaptureInvokeSegment failed; expected: '%v', but got: '%v'", test.expectedError, err) + } + if err := test.tracer.CaptureInitSubsegment(ctx, criticalFunction); err != test.expectedError { + t.Errorf("CaptureInitSubsegment failed; expected: '%v', but got: '%v'", test.expectedError, err) + } + if err := test.tracer.CaptureInvokeSubsegment(ctx, criticalFunction); err != test.expectedError { + t.Errorf("CaptureInvokeSubsegment failed; expected: '%v', but got: '%v'", test.expectedError, err) + } + if err := test.tracer.CaptureOverheadSubsegment(ctx, criticalFunction); err != test.expectedError { + t.Errorf("CaptureOverheadSubsegment failed; expected: '%v', but got: '%v'", test.expectedError, err) + } + }) + } +} diff --git a/lambda/testdata/flowtesting.go b/lambda/testdata/flowtesting.go index c028d7c..e2c4b49 100644 --- a/lambda/testdata/flowtesting.go +++ b/lambda/testdata/flowtesting.go @@ -4,10 +4,9 @@ package testdata import ( + "bytes" "context" - "io" "io/ioutil" - 
"net/http" "time" "go.amzn.com/lambda/appctx" @@ -25,15 +24,15 @@ const ( type MockInteropServer struct { Response []byte - ErrorResponse *interop.ErrorResponse + ErrorResponse *interop.ErrorInvokeResponse ResponseContentType string FunctionResponseMode string ActiveInvokeID string } // SendResponse writes response to a shared memory. -func (i *MockInteropServer) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error { - bytes, err := ioutil.ReadAll(reader) +func (i *MockInteropServer) SendResponse(invokeID string, resp *interop.StreamableInvokeResponse) error { + bytes, err := ioutil.ReadAll(resp.Payload) if err != nil { return err } @@ -44,23 +43,23 @@ func (i *MockInteropServer) SendResponse(invokeID string, headers map[string]str } } i.Response = bytes - i.ResponseContentType = headers[contentTypeHeader] - i.FunctionResponseMode = headers[functionResponseModeHeader] + i.ResponseContentType = resp.Headers[contentTypeHeader] + i.FunctionResponseMode = resp.Headers[functionResponseModeHeader] return nil } // SendErrorResponse writes error response to a shared memory and sends GIRD FAULT. -func (i *MockInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorResponse) error { +func (i *MockInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorInvokeResponse) error { i.ErrorResponse = response - i.ResponseContentType = response.ContentType - i.FunctionResponseMode = response.FunctionResponseMode + i.ResponseContentType = response.Headers.ContentType + i.FunctionResponseMode = response.Headers.FunctionResponseMode return nil } // SendInitErrorResponse writes error response during init to a shared memory and sends GIRD FAULT. 
-func (i *MockInteropServer) SendInitErrorResponse(invokeID string, response *interop.ErrorResponse) error { +func (i *MockInteropServer) SendInitErrorResponse(response *interop.ErrorInvokeResponse) error { i.ErrorResponse = response - i.ResponseContentType = response.ContentType + i.ResponseContentType = response.Headers.ContentType return nil } @@ -81,7 +80,7 @@ type FlowTest struct { InteropServer *MockInteropServer TelemetrySubscription *telemetry.NoOpSubscriptionAPI CredentialsService core.CredentialsService - EventsAPI telemetry.EventsAPI + EventsAPI interop.EventsAPI } // ConfigureForInit initialize synchronization gates and states for init. @@ -93,13 +92,25 @@ func (s *FlowTest) ConfigureForInit() { func (s *FlowTest) ConfigureForInvoke(ctx context.Context, invoke *interop.Invoke) { s.InteropServer.ActiveInvokeID = invoke.ID s.InvokeFlow.InitializeBarriers() - s.RenderingService.SetRenderer(rendering.NewInvokeRenderer(ctx, invoke, telemetry.GetCustomerTracingHeader)) + var buf bytes.Buffer // create default invoke renderer with new request buffer each time + s.ConfigureInvokeRenderer(ctx, invoke, &buf) +} + +// ConfigureInvokeRenderer overrides default invoke renderer to reuse request buffers (for benchmarks), etc. 
+func (s *FlowTest) ConfigureInvokeRenderer(ctx context.Context, invoke *interop.Invoke, buf *bytes.Buffer) { + s.RenderingService.SetRenderer(rendering.NewInvokeRenderer(ctx, invoke, buf, telemetry.NewNoOpTracer().BuildTracingHeader())) } func (s *FlowTest) ConfigureForRestore() { s.RenderingService.SetRenderer(rendering.NewRestoreRenderer()) } +func (s *FlowTest) ConfigureForRestoring() { + s.RegistrationService.PreregisterRuntime(s.Runtime) + s.Runtime.SetState(s.Runtime.RuntimeRestoringState) + s.RenderingService.SetRenderer(rendering.NewRestoreRenderer()) +} + func (s *FlowTest) ConfigureForInitCaching(token, awsKey, awsSecret, awsSession string) { credentialsExpiration := time.Now().Add(30 * time.Minute) s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession, credentialsExpiration) @@ -118,6 +129,8 @@ func NewFlowTest() *FlowTest { interopServer := &MockInteropServer{} eventsAPI := telemetry.NoOpEventsAPI{} appctx.StoreInteropServer(appCtx, interopServer) + appctx.StoreResponseSender(appCtx, interopServer) + return &FlowTest{ AppCtx: appCtx, InitFlow: initFlow, diff --git a/lambda/testdata/mocktracer/mocktracer.go b/lambda/testdata/mocktracer/mocktracer.go index f6ee9ab..3fb7054 100644 --- a/lambda/testdata/mocktracer/mocktracer.go +++ b/lambda/testdata/mocktracer/mocktracer.go @@ -5,14 +5,15 @@ package mocktracer import ( "context" - "go.amzn.com/lambda/xray" "time" + + "go.amzn.com/lambda/xray" ) // MockStartTime is start time set in Start method var MockStartTime = time.Now().UnixNano() -//MockEndTime is end time set in End method +// MockEndTime is end time set in End method var MockEndTime = time.Now().UnixNano() + 1 // MockTracer is used for unit tests