diff --git a/README.md b/README.md index 804223f..1fe10f6 100644 --- a/README.md +++ b/README.md @@ -99,8 +99,11 @@ You install the runtime interface emulator to your local machine. When you run t 2. Run your Lambda image function using the docker run command. - `docker run -d -v ~/.aws-lambda-rie:/aws-lambda -p 9000:8080 myfunction:latest - --entrypoint /aws-lambda/aws-lambda-rie <(optional) image command>` + ``` + docker run -d -v ~/.aws-lambda-rie:/aws-lambda -p 9000:8080 \ + --entrypoint /aws-lambda/aws-lambda-rie \ + myfunction:latest <(optional) image command> + ```` This runs the image as a container and starts up an endpoint locally at `localhost:9000/2015-03-31/functions/function/invocations`. diff --git a/cmd/aws-lambda-rie/handlers.go b/cmd/aws-lambda-rie/handlers.go index 75f5bda..39097fc 100644 --- a/cmd/aws-lambda-rie/handlers.go +++ b/cmd/aws-lambda-rie/handlers.go @@ -4,8 +4,8 @@ package main import ( + "bytes" "fmt" - "io" "io/ioutil" "math" "net/http" @@ -24,7 +24,7 @@ import ( type Sandbox interface { Init(i *interop.Init, invokeTimeoutMs int64) - Invoke(responseWriter io.Writer, invoke *interop.Invoke) error + Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error } var initDone bool @@ -98,7 +98,7 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { InvokedFunctionArn: fmt.Sprintf("arn:aws:lambda:us-east-1:012345678912:function:%s", GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function")), TraceID: r.Header.Get("X-Amzn-Trace-Id"), LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: bodyBytes, + Payload: bytes.NewReader(bodyBytes), CorrelationID: "invokeCorrelationID", } fmt.Println("START RequestId: " + invokePayload.ID + " Version: " + functionVersion) diff --git a/cmd/aws-lambda-rie/main.go b/cmd/aws-lambda-rie/main.go index edcefe9..4c28f51 100644 --- a/cmd/aws-lambda-rie/main.go +++ b/cmd/aws-lambda-rie/main.go @@ -15,7 +15,6 @@ import ( log "github.com/sirupsen/logrus" ) - const ( optBootstrap = "/opt/bootstrap" runtimeBootstrap = "/var/runtime/bootstrap" @@ -58,24 +57,22 @@ func getCLIArgs() (options, []string) { } func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) { - var bootstrapLookupCmdList [][]string + var bootstrapLookupCmd []string var handler string currentWorkingDir := "/var/task" // default value if len(args) <= 1 { - bootstrapLookupCmdList = [][]string{ - []string{fmt.Sprintf("%s/bootstrap", currentWorkingDir)}, - []string{optBootstrap}, - []string{runtimeBootstrap}, + bootstrapLookupCmd = []string{ + fmt.Sprintf("%s/bootstrap", currentWorkingDir), + optBootstrap, + runtimeBootstrap, } // handler is used later to set an env var for Lambda Image support handler = "" } else if len(args) > 1 { - bootstrapLookupCmdList = [][]string{ - args[1:], - } + bootstrapLookupCmd = args[1:] if cwd, err := os.Getwd(); err == nil { currentWorkingDir = cwd @@ -92,5 +89,5 @@ func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) { log.Panic("insufficient arguments: bootstrap not provided") } - return rapidcore.NewBootstrap(bootstrapLookupCmdList, currentWorkingDir), handler + return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir), handler } diff --git a/lambda/core/directinvoke/customerheaders.go b/lambda/core/directinvoke/customerheaders.go new file mode 100644 index 0000000..fd0e4ad --- /dev/null +++ b/lambda/core/directinvoke/customerheaders.go @@ -0,0 +1,41 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package directinvoke + +import ( + "bytes" + "encoding/base64" + "encoding/json" +) + +type CustomerHeaders struct { + CognitoIdentityID string `json:"Cognito-Identity-Id"` + CognitoIdentityPoolID string `json:"Cognito-Identity-Pool-Id"` + ClientContext string `json:"Client-Context"` +} + +func (s CustomerHeaders) Dump() string { + if (s == CustomerHeaders{}) { + return "" + } + + custHeadersJSON, err := json.Marshal(&s) + if err != nil { + panic(err) + } + + return base64.StdEncoding.EncodeToString(custHeadersJSON) +} + +func (s *CustomerHeaders) Load(in string) error { + *s = CustomerHeaders{} + + if in == "" { + return nil + } + + base64Decoder := base64.NewDecoder(base64.StdEncoding, bytes.NewReader([]byte(in))) + + return json.NewDecoder(base64Decoder).Decode(s) +} diff --git a/lambda/core/directinvoke/customerheaders_test.go b/lambda/core/directinvoke/customerheaders_test.go new file mode 100644 index 0000000..d81cbf4 --- /dev/null +++ b/lambda/core/directinvoke/customerheaders_test.go @@ -0,0 +1,25 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package directinvoke + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestCustomerHeadersEmpty(t *testing.T) { + in := CustomerHeaders{} + out := CustomerHeaders{} + + require.NoError(t, out.Load(in.Dump())) + require.Equal(t, in, out) +} + +func TestCustomerHeaders(t *testing.T) { + in := CustomerHeaders{CognitoIdentityID: "asd"} + out := CustomerHeaders{} + + require.NoError(t, out.Load(in.Dump())) + require.Equal(t, in, out) +} diff --git a/lambda/core/directinvoke/directinvoke.go b/lambda/core/directinvoke/directinvoke.go new file mode 100644 index 0000000..ab1075d --- /dev/null +++ b/lambda/core/directinvoke/directinvoke.go @@ -0,0 +1,116 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package directinvoke + +import ( + "fmt" + "io" + "net/http" + + "github.com/go-chi/chi" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" +) + +const ( + InvokeIDHeader = "Invoke-Id" + InvokedFunctionArnHeader = "Invoked-Function-Arn" + VersionIDHeader = "Invoked-Function-Version" + ReservationTokenHeader = "Reservation-Token" + CustomerHeadersHeader = "Customer-Headers" + ContentTypeHeader = "Content-Type" + + ErrorTypeHeader = "Error-Type" + + EndOfResponseTrailer = "End-Of-Response" + + SandboxErrorType = "Error.Sandbox" +) + +const ( + EndOfResponseComplete = "Complete" + EndOfResponseTruncated = "Truncated" + EndOfResponseOversized = "Oversized" +) + +var MaxDirectResponseSize int64 = interop.MaxPayloadSize // this is intentionally not a constant so we can configure it via CLI + +func renderBadRequest(w http.ResponseWriter, r *http.Request, errorType string) { + w.Header().Set(ErrorTypeHeader, errorType) + w.WriteHeader(http.StatusBadRequest) + w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) +} + +// ReceiveDirectInvoke parses invoke and verifies it against Token message. Uses deadline provided by Token +// Renders BadRequest in case of error +func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.Token) (*interop.Invoke, error) { + w.Header().Set("Trailer", EndOfResponseTrailer) + + custHeaders := CustomerHeaders{} + if err := custHeaders.Load(r.Header.Get(CustomerHeadersHeader)); err != nil { + renderBadRequest(w, r, interop.ErrMalformedCustomerHeaders.Error()) + return nil, interop.ErrMalformedCustomerHeaders + } + + now := metering.Monotime() + inv := &interop.Invoke{ + ID: r.Header.Get(InvokeIDHeader), + ReservationToken: chi.URLParam(r, "reservationtoken"), + InvokedFunctionArn: r.Header.Get(InvokedFunctionArnHeader), + VersionID: r.Header.Get(VersionIDHeader), + ContentType: r.Header.Get(ContentTypeHeader), + CognitoIdentityID: custHeaders.CognitoIdentityID, + CognitoIdentityPoolID: custHeaders.CognitoIdentityPoolID, + TraceID: token.TraceID, + LambdaSegmentID: token.LambdaSegmentID, + ClientContext: custHeaders.ClientContext, + Payload: r.Body, + CorrelationID: "invokeCorrelationID", + DeadlineNs: fmt.Sprintf("%d", now+token.FunctionTimeout.Nanoseconds()), + NeedDebugLogs: token.NeedDebugLogs, + InvokeReceivedTime: now, + } + + if inv.ID != token.InvokeID { + renderBadRequest(w, r, interop.ErrInvalidInvokeID.Error()) + return nil, interop.ErrInvalidInvokeID + } + + if inv.ReservationToken != token.ReservationToken { + renderBadRequest(w, r, interop.ErrInvalidReservationToken.Error()) + return nil, interop.ErrInvalidReservationToken + } + + if inv.VersionID != token.VersionID { + renderBadRequest(w, r, interop.ErrInvalidFunctionVersion.Error()) + return nil, interop.ErrInvalidFunctionVersion + } + + if now > token.InvackDeadlineNs { + renderBadRequest(w, r, interop.ErrReservationExpired.Error()) + return nil, interop.ErrReservationExpired + } + + w.Header().Set(VersionIDHeader, token.VersionID) + w.Header().Set(ReservationTokenHeader, token.ReservationToken) + w.Header().Set(InvokeIDHeader, token.InvokeID) + + return inv, nil +} + +func SendDirectInvokeResponse(additionalHeaders map[string]string, payload io.Reader, w http.ResponseWriter) error { + for k, v := range additionalHeaders { + w.Header().Add(k, v) + } + + n, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) // +1 because we do allow 10MB but not 10MB + 1 byte + if err != nil { + w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) + } else if n == MaxDirectResponseSize+1 { + w.Header().Set(EndOfResponseTrailer, EndOfResponseOversized) + } else { + w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) + } + return err +} diff --git a/lambda/core/registrations.go b/lambda/core/registrations.go index 89abc9a..dca9d90 100644 --- a/lambda/core/registrations.go +++ b/lambda/core/registrations.go @@ -158,6 +158,13 @@ func (s *registrationServiceImpl) getInternalStateDescription(appCtx appctx.Appl } func (s *registrationServiceImpl) CountAgents() int { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.countAgentsUnsafe() +} + +func (s *registrationServiceImpl) countAgentsUnsafe() int { res := 0 s.externalAgents.Visit(func(a *ExternalAgent) { res++ @@ -237,7 +244,7 @@ func (s *registrationServiceImpl) CreateInternalAgent(agentName string) (*Intern return nil, ErrRegistrationServiceOff } - if s.CountAgents() >= MaxAgentsAllowed { + if s.countAgentsUnsafe() >= MaxAgentsAllowed { return nil, ErrTooManyExtensions } diff --git a/lambda/core/statejson/description.go b/lambda/core/statejson/description.go index 8f0508d..eb46946 100644 --- a/lambda/core/statejson/description.go +++ b/lambda/core/statejson/description.go @@ -10,8 +10,9 @@ import ( // StateDescription ... type StateDescription struct { - Name string `json:"name"` - LastModified int64 `json:"lastModified"` + Name string `json:"name"` + LastModified int64 `json:"lastModified"` + ResponseTimeNs int64 `json:"responseTimeNs"` } // RuntimeDescription ... @@ -34,6 +35,11 @@ type InternalStateDescription struct { FirstFatalError string `json:"firstFatalError"` } +// ResetDescription describes fields of the response to an INVOKE API request +type ResetDescription struct { + ExtensionsResetMs int64 `json:"extensionsResetMs"` +} + func (s *InternalStateDescription) AsJSON() []byte { bytes, err := json.Marshal(s) if err != nil { @@ -41,3 +47,11 @@ func (s *InternalStateDescription) AsJSON() []byte { } return bytes } + +func (s *ResetDescription) AsJSON() []byte { + bytes, err := json.Marshal(s) + if err != nil { + log.Panicf("Failed to marshall reset description: %s", err) + } + return bytes +} diff --git a/lambda/core/states.go b/lambda/core/states.go index 32469bb..e6df068 100644 --- a/lambda/core/states.go +++ b/lambda/core/states.go @@ -85,6 +85,7 @@ type Runtime struct { currentState RuntimeState stateLastModified time.Time Pid int + responseTime time.Time RuntimeStartedState RuntimeState RuntimeInitErrorState RuntimeState @@ -150,19 +151,27 @@ func (s *Runtime) InitError() error { func (s *Runtime) ResponseSent() error { s.ManagedThread.Lock() defer s.ManagedThread.Unlock() - return s.currentState.ResponseSent() + err := s.currentState.ResponseSent() + if err == nil { + s.responseTime = time.Now() + } + return err } // GetRuntimeDescription returns runtime description object for debugging purposes func (s *Runtime) GetRuntimeDescription() statejson.RuntimeDescription { s.ManagedThread.Lock() defer s.ManagedThread.Unlock() - return statejson.RuntimeDescription{ + res := statejson.RuntimeDescription{ State: statejson.StateDescription{ Name: s.currentState.Name(), LastModified: s.stateLastModified.UnixNano() / int64(time.Millisecond), }, } + if !s.responseTime.IsZero() { + res.State.ResponseTimeNs = s.responseTime.UnixNano() + } + return res } // NewRuntime returns new Runtime instance. diff --git a/lambda/core/states_test.go b/lambda/core/states_test.go index 2dd25a2..1b6a62e 100644 --- a/lambda/core/states_test.go +++ b/lambda/core/states_test.go @@ -110,6 +110,7 @@ func TestRuntimeStateTransitionsFromInvocationResponseState(t *testing.T) { runtime.SetState(runtime.RuntimeInvocationResponseState) assert.NoError(t, runtime.ResponseSent()) assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) + assert.NotEqual(t, 0, runtime.GetRuntimeDescription().State.ResponseTimeNs) // InvocationResponse-> InvocationResponse runtime.SetState(runtime.RuntimeInvocationResponseState) assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) diff --git a/lambda/fatalerror/fatalerror.go b/lambda/fatalerror/fatalerror.go index 1aef106..7292baf 100644 --- a/lambda/fatalerror/fatalerror.go +++ b/lambda/fatalerror/fatalerror.go @@ -16,6 +16,7 @@ const ( AgentLaunchError ErrorType = "Extension.LaunchError" // agent could not be launched RuntimeExit ErrorType = "Runtime.ExitError" InvalidEntrypoint ErrorType = "Runtime.InvalidEntrypoint" + InvalidWorkingDir ErrorType = "Runtime.InvalidWorkingDir" InvalidTaskConfig ErrorType = "Runtime.InvalidTaskConfig" Unknown ErrorType = "Unknown" ) diff --git a/lambda/interop/model.go b/lambda/interop/model.go index 6f2af15..6735a8b 100644 --- a/lambda/interop/model.go +++ b/lambda/interop/model.go @@ -7,13 +7,18 @@ import ( "encoding/json" "fmt" "io" + "net/http" + "time" "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/fatalerror" ) // MaxPayloadSize max event body size declared as LAMBDA_EVENT_BODY_SIZE const MaxPayloadSize = 6*1024*1024 + 100 // 6 MiB + 100 bytes +const functionResponseSizeTooLargeType = "Function.ResponseSizeTooLarge" + // Message is a generic interop message. type Message interface{} @@ -30,23 +35,29 @@ type Invoke struct { DeadlineNs string ClientContext string ContentType string - Payload []byte + Payload io.Reader NeedDebugLogs bool CorrelationID string // internal use only + ReservationToken string + VersionID string + InvokeReceivedTime int64 } -// Response is a response to an invoke that is sent to the slicer. -type Response struct { - Payload []byte +type Token struct { + ReservationToken string + InvokeID string + VersionID string + FunctionTimeout time.Duration + InvackDeadlineNs int64 + TraceID string + LambdaSegmentID string + InvokeMetadata string + NeedDebugLogs bool } -// ErrorResponse is an error response that is sent to the slicer. -// -// Note, this struct is implementation-specific to how Slicer -// processes errors. type ErrorResponse struct { // Payload sent via shared memory. - Payload []byte + Payload []byte `json:"Payload,omitempty"` // When error response body (Payload) is not provided, e.g. // not retrievable, error type and error message will be @@ -57,11 +68,11 @@ type ErrorResponse struct { // // when error type is provided, error response becomes: // '{"errorMessage":"Unknown application error occurred","errorType":"ErrorType"}' - ErrorType string - ErrorMessage string + ErrorType string `json:"errorType,omitempty"` + ErrorMessage string `json:"errorMessage,omitempty"` // Attached to invoke segment - ErrorCause json.RawMessage + ErrorCause json.RawMessage `json:"ErrorCause,omitempty"` } // SandboxType identifies sandbox type (PreWarmed vs Classic) @@ -113,41 +124,85 @@ type Shutdown struct { // Metrics for response status of LogsAPI `/subscribe` calls type LogsAPIMetrics map[string]int -// Done message is sent to the slicer, part of the protocol. -type Done struct { - WaitForExit bool +type DoneMetadata struct { NumActiveExtensions int + ExtensionsResetMs int64 RuntimeRelease string - ErrorType string // internal use only, still in use by standalone - CorrelationID string // internal use only // Metrics for response status of LogsAPI `/subscribe` calls - LogsAPIMetrics LogsAPIMetrics + LogsAPIMetrics LogsAPIMetrics + InvokeRequestReadTimeNs int64 + InvokeRequestSizeBytes int64 + InvokeCompletionTimeNs int64 + InvokeReceivedTime int64 +} + +type Done struct { + WaitForExit bool + ErrorType fatalerror.ErrorType + CorrelationID string // internal use only + Meta DoneMetadata } -// DoneFail message is sent to the slicer to report error and request reset. type DoneFail struct { - RuntimeRelease string - NumActiveExtensions int - ErrorType string - CorrelationID string // internal use only - // Metrics for response status of LogsAPI `/subscribe` calls - LogsAPIMetrics LogsAPIMetrics + ErrorType fatalerror.ErrorType + CorrelationID string // internal use only + Meta DoneMetadata } -// ErrInvalidInvokeID is returned when provided invokeID doesn't match current invokeID +// ErrInvalidInvokeID is returned when invokeID provided in Invoke2 does not match one provided in Token var ErrInvalidInvokeID = fmt.Errorf("ErrInvalidInvokeID") +// ErrInvalidReservationToken is returned when reservationToken provided in Invoke2 does not match one provided in Token +var ErrInvalidReservationToken = fmt.Errorf("ErrInvalidReservationToken") + +// ErrInvalidFunctionVersion is returned when functionVersion provided in Invoke2 does not match one provided in Token +var ErrInvalidFunctionVersion = fmt.Errorf("ErrInvalidFunctionVersion") + +// ErrMalformedCustomerHeaders is returned when customer headers format is invalid +var ErrMalformedCustomerHeaders = fmt.Errorf("ErrMalformedCustomerHeaders") + // ErrResponseSent is returned when response with given invokeID was already sent. var ErrResponseSent = fmt.Errorf("ErrResponseSent") +// ErrReservationExpired is returned when invoke arrived after InvackDeadline +var ErrReservationExpired = fmt.Errorf("ErrReservationExpired") + +// ErrorResponseTooLarge is returned when response Payload exceeds shared memory buffer size +type ErrorResponseTooLarge struct { + MaxResponseSize int + ResponseSize int +} + +// ErrorResponseTooLarge is returned when response provided by Runtime does not fit into shared memory buffer +func (s *ErrorResponseTooLarge) Error() string { + return fmt.Sprintf("Response payload size (%d bytes) exceeded maximum allowed payload size (%d bytes).", s.ResponseSize, s.MaxResponseSize) +} + +// AsErrorResponse generates ErrorResponse from ErrorResponseTooLarge +func (s *ErrorResponseTooLarge) AsInteropError() *ErrorResponse { + resp := ErrorResponse{ + ErrorType: functionResponseSizeTooLargeType, + ErrorMessage: s.Error(), + } + respJSON, err := json.Marshal(resp) + if err != nil { + panic("Failed to marshal interop.ErrorResponse") + } + resp.Payload = respJSON + return &resp +} + // Server implements Slicer communication protocol. type Server interface { + // StartAcceptingDirectInvokes starts accepting on direct invoke socket (if one is available) + StartAcceptingDirectInvokes() error + // SendErrorResponse sends response. // Errors returned: // ErrInvalidInvokeID - validation error indicating that provided invokeID doesn't match current invokeID // ErrResponseSent - validation error indicating that response with given invokeID was already sent // Non-nil error - non-nil error indicating transport failure - SendResponse(invokeID string, response *Response) error + SendResponse(invokeID string, response io.Reader) error // SendErrorResponse sends error response. // Errors returned: @@ -207,7 +262,7 @@ type Server interface { Init(i *Start, invokeTimeoutMs int64) - Invoke(responseWriter io.Writer, invoke *Invoke) error + Invoke(responseWriter http.ResponseWriter, invoke *Invoke) error Shutdown(shutdown *Shutdown) *statejson.InternalStateDescription } diff --git a/lambda/logging/taillog.go b/lambda/logging/taillog.go index af663d3..9fe5352 100644 --- a/lambda/logging/taillog.go +++ b/lambda/logging/taillog.go @@ -5,26 +5,37 @@ package logging import ( "io" + "sync" ) // TailLogWriter writes tail/debug log to provided io.Writer type TailLogWriter struct { out io.Writer enabled bool + mutex sync.Mutex } // Enable enables log writer. func (lw *TailLogWriter) Enable() { + lw.mutex.Lock() + defer lw.mutex.Unlock() + lw.enabled = true } // Disable disables log writer. func (lw *TailLogWriter) Disable() { + lw.mutex.Lock() + defer lw.mutex.Unlock() + lw.enabled = false } // Writer wraps the basic io.Write method func (lw *TailLogWriter) Write(p []byte) (n int, err error) { + lw.mutex.Lock() + defer lw.mutex.Unlock() + if lw.enabled { return lw.out.Write(p) } diff --git a/lambda/metering/time.go b/lambda/metering/time.go index c3ccefe..1f5f047 100644 --- a/lambda/metering/time.go +++ b/lambda/metering/time.go @@ -5,7 +5,8 @@ package metering import ( _ "runtime" //for nanotime() and walltime() - _ "unsafe" //for go:linkname + "time" + _ "unsafe" //for go:linkname ) //go:linkname Monotime runtime.nanotime @@ -24,3 +25,35 @@ func MonoToEpoch(t int64) int64 { clockOffset := wallNsec - monoNsec return t + clockOffset } + +type ExtensionsResetDurationProfiler struct { + NumAgentsRegisteredForShutdown int + AvailableNs int64 + extensionsResetStartTimeNs int64 + extensionsResetEndTimeNs int64 +} + +func (p *ExtensionsResetDurationProfiler) Start() { + p.extensionsResetStartTimeNs = Monotime() +} + +func (p *ExtensionsResetDurationProfiler) Stop() { + p.extensionsResetEndTimeNs = Monotime() +} + +func (p *ExtensionsResetDurationProfiler) CalculateExtensionsResetMs() (int64, bool) { + var extensionsResetDurationNs = p.extensionsResetEndTimeNs - p.extensionsResetStartTimeNs + var extensionsResetMs int64 + timedOut := false + + if p.NumAgentsRegisteredForShutdown == 0 || p.AvailableNs < 0 || extensionsResetDurationNs < 0 { + extensionsResetMs = 0 + } else if extensionsResetDurationNs > p.AvailableNs { + extensionsResetMs = p.AvailableNs / time.Millisecond.Nanoseconds() + timedOut = true + } else { + extensionsResetMs = extensionsResetDurationNs / time.Millisecond.Nanoseconds() + } + + return extensionsResetMs, timedOut +} diff --git a/lambda/metering/time_test.go b/lambda/metering/time_test.go index 1025543..0088f9f 100644 --- a/lambda/metering/time_test.go +++ b/lambda/metering/time_test.go @@ -18,3 +18,60 @@ func TestMonoToEpochPrecision(t *testing.T) { // Conversion error is less than a millisecond. assert.True(t, math.Abs(float64(a-b)) < float64(time.Millisecond)) } + +func TestExtensionsResetDurationProfilerForExtensionsResetWithNoExtensions(t *testing.T) { + mono := Monotime() + profiler := ExtensionsResetDurationProfiler{} + + profiler.extensionsResetStartTimeNs = mono + profiler.extensionsResetEndTimeNs = mono + time.Second.Nanoseconds() + profiler.AvailableNs = 3 * time.Second.Nanoseconds() + profiler.NumAgentsRegisteredForShutdown = 0 + extensionsResetMs, resetTimeout := profiler.CalculateExtensionsResetMs() + + assert.Equal(t, int64(0), extensionsResetMs) + assert.Equal(t, false, resetTimeout) +} + +func TestExtensionsResetDurationProfilerForExtensionsResetWithinDeadline(t *testing.T) { + mono := Monotime() + profiler := ExtensionsResetDurationProfiler{} + + profiler.extensionsResetStartTimeNs = mono + profiler.extensionsResetEndTimeNs = mono + time.Second.Nanoseconds() + profiler.AvailableNs = 3 * time.Second.Nanoseconds() + profiler.NumAgentsRegisteredForShutdown = 1 + extensionsResetMs, resetTimeout := profiler.CalculateExtensionsResetMs() + + assert.Equal(t, time.Second.Milliseconds(), extensionsResetMs) + assert.Equal(t, false, resetTimeout) +} + +func TestExtensionsResetDurationProfilerForExtensionsResetTimeout(t *testing.T) { + mono := Monotime() + profiler := ExtensionsResetDurationProfiler{} + + profiler.extensionsResetStartTimeNs = mono + profiler.extensionsResetEndTimeNs = mono + 3*time.Second.Nanoseconds() + profiler.AvailableNs = time.Second.Nanoseconds() + profiler.NumAgentsRegisteredForShutdown = 1 + extensionsResetMs, resetTimeout := profiler.CalculateExtensionsResetMs() + + assert.Equal(t, time.Second.Milliseconds(), extensionsResetMs) + assert.Equal(t, true, resetTimeout) +} + +func TestExtensionsResetDurationProfilerEndToEnd(t *testing.T) { + profiler := ExtensionsResetDurationProfiler{} + + profiler.Start() + time.Sleep(time.Second) + profiler.Stop() + + profiler.AvailableNs = 2 * time.Second.Nanoseconds() + profiler.NumAgentsRegisteredForShutdown = 1 + extensionsResetMs, _ := profiler.CalculateExtensionsResetMs() + + assert.GreaterOrEqual(t, 2*time.Second.Milliseconds(), extensionsResetMs) + assert.LessOrEqual(t, time.Second.Milliseconds(), extensionsResetMs) +} diff --git a/lambda/rapi/handler/agentnext_test.go b/lambda/rapi/handler/agentnext_test.go index 8ea489a..ef14e49 100644 --- a/lambda/rapi/handler/agentnext_test.go +++ b/lambda/rapi/handler/agentnext_test.go @@ -10,6 +10,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -103,7 +104,7 @@ func TestRenderAgentInvokeNextHappy(t *testing.T) { ClientContext: "ClientContext", DeadlineNs: fmt.Sprintf("%d", deadlineNs), ContentType: "image/png", - Payload: []byte("Payload"), + Payload: strings.NewReader("Payload"), } renderingService := rendering.NewRenderingService() @@ -152,7 +153,7 @@ func TestRenderAgentInternalInvokeNextHappy(t *testing.T) { ClientContext: "ClientContext", DeadlineNs: fmt.Sprintf("%d", deadlineNs), ContentType: "image/png", - Payload: []byte("Payload"), + Payload: strings.NewReader("Payload"), } renderingService := rendering.NewRenderingService() @@ -261,3 +262,43 @@ func TestRenderAgentExternalShutdownEvent(t *testing.T) { assert.Equal(t, int64(deadlineMs), response.AgentEvent.DeadlineMs) assert.Equal(t, shutdownReason, response.ShutdownReason) } + +func TestRenderAgentInvokeNextHappyEmptyTraceID(t *testing.T) { + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization(), + core.NewInvokeFlowSynchronization(), + ) + agent, err := registrationService.CreateExternalAgent("dummyName") + assert.NoError(t, err) + + agent.SetState(agent.RegisteredState) + agent.Release() // sets operator condition to true so that the thread doesn't suspend waiting for invoke request + + deadlineNs := metering.Monotime() + int64(100*time.Millisecond) + requestID, functionArn := "ID", "InvokedFunctionArn" + traceID := "" + invoke := &interop.Invoke{ + TraceID: traceID, + ID: requestID, + InvokedFunctionArn: functionArn, + DeadlineNs: fmt.Sprintf("%d", deadlineNs), + ContentType: "image/png", + Payload: strings.NewReader("Payload"), + } + + renderingService := rendering.NewRenderingService() + renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, telemetry.GetCustomerTracingHeader)) + + handler := NewAgentNextHandler(registrationService, renderingService) + request := httptest.NewRequest("GET", "/", nil) + request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + assert.Equal(t, http.StatusOK, responseRecorder.Code) + var response model.AgentInvokeEvent + respBody, _ := ioutil.ReadAll(responseRecorder.Body) + json.Unmarshal(respBody, &response) + + assert.Nil(t, response.Tracing) +} diff --git a/lambda/rapi/handler/invocationerror_test.go b/lambda/rapi/handler/invocationerror_test.go index e44b030..b1e96c8 100644 --- a/lambda/rapi/handler/invocationerror_test.go +++ b/lambda/rapi/handler/invocationerror_test.go @@ -52,7 +52,7 @@ func runTestInvocationErrorHandler(t *testing.T) { DeadlineNs: "deadlinens1", ClientContext: "clientcontext1", ContentType: "image/png", - Payload: []byte("Payload1"), + Payload: strings.NewReader("Payload1"), } flowTest.ConfigureForInvoke(context.Background(), invoke) diff --git a/lambda/rapi/handler/invocationnext_test.go b/lambda/rapi/handler/invocationnext_test.go index 4c06281..beebd97 100644 --- a/lambda/rapi/handler/invocationnext_test.go +++ b/lambda/rapi/handler/invocationnext_test.go @@ -14,6 +14,7 @@ import ( "os" "runtime" "strconv" + "strings" "testing" "time" @@ -52,6 +53,7 @@ func TestRenderInvoke(t *testing.T) { appCtx := flowTest.AppCtx deadlineNs := 12345 + invokePayload := "Payload" invoke := &interop.Invoke{ TraceID: "Root=RootID;Parent=LambdaFrontend;Sampled=1", ID: "ID", @@ -61,7 +63,7 @@ func TestRenderInvoke(t *testing.T) { ClientContext: "ClientContext", DeadlineNs: strconv.Itoa(deadlineNs), ContentType: "image/png", - Payload: []byte("Payload"), + Payload: strings.NewReader(invokePayload), } ctx := telemetry.NewTraceContext(context.Background(), "RootID", "InvocationSubegmentID") @@ -87,7 +89,7 @@ func TestRenderInvoke(t *testing.T) { assert.Equal(t, "image/png", headers.Get("Content-Type")) assert.Len(t, headers, 7) - assert.Equal(t, invoke.Payload, responseRecorder.Body.Bytes()) + assert.Equal(t, invokePayload, responseRecorder.Body.String()) } //Cgo calls removed due to crashes while spawning threads under memory pressure. @@ -115,7 +117,7 @@ func BenchmarkRenderInvoke(b *testing.B) { ClientContext: "ClientContext", DeadlineNs: strconv.Itoa(deadlineNs), ContentType: "image/png", - Payload: []byte("Payload"), + Payload: strings.NewReader("Payload"), } ctx := telemetry.NewTraceContext(context.Background(), "RootID", "InvocationSubegmentID") diff --git a/lambda/rapi/handler/invocationresponse.go b/lambda/rapi/handler/invocationresponse.go index 2efb7ad..50c575c 100644 --- a/lambda/rapi/handler/invocationresponse.go +++ b/lambda/rapi/handler/invocationresponse.go @@ -4,40 +4,21 @@ package handler import ( - "fmt" - "io" - "io/ioutil" "net/http" "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/model" - "go.amzn.com/lambda/core" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/rendering" "github.com/go-chi/chi" log "github.com/sirupsen/logrus" ) -const ( - functionResponseSizeTooLargeType = "Function.ResponseSizeTooLarge" -) - type invocationResponseHandler struct { registrationService core.RegistrationService } -func readBody(request *http.Request) ([]byte, error) { - size := request.ContentLength - if size < 1 { - return ioutil.ReadAll(request.Body) - } - buffer := make([]byte, size) - _, err := io.ReadFull(request.Body, buffer) - return buffer, err -} - func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { appCtx := appctx.FromRequest(request) @@ -54,43 +35,28 @@ func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, reques return } - data, err := readBody(request) - if err != nil { - log.Error(err) - rendering.RenderInternalServerError(writer, request) - return - } + invokeID := chi.URLParam(request, "awsrequestid") - if len(data) > interop.MaxPayloadSize { - log.Warn("Request entity too large") + if err := server.SendResponse(invokeID, request.Body); err != nil { + switch err := err.(type) { + case *interop.ErrorResponseTooLarge: + if server.SendErrorResponse(invokeID, err.AsInteropError()) != nil { + rendering.RenderInteropError(writer, request, err) + return + } - resp := model.ErrorResponse{ - ErrorType: functionResponseSizeTooLargeType, - ErrorMessage: fmt.Sprintf("Response payload size (%d bytes) exceeded maximum allowed payload size (%d bytes).", len(data), interop.MaxPayloadSize), - } + appctx.StoreErrorResponse(appCtx, err.AsInteropError()) + + if err := runtime.ResponseSent(); err != nil { + log.Panic(err) + } - if server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), resp.AsInteropError()) != nil { + rendering.RenderRequestEntityTooLarge(writer, request) + return + default: rendering.RenderInteropError(writer, request, err) return } - - appctx.StoreErrorResponse(appCtx, resp.AsInteropError()) - - if err := runtime.ResponseSent(); err != nil { - log.Panic(err) - } - - rendering.RenderRequestEntityTooLarge(writer, request) - return - } - - response := &interop.Response{ - Payload: data, - } - - if err := server.SendResponse(chi.URLParam(request, "awsrequestid"), response); err != nil { - rendering.RenderInteropError(writer, request, err) - return } if err := runtime.ResponseSent(); err != nil { diff --git a/lambda/rapi/handler/invocationresponse_test.go b/lambda/rapi/handler/invocationresponse_test.go index 032abf5..e3ede59 100644 --- a/lambda/rapi/handler/invocationresponse_test.go +++ b/lambda/rapi/handler/invocationresponse_test.go @@ -8,8 +8,10 @@ import ( "context" "encoding/json" "fmt" + "io/ioutil" "net/http" "net/http/httptest" + "strings" "testing" "github.com/aws/aws-lambda-go/events/test" @@ -36,7 +38,7 @@ func TestResponseTooLarge(t *testing.T) { DeadlineNs: "deadlinens1", ClientContext: "clientcontext1", ContentType: "application/json", - Payload: []byte(`{"message": "hello"}`), + Payload: strings.NewReader(`{"message": "hello"}`), } flowTest.ConfigureForInvoke(context.Background(), invoke) @@ -53,7 +55,9 @@ func TestResponseTooLarge(t *testing.T) { responseRecorder.Code, http.StatusRequestEntityTooLarge) expectedAPIResponse := fmt.Sprintf("{\"errorMessage\":\"Exceeded maximum allowed payload size (%d bytes).\",\"errorType\":\"RequestEntityTooLarge\"}\n", interop.MaxPayloadSize) - test.AssertJsonsEqual(t, []byte(expectedAPIResponse), responseRecorder.Body.Bytes()) + body, err := ioutil.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + test.AssertJsonsEqual(t, []byte(expectedAPIResponse), body) errorResponse := flowTest.InteropServer.ErrorResponse assert.NotNil(t, errorResponse) @@ -62,7 +66,7 @@ func TestResponseTooLarge(t *testing.T) { assert.Equal(t, "Response payload size (6291557 bytes) exceeded maximum allowed payload size (6291556 bytes).", errorResponse.ErrorMessage) var errorPayload map[string]interface{} - json.Unmarshal(errorResponse.Payload, &errorPayload) + assert.NoError(t, json.Unmarshal(errorResponse.Payload, &errorPayload)) assert.Equal(t, errorResponse.ErrorType, errorPayload["errorType"]) assert.Equal(t, errorResponse.ErrorMessage, errorPayload["errorMessage"]) } @@ -84,7 +88,7 @@ func TestResponseAccepted(t *testing.T) { DeadlineNs: "deadlinens1", ClientContext: "clientcontext1", ContentType: "application/json", - Payload: []byte(`{"message": "hello"}`), + Payload: strings.NewReader(`{"message": "hello"}`), } flowTest.ConfigureForInvoke(context.Background(), invoke) @@ -101,11 +105,13 @@ func TestResponseAccepted(t *testing.T) { responseRecorder.Code, http.StatusAccepted) expectedAPIResponse := "{\"status\":\"OK\"}\n" - test.AssertJsonsEqual(t, []byte(expectedAPIResponse), responseRecorder.Body.Bytes()) + body, err := ioutil.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + test.AssertJsonsEqual(t, []byte(expectedAPIResponse), body) response := flowTest.InteropServer.Response assert.NotNil(t, response) assert.Nil(t, flowTest.InteropServer.ErrorResponse) - assert.Equal(t, responseBody, response.Payload, + assert.Equal(t, responseBody, response, "Persisted response data in app context must match the submitted.") } diff --git a/lambda/rapi/model/agentevent.go b/lambda/rapi/model/agentevent.go index d888af8..5c0cc73 100644 --- a/lambda/rapi/model/agentevent.go +++ b/lambda/rapi/model/agentevent.go @@ -12,9 +12,9 @@ type AgentEvent struct { // AgentInvokeEvent is the response to agent's get next request type AgentInvokeEvent struct { *AgentEvent - RequestID string `json:"requestId"` - InvokedFunctionArn string `json:"invokedFunctionArn"` - Tracing Tracing `json:"tracing"` + RequestID string `json:"requestId"` + InvokedFunctionArn string `json:"invokedFunctionArn"` + Tracing *Tracing `json:"tracing,omitempty"` } // AgentShutdownEvent is the response to agent's get next request diff --git a/lambda/rapi/model/tracing.go b/lambda/rapi/model/tracing.go index a0b3613..af90e8f 100644 --- a/lambda/rapi/model/tracing.go +++ b/lambda/rapi/model/tracing.go @@ -20,8 +20,12 @@ type XRayTracing struct { } // NewXRayTracing returns a new XRayTracing object with specified value -func NewXRayTracing(value string) Tracing { - return Tracing{ +func NewXRayTracing(value string) *Tracing { + if len(value) == 0 { + return nil + } + + return &Tracing{ XRayTracingType, XRayTracing{value}, } diff --git a/lambda/rapi/rendering/rendering.go b/lambda/rapi/rendering/rendering.go index 8cd974c..c75d010 100644 --- a/lambda/rapi/rendering/rendering.go +++ b/lambda/rapi/rendering/rendering.go @@ -8,6 +8,8 @@ import ( "encoding/json" "errors" "fmt" + "io" + "io/ioutil" "net/http" "strconv" "sync" @@ -82,11 +84,20 @@ func NewRenderingService() *EventRenderingService { } } +// InvokeRendererMetrics contains metrics of invoke request +type InvokeRendererMetrics struct { + ReadTime time.Duration + SizeBytes int +} + // InvokeRenderer knows how to render invoke event. type InvokeRenderer struct { ctx context.Context invoke *interop.Invoke tracingHeaderParser func(context.Context, *interop.Invoke) string + requestBuffer []byte + requestMutex sync.Mutex + metrics InvokeRendererMetrics } // NewAgentInvokeEvent forms a new AgentInvokeEvent from INVOKE request @@ -127,6 +138,22 @@ func (s *InvokeRenderer) RenderAgentEvent(writer http.ResponseWriter, request *h return nil } +func (s *InvokeRenderer) bufferInvokeRequest() error { + s.requestMutex.Lock() + defer s.requestMutex.Unlock() + var err error = nil + if nil == s.requestBuffer { + reader := io.LimitReader(s.invoke.Payload, interop.MaxPayloadSize) + start := time.Now() + s.requestBuffer, err = ioutil.ReadAll(reader) + s.metrics = InvokeRendererMetrics{ + ReadTime: time.Since(start), + SizeBytes: len(s.requestBuffer), + } + } + return err +} + // RenderRuntimeEvent renders invoke payload for runtime. func (s *InvokeRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { invoke := s.invoke @@ -155,20 +182,34 @@ func (s *InvokeRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request renderInvokeHeaders(writer, invoke.ID, customerTraceID, invoke.ClientContext, cognitoIdentityJSON, invoke.InvokedFunctionArn, deadlineHeader, invoke.ContentType) - _, err := writer.Write(invoke.Payload) + if invoke.Payload != nil { + if err := s.bufferInvokeRequest(); err != nil { + return err + } + _, err := writer.Write(s.requestBuffer) + return err + } - return err + return nil } // NewInvokeRenderer returns new invoke event renderer -func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, traceParser func(context.Context, *interop.Invoke) string) RendererState { +func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, traceParser func(context.Context, *interop.Invoke) string) *InvokeRenderer { return &InvokeRenderer{ invoke: invoke, ctx: ctx, tracingHeaderParser: traceParser, + requestBuffer: nil, + requestMutex: sync.Mutex{}, } } +func (s *InvokeRenderer) GetMetrics() InvokeRendererMetrics { + s.requestMutex.Lock() + defer s.requestMutex.Unlock() + return s.metrics +} + // ShutdownRenderer renderer for shutdown event. type ShutdownRenderer struct { AgentEvent model.AgentShutdownEvent diff --git a/lambda/rapi/router_test.go b/lambda/rapi/router_test.go index 1876563..f1cbde8 100644 --- a/lambda/rapi/router_test.go +++ b/lambda/rapi/router_test.go @@ -10,6 +10,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" "github.com/google/uuid" @@ -25,7 +26,7 @@ func createInvoke(id string) *interop.Invoke { return &interop.Invoke{ ID: id, InvokedFunctionArn: "arn::dummy:Function", - Payload: []byte("{\"invoke\":\"" + id + "\"}"), + Payload: strings.NewReader("{\"invoke\":\"" + id + "\"}"), DeadlineNs: "123456", } } diff --git a/lambda/rapid/bootstrap.go b/lambda/rapid/bootstrap.go index 5629fe1..e82ec6c 100644 --- a/lambda/rapid/bootstrap.go +++ b/lambda/rapid/bootstrap.go @@ -12,7 +12,7 @@ import ( type Bootstrap interface { Cmd() ([]string, error) // returns the args of bootstrap, where args[0] is the path to executable Env(e EnvironmentVariables) []string // returns the environment variables to be passed to the bootstrapped process - Cwd() string // returns the working directory of the bootstrap process + Cwd() (string, error) // returns the working directory of the bootstrap process ExtraFiles() []*os.File // returns the extra file descriptors apart from 1 & 2 to be passed to runtime CachedFatalError(err error) (fatalerror.ErrorType, string, bool) } diff --git a/lambda/rapid/exit.go b/lambda/rapid/exit.go index e8a2c79..af5cb72 100644 --- a/lambda/rapid/exit.go +++ b/lambda/rapid/exit.go @@ -7,7 +7,6 @@ import ( "fmt" "os" - "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/extensions" "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" @@ -39,13 +38,9 @@ func trySendDefaultErrorResponse(interopServer interop.Server, invokeID string, } } -func reportErrorAndExit(appCtx appctx.ApplicationContext, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error, correlationID string) { +func reportErrorAndExit(doneFailMsg *interop.DoneFail, invokeID string, interopServer interop.Server, err error) { // This function maintains compatibility of exit sequence behaviour // with Sandbox Factory in the absence of extensions - errorType, found := appctx.LoadFirstFatalError(appCtx) - if !found { - errorType = fatalerror.Unknown - } // NOTE this check will prevent us from sending FAULT message in case // response (positive or negative) has already been sent. This is done @@ -53,20 +48,18 @@ func reportErrorAndExit(appCtx appctx.ApplicationContext, execCtx *rapidContext, // ALSO NOTE, this works in case of positive response because this will // be followed by RAPID exit. if !interopServer.IsResponseSent() { - trySendDefaultErrorResponse(interopServer, invokeID, errorType, err) + trySendDefaultErrorResponse(interopServer, invokeID, doneFailMsg.ErrorType, err) } if err := interopServer.CommitResponse(); err != nil { checkInteropError("Failed to commit error response: %s", err) } + // old behavior: no DoneFails doneMsg := &interop.Done{ - WaitForExit: true, - RuntimeRelease: appctx.GetRuntimeRelease(appCtx), - CorrelationID: correlationID, // required for standalone mode - } - if execCtx.telemetryAPIEnabled { - doneMsg.LogsAPIMetrics = execCtx.telemetryService.FlushMetrics() + WaitForExit: true, + CorrelationID: doneFailMsg.CorrelationID, // required for standalone mode + Meta: doneFailMsg.Meta, } if err := interopServer.SendDone(doneMsg); err != nil { @@ -76,67 +69,48 @@ func reportErrorAndExit(appCtx appctx.ApplicationContext, execCtx *rapidContext, os.Exit(1) } -func reportErrorAndRequestReset(appCtx appctx.ApplicationContext, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error, correlationID string) { - errorType, found := appctx.LoadFirstFatalError(appCtx) - if !found { - errorType = fatalerror.Unknown +func reportErrorAndRequestReset(doneFailMsg *interop.DoneFail, invokeID string, interopServer interop.Server, err error) { + if err == errResetReceived { + // errResetReceived is returned when execution flow was interrupted by the Reset message, + // hence this error deserves special handling and we yield to main receive loop to handle it + return } - trySendDefaultErrorResponse(interopServer, invokeID, errorType, err) + trySendDefaultErrorResponse(interopServer, invokeID, doneFailMsg.ErrorType, err) if err := interopServer.CommitResponse(); err != nil { checkInteropError("Failed to commit error response: %s", err) } - doneFailMsg := &interop.DoneFail{ - ErrorType: string(errorType), - RuntimeRelease: appctx.GetRuntimeRelease(appCtx), - CorrelationID: correlationID, // required for standalone mode - NumActiveExtensions: execCtx.registrationService.CountAgents(), - } - if execCtx.telemetryAPIEnabled { - doneFailMsg.LogsAPIMetrics = execCtx.telemetryService.FlushMetrics() - } - if err := interopServer.SendDoneFail(doneFailMsg); err != nil { checkInteropError("Failed to send DONEFAIL: %s", err) } } -func handleError(appCtx appctx.ApplicationContext, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error, correlationID string) { - if err == errResetReceived { - // errResetReceived is returned when execution flow was interrupted by the Reset message, - // hence this error deserves special handling and we yield to main receive loop to handle it - return - } - - reportErrorAndRequestReset(appCtx, execCtx, invokeID, interopServer, err, correlationID) -} - -func handleInitError(appCtx appctx.ApplicationContext, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error, correlationID string) { +func handleInitError(doneFailMsg *interop.DoneFail, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error) { if execCtx.standaloneMode { - handleError(appCtx, execCtx, invokeID, interopServer, err, correlationID) + reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) return } if !execCtx.HasActiveExtensions() { // we don't expect Slicer to send RESET during INIT, that's why we Exit here - reportErrorAndExit(appCtx, execCtx, invokeID, interopServer, err, correlationID) + reportErrorAndExit(doneFailMsg, invokeID, interopServer, err) } - handleError(appCtx, execCtx, invokeID, interopServer, err, correlationID) + reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) } -func handleInvokeError(appCtx appctx.ApplicationContext, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error, correlationID string) { +func handleInvokeError(doneFailMsg *interop.DoneFail, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error) { if execCtx.standaloneMode { - handleError(appCtx, execCtx, invokeID, interopServer, err, correlationID) + reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) return } // Invoke with extensions disabled maintains behaviour parity with pre-extensions rapid if !extensions.AreEnabled() { - reportErrorAndExit(appCtx, execCtx, invokeID, interopServer, err, correlationID) + reportErrorAndExit(doneFailMsg, invokeID, interopServer, err) } - handleError(appCtx, execCtx, invokeID, interopServer, err, correlationID) + reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) } diff --git a/lambda/rapid/graceful_shutdown.go b/lambda/rapid/graceful_shutdown.go index dab9978..5ad1326 100644 --- a/lambda/rapid/graceful_shutdown.go +++ b/lambda/rapid/graceful_shutdown.go @@ -43,7 +43,7 @@ func awaitSigkilledProcessesToExit(exitPidChan chan int, processesExited, sigkil } } -func gracefulShutdown(execCtx *rapidContext, watchdog *core.Watchdog, deadlineNs int64, killAgents bool, reason string) { +func gracefulShutdown(execCtx *rapidContext, watchdog *core.Watchdog, profiler *metering.ExtensionsResetDurationProfiler, deadlineNs int64, killAgents bool, reason string) { watchdog.Mute() defer watchdog.Unmute() @@ -68,18 +68,24 @@ func gracefulShutdown(execCtx *rapidContext, watchdog *core.Watchdog, deadlineNs availableNs = 0 } + profiler.AvailableNs = availableNs + start := time.Now() + profiler.Start() runtimeDeadline := start.Add(time.Duration(float64(availableNs) * runtimeDeadlineShare)) agentsDeadline := start.Add(time.Duration(availableNs)) sigkilledPids := make(map[int]bool) // Track process ids that were sent sigkill processesExited := make(map[int]bool) // Don't send sigkill to processes that exit out of order + processesExited, sigkilledPids = shutdownRuntime(execCtx, start, runtimeDeadline, processesExited, sigkilledPids) - processesExited, sigkilledPids = shutdownAgents(execCtx, start, agentsDeadline, killAgents, reason, processesExited, sigkilledPids) + processesExited, sigkilledPids = shutdownAgents(execCtx, start, profiler, agentsDeadline, killAgents, reason, processesExited, sigkilledPids) if execCtx.standaloneMode { awaitSigkilledProcessesToExit(execCtx.exitPidChan, processesExited, sigkilledPids) } + + profiler.Stop() } func shutdownRuntime(execCtx *rapidContext, start time.Time, deadline time.Time, processesExited, sigkilledPids map[int]bool) (map[int]bool, map[int]bool) { @@ -110,14 +116,14 @@ func shutdownRuntime(execCtx *rapidContext, start time.Time, deadline time.Time, log.Warnf("Process %d exited unexpectedly", pid) case <-runtimeTimer.C: - log.Warnf("Runtime didn't exit after SIGTERM in %v; dispatching SIGKILL to runtime process group", runtimeTimeout) + log.Warnf("Timeout: no SIGCHLD from Runtime after %d ms; dispatching SIGKILL to runtime process group", int64(runtimeTimeout/time.Millisecond)) sigkilledPids = sigkillProcessGroup(runtime.Pid, sigkilledPids) return processesExited, sigkilledPids } } } -func shutdownAgents(execCtx *rapidContext, start time.Time, deadline time.Time, killAgents bool, reason string, processesExited, sigkilledPids map[int]bool) (map[int]bool, map[int]bool) { +func shutdownAgents(execCtx *rapidContext, start time.Time, profiler *metering.ExtensionsResetDurationProfiler, deadline time.Time, killAgents bool, reason string, processesExited, sigkilledPids map[int]bool) (map[int]bool, map[int]bool) { // For each external agent, if agent is launched: // 1. Send Shutdown event if subscribed for it, else send SIGKILL to process group // 2. Wait for all Shutdown-subscribed agents to exit with timeout @@ -150,6 +156,7 @@ func shutdownAgents(execCtx *rapidContext, start time.Time, deadline time.Time, } } } + profiler.NumAgentsRegisteredForShutdown = len(pidsToShutdown) var timerChan <-chan time.Time // default timerChan if killAgents { diff --git a/lambda/rapid/start.go b/lambda/rapid/start.go index 8053fb1..711f122 100644 --- a/lambda/rapid/start.go +++ b/lambda/rapid/start.go @@ -9,6 +9,7 @@ import ( "errors" "io" "os" + "time" "go.amzn.com/lambda/agents" "go.amzn.com/lambda/appctx" @@ -67,7 +68,7 @@ type rapidContext struct { } func (c *rapidContext) HasActiveExtensions() bool { - return extensions.AreEnabled() + return extensions.AreEnabled() && c.registrationService.CountAgents() > 0 } func logAgentsInitStatus(execCtx *rapidContext) { @@ -163,7 +164,17 @@ func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) } bootstrapEnv := bootstrap.Env(execCtx.environment) - bootstrapCwd := bootstrap.Cwd() + bootstrapCwd, err := bootstrap.Cwd() + if err != nil { + if fatalError, formattedLog, hasError := bootstrap.CachedFatalError(err); hasError { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) + execCtx.platformLogger.Printf("%s", formattedLog) + } else { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidWorkingDir) + } + return err + } + bootstrapExtraFiles := bootstrap.ExtraFiles() runtimeCmd := runtimecmd.NewCustomRuntimeCmd(ctx, bootstrapCmd, bootstrapCwd, bootstrapEnv, execCtx.runtimeLogWriter, bootstrapExtraFiles) @@ -207,7 +218,7 @@ func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) return nil } -func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, invokeRequest *interop.Invoke) error { +func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, invokeRequest *interop.Invoke, mx *rendering.InvokeRendererMetrics) error { appCtx := execCtx.appCtx appctx.StoreErrorResponse(appCtx, nil) @@ -255,7 +266,12 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo // Invoke if err := xray.CaptureInvokeSubsegment(ctx, xray.WithError(ctx, appCtx, func(ctx context.Context) error { log.Debug("Set renderer for invoke") - execCtx.renderingService.SetRenderer(rendering.NewInvokeRenderer(ctx, invokeRequest, xray.TracingHeaderParser())) + renderer := rendering.NewInvokeRenderer(ctx, invokeRequest, xray.TracingHeaderParser()) + defer func() { + *mx = renderer.GetMetrics() + }() + + execCtx.renderingService.SetRenderer(renderer) if extensions.AreEnabled() { log.Debug("Release agents conditions") for _, agent := range extAgents { @@ -265,6 +281,7 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo agent.Release() } } + log.Debug("Release runtime condition") runtime.Release() log.Debug("Await runtime response") @@ -342,30 +359,68 @@ func handleStart(ctx context.Context, execCtx *rapidContext, watchdog *core.Watc if !startRequest.SuppressInit { if err := doInit(ctx, execCtx, watchdog); err != nil { log.WithError(err).WithField("InvokeID", startRequest.InvokeID).Error("Init failed") - handleInitError(appCtx, execCtx, startRequest.InvokeID, interopServer, err, startRequest.CorrelationID) + doneFailMsg := generateDoneFail(execCtx, startRequest.CorrelationID, nil, 0) + handleInitError(doneFailMsg, execCtx, startRequest.InvokeID, interopServer, err) return } } doneMsg := &interop.Done{ - RuntimeRelease: appctx.GetRuntimeRelease(appCtx), - CorrelationID: startRequest.CorrelationID, - NumActiveExtensions: execCtx.registrationService.CountAgents(), + CorrelationID: startRequest.CorrelationID, + Meta: interop.DoneMetadata{ + RuntimeRelease: appctx.GetRuntimeRelease(appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + }, } if execCtx.telemetryAPIEnabled { - doneMsg.LogsAPIMetrics = execCtx.telemetryService.FlushMetrics() + doneMsg.Meta.LogsAPIMetrics = execCtx.telemetryService.FlushMetrics() } if err := interopServer.SendDone(doneMsg); err != nil { log.Panic(err) } + + if err := interopServer.StartAcceptingDirectInvokes(); err != nil { + log.Panic(err) + } +} + +func generateDoneFail(execCtx *rapidContext, correlationID string, invokeMx *rendering.InvokeRendererMetrics, invokeReceivedTime int64) *interop.DoneFail { + errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) + if !found { + errorType = fatalerror.Unknown + } + + doneFailMsg := &interop.DoneFail{ + ErrorType: errorType, + CorrelationID: correlationID, // required for standalone mode + Meta: interop.DoneMetadata{ + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + InvokeReceivedTime: invokeReceivedTime, + }, + } + + if invokeMx != nil { + doneFailMsg.Meta.InvokeRequestReadTimeNs = invokeMx.ReadTime.Nanoseconds() + doneFailMsg.Meta.InvokeRequestSizeBytes = int64(invokeMx.SizeBytes) + } + + if execCtx.telemetryAPIEnabled { + doneFailMsg.Meta.LogsAPIMetrics = execCtx.telemetryService.FlushMetrics() + } + + return doneFailMsg } func handleInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, invokeRequest *interop.Invoke) { interopServer, appCtx := execCtx.interopServer, execCtx.appCtx - if err := doInvoke(ctx, execCtx, watchdog, invokeRequest); err != nil { + invokeMx := rendering.InvokeRendererMetrics{} + + if err := doInvoke(ctx, execCtx, watchdog, invokeRequest, &invokeMx); err != nil { log.WithError(err).WithField("InvokeID", invokeRequest.ID).Error("Invoke failed") - handleInvokeError(appCtx, execCtx, invokeRequest.ID, interopServer, err, invokeRequest.CorrelationID) + doneFailMsg := generateDoneFail(execCtx, invokeRequest.CorrelationID, &invokeMx, invokeRequest.InvokeReceivedTime) + handleInvokeError(doneFailMsg, execCtx, invokeRequest.ID, interopServer, err) return } @@ -373,13 +428,24 @@ func handleInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Wat log.Panic(err) } + var invokeCompletionTimeNs int64 + if responseTimeNs := execCtx.registrationService.GetRuntime().GetRuntimeDescription().State.ResponseTimeNs; responseTimeNs != 0 { + invokeCompletionTimeNs = time.Now().UnixNano() - responseTimeNs + } + doneMsg := &interop.Done{ - RuntimeRelease: appctx.GetRuntimeRelease(appCtx), - CorrelationID: invokeRequest.CorrelationID, - NumActiveExtensions: execCtx.registrationService.CountAgents(), + CorrelationID: invokeRequest.CorrelationID, + Meta: interop.DoneMetadata{ + RuntimeRelease: appctx.GetRuntimeRelease(appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + InvokeRequestReadTimeNs: invokeMx.ReadTime.Nanoseconds(), + InvokeRequestSizeBytes: int64(invokeMx.SizeBytes), + InvokeCompletionTimeNs: invokeCompletionTimeNs, + InvokeReceivedTime: invokeRequest.InvokeReceivedTime, + }, } if execCtx.telemetryAPIEnabled { - doneMsg.LogsAPIMetrics = execCtx.telemetryService.FlushMetrics() + doneMsg.Meta.LogsAPIMetrics = execCtx.telemetryService.FlushMetrics() } if err := interopServer.SendDone(doneMsg); err != nil { @@ -411,13 +477,30 @@ func blockForever() { func handleReset(execCtx *rapidContext, watchdog *core.Watchdog, reset *interop.Reset) { log.Warnf("Reset initiated: %s", reset.Reason) - gracefulShutdown(execCtx, watchdog, reset.DeadlineNs, execCtx.standaloneMode, reset.Reason) + profiler := metering.ExtensionsResetDurationProfiler{} + gracefulShutdown(execCtx, watchdog, &profiler, reset.DeadlineNs, execCtx.standaloneMode, reset.Reason) + + extensionsResetMs, resetTimeout := profiler.CalculateExtensionsResetMs() + + meta := interop.DoneMetadata{ + ExtensionsResetMs: extensionsResetMs, + } if !execCtx.standaloneMode { - // GIRP interopServer implementation sends GIRP RSTDONE - if err := execCtx.interopServer.SendDone(&interop.Done{CorrelationID: reset.CorrelationID}); err != nil { - log.Panicf("Failed to SendDone: %s", err) + // GIRP interopServer implementation sends GIRP RSTFAIL/RSTDONE + if resetTimeout { + // TODO: DoneFail must contain a reset timeout ErrorType for rapid local to distinguish errors + doneFail := &interop.DoneFail{CorrelationID: reset.CorrelationID, Meta: meta} + if err := execCtx.interopServer.SendDoneFail(doneFail); err != nil { + log.Panicf("Failed to SendDoneFail: %s", err) + } + } else { + done := &interop.Done{CorrelationID: reset.CorrelationID, Meta: meta} + if err := execCtx.interopServer.SendDone(done); err != nil { + log.Panicf("Failed to SendDone: %s", err) + } } + os.Exit(0) } @@ -425,8 +508,16 @@ func handleReset(execCtx *rapidContext, watchdog *core.Watchdog, reset *interop. fatalErrorType, _ := appctx.LoadFirstFatalError(execCtx.appCtx) - if err := execCtx.interopServer.SendDone(&interop.Done{CorrelationID: reset.CorrelationID, ErrorType: string(fatalErrorType)}); err != nil { - log.Panicf("Failed to SendDone: %s", err) + if resetTimeout { + doneFail := &interop.DoneFail{CorrelationID: reset.CorrelationID, ErrorType: fatalErrorType, Meta: meta} + if err := execCtx.interopServer.SendDoneFail(doneFail); err != nil { + log.Panicf("Failed to SendDoneFail: %s", err) + } + } else { + done := &interop.Done{CorrelationID: reset.CorrelationID, ErrorType: fatalErrorType, Meta: meta} + if err := execCtx.interopServer.SendDone(done); err != nil { + log.Panicf("Failed to SendDone: %s", err) + } } } @@ -434,11 +525,11 @@ func handleReset(execCtx *rapidContext, watchdog *core.Watchdog, reset *interop. func handleShutdown(execCtx *rapidContext, watchdog *core.Watchdog, shutdown *interop.Shutdown, reason string) { log.Warnf("Shutdown initiated") - gracefulShutdown(execCtx, watchdog, shutdown.DeadlineNs, true, reason) + gracefulShutdown(execCtx, watchdog, &metering.ExtensionsResetDurationProfiler{}, shutdown.DeadlineNs, true, reason) fatalErrorType, _ := appctx.LoadFirstFatalError(execCtx.appCtx) - if err := execCtx.interopServer.SendDone(&interop.Done{CorrelationID: shutdown.CorrelationID, ErrorType: string(fatalErrorType)}); err != nil { + if err := execCtx.interopServer.SendDone(&interop.Done{CorrelationID: shutdown.CorrelationID, ErrorType: fatalErrorType}); err != nil { log.Panicf("Failed to SendDone: %s", err) } diff --git a/lambda/rapid/start_test.go b/lambda/rapid/start_test.go index c2bc8ef..3210b68 100644 --- a/lambda/rapid/start_test.go +++ b/lambda/rapid/start_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io/ioutil" "net/http" + "strings" "testing" "time" @@ -79,7 +80,7 @@ func BenchmarkChannelsSelect2(b *testing.B) { func TestListen(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - flowTest.ConfigureForInvoke(context.Background(), &interop.Invoke{ID: "ID", DeadlineNs: "1", Payload: []byte("MyTest")}) + flowTest.ConfigureForInvoke(context.Background(), &interop.Invoke{ID: "ID", DeadlineNs: "1", Payload: strings.NewReader("MyTest")}) ctx := context.Background() telemetryAPIEnabled := true diff --git a/lambda/rapidcore/bootstrap.go b/lambda/rapidcore/bootstrap.go index b6edcc9..e39f7e9 100644 --- a/lambda/rapidcore/bootstrap.go +++ b/lambda/rapidcore/bootstrap.go @@ -6,6 +6,7 @@ package rapidcore import ( "fmt" "os" + "path/filepath" "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/logging" @@ -37,6 +38,12 @@ func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string) *Bootstrap orderedLookupBootstrapPaths = append(orderedLookupBootstrapPaths, args[0]) } } + + if currentWorkingDir == "" { + // use the root directory as the default working directory + currentWorkingDir = "/" + } + return &Bootstrap{ orderedLookupPaths: orderedLookupBootstrapPaths, workingDir: currentWorkingDir, @@ -44,6 +51,14 @@ func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string) *Bootstrap } } +func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string) *Bootstrap { + // a single candidate command makes it automatically valid + return &Bootstrap{ + validCmd: cmd, + workingDir: currentWorkingDir, + } +} + // locateBootstrap sets the first occurrence of an // actual bootstrap, given a list of possible files func (b *Bootstrap) locateBootstrap() error { @@ -60,6 +75,10 @@ func (b *Bootstrap) locateBootstrap() error { // Cmd returns the args of bootstrap, where args[0] // is the path to executable func (b *Bootstrap) Cmd() ([]string, error) { + if len(b.validCmd) > 0 { + return b.validCmd, nil + } + if err := b.locateBootstrap(); err != nil { return []string{}, err } @@ -75,8 +94,14 @@ func (b *Bootstrap) Env(e rapid.EnvironmentVariables) []string { } // Cwd returns the working directory of the bootstrap process -func (b *Bootstrap) Cwd() string { - return b.workingDir +func (b *Bootstrap) Cwd() (string, error) { + if !filepath.IsAbs(b.workingDir) { + return "", fmt.Errorf("the working directory '%s' is invalid, it needs to be an absolute path", b.workingDir) + } else if _, err := os.Stat(b.workingDir); os.IsNotExist(err) { + return "", fmt.Errorf("the working directory doesn't exist: %s", b.workingDir) + } + + return b.workingDir, nil } // SetExtraFiles sets the extra file descriptors apart from 1 & 2 to be passed to runtime @@ -107,16 +132,22 @@ func (b *Bootstrap) SetCachedFatalError(bootstrapErrFn BootstrapError) { b.bootstrapError = bootstrapErrFn } -// BootstrapErrInvalidOCITaskConfig represents an error while parsing OCI task config -func BootstrapErrInvalidOCITaskConfig(err error) BootstrapError { +// BootstrapErrInvalidLCISTaskConfig represents an error while parsing LCIS task config +func BootstrapErrInvalidLCISTaskConfig(err error) BootstrapError { return func() (fatalerror.ErrorType, LogFormatter) { return fatalerror.InvalidTaskConfig, logging.SupernovaInvalidTaskConfigRepr(err) } } -// BootstrapErrInvalidOCIEntrypoint represents an invalid OCI entrypoint error -func BootstrapErrInvalidOCIEntrypoint(entrypoint []string, cmd []string, workingdir string) BootstrapError { +// BootstrapErrInvalidLCISEntrypoint represents an invalid LCIS entrypoint error +func BootstrapErrInvalidLCISEntrypoint(entrypoint []string, cmd []string, workingdir string) BootstrapError { return func() (fatalerror.ErrorType, LogFormatter) { return fatalerror.InvalidEntrypoint, logging.SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) } } + +func BootstrapErrInvalidLCISWorkingDir(entrypoint []string, cmd []string, workingdir string) BootstrapError { + return func() (fatalerror.ErrorType, LogFormatter) { + return fatalerror.InvalidWorkingDir, logging.SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) + } +} diff --git a/lambda/rapidcore/bootstrap_test.go b/lambda/rapidcore/bootstrap_test.go index 4b40eea..a0e466e 100644 --- a/lambda/rapidcore/bootstrap_test.go +++ b/lambda/rapidcore/bootstrap_test.go @@ -39,7 +39,9 @@ func TestBootstrap(t *testing.T) { // Test b := NewBootstrap(cmdCandidates, cwd) - assert.Equal(t, cwd, b.Cwd()) + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, cwd, bCwd) assert.ElementsMatch(t, environment.RuntimeExecEnv(), b.Env(environment)) cmd, err := b.Cmd() @@ -56,6 +58,71 @@ func TestBootstrapEmptyCandidate(t *testing.T) { assert.Error(t, err) } +func TestBootstrapSingleCmd(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "lcis-test-invalid-bootstrap") + assert.NoError(t, err) + defer os.RemoveAll(tmpDir) + + tmpFile, err := ioutil.TempFile("", "lcis-test-bootstrap") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // Setup single cmd candidate + file := []string{tmpFile.Name(), "--arg1 s", "foo"} + cmdCandidate := file + + // Setup working dir + cwd, err := os.Getwd() + assert.NoError(t, err) + + // Setup environment + environment := env.NewEnvironment() + environment.StoreRuntimeAPIEnvironmentVariable("host:port") + environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") + + // Test + b := NewBootstrapSingleCmd(cmdCandidate, cwd) + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, cwd, bCwd) + assert.ElementsMatch(t, environment.RuntimeExecEnv(), b.Env(environment)) + + cmd, err := b.Cmd() + assert.NoError(t, err) + assert.Equal(t, file, cmd) +} + +func TestBootstrapSingleCmdNonExistingCandidate(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "lcis-test-invalid-bootstrap") + assert.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Setup inexistent single cmd candidate + file := []string{"/foo/bar", "--arg1 s", "foo"} + cmdCandidate := file + + // Setup working dir + cwd, err := os.Getwd() + assert.NoError(t, err) + + // Setup environment + environment := env.NewEnvironment() + environment.StoreRuntimeAPIEnvironmentVariable("host:port") + environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") + + // Test + b := NewBootstrapSingleCmd(cmdCandidate, cwd) + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, cwd, bCwd) + assert.ElementsMatch(t, environment.RuntimeExecEnv(), b.Env(environment)) + + // No validations run against single candidates + cmd, err := b.Cmd() + assert.NoError(t, err) + assert.Equal(t, file, cmd) +} + // Test our ability to locate bootstrap files in the file system func TestFindCustomRuntimeIfExists(t *testing.T) { tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp-") @@ -97,3 +164,55 @@ func TestFindCustomRuntimeIfExists(t *testing.T) { assert.EqualError(t, err, "Couldn't find valid bootstrap(s): [mk mk2]") assert.Equal(t, []string{}, cmd) } + +func TestCwdIsAbsolute(t *testing.T) { + tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp-") + if err != nil { + t.Fatal("Cannot create temporary file", err) + } + defer os.Remove(tmpFile.Name()) + + cmdCandidates := [][]string{[]string{tmpFile.Name()}} + + // no errors when currentWorkingDir is absolute + bootstrap := NewBootstrap(cmdCandidates, "/tmp") + cwd, err := bootstrap.Cwd() + assert.Nil(t, err) + assert.Equal(t, "/tmp", cwd) + + bootstrap = NewBootstrap(cmdCandidates, "tmp") + _, err = bootstrap.Cwd() + assert.EqualError(t, err, "the working directory 'tmp' is invalid, it needs to be an absolute path") + + bootstrap = NewBootstrap(cmdCandidates, "./") + _, err = bootstrap.Cwd() + assert.EqualError(t, err, "the working directory './' is invalid, it needs to be an absolute path") +} + +func TestBootstrapMissingWorkingDirectory(t *testing.T) { + tmpFile, err := ioutil.TempFile(os.TempDir(), "cwd-test-bootstrap") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + tmpDir, err := ioutil.TempDir("", "cwd-test") + assert.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // cwd argument exists + bootstrap := NewBootstrap([][]string{[]string{tmpFile.Name()}}, tmpDir) + cwd, err := bootstrap.Cwd() + assert.Equal(t, cwd, tmpDir) + assert.NoError(t, err) + + // cwd argument doesn't exist + bootstrap = NewBootstrap([][]string{[]string{tmpFile.Name()}}, "/foo") + _, err = bootstrap.Cwd() + assert.EqualError(t, err, "the working directory doesn't exist: /foo") +} + +func TestDefaultWorkeringDirectory(t *testing.T) { + bootstrap := NewBootstrap([][]string{[]string{}}, "") + cwd, err := bootstrap.Cwd() + assert.NoError(t, err) + assert.Equal(t, "/", cwd) +} diff --git a/lambda/rapidcore/env/constants.go b/lambda/rapidcore/env/constants.go index b230352..50e9fcb 100644 --- a/lambda/rapidcore/env/constants.go +++ b/lambda/rapidcore/env/constants.go @@ -5,12 +5,13 @@ package env func predefinedInternalEnvVarKeys() map[string]bool { return map[string]bool{ - "_LAMBDA_SB_ID": true, - "_LAMBDA_LOG_FD": true, - "_LAMBDA_SHARED_MEM_FD": true, - "_LAMBDA_CONTROL_SOCKET": true, - "_LAMBDA_RUNTIME_LOAD_TIME": true, - "_LAMBDA_CONSOLE_SOCKET": true, + "_LAMBDA_SB_ID": true, + "_LAMBDA_LOG_FD": true, + "_LAMBDA_SHARED_MEM_FD": true, + "_LAMBDA_CONTROL_SOCKET": true, + "_LAMBDA_DIRECT_INVOKE_SOCKET": true, + "_LAMBDA_RUNTIME_LOAD_TIME": true, + "_LAMBDA_CONSOLE_SOCKET": true, // _X_AMZN_TRACE_ID is set by stock runtimes. Provided // runtimes should set and mutate it on each invoke. "_X_AMZN_TRACE_ID": true, diff --git a/lambda/rapidcore/env/customer.go b/lambda/rapidcore/env/customer.go index 17826ea..f784570 100644 --- a/lambda/rapidcore/env/customer.go +++ b/lambda/rapidcore/env/customer.go @@ -10,24 +10,32 @@ import ( log "github.com/sirupsen/logrus" ) -func logUnfilteredInternalEnvVars(envKey string) { - // We would like to filter out all internal environment variables, but we - // log this for now to get data to ensure customers aren't depending on it. - if strings.HasPrefix(envKey, "_") { - log.Warn("Internal environment variable not filtered") +func isInternalEnvVar(envKey string) bool { + // the rule is no '_' prefixed env. variables will be propagated to the runtime but the ones explicitly exempted + allowedKeys := map[string]bool{ + "_HANDLER": true, + "_AWS_XRAY_DAEMON_ADDRESS": true, + "_AWS_XRAY_DAEMON_PORT": true, + "_LAMBDA_TELEMETRY_LOG_FD": true, } + return strings.HasPrefix(envKey, "_") && !allowedKeys[envKey] } // CustomerEnvironmentVariables parses all environment variables that are // not internal/credential/platform, and must be called before agent bootstrap. func CustomerEnvironmentVariables() map[string]string { - isInternal := predefinedInternalEnvVarKeys() - isPlatform := predefinedPlatformEnvVarKeys() - isRuntime := predefinedRuntimeEnvVarKeys() - isCredential := predefinedCredentialsEnvVarKeys() - isPlatformUnreserved := predefinedPlatformUnreservedEnvVarKeys() + internalKeys := predefinedInternalEnvVarKeys() + platformKeys := predefinedPlatformEnvVarKeys() + runtimeKeys := predefinedRuntimeEnvVarKeys() + credentialKeys := predefinedCredentialsEnvVarKeys() + platformUnreservedKeys := predefinedPlatformUnreservedEnvVarKeys() isCustomer := func(key string) bool { - return !isInternal[key] && !isRuntime[key] && !isPlatform[key] && !isCredential[key] && !isPlatformUnreserved[key] + return !internalKeys[key] && + !runtimeKeys[key] && + !platformKeys[key] && + !credentialKeys[key] && + !platformUnreservedKeys[key] && + !isInternalEnvVar(key) } customerEnv := map[string]string{} @@ -39,7 +47,6 @@ func CustomerEnvironmentVariables() map[string]string { } if isCustomer(key) { - logUnfilteredInternalEnvVars(key) customerEnv[key] = val } } diff --git a/lambda/rapidcore/env/environment.go b/lambda/rapidcore/env/environment.go index 16c19af..5c80229 100644 --- a/lambda/rapidcore/env/environment.go +++ b/lambda/rapidcore/env/environment.go @@ -14,6 +14,7 @@ import ( const runtimeAPIAddressKey = "AWS_LAMBDA_RUNTIME_API" const handlerEnvKey = "_HANDLER" +const executionEnvKey = "AWS_EXECUTION_ENV" // Environment holds env vars for runtime, agents, and for // internal use, parsed during startup and from START msg @@ -37,6 +38,7 @@ type RapidConfig struct { ShmFd int CtrlFd int CnslFd int + DirectInvokeFd int LambdaTaskRoot string XrayDaemonAddress string PreLoadTimeNs int64 @@ -88,6 +90,16 @@ func (e *Environment) SetHandler(handler string) { e.Runtime[handlerEnvKey] = handler } +// GetExecutionEnv returns the current setting for AWS_EXECUTION_ENV +func (e *Environment) GetExecutionEnv() string { + return e.Runtime[executionEnvKey] +} + +// SetExecutionEnv sets AWS_EXECUTION_ENV variable value for Runtime +func (e *Environment) SetExecutionEnv(executionEnv string) { + e.Runtime[executionEnvKey] = executionEnv +} + // StoreEnvironmentVariablesFromInit sets the environment variables // for credentials & _HANDLER which are received in the START message func (e *Environment) StoreEnvironmentVariablesFromInit(customerEnv map[string]string, handler, awsKey, awsSecret, awsSession, funcName, funcVer string) { @@ -112,7 +124,7 @@ func (e *Environment) StoreEnvironmentVariablesFromInit(customerEnv map[string]s } // StoreEnvironmentVariablesFromCLIOptions sets the environment -// variables received via a CLI flag, for example OCI config +// variables received via a CLI flag, for example LCIS config func (e *Environment) StoreEnvironmentVariablesFromCLIOptions(envVars map[string]string) { e.mergeCustomerEnvironmentVariables(envVars) } @@ -154,6 +166,7 @@ func (e *Environment) RAPIDInternalConfig() RapidConfig { ShmFd: e.getSocketEnvVarOrDie(e.RAPID, "_LAMBDA_SHARED_MEM_FD"), CtrlFd: e.getSocketEnvVarOrDie(e.RAPID, "_LAMBDA_CONTROL_SOCKET"), CnslFd: e.getSocketEnvVarOrDie(e.RAPID, "_LAMBDA_CONSOLE_SOCKET"), + DirectInvokeFd: e.getOptionalSocketEnvVar(e.RAPID, "_LAMBDA_DIRECT_INVOKE_SOCKET"), PreLoadTimeNs: e.getInt64EnvVarOrDie(e.RAPID, "_LAMBDA_RUNTIME_LOAD_TIME"), LambdaTaskRoot: e.getStrEnvVarOrDie(e.Runtime, "LAMBDA_TASK_ROOT"), XrayDaemonAddress: e.getStrEnvVarOrDie(e.PlatformUnreserved, "AWS_XRAY_DAEMON_ADDRESS"), @@ -192,6 +205,26 @@ func (e *Environment) getSocketEnvVarOrDie(env map[string]string, name string) i return sock } +// returns -1 if env variable was not set. Exits if it holds unexpected (non-int) value +func (e *Environment) getOptionalSocketEnvVar(env map[string]string, name string) int { + val, found := env[name] + if !found { + return -1 + } + + sock, err := strconv.Atoi(val) + if err != nil { + log.WithError(err).WithField("name", name).Fatal("Unable to parse socket env var.") + } + + if sock < 0 { + log.WithError(err).WithField("name", name).Fatal("Negative socket descriptor value") + } + + syscall.CloseOnExec(sock) + return sock +} + func mapUnion(maps ...map[string]string) map[string]string { // last maps in argument overwrite values of ones before union := map[string]string{} diff --git a/lambda/rapidcore/env/environment_test.go b/lambda/rapidcore/env/environment_test.go index 09812a9..53a2f1e 100644 --- a/lambda/rapidcore/env/environment_test.go +++ b/lambda/rapidcore/env/environment_test.go @@ -23,6 +23,7 @@ func TestRAPIDInternalConfig(t *testing.T) { os.Setenv("AWS_XRAY_DAEMON_ADDRESS", "a") os.Setenv("AWS_LAMBDA_FUNCTION_NAME", "a") os.Setenv("_LAMBDA_TELEMETRY_API_PASSPHRASE", "a") + os.Setenv("_LAMBDA_DIRECT_INVOKE_SOCKET", "1") NewEnvironment().RAPIDInternalConfig() } @@ -40,6 +41,7 @@ func TestEnvironmentParsing(t *testing.T) { setAll(predefinedCredentialsEnvVarKeys(), credsEnvVal) os.Setenv("MY_FOOBAR_ENV_1", customerEnvVal) os.Setenv("MY_EMPTY_ENV", "") + os.Setenv("_UNKNOWN_INTERNAL_ENV", platformEnvVal) env := NewEnvironment() // parse environment variables customerEnv := CustomerEnvironmentVariables() @@ -72,8 +74,9 @@ func TestEnvironmentParsing(t *testing.T) { assert.Equal(t, customerEnvVal, val) } - assert.Equal(t, env.Customer["MY_FOOBAR_ENV_1"], customerEnvVal) - assert.Equal(t, env.Customer["MY_EMPTY_ENV"], "") + assert.Equal(t, customerEnvVal, env.Customer["MY_FOOBAR_ENV_1"]) + assert.Equal(t, "", env.Customer["MY_EMPTY_ENV"]) + assert.Equal(t, "", env.Customer["_UNKNOWN_INTERNAL_ENV"]) } func TestEnvironmentParsingUnsetPlatformAndInternalEnvVarKeysAreDeleted(t *testing.T) { @@ -172,7 +175,7 @@ func TestRuntimeExecEnvironmentVariablesPriority(t *testing.T) { } cliOptionsEnv := map[string]string{ - "LCIS_ARG1": lcisCLIArgEnvVal, + "LCIS_ARG1": lcisCLIArgEnvVal, conflictPlatformKeyFromCLI: lcisCLIArgEnvVal, } @@ -208,7 +211,7 @@ func TestCustomerEnvironmentVariablesFromInitCanOverrideEnvironmentVariablesFrom } cliOptionsEnv := map[string]string{ - "LCIS_ARG1": lcisCLIArgEnvVal, + "LCIS_ARG1": lcisCLIArgEnvVal, "MY_FOOBAR_ENV_1": lcisCLIArgEnvVal, } diff --git a/lambda/rapidcore/errors.go b/lambda/rapidcore/errors.go index ea142be..06a4830 100644 --- a/lambda/rapidcore/errors.go +++ b/lambda/rapidcore/errors.go @@ -7,6 +7,7 @@ import "errors" var ErrInitAlreadyDone = errors.New("InitAlreadyDone") var ErrInitDoneFailed = errors.New("InitDoneFailed") +var ErrInitError = errors.New("InitError") var ErrNotReserved = errors.New("NotReserved") var ErrAlreadyReserved = errors.New("AlreadyReserved") @@ -21,7 +22,6 @@ var ErrInvokeReservationDone = errors.New("InvokeReservationDone") var ErrReleaseReservationDone = errors.New("ReleaseReservationDone") var ErrInternalServerError = errors.New("InternalServerError") - var ErrInvokeTimeout = errors.New("InvokeTimeout") var ErrTerminated = errors.New("SandboxTerminated") // sent to signal a process exit diff --git a/lambda/rapidcore/sandbox.go b/lambda/rapidcore/sandbox.go index 739de68..fee0e73 100644 --- a/lambda/rapidcore/sandbox.go +++ b/lambda/rapidcore/sandbox.go @@ -7,6 +7,7 @@ import ( "context" "io" "io/ioutil" + "net/http" "os" "os/signal" "syscall" @@ -28,17 +29,23 @@ const ( type Sandbox interface { Init(i *interop.Init, invokeTimeoutMs int64) - Invoke(responseWriter io.Writer, invoke *interop.Invoke) error + Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error InteropServer() InteropServer } +type ReserveResponse struct { + Token interop.Token + InternalState *statejson.InternalStateDescription +} + type InteropServer interface { - FastInvoke(w io.Writer, i *interop.Invoke) error - Reserve(id string) (string, *statejson.InternalStateDescription, error) - Reset(reason string, timeoutMs int64) error + FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error + Reserve(id string, traceID, lambdaSegmentID string) (*ReserveResponse, error) + Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) AwaitRelease() (*statejson.InternalStateDescription, error) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription InternalState() (*statejson.InternalStateDescription, error) + CurrentToken() *interop.Token } type SandboxBuilder struct { @@ -197,7 +204,7 @@ func (b *SandboxBuilder) Init(i *interop.Init, timeoutMs int64) { }, timeoutMs) } -func (b *SandboxBuilder) Invoke(w io.Writer, i *interop.Invoke) error { +func (b *SandboxBuilder) Invoke(w http.ResponseWriter, i *interop.Invoke) error { return b.sandbox.InteropServer.Invoke(w, i) } diff --git a/lambda/rapidcore/server.go b/lambda/rapidcore/server.go index 462dc1d..ee5b243 100644 --- a/lambda/rapidcore/server.go +++ b/lambda/rapidcore/server.go @@ -4,13 +4,18 @@ package rapidcore import ( + "bytes" "context" "errors" "fmt" "io" + "io/ioutil" + "math" + "net/http" "sync" "time" + "go.amzn.com/lambda/core/directinvoke" "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" @@ -23,10 +28,37 @@ const ( autoresetReasonTimeout = "Timeout" autoresetReasonReserveFail = "ReserveFail" autoresetReasonReleaseFail = "ReleaseFail" + standaloneVersionID = "1" resetDefaultTimeoutMs = 2000 ) +// rapidPhase tracks the state machine in the go.amzn.com/lambda/rapid receive loop. See +// a state diagram of how the events and states of rapid package and this interop server +type rapidPhase int + +const ( + phaseIdle rapidPhase = iota + phaseInitializing + phaseInvoking +) + +type runtimeState int + +const ( + runtimeNotStarted = iota + + runtimeInitStarted + runtimeInitError + runtimeInitComplete + runtimeInitFailed + + runtimeInvokeResponseSent + runtimeInvokeError + runtimeReady + runtimeInvokeComplete +) + type DoneWithState struct { *interop.Done State statejson.InternalStateDescription @@ -37,9 +69,10 @@ func (s *DoneWithState) String() string { } type InvokeContext struct { - ID string + Token interop.Token ReplySent bool - ReplyStream io.Writer + ReplyStream http.ResponseWriter + Direct bool } type Server struct { @@ -66,6 +99,41 @@ type Server struct { reservationContext context.Context reservationCancel func() + + rapidPhase rapidPhase + runtimeState runtimeState +} + +func (s *Server) StartAcceptingDirectInvokes() error { + return nil +} + +func (s *Server) setRapidPhase(phase rapidPhase) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.rapidPhase = phase +} + +func (s *Server) getRapidPhase() rapidPhase { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.rapidPhase +} + +func (s *Server) setRuntimeState(state runtimeState) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.runtimeState = state +} + +func (s *Server) getRuntimeState() runtimeState { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.runtimeState } func (s *Server) SetInvokeTimeout(timeout time.Duration) { @@ -82,28 +150,56 @@ func (s *Server) GetInvokeTimeout() time.Duration { return s.invokeTimeout } -// Reserve allocates invoke context, returnes new invokeID -func (s *Server) Reserve(id string) (string, *statejson.InternalStateDescription, error) { +func (s *Server) GetInvokeContext() *InvokeContext { s.mutex.Lock() + defer s.mutex.Unlock() + + ctx := *s.invokeCtx + return &ctx +} + +func (s *Server) setNewInvokeContext(invokeID string, traceID, lambdaSegmentID string) (*ReserveResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() if s.invokeCtx != nil { - return "", nil, ErrAlreadyReserved - } - invokeID := uuid.New().String() - if len(id) > 0 { - invokeID = id + return nil, ErrAlreadyReserved } s.invokeCtx = &InvokeContext{ - ID: invokeID, + Token: interop.Token{ + ReservationToken: uuid.New().String(), + InvokeID: invokeID, + VersionID: standaloneVersionID, + FunctionTimeout: s.invokeTimeout, + TraceID: traceID, + LambdaSegmentID: lambdaSegmentID, + InvackDeadlineNs: math.MaxInt64, // no INVACK in standalone + }, + } + + resp := &ReserveResponse{ + Token: s.invokeCtx.Token, } s.reservationContext, s.reservationCancel = context.WithCancel(context.Background()) - s.mutex.Unlock() + return resp, nil +} + +// Reserve allocates invoke context +func (s *Server) Reserve(id string, traceID, lambdaSegmentID string) (*ReserveResponse, error) { + invokeID := uuid.New().String() + if len(id) > 0 { + invokeID = id + } + resp, err := s.setNewInvokeContext(invokeID, traceID, lambdaSegmentID) + if err != nil { + return nil, err + } - internalState, err := s.waitInit() - return invokeID, internalState, err + resp.InternalState, err = s.waitInit() + return resp, err } func (s *Server) waitInit() (*statejson.InternalStateDescription, error) { @@ -112,11 +208,15 @@ func (s *Server) waitInit() (*statejson.InternalStateDescription, error) { case doneWithState, chanOpen := <-s.InitDoneChan: if !chanOpen { + // init only happens once return nil, ErrInitAlreadyDone } - // this was first call to reserve - close(s.InitDoneChan) + close(s.InitDoneChan) // this was first call to reserve + + if s.getRuntimeState() == runtimeInitFailed { + return &doneWithState.State, ErrInitError + } if len(doneWithState.ErrorType) > 0 { log.Errorf("INIT DONE failed: %s", doneWithState.ErrorType) @@ -131,7 +231,7 @@ func (s *Server) waitInit() (*statejson.InternalStateDescription, error) { } } -func (s *Server) setReplyStream(w io.Writer) (string, error) { +func (s *Server) setReplyStream(w http.ResponseWriter, direct bool) (string, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -148,7 +248,8 @@ func (s *Server) setReplyStream(w io.Writer) (string, error) { } s.invokeCtx.ReplyStream = w - return s.invokeCtx.ID, nil + s.invokeCtx.Direct = direct + return s.invokeCtx.Token.InvokeID, nil } // Release closes the invocation, making server ready for reserve again @@ -177,7 +278,7 @@ func (s *Server) GetCurrentInvokeID() string { return "" } - return s.invokeCtx.ID + return s.invokeCtx.Token.InvokeID } // SetInternalStateGetter is used to set callback which returnes internal state for /test/internalState request @@ -210,8 +311,8 @@ func (s *Server) TransportErrorChan() <-chan error { return s.errorChanOut } -func (s *Server) sendResponseUnsafe(invokeID string, status int, payload []byte) error { - if s.invokeCtx == nil || invokeID != s.invokeCtx.ID { +func (s *Server) sendResponseUnsafe(invokeID string, status int, payload io.Reader) error { + if s.invokeCtx == nil || invokeID != s.invokeCtx.Token.InvokeID { return interop.ErrInvalidInvokeID } @@ -223,34 +324,70 @@ func (s *Server) sendResponseUnsafe(invokeID string, status int, payload []byte) return fmt.Errorf("ReplyStream not available") } + // TODO: earlier, status was set to 500 if runtime called /invocation/error. However, the integration + // tests do not differentiate between /invocation/error and /invocation/response, but they check the error type: + // To identify user-errors, we should also allowlist custom errortypes and propagate them via headers. + // s.invokeCtx.ReplyStream.WriteHeader(status) - if _, err := s.invokeCtx.ReplyStream.Write(payload); err != nil { - return fmt.Errorf("Failed to write response to %s: %s", invokeID, err) + + if s.invokeCtx.Direct { + if err := directinvoke.SendDirectInvokeResponse(nil, payload, s.invokeCtx.ReplyStream); err != nil { + // we intentionally do not return an error here: + // even if error happened, the response has already been initiated (and might be partially written into the socket) + // so there is no other option except to consider response to be sent. + log.Errorf("Failed to write response to %s: %s", invokeID, err) + } + } else { + data, err := ioutil.ReadAll(payload) + if err != nil { + return fmt.Errorf("Failed to read response on %s: %s", invokeID, err) + } + if len(data) > interop.MaxPayloadSize { + return &interop.ErrorResponseTooLarge{ + ResponseSize: len(data), + MaxResponseSize: interop.MaxPayloadSize, + } + } + if _, err := s.invokeCtx.ReplyStream.Write(data); err != nil { + return fmt.Errorf("Failed to write response to %s: %s", invokeID, err) + } } s.sendResponseChan <- struct{}{} - s.invokeCtx.ReplySent = true + s.invokeCtx.Direct = false return nil } -func (s *Server) SendResponse(invokeID string, resp *interop.Response) error { +func (s *Server) SendResponse(invokeID string, reader io.Reader) error { + s.setRuntimeState(runtimeInvokeResponseSent) s.mutex.Lock() defer s.mutex.Unlock() - return s.sendResponseUnsafe(invokeID, 200, resp.Payload) + return s.sendResponseUnsafe(invokeID, http.StatusOK, reader) } func (s *Server) CommitResponse() error { return nil } func (s *Server) SendRunning(run *interop.Running) error { + s.setRuntimeState(runtimeInitStarted) s.sendRunningChan <- run return nil } func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorResponse) error { - s.mutex.Lock() - defer s.mutex.Unlock() - return s.sendResponseUnsafe(invokeID, 500, resp.Payload) + switch s.getRapidPhase() { + case phaseInitializing: + s.setRuntimeState(runtimeInitError) + return nil + case phaseInvoking: + // This branch can also occur during a suppressed init error, which is reported as invoke error + s.setRuntimeState(runtimeInvokeError) + s.mutex.Lock() + defer s.mutex.Unlock() + return s.sendResponseUnsafe(invokeID, http.StatusInternalServerError, bytes.NewReader(resp.Payload)) + default: + panic("received unexpected error response outside invoke or init phases") + } } func (s *Server) SendDone(done *interop.Done) error { @@ -260,15 +397,14 @@ func (s *Server) SendDone(done *interop.Done) error { func (s *Server) SendDoneFail(doneFail *interop.DoneFail) error { s.doneChan <- &interop.Done{ - RuntimeRelease: doneFail.RuntimeRelease, - NumActiveExtensions: doneFail.NumActiveExtensions, - ErrorType: doneFail.ErrorType, - CorrelationID: doneFail.CorrelationID, // filipovi: correlationID is required to dispatch message into correct channel + ErrorType: doneFail.ErrorType, + CorrelationID: doneFail.CorrelationID, // filipovi: correlationID is required to dispatch message into correct channel + Meta: doneFail.Meta, } return nil } -func (s *Server) Reset(reason string, timeoutMs int64) error { +func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) { // pass reset to rapid s.resetChanOut <- &interop.Reset{ Reason: reason, @@ -282,10 +418,10 @@ func (s *Server) Reset(reason string, timeoutMs int64) error { s.Release() if done.ErrorType != "" { - return errors.New(done.ErrorType) + return nil, errors.New(string(done.ErrorType)) } - return nil + return &statejson.ResetDescription{ExtensionsResetMs: done.Meta.ExtensionsResetMs}, nil } func NewServer(ctx context.Context) *Server { @@ -313,6 +449,14 @@ func NewServer(ctx context.Context) *Server { return s } +func (s *Server) setInitDoneRuntimeState(done *interop.Done) { + if len(done.ErrorType) > 0 { + s.setRuntimeState(runtimeInitFailed) // donefail + } else { + s.setRuntimeState(runtimeInitComplete) // done + } +} + // Note, the dispatch loop below has potential to block, when // channel is not drained. E.g. if test assumes sandbox init // completion before dispatching reset, then reset will block @@ -321,14 +465,19 @@ func (s *Server) dispatchDone() { for { done := <-s.doneChan log.Debug("Dispatching DONE:", done.CorrelationID) - + internalState := s.InternalStateGetter() + s.setRapidPhase(phaseIdle) if done.CorrelationID == "initCorrelationID" { - s.InitDoneChan <- DoneWithState{Done: done, State: s.InternalStateGetter()} + s.setInitDoneRuntimeState(done) + s.InitDoneChan <- DoneWithState{Done: done, State: internalState} } else if done.CorrelationID == "invokeCorrelationID" { - s.InvokeDoneChan <- DoneWithState{Done: done, State: s.InternalStateGetter()} + s.setRuntimeState(runtimeInvokeComplete) + s.InvokeDoneChan <- DoneWithState{Done: done, State: internalState} } else if done.CorrelationID == "resetCorrelationID" { + s.setRuntimeState(runtimeNotStarted) s.ResetDoneChan <- done } else if done.CorrelationID == "shutdownCorrelationID" { + s.setRuntimeState(runtimeNotStarted) s.ShutdownDoneChan <- done } else { panic("Received DONE without correlation ID") @@ -359,7 +508,11 @@ func (s *Server) IsResponseSent() bool { panic("unexpected call to unimplemented method in rapidcore: IsResponseSent()") } -func (s *Server) SendRuntimeReady() error { return nil } +func (s *Server) SendRuntimeReady() error { + // only called when extensions are enabled + s.setRuntimeState(runtimeReady) + return nil +} func deadlineNsFromTimeoutMs(timeoutMs int64) int64 { mono := metering.Monotime() @@ -370,17 +523,19 @@ func (s *Server) Init(i *interop.Start, invokeTimeoutMs int64) { s.SetInvokeTimeout(time.Duration(invokeTimeoutMs) * time.Millisecond) s.startChanOut <- i + s.setRapidPhase(phaseInitializing) <-s.sendRunningChan log.Debug("Received RUNNING") } -func (s *Server) FastInvoke(w io.Writer, i *interop.Invoke) error { - i.DeadlineNs = fmt.Sprintf("%v", time.Now().Add(s.invokeTimeout).UnixNano()) - invokeID, err := s.setReplyStream(w) +func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error { + invokeID, err := s.setReplyStream(w, direct) if err != nil { return err } + s.setRapidPhase(phaseInvoking) + i.ID = invokeID select { @@ -402,7 +557,19 @@ func (s *Server) FastInvoke(w io.Writer, i *interop.Invoke) error { return nil } -func (s *Server) Invoke(responseWriter io.Writer, invoke *interop.Invoke) error { +func (s *Server) CurrentToken() *interop.Token { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.invokeCtx == nil { + return nil + } + tok := s.invokeCtx.Token + return &tok +} + +// Invoke is used by the Runtime Interface Emulator (Rapid Local) +// https://github.com/aws/aws-lambda-runtime-interface-emulator +func (s *Server) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error { resetCtx, resetCancel := context.WithCancel(context.Background()) defer resetCancel() @@ -416,14 +583,26 @@ func (s *Server) Invoke(responseWriter io.Writer, invoke *interop.Invoke) error log.Debugf("execute finished, autoreset cancelled") } }() - if _, _, err := s.Reserve(invoke.ID); err != nil { + + reserveResp, err := s.Reserve(invoke.ID, "", "") + if err != nil { switch err { + case ErrInitError: + // Simulate 'Suppressed Init' scenario + s.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) + reserveResp, err = s.Reserve("", "", "") + if err == ErrInitAlreadyDone { + break + } + return err case ErrInitDoneFailed, ErrTerminated: s.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) return err + case ErrInitAlreadyDone: - // init already happened, just return internal state - // this was retained to prevent execute test regressions + // This is a valid response (e.g. for 2nd and 3rd invokes) + // TODO: switch on ReserveResponse status instead of err, + // since these are valid values if s.InternalStateGetter == nil { responseWriter.Write([]byte("error: internal state callback not set")) return ErrInternalServerError @@ -433,11 +612,12 @@ func (s *Server) Invoke(responseWriter io.Writer, invoke *interop.Invoke) error return err } } - invoke.DeadlineNs = fmt.Sprintf("%v", time.Now().Add(s.invokeTimeout).UnixNano()) + + invoke.DeadlineNs = fmt.Sprintf("%d", metering.Monotime()+reserveResp.Token.FunctionTimeout.Nanoseconds()) invokeChan := make(chan error) go func() { - if err := s.FastInvoke(responseWriter, invoke); err != nil { + if err := s.FastInvoke(responseWriter, invoke, false); err != nil { invokeChan <- err } }() @@ -448,8 +628,6 @@ func (s *Server) Invoke(responseWriter io.Writer, invoke *interop.Invoke) error releaseChan <- err }() - var err error - // TODO: verify the order of channel receives. When timeouts happen, Reset() // is called first, which also does Release() => this may signal a type // Err<*>ReservationDone error to the non-timeout channels. This is currently diff --git a/lambda/rapidcore/server_test.go b/lambda/rapidcore/server_test.go new file mode 100644 index 0000000..85530f1 --- /dev/null +++ b/lambda/rapidcore/server_test.go @@ -0,0 +1,403 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "bytes" + "context" + "errors" + "fmt" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/interop" +) + +func waitForChanWithTimeout(channel <-chan error, timeout time.Duration) error { + select { + case err := <-channel: + return err + case <-time.After(timeout): + return nil + } +} + +func TestReserveDoesNotDeadlockWhenCalledMultipleTimes(t *testing.T) { + srv := NewServer(context.Background()) + srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + + go func() { <-srv.StartChan() }() + go srv.SendRunning(&interop.Running{}) + srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + + go srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"}) + _, err := srv.Reserve("", "", "") // reserve successfully + require.NoError(t, err) + + resp, err := srv.Reserve("", "", "") // attempt double reservation + require.Nil(t, resp) + require.Equal(t, ErrAlreadyReserved, err) + + successChan := make(chan error) + go func() { + resp, err := srv.Reserve("", "", "") + require.Nil(t, resp) + require.Equal(t, ErrAlreadyReserved, err) + successChan <- nil + }() + + select { + case <-time.After(1 * time.Second): + require.Fail(t, "Timed out while waiting for Reserve() response") + case <-successChan: + } +} + +func TestInitSuccess(t *testing.T) { + srv := NewServer(context.Background()) + srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + + go func() { + <-srv.StartChan() + require.NoError(t, srv.SendRunning(&interop.Running{})) + require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) + }() + + srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) + + _, err := srv.Reserve("", "", "") + require.NoError(t, err) + require.Equal(t, phaseIdle, srv.getRapidPhase()) + require.Equal(t, runtimeState(runtimeInitComplete), srv.getRuntimeState()) +} + +func TestInitErrorBeforeReserve(t *testing.T) { + srv := NewServer(context.Background()) + srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + + initErrorResponseSent := make(chan error) + go func() { + <-srv.StartChan() + require.NoError(t, srv.SendRunning(&interop.Running{})) + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "foobar"})) + initErrorResponseSent <- errors.New("initErrorResponseSent") + }() + + srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + + if msg := waitForChanWithTimeout(initErrorResponseSent, 1*time.Second); msg == nil { + require.Fail(t, "Timed out waiting for init error response to be sent") + } + + resp, err := srv.Reserve("", "", "") + require.EqualError(t, err, ErrInitError.Error()) + require.True(t, len(resp.Token.InvokeID) > 0) + require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) +} + +func TestInitErrorDuringReserve(t *testing.T) { + srv := NewServer(context.Background()) + srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + + go func() { + <-srv.StartChan() + require.NoError(t, srv.SendRunning(&interop.Running{})) + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "foobar"})) + }() + + srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + resp, err := srv.Reserve("", "", "") + require.EqualError(t, err, ErrInitError.Error()) + require.True(t, len(resp.Token.InvokeID) > 0) + require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) +} + +func TestInvokeSuccess(t *testing.T) { + srv := NewServer(context.Background()) + srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + + go func() { + <-srv.StartChan() + require.NoError(t, srv.SendRunning(&interop.Running{})) + require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) + + <-srv.InvokeChan() + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), bytes.NewReader([]byte("response")))) + require.NoError(t, srv.SendRuntimeReady()) + require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) + }() + + srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) + + _, err := srv.Reserve("", "", "") + require.NoError(t, err) + require.Equal(t, phaseIdle, srv.getRapidPhase()) + require.Equal(t, runtimeState(runtimeInitComplete), srv.getRuntimeState()) + + responseRecorder := httptest.NewRecorder() + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + require.NoError(t, invokeErr) + require.Equal(t, "response", responseRecorder.Body.String()) + + _, err = srv.AwaitRelease() + require.NoError(t, err) + require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) +} + +func TestInvokeError(t *testing.T) { + srv := NewServer(context.Background()) + srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + + go func() { + <-srv.StartChan() + require.NoError(t, srv.SendRunning(&interop.Running{})) + require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) + + <-srv.InvokeChan() + + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + require.NoError(t, srv.SendRuntimeReady()) + require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) + }() + + srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) + + _, err := srv.Reserve("", "", "") + require.NoError(t, err) + require.Equal(t, phaseIdle, srv.getRapidPhase()) + require.Equal(t, runtimeState(runtimeInitComplete), srv.getRuntimeState()) + + responseRecorder := httptest.NewRecorder() + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + require.NoError(t, invokeErr) + require.Equal(t, "{ 'errorType': 'A.B' }", responseRecorder.Body.String()) + + _, err = srv.AwaitRelease() + require.NoError(t, err) + require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) +} + +func TestInvokeWithSuppressedInitSuccess(t *testing.T) { + // Tests an init/error followed by suppressed init: + // Runtime may have called init/error before Reserve, in which case we + // expect a suppressed init, i.e. init during the invoke. + // The first Reserve() after init/error returns ErrInitError because + // SendDoneFail was called on init/error. + // We expect the caller to then call Reset() to prepare for suppressed init, + // followed by Reserve() so that a valid reservation context is available. + // Reserve() returns ErrInitAlreadyDone, since the server implementation + // closes the InitDone channel after the first InitDone message. + + srv := NewServer(context.Background()) + srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + + initErrorCompleted := make(chan error) + go func() { + <-srv.StartChan() + require.NoError(t, srv.SendRunning(&interop.Running{})) + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "foobar"})) + initErrorCompleted <- errors.New("initErrorSequenceCompleted") + + <-srv.ResetChan() + require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"})) + + <-srv.InvokeChan() // run only after FastInvoke is called + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), bytes.NewReader([]byte("response")))) + require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) + }() + + srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) + + if msg := waitForChanWithTimeout(initErrorCompleted, 1*time.Second); msg == nil { + require.Fail(t, "Timed out waiting for init error sequence to be called") + } + + _, err := srv.Reserve("", "", "") + require.EqualError(t, err, ErrInitError.Error()) + require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + + _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for suppressed init + require.NoError(t, err) + + _, err = srv.Reserve("", "", "") + require.EqualError(t, err, ErrInitAlreadyDone.Error()) + + responseRecorder := httptest.NewRecorder() + successChan := make(chan error) + go func() { + directInvoke := false + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, directInvoke) + require.NoError(t, invokeErr) + successChan <- errors.New("invokeResponseWritten") + }() + + invokeErr := waitForChanWithTimeout(successChan, 1*time.Second) + if invokeErr == nil { + require.Fail(t, "Timed out while waiting for invoke response") + } + + require.Equal(t, "response", responseRecorder.Body.String()) + + _, err = srv.AwaitRelease() + require.NoError(t, err) + require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) +} + +func TestInvokeWithSuppressedInitErrorDueToInitError(t *testing.T) { + // Tests init/error followed by init/error during suppressed init + srv := NewServer(context.Background()) + srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + + releaseChan := make(chan error) + go func() { + <-srv.StartChan() + require.NoError(t, srv.SendRunning(&interop.Running{})) + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "A.B"})) + + <-srv.ResetChan() + srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"}) + + <-srv.InvokeChan() + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + releaseChan <- nil + require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "invokeCorrelationID", ErrorType: "A.B"})) + }() + + srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) + + _, err := srv.Reserve("", "", "") + require.EqualError(t, err, ErrInitError.Error()) + require.Equal(t, phaseIdle, srv.getRapidPhase()) + require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + + _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for invoke with suppressed init + require.NoError(t, err) + require.Equal(t, phaseIdle, srv.getRapidPhase()) + + _, err = srv.Reserve("", "", "") + require.EqualError(t, err, ErrInitAlreadyDone.Error()) + require.Equal(t, phaseIdle, srv.getRapidPhase()) + + responseRecorder := httptest.NewRecorder() + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + require.NoError(t, invokeErr) + require.Equal(t, "{ 'errorType': 'A.B' }", responseRecorder.Body.String()) + require.Equal(t, phaseInvoking, srv.getRapidPhase()) + + <-releaseChan // Unblock gorotune to send donefail + _, err = srv.AwaitRelease() + require.EqualError(t, err, ErrInvokeDoneFailed.Error()) + require.Equal(t, phaseIdle, srv.getRapidPhase()) + require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) +} + +func TestInvokeWithSuppressedInitErrorDueToInvokeError(t *testing.T) { + // Tests init/error followed by init/error during suppressed init + srv := NewServer(context.Background()) + srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + + go func() { + <-srv.StartChan() + require.NoError(t, srv.SendRunning(&interop.Running{})) + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "A.B"})) + + <-srv.ResetChan() + srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"}) + + <-srv.InvokeChan() + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'B.C' }")})) + require.NoError(t, srv.SendRuntimeReady()) + require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) + }() + + srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) + + _, err := srv.Reserve("", "", "") + require.EqualError(t, err, ErrInitError.Error()) + require.Equal(t, phaseIdle, srv.getRapidPhase()) + require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + + _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for invoke with suppressed init + require.NoError(t, err) + require.Equal(t, phaseIdle, srv.getRapidPhase()) + + _, err = srv.Reserve("", "", "") + require.EqualError(t, err, ErrInitAlreadyDone.Error()) + require.Equal(t, phaseIdle, srv.getRapidPhase()) + + responseRecorder := httptest.NewRecorder() + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + require.NoError(t, invokeErr) + require.Equal(t, "{ 'errorType': 'B.C' }", responseRecorder.Body.String()) + + _, err = srv.AwaitRelease() + require.NoError(t, err) // /invocation/error -> /invocation/next returns no error / donefail + require.Equal(t, phaseIdle, srv.getRapidPhase()) + require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) +} + +func TestMultipleInvokeSuccess(t *testing.T) { + srv := NewServer(context.Background()) + srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + + go func() { + <-srv.StartChan() + require.NoError(t, srv.SendRunning(&interop.Running{})) + require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) + }() + + srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) + + invokeFunc := func(i int) { + <-srv.InvokeChan() + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), bytes.NewReader([]byte("response-"+fmt.Sprint(i))))) + require.NoError(t, srv.SendRuntimeReady()) + require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) + } + go func() { + for i := 0; i < 3; i++ { + invokeFunc(i) + } + }() + + for i := 0; i < 3; i++ { + _, err := srv.Reserve("", "", "") + require.Contains(t, []error{nil, ErrInitAlreadyDone}, err) + require.Equal(t, phaseIdle, srv.getRapidPhase()) + + responseRecorder := httptest.NewRecorder() + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + require.NoError(t, invokeErr) + require.Equal(t, "response-"+fmt.Sprint(i), responseRecorder.Body.String()) + + _, err = srv.AwaitRelease() + require.NoError(t, err) + require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) + } +} + +/* Unit tests remaining: +- Shutdown behaviour +- Reset behaviour during various phases +- Runtime / extensions process exit sequences +- Invoke() and Init() api tests + +See PlantUML state diagram for potential other uncovered paths +through the state machine +*/ diff --git a/lambda/rapidcore/standalone/directInvokeHandler.go b/lambda/rapidcore/standalone/directInvokeHandler.go new file mode 100644 index 0000000..a485deb --- /dev/null +++ b/lambda/rapidcore/standalone/directInvokeHandler.go @@ -0,0 +1,39 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package standalone + +import ( + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/core/directinvoke" + "go.amzn.com/lambda/rapidcore" + "net/http" +) + +func DirectInvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { + tok := s.CurrentToken() + if tok == nil { + log.Errorf("Attempt to call directInvoke without Reserve") + w.WriteHeader(http.StatusBadRequest) + return + } + + invoke, err := directinvoke.ReceiveDirectInvoke(w, r, *tok) + if err != nil { + log.Errorf("direct invoke error: %s", err) + return + } + + if err := s.FastInvoke(w, invoke, true); err != nil { + switch err { + case rapidcore.ErrNotReserved: + case rapidcore.ErrAlreadyReplied: + case rapidcore.ErrAlreadyInvocating: + log.Errorf("Failed to set reply stream: %s", err) + w.WriteHeader(http.StatusBadRequest) + return + case rapidcore.ErrInvokeReservationDone: + w.WriteHeader(http.StatusBadGateway) + } + } +} diff --git a/lambda/rapidcore/standalone/executeHandler.go b/lambda/rapidcore/standalone/executeHandler.go index ff425dd..0b89322 100644 --- a/lambda/rapidcore/standalone/executeHandler.go +++ b/lambda/rapidcore/standalone/executeHandler.go @@ -4,7 +4,6 @@ package standalone import ( - "io/ioutil" "net/http" log "github.com/sirupsen/logrus" @@ -13,17 +12,11 @@ import ( ) func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandbox) { - bodyBytes, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Errorf("Failed to read invoke body: %s", err) - w.WriteHeader(500) - return - } invokePayload := &interop.Invoke{ TraceID: r.Header.Get("X-Amzn-Trace-Id"), LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: bodyBytes, + Payload: r.Body, CorrelationID: "invokeCorrelationID", } diff --git a/lambda/rapidcore/standalone/invokeHandler.go b/lambda/rapidcore/standalone/invokeHandler.go index 45662f2..25819e3 100644 --- a/lambda/rapidcore/standalone/invokeHandler.go +++ b/lambda/rapidcore/standalone/invokeHandler.go @@ -4,31 +4,33 @@ package standalone import ( - "io/ioutil" + "fmt" "net/http" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapidcore" log "github.com/sirupsen/logrus" ) func InvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { - bodyBytes, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Errorf("Failed to read invoke body: %s", err) - w.WriteHeader(500) + tok := s.CurrentToken() + if tok == nil { + log.Errorf("Attempt to call directInvoke without Reserve") + w.WriteHeader(http.StatusBadRequest) return } invokePayload := &interop.Invoke{ TraceID: r.Header.Get("X-Amzn-Trace-Id"), LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: bodyBytes, + Payload: r.Body, CorrelationID: "invokeCorrelationID", + DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), } - if err := s.FastInvoke(w, invokePayload); err != nil { + if err := s.FastInvoke(w, invokePayload, false); err != nil { switch err { case rapidcore.ErrNotReserved: case rapidcore.ErrAlreadyReplied: diff --git a/lambda/rapidcore/standalone/reserveHandler.go b/lambda/rapidcore/standalone/reserveHandler.go index 05c102f..d3e0b9f 100644 --- a/lambda/rapidcore/standalone/reserveHandler.go +++ b/lambda/rapidcore/standalone/reserveHandler.go @@ -7,33 +7,48 @@ import ( "net/http" log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/core/directinvoke" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" ) +const ( + ReservationTokenHeader = "Reservation-Token" + InvokeIDHeader = "Invoke-ID" + VersionIDHeader = "Version-ID" +) + +func tokenToHeaders(w http.ResponseWriter, token interop.Token) { + w.Header().Set(ReservationTokenHeader, token.ReservationToken) + w.Header().Set(directinvoke.InvokeIDHeader, token.InvokeID) + w.Header().Set(directinvoke.VersionIDHeader, token.VersionID) +} + func ReserveHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { - _, internalState, err := s.Reserve("") + reserveResp, err := s.Reserve("", r.Header.Get("X-Amzn-Trace-Id"), r.Header.Get("X-Amzn-Segment-Id")) + if err != nil { switch err { - case rapidcore.ErrAlreadyReserved: - log.Errorf("Failed to reserve: %s", err) - w.WriteHeader(400) - return case rapidcore.ErrInitAlreadyDone: // init already happened before, just provide internal state and return + tokenToHeaders(w, reserveResp.Token) InternalStateHandler(w, r, s) - return case rapidcore.ErrReserveReservationDone: // TODO use http.StatusBadGateway w.WriteHeader(http.StatusGatewayTimeout) - return - case rapidcore.ErrInitDoneFailed: + case rapidcore.ErrInitDoneFailed, rapidcore.ErrInitError: w.WriteHeader(DoneFailedHTTPCode) + w.Write(reserveResp.InternalState.AsJSON()) case rapidcore.ErrTerminated: w.WriteHeader(DoneFailedHTTPCode) - w.Write(internalState.AsJSON()) - return + w.Write(reserveResp.InternalState.AsJSON()) + default: + log.Errorf("Failed to reserve: %s", err) + w.WriteHeader(400) } + return } - w.Write(internalState.AsJSON()) + tokenToHeaders(w, reserveResp.Token) + w.Write(reserveResp.InternalState.AsJSON()) } diff --git a/lambda/rapidcore/standalone/resetHandler.go b/lambda/rapidcore/standalone/resetHandler.go index d6d7b8e..1a719ff 100644 --- a/lambda/rapidcore/standalone/resetHandler.go +++ b/lambda/rapidcore/standalone/resetHandler.go @@ -21,10 +21,11 @@ func ResetHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropSer return } - if err := s.Reset(reset.Reason, reset.TimeoutMs); err != nil { + resetDescription, err := s.Reset(reset.Reason, reset.TimeoutMs) + if err != nil { (&FailureReply{}).Send(w, r) return } - (&SuccessReply{}).Send(w, r) + w.Write(resetDescription.AsJSON()) } diff --git a/lambda/rapidcore/standalone/router.go b/lambda/rapidcore/standalone/router.go index dfa97bb..5a4ae7c 100644 --- a/lambda/rapidcore/standalone/router.go +++ b/lambda/rapidcore/standalone/router.go @@ -24,6 +24,7 @@ func NewHTTPRouter(sandbox rapidcore.Sandbox, eventLog *telemetry.EventLog, shut r.Post("/test/waitUntilRelease", func(w http.ResponseWriter, r *http.Request) { WaitUntilReleaseHandler(w, r, ipcSrv) }) r.Post("/test/reset", func(w http.ResponseWriter, r *http.Request) { ResetHandler(w, r, ipcSrv) }) r.Post("/test/shutdown", func(w http.ResponseWriter, r *http.Request) { ShutdownHandler(w, r, ipcSrv, shutdownFunc) }) + r.Post("/test/directInvoke/{reservationtoken}", func(w http.ResponseWriter, r *http.Request) { DirectInvokeHandler(w, r, ipcSrv) }) r.Get("/test/internalState", func(w http.ResponseWriter, r *http.Request) { InternalStateHandler(w, r, ipcSrv) }) r.Get("/test/eventLog", func(w http.ResponseWriter, r *http.Request) { EventLogHandler(w, r, eventLog) }) diff --git a/lambda/testdata/flowtesting.go b/lambda/testdata/flowtesting.go index 14140b6..f729632 100644 --- a/lambda/testdata/flowtesting.go +++ b/lambda/testdata/flowtesting.go @@ -6,6 +6,7 @@ package testdata import ( "context" "io" + "io/ioutil" "net/http" "go.amzn.com/lambda/appctx" @@ -18,14 +19,27 @@ import ( ) type MockInteropServer struct { - Response *interop.Response + Response []byte ErrorResponse *interop.ErrorResponse ActiveInvokeID string } +// StartAcceptingDirectInvokes +func (i *MockInteropServer) StartAcceptingDirectInvokes() error { return nil } + // SendResponse writes response to a shared memory. -func (i *MockInteropServer) SendResponse(invokeID string, response *interop.Response) error { - i.Response = response +func (i *MockInteropServer) SendResponse(invokeID string, reader io.Reader) error { + bytes, err := ioutil.ReadAll(reader) + if err != nil { + return err + } + if len(bytes) > interop.MaxPayloadSize { + return &interop.ErrorResponseTooLarge{ + ResponseSize: len(bytes), + MaxResponseSize: interop.MaxPayloadSize, + } + } + i.Response = bytes return nil } @@ -77,11 +91,10 @@ func (i *MockInteropServer) SetInternalStateGetter(isd interop.InternalStateGett func (m *MockInteropServer) Init(i *interop.Start, invokeTimeoutMs int64) {} -func (m *MockInteropServer) Invoke(w io.Writer, i *interop.Invoke) error { return nil } +func (m *MockInteropServer) Invoke(w http.ResponseWriter, i *interop.Invoke) error { return nil } func (m *MockInteropServer) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription { return nil } - // FlowTest provides configuration for tests that involve synchronization flows. type FlowTest struct { AppCtx appctx.ApplicationContext diff --git a/test/integration/local_lambda/end-to-end-test.py b/test/integration/local_lambda/end-to-end-test.py index a99a84e..27d0e07 100644 --- a/test/integration/local_lambda/end-to-end-test.py +++ b/test/integration/local_lambda/end-to-end-test.py @@ -106,6 +106,48 @@ def test_exception_returned(self): r = requests.post("http://localhost:9002/2015-03-31/functions/function/invocations", json={}) self.assertEqual(b'{"errorMessage": "Raising an exception", "errorType": "Exception", "stackTrace": [" File \\"/var/task/main.py\\", line 13, in exception_handler\\n raise Exception(\\"Raising an exception\\")\\n"]}', r.content) + def test_context_get_remaining_time_in_three_seconds(self): + cmd = f"docker run --name remainingtimethree -d --env AWS_LAMBDA_FUNCTION_TIMEOUT=3 -v {self.path_to_binary}:/local-lambda-runtime-server -p 9004:8080 --entrypoint /local-lambda-runtime-server/aws-lambda-rie {self.image_name} {DEFAULT_1P_ENTRYPOINT} main.check_remaining_time_handler" + + Popen(cmd.split(' ')).communicate() + + r = requests.post("http://localhost:9004/2015-03-31/functions/function/invocations", json={}) + + # sleep 1s to give enough time for the endpoint to be up to curl + time.sleep(SLEEP_TIME) + # Executation time is not decided, 1.0s ~ 3.0s is a good estimation + self.assertLess(int(r.content), 3000) + self.assertGreater(int(r.content), 1000) + + + def test_context_get_remaining_time_in_ten_seconds(self): + cmd = f"docker run --name remainingtimeten -d --env AWS_LAMBDA_FUNCTION_TIMEOUT=10 -v {self.path_to_binary}:/local-lambda-runtime-server -p 9005:8080 --entrypoint /local-lambda-runtime-server/aws-lambda-rie {self.image_name} {DEFAULT_1P_ENTRYPOINT} main.check_remaining_time_handler" + + Popen(cmd.split(' ')).communicate() + + r = requests.post("http://localhost:9005/2015-03-31/functions/function/invocations", json={}) + + # sleep 1s to give enough time for the endpoint to be up to curl + time.sleep(SLEEP_TIME) + # Executation time is not decided, 8.0s ~ 10.0s is a good estimation + self.assertLess(int(r.content), 10000) + self.assertGreater(int(r.content), 8000) + + + def test_context_get_remaining_time_in_default_deadline(self): + cmd = f"docker run --name remainingtimedefault -d -v {self.path_to_binary}:/local-lambda-runtime-server -p 9006:8080 --entrypoint /local-lambda-runtime-server/aws-lambda-rie {self.image_name} {DEFAULT_1P_ENTRYPOINT} main.check_remaining_time_handler" + + Popen(cmd.split(' ')).communicate() + + r = requests.post("http://localhost:9006/2015-03-31/functions/function/invocations", json={}) + + # sleep 1s to give enough time for the endpoint to be up to curl + time.sleep(SLEEP_TIME) + # Executation time is not decided, 298.0s ~ 300.0s is a good estimation + self.assertLess(int(r.content), 300000) + self.assertGreater(int(r.content), 298000) + + class TestPython36Runtime(TestCase): @classmethod @@ -153,4 +195,4 @@ def test_function_name_is_overriden(self): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/test/integration/testdata/main.py b/test/integration/testdata/main.py index b22df8f..b6b527d 100644 --- a/test/integration/testdata/main.py +++ b/test/integration/testdata/main.py @@ -35,3 +35,9 @@ def assert_lambda_arn_in_context(event, context): return "My lambda ran succesfully" else: raise("Function Arn was not there") + + +def check_remaining_time_handler(event, context): + # Wait 1s to see if the remaining time changes + time.sleep(1) + return context.get_remaining_time_in_millis()