diff --git a/README.md b/README.md index 0eec533..3599cc0 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ ## AWS Lambda Runtime Interface Emulator +![GitHub release (latest by date)](https://img.shields.io/github/v/release/aws/aws-lambda-runtime-interface-emulator) +![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/aws/aws-lambda-runtime-interface-emulator) +![GitHub](https://img.shields.io/github/license/aws/aws-lambda-runtime-interface-emulator) -![Apache-2.0](https://img.shields.io/npm/l/aws-sam-local.svg) The Lambda Runtime Interface Emulator is a proxy for Lambda’s Runtime and Extensions APIs, which allows customers to locally test their Lambda function packaged as a container image. It is a lightweight web-server that converts @@ -12,6 +14,21 @@ requests instead of the JSON events required for deployment to Lambda. This comp Lambda’s orchestrator, or security and authentication configurations. You can get started by downloading and installing it on your local machine. When the Lambda Runtime API emulator is executed, a `/2015-03-31/functions/function/invocations` endpoint will be stood up within the container that you post data to it in order to invoke your function for testing. +## Content +* [Installing](#installing) +* [Getting started](#getting-started) + * [Test an image with RIE included in the image](#test-an-image-with-rie-included-in-the-image) + * [To test your Lambda function with the emulator](#to-test-your-lambda-function-with-the-emulator) + * [Build RIE into your base image](#build-rie-into-your-base-image) + * [To build the emulator into your image](#to-build-the-emulator-into-your-image) + * [Test an image without adding RIE to the image](#test-an-image-without-adding-rie-to-the-image) + * [To test an image without adding RIE to the image](#to-test-an-image-without-adding-rie-to-the-image) +* [How to configure](#how-to-configure) +* [Level of support](#level-of-support) +* [Security](#security) +* [License](#license) + + ## Installing Instructions for installing AWS Lambda Runtime Interface Emulator for your platform @@ -26,26 +43,26 @@ Instructions for installing AWS Lambda Runtime Interface Emulator for your platf ## Getting started -There are a few ways you use the Runtime Interface Emulator (RIE) to locally test your function depending on the base image used. +There are a few ways you use the Runtime Interface Emulator (RIE) to locally test your function depending on the base image used. ### Test an image with RIE included in the image -The AWS base images for Lambda include the runtime interface emulator. You can also follow these steps if you built the RIE into your alternative base image. +The AWS base images for Lambda include the runtime interface emulator. You can also follow these steps if you built the RIE into your alternative base image. #### To test your Lambda function with the emulator -1. Build your image locally using the docker build command. +1. Build your image locally using the docker build command. `docker build -t myfunction:latest .` -2. Run your container image locally using the docker run command. +2. Run your container image locally using the docker run command. `docker run -p 9000:8080 myfunction:latest` - This command runs the image as a container and starts up an endpoint locally at `localhost:9000/2015-03-31/functions/function/invocations`. + This command runs the image as a container and starts up an endpoint locally at `localhost:9000/2015-03-31/functions/function/invocations`. -3. Post an event to the following endpoint using a curl command: +3. Post an event to the following endpoint using a curl command: `curl -XPOST "http://localhost:9000/2015-03-31/functions/function/invocations" -d '{}'` @@ -59,10 +76,11 @@ You can build RIE into a base image. Download the RIE from GitHub to your local 1. Create a script and save it in your project directory. Set execution permissions for the script file. -The script checks for the presence of the `AWS_LAMBDA_RUNTIME_API` environment variable, which indicates the presence of the runtime API. If the runtime API is present, the script runs [the runtime interface client](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-images.html#runtimes-api-client). Otherwise, the script runs the runtime interface emulator. + The script checks for the presence of the `AWS_LAMBDA_RUNTIME_API` environment variable, which indicates the presence of the runtime API. If the runtime API is present, the script runs [the runtime interface client](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-images.html#runtimes-api-client). Otherwise, the script runs the runtime interface emulator. -The following example shows a typical script for a Node.js function. - ``` + The following example shows a typical script for a Node.js function. + + ```sh #!/bin/sh if [ -z "${AWS_LAMBDA_RUNTIME_API}" ]; then exec /usr/local/bin/aws-lambda-rie /usr/bin/npx aws-lambda-ric @@ -71,74 +89,83 @@ The following example shows a typical script for a Node.js function. fi ``` -2. Download the [runtime interface emulator](https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest) for your target architecture (`aws-lambda-rie` for x86\_64 or `aws-lambda-rie-arm64` for arm64) from GitHub into your project directory. +2. Download the [runtime interface emulator](https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest) for your target architecture (`aws-lambda-rie` for x86\_64 or `aws-lambda-rie-arm64` for arm64) from GitHub into your project directory. 3. Install the emulator package and change `ENTRYPOINT` to run the new script by adding the following lines to your Dockerfile: -To use the default x86\_64 architecture - ``` + To use the default x86\_64 architecture + + ```dockerfile ADD aws-lambda-rie /usr/local/bin/aws-lambda-rie ENTRYPOINT [ "/entry_script.sh" ] ``` -To use the arm64 architecture: - ``` + To use the arm64 architecture: + + ```dockerfile ADD aws-lambda-rie-arm64 /usr/local/bin/aws-lambda-rie ENTRYPOINT [ "/entry_script.sh" ] ``` 4. Build your image locally using the docker build command. - ``` + + ```sh docker build -t myfunction:latest . ``` 5. Run your image locally using the docker run command. - ``` + + ```sh docker run -p 9000:8080 myfunction:latest ``` ### Test an image without adding RIE to the image You install the runtime interface emulator to your local machine. When you run the container image, you set the entry point to be the emulator. -*To test an image without adding RIE to the image * + +#### To test an image without adding RIE to the image 1. From your project directory, run the following command to download the RIE (x86-64 architecture) from GitHub and install it on your local machine. - ``` + ```sh mkdir -p ~/.aws-lambda-rie && curl -Lo ~/.aws-lambda-rie/aws-lambda-rie \ https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie \ && chmod +x ~/.aws-lambda-rie/aws-lambda-rie - ``` + ``` + + To download the RIE for arm64 architecture, use the previous command with a different GitHub download url. -To download the RIE for arm64 architecture, use the previous command with a different GitHub download url. ``` https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie-arm64 \ ``` -2. Run your Lambda image function using the docker run command. - ``` - docker run -d -v ~/.aws-lambda-rie:/aws-lambda -p 9000:8080 myfunction:latest - --entrypoint /aws-lambda/aws-lambda-rie <(optional) image command>` +2. Run your Lambda image function using the docker run command. + + ```sh + docker run -d -v ~/.aws-lambda-rie:/aws-lambda -p 9000:8080 myfunction:latest \ + --entrypoint /aws-lambda/aws-lambda-rie <(optional) image command> ``` - This runs the image as a container and starts up an endpoint locally at `localhost:9000/2015-03-31/functions/function/invocations`. + This runs the image as a container and starts up an endpoint locally at `localhost:9000/2015-03-31/functions/function/invocations`. -3. Post an event to the following endpoint using a curl command: +3. Post an event to the following endpoint using a curl command: - `curl -XPOST "http://localhost:9000/2015-03-31/functions/function/invocations" -d '{}'` + ```sh + curl -XPOST "http://localhost:9000/2015-03-31/functions/function/invocations" -d '{}' + ``` This command invokes the function running in the container image and returns a response. -## How to configure +## How to configure -`aws-lambda-rie` can be configured through Environment Variables within the local running Image. +`aws-lambda-rie` can be configured through Environment Variables within the local running Image. You can configure your credentials by setting: * `AWS_ACCESS_KEY_ID` * `AWS_SECRET_ACCESS_KEY` * `AWS_SESSION_TOKEN` * `AWS_REGION` -You can configure timeout by setting AWS_LAMBDA_FUNCTION_TIMEOUT to the number of seconds you want your function to timeout in. +You can configure timeout by setting `AWS_LAMBDA_FUNCTION_TIMEOUT` to the number of seconds you want your function to timeout in. The rest of these Environment Variables can be set to match AWS Lambda's environment but are not required. * `AWS_LAMBDA_FUNCTION_VERSION` @@ -147,17 +174,16 @@ The rest of these Environment Variables can be set to match AWS Lambda's environ ## Level of support -You can use the emulator to test if your function code is compatible with the Lambda environment, executes successfully -and provides the expected output. For example, you can mock test events from different event sources. You can also use -it to test extensions and agents built into the container image against the Lambda Extensions API. This component -does *not *emulate* *the orchestration behavior of AWS Lambda. For example, Lambda has a network and security -configurations that will not be emulated by this component. - +You can use the emulator to test if your function code is compatible with the Lambda environment, executes successfully +and provides the expected output. For example, you can mock test events from different event sources. You can also use +it to test extensions and agents built into the container image against the Lambda Extensions API. This component +does _not_ emulate the orchestration behavior of AWS Lambda. For example, Lambda has a network and security +configurations that will not be emulated by this component. * You can use the emulator to test if your function code is compatible with the Lambda environment, runs successfully and provides the expected output. * You can also use it to test extensions and agents built into the container image against the Lambda Extensions API. -* This component does _not_ emulate Lambda’s orchestration, or security and authentication configurations. -* The component does _not_ support X-ray and other Lambda integrations locally. +* This component does _not_ emulate Lambda’s orchestration, or security and authentication configurations. +* The component does _not_ support X-ray and other Lambda integrations locally. * The component supports only Linux, for x86-64 and arm64 architectures. ## Security diff --git a/cmd/aws-lambda-rie/handlers.go b/cmd/aws-lambda-rie/handlers.go index 39097fc..42032cf 100644 --- a/cmd/aws-lambda-rie/handlers.go +++ b/cmd/aws-lambda-rie/handlers.go @@ -14,8 +14,10 @@ import ( "strings" "time" + "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" + "go.amzn.com/lambda/rapidcore/env" "github.com/google/uuid" @@ -27,6 +29,19 @@ type Sandbox interface { Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error } +type InteropServer interface { + Init(i *interop.Init, invokeTimeoutMs int64) error + AwaitInitialized() error + FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error + Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) + Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) + AwaitRelease() (*statejson.InternalStateDescription, error) + Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription + InternalState() (*statejson.InternalStateDescription, error) + CurrentToken() *interop.Token + Restore(restore *interop.Restore) error +} + var initDone bool func GetenvWithDefault(key string, defaultValue string) string { @@ -57,7 +72,7 @@ func printEndReports(invokeId string, initDuration string, memorySize string, in invokeId, invokeDuration, math.Ceil(invokeDuration), memorySize, memorySize) } -func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { +func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs interop.Bootstrap) { log.Debugf("invoke: -> %s %s %v", r.Method, r.URL, r.Header) bodyBytes, err := ioutil.ReadAll(r.Body) if err != nil { @@ -80,7 +95,7 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { if !initDone { - initStart, initEnd := InitHandler(sandbox, functionVersion, timeout) + initStart, initEnd := InitHandler(sandbox, functionVersion, timeout, bs) // Calculate InitDuration initTimeMS := math.Min(float64(initEnd.Sub(initStart).Nanoseconds()), @@ -99,7 +114,6 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { TraceID: r.Header.Get("X-Amzn-Trace-Id"), LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), Payload: bytes.NewReader(bodyBytes), - CorrelationID: "invokeCorrelationID", } fmt.Println("START RequestId: " + invokePayload.ID + " Version: " + functionVersion) @@ -166,7 +180,7 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { w.Write(invokeResp.Body) } -func InitHandler(sandbox Sandbox, functionVersion string, timeout int64) (time.Time, time.Time) { +func InitHandler(sandbox Sandbox, functionVersion string, timeout int64, bs interop.Bootstrap) (time.Time, time.Time) { additionalFunctionEnvironmentVariables := map[string]string{} // Add default Env Vars if they were not defined. This is a required otherwise 1p Python2.7, Python3.6, and @@ -189,15 +203,20 @@ func InitHandler(sandbox Sandbox, functionVersion string, timeout int64) (time.T // pass to rapid sandbox.Init(&interop.Init{ Handler: GetenvWithDefault("AWS_LAMBDA_FUNCTION_HANDLER", os.Getenv("_HANDLER")), - CorrelationID: "initCorrelationID", AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"), AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"), AwsSession: os.Getenv("AWS_SESSION_TOKEN"), XRayDaemonAddress: "0.0.0.0:0", // TODO FunctionName: GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function"), FunctionVersion: functionVersion, - + RuntimeInfo: interop.RuntimeInfo{ + ImageJSON: "{}", + Arn: "", + Version: ""}, CustomerEnvironmentVariables: additionalFunctionEnvironmentVariables, + SandboxType: interop.SandboxClassic, + Bootstrap: bs, + EnvironmentVariables: env.NewEnvironment(), }, timeout*1000) initEnd := time.Now() return initStart, initEnd diff --git a/cmd/aws-lambda-rie/http.go b/cmd/aws-lambda-rie/http.go index be4002d..88bd39b 100644 --- a/cmd/aws-lambda-rie/http.go +++ b/cmd/aws-lambda-rie/http.go @@ -7,16 +7,18 @@ import ( "net/http" log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapidcore" ) -func startHTTPServer(ipport string, sandbox Sandbox) { +func startHTTPServer(ipport string, sandbox *rapidcore.SandboxBuilder, bs interop.Bootstrap) { srv := &http.Server{ Addr: ipport, } // Pass a channel http.HandleFunc("/2015-03-31/functions/function/invocations", func(w http.ResponseWriter, r *http.Request) { - InvokeHandler(w, r, sandbox) + InvokeHandler(w, r, sandbox.LambdaInvokeAPI(), bs) }) // go routine (main thread waits) diff --git a/cmd/aws-lambda-rie/main.go b/cmd/aws-lambda-rie/main.go index 3a87e46..65879c0 100644 --- a/cmd/aws-lambda-rie/main.go +++ b/cmd/aws-lambda-rie/main.go @@ -6,6 +6,7 @@ package main import ( "context" "fmt" + "net" "os" "runtime/debug" @@ -21,8 +22,11 @@ const ( ) type options struct { - LogLevel string `long:"log-level" default:"info" description:"log level"` + LogLevel string `long:"log-level" description:"The level of AWS Lambda Runtime Interface Emulator logs to display. Can also be set by the environment variable 'LOG_LEVEL'. Defaults to the value 'info'."` InitCachingEnabled bool `long:"enable-init-caching" description:"Enable support for Init Caching"` + // Do not have a default value so we do not need to keep it in sync with the default value in lambda/rapidcore/sandbox_builder.go + RuntimeAPIAddress string `long:"runtime-api-address" description:"The address of the AWS Lambda Runtime API to communicate with the Lambda execution environment."` + RuntimeInterfaceEmulatorAddress string `long:"runtime-interface-emulator-address" default:"0.0.0.0:8080" description:"The address for the AWS Lambda Runtime Interface Emulator to accept HTTP request upon."` } func main() { @@ -30,11 +34,37 @@ func main() { debug.SetGCPercent(33) opts, args := getCLIArgs() - rapidcore.SetLogLevel(opts.LogLevel) + + logLevel := "info" + + // If you specify an option by using a parameter on the CLI command line, it overrides any value from either the corresponding environment variable. + // + // https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html + if opts.LogLevel != "" { + logLevel = opts.LogLevel + } else if envLogLevel, envLogLevelSet := os.LookupEnv("LOG_LEVEL"); envLogLevelSet { + logLevel = envLogLevel + } + + rapidcore.SetLogLevel(logLevel) + + if opts.RuntimeAPIAddress != "" { + _, _, err := net.SplitHostPort(opts.RuntimeAPIAddress) + + if err != nil { + log.WithError(err).Fatalf("The command line value for \"--runtime-api-address\" is not a valid network address %q.", opts.RuntimeAPIAddress) + } + } + + _, _, err := net.SplitHostPort(opts.RuntimeInterfaceEmulatorAddress) + + if err != nil { + log.WithError(err).Fatalf("The command line value for \"--runtime-interface-emulator-address\" is not a valid network address %q.", opts.RuntimeInterfaceEmulatorAddress) + } bootstrap, handler := getBootstrap(args, opts) sandbox := rapidcore. - NewSandboxBuilder(bootstrap). + NewSandboxBuilder(). AddShutdownFunc(context.CancelFunc(func() { os.Exit(0) })). SetExtensionsFlag(true). SetInitCachingFlag(opts.InitCachingEnabled) @@ -43,10 +73,17 @@ func main() { sandbox.SetHandler(handler) } - go sandbox.Create() + if opts.RuntimeAPIAddress != "" { + sandbox.SetRuntimeAPIAddress(opts.RuntimeAPIAddress) + } + + sandboxContext, internalStateFn := sandbox.Create() + // Since we have not specified a custom interop server for standalone, we can + // directly reference the default interop server, which is a concrete type + sandbox.DefaultInteropServer().SetSandboxContext(sandboxContext) + sandbox.DefaultInteropServer().SetInternalStateGetter(internalStateFn) - testAPIipport := "0.0.0.0:8080" - startHTTPServer(testAPIipport, sandbox) + startHTTPServer(opts.RuntimeInterfaceEmulatorAddress, sandbox, bootstrap) } func getCLIArgs() (options, []string) { @@ -112,5 +149,5 @@ func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) { log.Panic("insufficient arguments: bootstrap not provided") } - return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir), handler + return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir, ""), handler } diff --git a/go.mod b/go.mod index 871b812..278c63a 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/konsorten/go-windows-terminal-sequences v1.0.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.1.0 // indirect - golang.org/x/net v0.1.0 // indirect - golang.org/x/sys v0.1.0 // indirect + golang.org/x/net v0.7.0 // indirect + golang.org/x/sys v0.5.0 // indirect gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect ) diff --git a/go.sum b/go.sum index daa8fe3..905e315 100644 --- a/go.sum +++ b/go.sum @@ -28,14 +28,14 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/urfave/cli/v2 v2.2.0/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= -golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/lambda/agents/agent.go b/lambda/agents/agent.go index 16625c2..b1f8563 100644 --- a/lambda/agents/agent.go +++ b/lambda/agents/agent.go @@ -4,77 +4,38 @@ package agents import ( - "fmt" - "io" - "io/ioutil" - "os/exec" + "os" "path" - "syscall" + "path/filepath" log "github.com/sirupsen/logrus" ) -// AgentProcess is the common interface exposed by both internal and external agent processes -type AgentProcess interface { - Name() string -} - -// ExternalAgentProcess represents an external agent process -type ExternalAgentProcess struct { - cmd *exec.Cmd -} - -// NewExternalAgentProcess returns a new external agent process -func NewExternalAgentProcess(path string, env []string, stdoutWriter io.Writer, stderrWriter io.Writer) ExternalAgentProcess { - command := exec.Command(path) - command.Env = env - - command.Stdout = NewNewlineSplitWriter(stdoutWriter) - command.Stderr = NewNewlineSplitWriter(stderrWriter) - command.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - - return ExternalAgentProcess{ - cmd: command, - } -} - -// Name returns the name of the agent -// For external agents is the executable name -func (a *ExternalAgentProcess) Name() string { - return path.Base(a.cmd.Path) -} - -func (a *ExternalAgentProcess) Pid() int { - return a.cmd.Process.Pid -} - -// Start starts an external agent process -func (a *ExternalAgentProcess) Start() error { - return a.cmd.Start() -} - -// Wait waits for the external agent process to exit -func (a *ExternalAgentProcess) Wait() error { - return a.cmd.Wait() -} - -// String is used to print values passed as an operand to any format that accepts a string or to an unformatted printer such as Print. -func (a *ExternalAgentProcess) String() string { - return fmt.Sprintf("%s (%s)", a.Name(), a.cmd.Path) -} - // ListExternalAgentPaths return a list of external agents found in a given directory -func ListExternalAgentPaths(root string) []string { +func ListExternalAgentPaths(dir string, root string) []string { var agentPaths []string - files, err := ioutil.ReadDir(root) + if !isCanonical(dir) || !isCanonical(root) { + log.Warningf("Agents base paths are not absolute and in canonical form: %s, %s", dir, root) + return agentPaths + } + fullDir := path.Join(root, dir) + files, err := os.ReadDir(fullDir) if err != nil { log.WithError(err).Warning("Cannot list external agents") return agentPaths } for _, file := range files { if !file.IsDir() { - agentPaths = append(agentPaths, path.Join(root, file.Name())) + // The returned path is absolute wrt to `root`. This allows + // to exec the agents in their own mount namespace + p := path.Join("/", dir, file.Name()) + agentPaths = append(agentPaths, p) } } return agentPaths } + +func isCanonical(path string) bool { + absPath, err := filepath.Abs(path) + return err == nil && absPath == path +} diff --git a/lambda/agents/agent_test.go b/lambda/agents/agent_test.go index d314a76..e6732ff 100644 --- a/lambda/agents/agent_test.go +++ b/lambda/agents/agent_test.go @@ -4,13 +4,12 @@ package agents import ( - "bytes" - "io/ioutil" "os" "path" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // - Test utilities @@ -50,14 +49,8 @@ func mkLink(name, target string) fileInfo { } } -// populate a temporary directory with a list of files and directories -// returns the name of the temporary root directory -func createFileTree(fs []fileInfo) (string, error) { - - root, err := ioutil.TempDir(os.TempDir(), "tmp-") - if err != nil { - return "", err - } +// populate a directory with a list of files and directories +func createFileTree(root string, fs []fileInfo) error { for _, info := range fs { filename := info.name @@ -65,67 +58,40 @@ func createFileTree(fs []fileInfo) (string, error) { name := path.Base(filename) err := os.MkdirAll(dir, 0775) if err != nil && !os.IsExist(err) { - return "", err + return err } if os.ModeDir == info.mode&os.ModeDir { err := os.Mkdir(path.Join(dir, name), info.mode&os.ModePerm) if err != nil { - return "", err + return err } } else if os.ModeSymlink == info.mode&os.ModeSymlink { target := path.Join(root, info.target) _, err = os.Stat(target) if err != nil { - return "", err + return err } err := os.Symlink(target, path.Join(dir, name)) if err != nil { - return "", err + return err } } else { file, err := os.OpenFile(path.Join(dir, name), os.O_RDWR|os.O_CREATE, info.mode&os.ModePerm) if err != nil { - return "", err + return err } file.Truncate(info.size) file.Close() } } - return root, nil -} - -// executes a given closure inside a temporary directory populated with the given FS tree -func within(fs []fileInfo, closure func()) error { - - var root string - var cwd string - var err error - - if root, err = createFileTree(fs); err != nil { - return err - } - - defer os.RemoveAll(root) - - if cwd, err = os.Getwd(); err != nil { - return err - } - - if err = os.Chdir(root); err != nil { - return err - } - - defer os.Chdir(cwd) - - closure() return nil } // - Actual tests // If the agents folder is empty it is not an error -func TestRootEmpty(t *testing.T) { +func TestBaseEmpty(t *testing.T) { assert := assert.New(t) @@ -133,34 +99,51 @@ func TestRootEmpty(t *testing.T) { mkDir("/opt/extensions", 0777), } - assert.NoError(within(fs, func() { - agents := ListExternalAgentPaths("opt/extensions") - assert.Equal(0, len(agents)) - })) + tmpDir, err := os.MkdirTemp("", "ext-") + require.NoError(t, err) + + createFileTree(tmpDir, fs) + defer os.RemoveAll(tmpDir) + + agents := ListExternalAgentPaths(path.Join(tmpDir, "/opt/extensions"), "/") + assert.Equal(0, len(agents)) } // Test that non-existant /opt/extensions is treated as if no agents were found -func TestRootNotExist(t *testing.T) { +func TestBaseNotExist(t *testing.T) { assert := assert.New(t) - agents := ListExternalAgentPaths("/path/which/does/not/exist") + agents := ListExternalAgentPaths("/path/which/does/not/exist", "/") + assert.Equal(0, len(agents)) +} + +// Test that non-existant root dir is teaded as if no agents were found +func TestChrootNotExist(t *testing.T) { + + assert := assert.New(t) + + agents := ListExternalAgentPaths("/bin", "/does/not/exist") assert.Equal(0, len(agents)) } // Test that non-directory /opt/extensions is treated as if no agents were found -func TestRootNotDir(t *testing.T) { +func TestBaseNotDir(t *testing.T) { assert := assert.New(t) fs := []fileInfo{ mkFile("/opt/extensions", 1, 0777), } + tmpDir, err := os.MkdirTemp("", "ext-") + require.NoError(t, err) + + createFileTree(tmpDir, fs) + defer os.RemoveAll(tmpDir) - assert.NoError(within(fs, func() { - agents := ListExternalAgentPaths("opt/extensions") - assert.Equal(0, len(agents)) - })) + path := path.Join(tmpDir, "/opt/extensions") + agents := ListExternalAgentPaths(path, "/") + assert.Equal(0, len(agents)) } // Test our ability to find agent bootstraps in the FS and return them sorted. @@ -188,99 +171,63 @@ func TestFindAgentMixed(t *testing.T) { fs := append([]fileInfo{}, listed...) fs = append(fs, unlisted...) - assert.NoError(within(fs, func() { - agentPaths := ListExternalAgentPaths("opt/extensions") - assert.Equal(len(listed), len(agentPaths)) - last := "" - for index := range listed { - if len(last) > 0 { - assert.GreaterOrEqual(agentPaths[index], last) - } - last = agentPaths[index] - } - })) -} - -// Test our ability to start agents -func TestAgentStart(t *testing.T) { - assert := assert.New(t) - agent := NewExternalAgentProcess("../testdata/agents/bash_true.sh", []string{}, &mockWriter{}, &mockWriter{}) - assert.Nil(agent.Start()) - assert.Nil(agent.Wait()) -} + tmpDir, err := os.MkdirTemp("", "ext-") + require.NoError(t, err) -// Test that execution of invalid agents is correctly reported -func TestInvalidAgentStart(t *testing.T) { - assert := assert.New(t) - agent := NewExternalAgentProcess("/bin/none", []string{}, &mockWriter{}, &mockWriter{}) - assert.True(os.IsNotExist(agent.Start())) -} + createFileTree(tmpDir, fs) + defer os.RemoveAll(tmpDir) -func TestAgentStdoutWriter(t *testing.T) { - // Given - assert := assert.New(t) - - stdout := &mockWriter{} - stderr := &mockWriter{} - expectedStdout := "stdout line 1\nstdout line 2\nstdout line 3\n" - expectedStderr := "" - - agent := NewExternalAgentProcess("../testdata/agents/bash_stdout.sh", []string{}, stdout, stderr) - - // When - assert.NoError(agent.Start()) - assert.NoError(agent.Wait()) - - // Then - assert.Equal(expectedStdout, string(bytes.Join(stdout.bytesReceived, []byte("")))) - assert.Equal(expectedStderr, string(bytes.Join(stderr.bytesReceived, []byte("")))) + path := path.Join(tmpDir, "/opt/extensions") + agentPaths := ListExternalAgentPaths(path, "/") + assert.Equal(len(listed), len(agentPaths)) + last := "" + for index := range listed { + if len(last) > 0 { + assert.GreaterOrEqual(agentPaths[index], last) + } + last = agentPaths[index] + } } -func TestAgentStderrWriter(t *testing.T) { - // Given - assert := assert.New(t) - - stdout := &mockWriter{} - stderr := &mockWriter{} - expectedStdout := "" - expectedStderr := "stderr line 1\nstderr line 2\nstderr line 3\n" - - agent := NewExternalAgentProcess("../testdata/agents/bash_stderr.sh", []string{}, stdout, stderr) - - // When - assert.NoError(agent.Start()) - assert.NoError(agent.Wait()) - - // Then - assert.Equal(expectedStdout, string(bytes.Join(stdout.bytesReceived, []byte("")))) - assert.Equal(expectedStderr, string(bytes.Join(stderr.bytesReceived, []byte("")))) -} +// Test our ability to find agent bootstraps in the FS and return them sorted, +// when using a different mount namespace root for the extensiosn domain. +// Even if not all files are valid as executable agents, +// ListExternalAgentPaths() is expected to return all of them. +func TestFindAgentMixedInChroot(t *testing.T) { -func TestAgentStdoutAndStderrSeperateWriters(t *testing.T) { - // Given assert := assert.New(t) - stdout := &mockWriter{} - stderr := &mockWriter{} - expectedStdout := "stdout line 1\nstdout line 2\nstdout line 3\n" - expectedStderr := "stderr line 1\nstderr line 2\nstderr line 3\n" + listed := []fileInfo{ + mkFile("/opt/extensions/ok2", 1, 0777), // this is ok + mkFile("/opt/extensions/ok1", 1, 0777), // this is ok as well + mkFile("/opt/extensions/not_exec", 1, 0666), // this is not executable + mkFile("/opt/extensions/not_read", 1, 0333), // this is not readable + mkFile("/opt/extensions/empty_file", 0, 0777), // this is empty + mkLink("/opt/extensions/link", "/opt/extensions/ok1"), // symlink must be ignored + } - agent := NewExternalAgentProcess("../testdata/agents/bash_stdout_and_stderr.sh", []string{}, stdout, stderr) + unlisted := []fileInfo{ + mkDir("/opt/extensions/empty_dir", 0777), // this is an empty directory + mkDir("/opt/extensions/nonempty_dir", 0777), // subdirs should not be listed + mkFile("/opt/extensions/nonempty_dir/notok", 1, 0777), // files in subdirs should not be listed + } - // When - assert.NoError(agent.Start()) - assert.NoError(agent.Wait()) + fs := append([]fileInfo{}, listed...) + fs = append(fs, unlisted...) - // Then - assert.Equal(expectedStdout, string(bytes.Join(stdout.bytesReceived, []byte("")))) - assert.Equal(expectedStderr, string(bytes.Join(stderr.bytesReceived, []byte("")))) -} + rootDir, err := os.MkdirTemp("", "rootfs") + require.NoError(t, err) -type mockWriter struct { - bytesReceived [][]byte -} + createFileTree(rootDir, fs) + defer os.RemoveAll(rootDir) -func (m *mockWriter) Write(bytes []byte) (int, error) { - m.bytesReceived = append(m.bytesReceived, bytes) - return len(bytes), nil + agentPaths := ListExternalAgentPaths("/opt/extensions", rootDir) + assert.Equal(len(listed), len(agentPaths)) + last := "" + for index := range listed { + if len(last) > 0 { + assert.GreaterOrEqual(agentPaths[index], last) + } + last = agentPaths[index] + } } diff --git a/lambda/agents/log_line_splitter.go b/lambda/agents/log_line_splitter.go deleted file mode 100644 index ac2c134..0000000 --- a/lambda/agents/log_line_splitter.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package agents - -import ( - "bytes" - "io" -) - -// NewlineSplitWriter wraps an io.Writer and calls the underlying writer for each newline separated line -type NewlineSplitWriter struct { - writer io.Writer -} - -// NewNewlineSplitWriter returns an instance of NewlineSplitWriter -func NewNewlineSplitWriter(w io.Writer) *NewlineSplitWriter { - return &NewlineSplitWriter{ - writer: w, - } -} - -// Write splits the byte buffer by newline and calls the underlying writer for each line -func (nsw *NewlineSplitWriter) Write(buf []byte) (int, error) { - newBuf := make([]byte, len(buf)) - copy(newBuf, buf) - lines := bytes.SplitAfter(newBuf, []byte("\n")) - var bytesWritten int - for _, line := range lines { - if len(line) > 0 { - n, err := nsw.writer.Write(line) - bytesWritten += n - if err != nil { - return bytesWritten, err - } - } - } - - return bytesWritten, nil -} diff --git a/lambda/appctx/appctx.go b/lambda/appctx/appctx.go index 44776ab..6c81653 100644 --- a/lambda/appctx/appctx.go +++ b/lambda/appctx/appctx.go @@ -10,6 +10,8 @@ import ( // A Key type is used as a key for storing values in the application context. type Key int +type InitType int + const ( // AppCtxInvokeErrorResponseKey is used for storing deferred invoke error response. // Only used by xray. TODO refactor xray interface so it doesn't use appctx @@ -23,6 +25,18 @@ const ( // AppCtxFirstFatalErrorKey is used to store first unrecoverable error message encountered to propagate it to slicer with DONE(errortype) or DONEFAIL(errortype) AppCtxFirstFatalErrorKey + + // AppCtxInitType is used to store the init type (init caching or plain INIT) + AppCtxInitType + + // AppCtxSandbox type is used to store the sandbox type (SandboxClassic or SandboxPreWarmed) + AppCtxSandboxType +) + +// Possible values for AppCtxInitType key +const ( + Init InitType = iota + InitCaching ) // ApplicationContext is an application scope context. diff --git a/lambda/appctx/appctxutil.go b/lambda/appctx/appctxutil.go index a3e652f..a30677f 100644 --- a/lambda/appctx/appctxutil.go +++ b/lambda/appctx/appctxutil.go @@ -5,11 +5,12 @@ package appctx import ( "context" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" "net/http" "strings" + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/interop" + log "github.com/sirupsen/logrus" ) @@ -164,3 +165,24 @@ func LoadFirstFatalError(appCtx ApplicationContext) (errorType fatalerror.ErrorT } return v.(fatalerror.ErrorType), true } + +func StoreInitType(appCtx ApplicationContext, initCachingEnabled bool) { + if initCachingEnabled { + appCtx.Store(AppCtxInitType, InitCaching) + } else { + appCtx.Store(AppCtxInitType, Init) + } +} + +// Default Init Type is Init unless it's explicitly stored in ApplicationContext +func LoadInitType(appCtx ApplicationContext) InitType { + return appCtx.GetOrDefault(AppCtxInitType, Init).(InitType) +} + +func StoreSandboxType(appCtx ApplicationContext, sandboxType interop.SandboxType) { + appCtx.Store(AppCtxSandboxType, sandboxType) +} + +func LoadSandboxType(appCtx ApplicationContext) interop.SandboxType { + return appCtx.GetOrDefault(AppCtxSandboxType, interop.SandboxClassic).(interop.SandboxType) +} diff --git a/lambda/appctx/appctxutil_test.go b/lambda/appctx/appctxutil_test.go index a8a4761..b6df9aa 100644 --- a/lambda/appctx/appctxutil_test.go +++ b/lambda/appctx/appctxutil_test.go @@ -11,6 +11,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.amzn.com/lambda/fatalerror" + + "go.amzn.com/lambda/interop" ) func runTestRequestWithUserAgent(t *testing.T, userAgent string, expectedRuntimeRelease string) { @@ -200,3 +202,26 @@ func TestFirstFatalError(t *testing.T) { require.True(t, found) require.Equal(t, fatalerror.AgentCrash, v) } + +func TestStoreLoadInitType(t *testing.T) { + appCtx := NewApplicationContext() + + initType := LoadInitType(appCtx) + assert.Equal(t, Init, initType) + + StoreInitType(appCtx, true) + initType = LoadInitType(appCtx) + assert.Equal(t, InitCaching, initType) +} + +func TestStoreLoadSandboxType(t *testing.T) { + appCtx := NewApplicationContext() + + sandboxType := LoadSandboxType(appCtx) + assert.Equal(t, interop.SandboxClassic, sandboxType) + + StoreSandboxType(appCtx, interop.SandboxPreWarmed) + + sandboxType = LoadSandboxType(appCtx) + assert.Equal(t, interop.SandboxPreWarmed, sandboxType) +} diff --git a/lambda/core/bandwidthlimiter/bandwidthlimiter.go b/lambda/core/bandwidthlimiter/bandwidthlimiter.go new file mode 100644 index 0000000..05c600a --- /dev/null +++ b/lambda/core/bandwidthlimiter/bandwidthlimiter.go @@ -0,0 +1,61 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "io" + + "go.amzn.com/lambda/interop" +) + +func BandwidthLimitingCopy(dst *BandwidthLimitingWriter, src io.Reader) (written int64, err error) { + written, err = io.Copy(dst, src) + _ = dst.Close() + return +} + +func NewBandwidthLimitingWriter(w io.Writer, bucket *Bucket) (*BandwidthLimitingWriter, error) { + throttler, err := NewThrottler(bucket) + if err != nil { + return nil, err + } + return &BandwidthLimitingWriter{w: w, th: throttler}, nil +} + +type BandwidthLimitingWriter struct { + w io.Writer + th *Throttler +} + +func (w *BandwidthLimitingWriter) ChunkedWrite(p []byte) (n int, err error) { + i := NewChunkIterator(p, int(w.th.b.capacity)) + for { + buf := i.Next() + if buf == nil { + return + } + written, writeErr := w.th.bandwidthLimitingWrite(w.w, buf) + n += written + if writeErr != nil { + return n, writeErr + } + } +} + +func (w *BandwidthLimitingWriter) Write(p []byte) (n int, err error) { + w.th.start() + if int64(len(p)) > w.th.b.capacity { + return w.ChunkedWrite(p) + } + return w.th.bandwidthLimitingWrite(w.w, p) +} + +func (w *BandwidthLimitingWriter) Close() (err error) { + w.th.stop() + return +} + +func (w *BandwidthLimitingWriter) GetMetrics() (metrics *interop.InvokeResponseMetrics) { + return w.th.metrics +} diff --git a/lambda/core/bandwidthlimiter/bandwidthlimiter_test.go b/lambda/core/bandwidthlimiter/bandwidthlimiter_test.go new file mode 100644 index 0000000..7ede24b --- /dev/null +++ b/lambda/core/bandwidthlimiter/bandwidthlimiter_test.go @@ -0,0 +1,106 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "bytes" + "io" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBandwidthLimitingCopy(t *testing.T) { + var size10mb int64 = 10 * 1024 * 1024 + + inputBuffer := []byte(strings.Repeat("a", int(size10mb))) + reader := bytes.NewReader(inputBuffer) + + bucket, err := NewBucket(size10mb/2, size10mb/4, size10mb/2, time.Millisecond/2) + assert.NoError(t, err) + + internalWriter := bytes.NewBuffer(make([]byte, 0, size10mb)) + writer, err := NewBandwidthLimitingWriter(internalWriter, bucket) + assert.NoError(t, err) + + n, err := BandwidthLimitingCopy(writer, reader) + assert.Equal(t, size10mb, n) + assert.Equal(t, nil, err) + assert.Equal(t, inputBuffer, internalWriter.Bytes()) +} + +type ErrorBufferWriter struct { + w ByteBufferWriter + failAfter int +} + +func (w *ErrorBufferWriter) Write(p []byte) (n int, err error) { + if w.failAfter >= 1 { + w.failAfter-- + } + n, err = w.w.Write(p) + if w.failAfter == 0 { + return n, io.ErrUnexpectedEOF + } + return n, err +} + +func (w *ErrorBufferWriter) Bytes() []byte { + return w.w.Bytes() +} + +func TestNewBandwidthLimitingWriter(t *testing.T) { + type testCase struct { + refillNumber int64 + internalWriter ByteBufferWriter + inputBuffer []byte + expectedN int + expectedError error + } + testCases := []testCase{ + { + refillNumber: 2, + internalWriter: bytes.NewBuffer(make([]byte, 0, 36)), // buffer size greater than bucket size + inputBuffer: []byte(strings.Repeat("a", 36)), + expectedN: 36, + expectedError: nil, + }, + { + refillNumber: 2, + internalWriter: bytes.NewBuffer(make([]byte, 0, 12)), // buffer size lesser than bucket size + inputBuffer: []byte(strings.Repeat("a", 12)), + expectedN: 12, + expectedError: nil, + }, + { + // buffer size greater than bucket size and error after two Write() invocations + refillNumber: 2, + internalWriter: &ErrorBufferWriter{w: bytes.NewBuffer(make([]byte, 0, 36)), failAfter: 2}, + inputBuffer: []byte(strings.Repeat("a", 36)), + expectedN: 32, + expectedError: io.ErrUnexpectedEOF, + }, + } + + for _, test := range testCases { + bucket, err := NewBucket(16, 8, test.refillNumber, 100*time.Millisecond) + assert.NoError(t, err) + + writer, err := NewBandwidthLimitingWriter(test.internalWriter, bucket) + assert.NoError(t, err) + assert.False(t, writer.th.running) + + n, err := writer.Write(test.inputBuffer) + assert.True(t, writer.th.running) + assert.Equal(t, test.expectedN, n) + assert.Equal(t, test.expectedError, err) + assert.Equal(t, test.inputBuffer[:n], test.internalWriter.Bytes()) + + err = writer.Close() + assert.Nil(t, err) + assert.False(t, writer.th.running) + } +} diff --git a/lambda/core/bandwidthlimiter/throttler.go b/lambda/core/bandwidthlimiter/throttler.go new file mode 100644 index 0000000..b3b57dd --- /dev/null +++ b/lambda/core/bandwidthlimiter/throttler.go @@ -0,0 +1,154 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "errors" + "fmt" + "io" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" +) + +var ErrBufferSizeTooLarge = errors.New("buffer size cannot be greater than bucket size") + +func NewBucket(capacity int64, initialTokenCount int64, refillNumber int64, refillInterval time.Duration) (*Bucket, error) { + if capacity <= 0 || initialTokenCount < 0 || refillNumber <= 0 || refillInterval <= 0 || + capacity < initialTokenCount { + errorMsg := fmt.Sprintf("invalid bucket parameters (capacity: %d, initialTokenCount: %d, refillNumber: %d,"+ + "refillInterval: %d)", capacity, initialTokenCount, refillInterval, refillInterval) + log.Error(errorMsg) + return nil, errors.New(errorMsg) + } + return &Bucket{ + capacity: capacity, + tokenCount: initialTokenCount, + refillNumber: refillNumber, + refillInterval: refillInterval, + mutex: sync.Mutex{}, + }, nil +} + +type Bucket struct { + capacity int64 + tokenCount int64 + refillNumber int64 + refillInterval time.Duration + mutex sync.Mutex +} + +func (b *Bucket) produceTokens() { + b.mutex.Lock() + defer b.mutex.Unlock() + if b.tokenCount < b.capacity { + b.tokenCount = min64(b.tokenCount+b.refillNumber, b.capacity) + } +} + +func (b *Bucket) consumeTokens(n int64) bool { + b.mutex.Lock() + defer b.mutex.Unlock() + if n <= b.tokenCount { + b.tokenCount -= n + return true + } + return false +} + +func (b *Bucket) getTokenCount() int64 { + b.mutex.Lock() + defer b.mutex.Unlock() + return b.tokenCount +} + +func NewThrottler(bucket *Bucket) (*Throttler, error) { + if bucket == nil { + errorMsg := "cannot create a throttler with nil bucket" + log.Error(errorMsg) + return nil, errors.New(errorMsg) + } + return &Throttler{ + b: bucket, + running: false, + produced: make(chan int64), + done: make(chan struct{}), + // FIXME: + // The runtime tells whether the function response mode is streaming or not. + // Ideally, we would want to use that value here. Since I'm just rebasing, I will leave + // as-is, but we should use that instead of relying on our memory to set this here + // because we "know" it's a streaming code path. + metrics: &interop.InvokeResponseMetrics{FunctionResponseMode: interop.FunctionResponseModeStreaming}, + }, nil +} + +type Throttler struct { + b *Bucket + running bool + produced chan int64 + done chan struct{} + metrics *interop.InvokeResponseMetrics +} + +func (th *Throttler) start() { + if th.running { + return + } + th.running = true + th.metrics.StartReadingResponseMonoTimeMs = metering.Monotime() + go func() { + ticker := time.NewTicker(th.b.refillInterval) + for { + select { + case <-ticker.C: + th.b.produceTokens() + select { + case th.produced <- metering.Monotime(): + default: + } + case <-th.done: + ticker.Stop() + return + } + } + }() +} + +func (th *Throttler) stop() { + if !th.running { + return + } + th.running = false + th.metrics.FinishReadingResponseMonoTimeMs = metering.Monotime() + durationMs := (th.metrics.FinishReadingResponseMonoTimeMs - th.metrics.StartReadingResponseMonoTimeMs) / int64(time.Millisecond) + if durationMs > 0 { + th.metrics.OutboundThroughputBps = (th.metrics.ProducedBytes / durationMs) * int64(time.Second/time.Millisecond) + } else { + th.metrics.OutboundThroughputBps = -1 + } + th.done <- struct{}{} +} + +func (th *Throttler) bandwidthLimitingWrite(w io.Writer, p []byte) (written int, err error) { + n := int64(len(p)) + if n > th.b.capacity { + return 0, ErrBufferSizeTooLarge + } + for { + if th.b.consumeTokens(n) { + written, err = w.Write(p) + th.metrics.ProducedBytes += int64(written) + return + } + waitStart := metering.Monotime() + elapsed := <-th.produced - waitStart + if elapsed > 0 { + th.metrics.TimeShapedNs += elapsed + } + } +} diff --git a/lambda/core/bandwidthlimiter/throttler_test.go b/lambda/core/bandwidthlimiter/throttler_test.go new file mode 100644 index 0000000..a88a14d --- /dev/null +++ b/lambda/core/bandwidthlimiter/throttler_test.go @@ -0,0 +1,215 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "bytes" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewBucket(t *testing.T) { + type testCase struct { + capacity int64 + initialTokenCount int64 + refillNumber int64 + refillInterval time.Duration + bucketCreated bool + } + testCases := []testCase{ + {capacity: 8, initialTokenCount: 6, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: true}, + {capacity: 8, initialTokenCount: 6, refillNumber: 2, refillInterval: -100 * time.Millisecond, bucketCreated: false}, + {capacity: 8, initialTokenCount: 6, refillNumber: -5, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + {capacity: 8, initialTokenCount: -2, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + {capacity: -2, initialTokenCount: 6, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + {capacity: 8, initialTokenCount: 10, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + } + + for _, test := range testCases { + bucket, err := NewBucket(test.capacity, test.initialTokenCount, test.refillNumber, test.refillInterval) + if test.bucketCreated { + assert.NoError(t, err) + assert.NotNil(t, bucket) + } else { + assert.Error(t, err) + assert.Nil(t, bucket) + } + } +} + +func TestBucket_produceTokens_consumeTokens(t *testing.T) { + var consumed bool + bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) + assert.NoError(t, err) + assert.Equal(t, int64(8), bucket.getTokenCount()) + + consumed = bucket.consumeTokens(5) + assert.Equal(t, int64(3), bucket.getTokenCount()) + assert.True(t, consumed) + + bucket.produceTokens() + assert.Equal(t, int64(9), bucket.getTokenCount()) + + bucket.produceTokens() + assert.Equal(t, int64(15), bucket.getTokenCount()) + + bucket.produceTokens() + assert.Equal(t, int64(16), bucket.getTokenCount()) + + bucket.produceTokens() + assert.Equal(t, int64(16), bucket.getTokenCount()) + + consumed = bucket.consumeTokens(18) + assert.Equal(t, int64(16), bucket.getTokenCount()) + assert.False(t, consumed) + + consumed = bucket.consumeTokens(16) + assert.Equal(t, int64(0), bucket.getTokenCount()) + assert.True(t, consumed) +} + +func TestNewThrottler(t *testing.T) { + bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) + assert.NoError(t, err) + + throttler, err := NewThrottler(bucket) + assert.NoError(t, err) + assert.NotNil(t, throttler) + + throttler, err = NewThrottler(nil) + assert.Error(t, err) + assert.Nil(t, throttler) +} + +func TestNewThrottler_start_stop(t *testing.T) { + bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) + assert.NoError(t, err) + + throttler, err := NewThrottler(bucket) + assert.NoError(t, err) + + assert.False(t, throttler.running) + + throttler.start() + assert.True(t, throttler.running) + + <-time.Tick(2 * throttler.b.refillInterval) + assert.LessOrEqual(t, int64(14), throttler.b.getTokenCount()) + assert.True(t, throttler.running) + + throttler.start() + assert.True(t, throttler.running) + <-time.Tick(2 * throttler.b.refillInterval) + assert.Equal(t, int64(16), throttler.b.getTokenCount()) + assert.True(t, throttler.running) + + throttler.stop() + assert.False(t, throttler.running) + + throttler.stop() + assert.False(t, throttler.running) + + throttler.start() + assert.True(t, throttler.running) + + throttler.stop() + assert.False(t, throttler.running) +} + +type ByteBufferWriter interface { + Write(p []byte) (n int, err error) + Bytes() []byte +} + +type FixedSizeBufferWriter struct { + buf []byte +} + +func (w *FixedSizeBufferWriter) Write(p []byte) (n int, err error) { + n = copy(w.buf, p) + return +} + +func (w *FixedSizeBufferWriter) Bytes() []byte { + return w.buf +} + +func TestNewThrottler_bandwidthLimitingWrite(t *testing.T) { + var size10mb int64 = 10 * 1024 * 1024 + + type testCase struct { + capacity int64 + initialTokenCount int64 + writer ByteBufferWriter + inputBuffer []byte + expectedN int + expectedError error + } + testCases := []testCase{ + { + capacity: 16, + initialTokenCount: 8, + writer: bytes.NewBuffer(make([]byte, 0, 14)), + inputBuffer: []byte(strings.Repeat("a", 12)), + expectedN: 12, + expectedError: nil, + }, + { + capacity: 16, + initialTokenCount: 8, + writer: bytes.NewBuffer(make([]byte, 0, 12)), + inputBuffer: []byte(strings.Repeat("a", 14)), + expectedN: 14, + expectedError: nil, + }, + { + capacity: size10mb, + initialTokenCount: size10mb, + writer: bytes.NewBuffer(make([]byte, 0, size10mb)), + inputBuffer: []byte(strings.Repeat("a", int(size10mb))), + expectedN: int(size10mb), + expectedError: nil, + }, + { + capacity: 16, + initialTokenCount: 8, + writer: bytes.NewBuffer(make([]byte, 0, 18)), + inputBuffer: []byte(strings.Repeat("a", 18)), + expectedN: 0, + expectedError: ErrBufferSizeTooLarge, + }, + { + capacity: 16, + initialTokenCount: 8, + writer: &FixedSizeBufferWriter{buf: make([]byte, 12)}, + inputBuffer: []byte(strings.Repeat("a", 14)), + expectedN: 12, + expectedError: nil, + }, + } + + for _, test := range testCases { + bucket, err := NewBucket(test.capacity, test.initialTokenCount, 2, 100*time.Millisecond) + assert.NoError(t, err) + + throttler, err := NewThrottler(bucket) + assert.NoError(t, err) + + writer := test.writer + throttler.start() + n, err := throttler.bandwidthLimitingWrite(writer, test.inputBuffer) + assert.Equal(t, test.expectedN, n) + assert.Equal(t, test.expectedError, err) + + if test.expectedError == nil { + assert.Equal(t, test.inputBuffer[:n], test.writer.Bytes()) + } else { + assert.Equal(t, []byte{}, test.writer.Bytes()) + } + throttler.stop() + } +} diff --git a/lambda/core/bandwidthlimiter/util.go b/lambda/core/bandwidthlimiter/util.go new file mode 100644 index 0000000..7078d5d --- /dev/null +++ b/lambda/core/bandwidthlimiter/util.go @@ -0,0 +1,46 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func min64(a, b int64) int64 { + if a < b { + return a + } + return b +} + +func NewChunkIterator(buf []byte, chunkSize int) *ChunkIterator { + if buf == nil { + return nil + } + return &ChunkIterator{ + buf: buf, + chunkSize: chunkSize, + offset: 0, + } +} + +type ChunkIterator struct { + buf []byte + chunkSize int + offset int +} + +func (i *ChunkIterator) Next() []byte { + begin := i.offset + end := min(i.offset+i.chunkSize, len(i.buf)) + i.offset = end + + if begin == end { + return nil + } + return i.buf[begin:end] +} diff --git a/lambda/core/bandwidthlimiter/util_test.go b/lambda/core/bandwidthlimiter/util_test.go new file mode 100644 index 0000000..ed93c77 --- /dev/null +++ b/lambda/core/bandwidthlimiter/util_test.go @@ -0,0 +1,45 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewChunkIterator(t *testing.T) { + buf := []byte("abcdefghijk") + + type testCase struct { + buf []byte + chunkSize int + expectedResult [][]byte + } + testCases := []testCase{ + {buf: nil, chunkSize: 0, expectedResult: [][]byte{}}, + {buf: nil, chunkSize: 1, expectedResult: [][]byte{}}, + {buf: buf, chunkSize: 0, expectedResult: [][]byte{}}, + {buf: buf, chunkSize: 1, expectedResult: [][]byte{ + []byte("a"), []byte("b"), []byte("c"), []byte("d"), []byte("e"), []byte("f"), []byte("g"), []byte("h"), + []byte("i"), []byte("j"), []byte("k"), + }}, + {buf: buf, chunkSize: 4, expectedResult: [][]byte{[]byte("abcd"), []byte("efgh"), []byte("ijk")}}, + {buf: buf, chunkSize: 5, expectedResult: [][]byte{[]byte("abcde"), []byte("fghij"), []byte("k")}}, + {buf: buf, chunkSize: 11, expectedResult: [][]byte{[]byte("abcdefghijk")}}, + {buf: buf, chunkSize: 12, expectedResult: [][]byte{[]byte("abcdefghijk")}}, + } + + for _, test := range testCases { + iterator := NewChunkIterator(test.buf, test.chunkSize) + if test.buf == nil { + assert.Nil(t, iterator) + } else { + for _, expectedChunk := range test.expectedResult { + assert.Equal(t, expectedChunk, iterator.Next()) + } + assert.Nil(t, iterator.Next()) + } + } +} diff --git a/lambda/core/credentials.go b/lambda/core/credentials.go index 7b1bf14..ad152d0 100644 --- a/lambda/core/credentials.go +++ b/lambda/core/credentials.go @@ -7,8 +7,6 @@ import ( "fmt" "sync" "time" - - log "github.com/sirupsen/logrus" ) const ( @@ -26,11 +24,9 @@ type Credentials struct { } type CredentialsService interface { - SetCredentials(token, awsKey, awsSecret, awsSession string) + SetCredentials(token, awsKey, awsSecret, awsSession string, expiration time.Time) GetCredentials(token string) (*Credentials, error) - UpdateCredentials(awsKey, awsSecret, awsSession string) error - BlockService() - UnblockService() + UpdateCredentials(awsKey, awsSecret, awsSession string, expiration time.Time) error } type credentialsServiceImpl struct { @@ -51,7 +47,7 @@ func NewCredentialsService() CredentialsService { return credentialsService } -func (c *credentialsServiceImpl) SetCredentials(token, awsKey, awsSecret, awsSession string) { +func (c *credentialsServiceImpl) SetCredentials(token, awsKey, awsSecret, awsSession string, expiration time.Time) { c.contentMutex.Lock() defer c.contentMutex.Unlock() @@ -59,7 +55,7 @@ func (c *credentialsServiceImpl) SetCredentials(token, awsKey, awsSecret, awsSes AwsKey: awsKey, AwsSecret: awsSecret, AwsSession: awsSession, - Expiration: time.Now().Add(16 * time.Minute), + Expiration: expiration, } } @@ -77,33 +73,7 @@ func (c *credentialsServiceImpl) GetCredentials(token string) (*Credentials, err return nil, ErrCredentialsNotFound } -func (c *credentialsServiceImpl) BlockService() { - if c.currentState == BLOCKED { - return - } - log.Info("blocking the credentials service") - c.serviceMutex.Lock() - - c.contentMutex.Lock() - defer c.contentMutex.Unlock() - - c.currentState = BLOCKED -} - -func (c *credentialsServiceImpl) UnblockService() { - if c.currentState == UNBLOCKED { - return - } - log.Info("unblocking the credentials service") - - c.contentMutex.Lock() - defer c.contentMutex.Unlock() - - c.currentState = UNBLOCKED - c.serviceMutex.Unlock() -} - -func (c *credentialsServiceImpl) UpdateCredentials(awsKey, awsSecret, awsSession string) error { +func (c *credentialsServiceImpl) UpdateCredentials(awsKey, awsSecret, awsSession string, expiration time.Time) error { mapSize := len(c.credentials) if mapSize != 1 { return fmt.Errorf("there are %d set of credentials", mapSize) @@ -114,6 +84,6 @@ func (c *credentialsServiceImpl) UpdateCredentials(awsKey, awsSecret, awsSession token = key } - c.SetCredentials(token, awsKey, awsSecret, awsSession) + c.SetCredentials(token, awsKey, awsSecret, awsSession, expiration) return nil } diff --git a/lambda/core/credentials_test.go b/lambda/core/credentials_test.go index ab0b247..625ab8e 100644 --- a/lambda/core/credentials_test.go +++ b/lambda/core/credentials_test.go @@ -19,7 +19,8 @@ const ( func TestGetSetCredentialsHappy(t *testing.T) { credentialsService := NewCredentialsService() - credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession) + credentialsExpiration := time.Now().Add(15 * time.Minute) + credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession, credentialsExpiration) credentials, err := credentialsService.GetCredentials(Token) @@ -40,8 +41,12 @@ func TestGetCredentialsFail(t *testing.T) { func TestUpdateCredentialsHappy(t *testing.T) { credentialsService := NewCredentialsService() - credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession) - err := credentialsService.UpdateCredentials("sampleKey1", "sampleSecret1", "sampleSession1") + credentialsExpiration := time.Now().Add(15 * time.Minute) + credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession, credentialsExpiration) + + restoreCredentialsExpiration := time.Now().Add(10 * time.Hour) + + err := credentialsService.UpdateCredentials("sampleKey1", "sampleSecret1", "sampleSession1", restoreCredentialsExpiration) assert.NoError(t, err) credentials, err := credentialsService.GetCredentials(Token) @@ -50,49 +55,16 @@ func TestUpdateCredentialsHappy(t *testing.T) { assert.Equal(t, "sampleKey1", credentials.AwsKey) assert.Equal(t, "sampleSecret1", credentials.AwsSecret) assert.Equal(t, "sampleSession1", credentials.AwsSession) -} - -func TestUpdateCredentialsFail(t *testing.T) { - credentialsService := NewCredentialsService() - err := credentialsService.UpdateCredentials("unknownKey", "unknownSecret", "unknownSession") - - assert.Error(t, err) -} + nineHoursLater := time.Now().Add(9 * time.Hour) -func TestUpdateCredentialsOfBlockedService(t *testing.T) { - credentialsService := NewCredentialsService() - credentialsService.BlockService() - credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession) - err := credentialsService.UpdateCredentials("sampleKey1", "sampleSecret1", "sampleSession1") - assert.NoError(t, err) + assert.True(t, nineHoursLater.Before(credentials.Expiration)) } -func TestConsecutiveBlockService(t *testing.T) { +func TestUpdateCredentialsFail(t *testing.T) { credentialsService := NewCredentialsService() - timeout := time.After(1 * time.Second) - done := make(chan bool) - - go func() { - for i := 0; i < 10; i++ { - credentialsService.BlockService() - } - done <- true - }() - - select { - case <-timeout: - t.Fatal("BlockService should not block the calling thread.") - case <-done: - } -} - -// unlocking a mutex twice causes panic -// the assertion here is basically not having panic -func TestConsecutiveUnblockService(t *testing.T) { - credentialsService := NewCredentialsService() + err := credentialsService.UpdateCredentials("unknownKey", "unknownSecret", "unknownSession", time.Now()) - credentialsService.UnblockService() - credentialsService.UnblockService() + assert.Error(t, err) } diff --git a/lambda/core/directinvoke/directinvoke.go b/lambda/core/directinvoke/directinvoke.go index 1699121..8ef59ae 100644 --- a/lambda/core/directinvoke/directinvoke.go +++ b/lambda/core/directinvoke/directinvoke.go @@ -4,28 +4,38 @@ package directinvoke import ( + "context" "fmt" "io" "net/http" + "strconv" "github.com/go-chi/chi" + "go.amzn.com/lambda/core/bandwidthlimiter" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" + + log "github.com/sirupsen/logrus" ) const ( - InvokeIDHeader = "Invoke-Id" - InvokedFunctionArnHeader = "Invoked-Function-Arn" - VersionIDHeader = "Invoked-Function-Version" - ReservationTokenHeader = "Reservation-Token" - CustomerHeadersHeader = "Customer-Headers" - ContentTypeHeader = "Content-Type" + InvokeIDHeader = "Invoke-Id" + InvokedFunctionArnHeader = "Invoked-Function-Arn" + VersionIDHeader = "Invoked-Function-Version" + ReservationTokenHeader = "Reservation-Token" + CustomerHeadersHeader = "Customer-Headers" + ContentTypeHeader = "Content-Type" + MaxPayloadSizeHeader = "MaxPayloadSize" + ResponseBandwidthRateHeader = "ResponseBandwidthRate" + ResponseBandwidthBurstSizeHeader = "ResponseBandwidthBurstSize" + FunctionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" ErrorTypeHeader = "Error-Type" - EndOfResponseTrailer = "End-Of-Response" - - SandboxErrorType = "Error.Sandbox" + EndOfResponseTrailer = "End-Of-Response" + FunctionErrorTypeTrailer = "Lambda-Runtime-Function-Error-Type" + FunctionErrorBodyTrailer = "Lambda-Runtime-Function-Error-Body" ) const ( @@ -34,7 +44,14 @@ const ( EndOfResponseOversized = "Oversized" ) +var ResetReasonMap = map[string]fatalerror.ErrorType{ + "failure": fatalerror.SandboxFailure, + "timeout": fatalerror.SandboxTimeout, +} + var MaxDirectResponseSize int64 = interop.MaxPayloadSize // this is intentionally not a constant so we can configure it via CLI +var ResponseBandwidthRate int64 = interop.ResponseBandwidthRate +var ResponseBandwidthBurstSize int64 = interop.ResponseBandwidthBurstSize func renderBadRequest(w http.ResponseWriter, r *http.Request, errorType string) { w.Header().Set(ErrorTypeHeader, errorType) @@ -42,6 +59,12 @@ func renderBadRequest(w http.ResponseWriter, r *http.Request, errorType string) w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) } +func renderInternalServerError(w http.ResponseWriter, errorType string) { + w.Header().Set(ErrorTypeHeader, errorType) + w.WriteHeader(http.StatusInternalServerError) + w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) +} + // ReceiveDirectInvoke parses invoke and verifies it against Token message. Uses deadline provided by Token // Renders BadRequest in case of error func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.Token) (*interop.Invoke, error) { @@ -54,6 +77,47 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T } now := metering.Monotime() + + MaxDirectResponseSize = interop.MaxPayloadSize + if maxPayloadSize := r.Header.Get(MaxPayloadSizeHeader); maxPayloadSize != "" { + if n, err := strconv.ParseInt(maxPayloadSize, 10, 64); err == nil && n >= -1 { + MaxDirectResponseSize = n + } else { + log.Error("MaxPayloadSize header is not a valid number") + renderBadRequest(w, r, interop.ErrInvalidMaxPayloadSize.Error()) + return nil, interop.ErrInvalidMaxPayloadSize + } + } + + if MaxDirectResponseSize == -1 { + w.Header().Add("Trailer", FunctionErrorTypeTrailer) + w.Header().Add("Trailer", FunctionErrorBodyTrailer) + + ResponseBandwidthRate = interop.ResponseBandwidthRate + if responseBandwidthRate := r.Header.Get(ResponseBandwidthRateHeader); responseBandwidthRate != "" { + if n, err := strconv.ParseInt(responseBandwidthRate, 10, 64); err == nil && + interop.MinResponseBandwidthRate <= n && n <= interop.MaxResponseBandwidthRate { + ResponseBandwidthRate = n + } else { + log.Error("ResponseBandwidthRate header is not a valid number or is out of the allowed range") + renderBadRequest(w, r, interop.ErrInvalidResponseBandwidthRate.Error()) + return nil, interop.ErrInvalidResponseBandwidthRate + } + } + + ResponseBandwidthBurstSize = interop.ResponseBandwidthBurstSize + if responseBandwidthBurstSize := r.Header.Get(ResponseBandwidthBurstSizeHeader); responseBandwidthBurstSize != "" { + if n, err := strconv.ParseInt(responseBandwidthBurstSize, 10, 64); err == nil && + interop.MinResponseBandwidthBurstSize <= n && n <= interop.MaxResponseBandwidthBurstSize { + ResponseBandwidthBurstSize = n + } else { + log.Error("ResponseBandwidthBurstSize header is not a valid number or is out of the allowed range") + renderBadRequest(w, r, interop.ErrInvalidResponseBandwidthBurstSize.Error()) + return nil, interop.ErrInvalidResponseBandwidthBurstSize + } + } + } + inv := &interop.Invoke{ ID: r.Header.Get(InvokeIDHeader), ReservationToken: chi.URLParam(r, "reservationtoken"), @@ -66,7 +130,6 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T LambdaSegmentID: token.LambdaSegmentID, ClientContext: custHeaders.ClientContext, Payload: r.Body, - CorrelationID: "invokeCorrelationID", DeadlineNs: fmt.Sprintf("%d", now+token.FunctionTimeout.Nanoseconds()), NeedDebugLogs: token.NeedDebugLogs, InvokeReceivedTime: now, @@ -99,24 +162,215 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T return inv, nil } -func SendDirectInvokeResponse(additionalHeaders map[string]string, payload io.Reader, w http.ResponseWriter) error { - for k, v := range additionalHeaders { - w.Header().Add(k, v) +type CopyDoneResult struct { + Metrics *interop.InvokeResponseMetrics + Error error +} + +func getErrorTypeFromResetReason(resetReason string) fatalerror.ErrorType { + errorTypeTrailer, ok := ResetReasonMap[resetReason] + if !ok { + errorTypeTrailer = fatalerror.Unknown + } + return errorTypeTrailer +} + +func isErrorResponse(additionalHeaders map[string]string) (isErrorResponse bool) { + _, isErrorResponse = additionalHeaders[ErrorTypeHeader] + return +} + +func isStreamingInvoke() bool { + return MaxDirectResponseSize == -1 +} + +func asyncPayloadCopy(w http.ResponseWriter, payload io.Reader) (copyDone chan CopyDoneResult, cancel context.CancelFunc, err error) { + copyDone = make(chan CopyDoneResult) + streamedResponseWriter, cancel, err := NewStreamedResponseWriter(w) + if err != nil { + return nil, nil, &interop.ErrInternalPlatformError{} + } + go func() { // copy payload in a separate go routine + _, copyError := bandwidthlimiter.BandwidthLimitingCopy(streamedResponseWriter, payload) + if copyError != nil { + w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) + } else { + w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) + } + copyDoneResult := CopyDoneResult{ + Metrics: streamedResponseWriter.GetMetrics(), + Error: copyError, + } + copyDone <- copyDoneResult + cancel() // free resources + }() + return +} + +func sendStreamingInvokeResponse(payload io.Reader, trailers http.Header, w http.ResponseWriter, + interruptedResponseChan chan *interop.Reset, sendResponseChan chan *interop.InvokeResponseMetrics, + request *interop.CancellableRequest, runtimeCalledResponse bool) (err error) { + /* In case of /response, we copy the payload and, once copied, we attach: + * 1) 'Lambda-Runtime-Function-Error-Type' + * 2) 'Lambda-Runtime-Function-Error-Body' + * trailers. */ + copyDone, cancel, err := asyncPayloadCopy(w, payload) + if err != nil { + renderInternalServerError(w, err.Error()) + return err + } + + var errorTypeTrailer string + var errorBodyTrailer string + var copyDoneResult CopyDoneResult + select { + case copyDoneResult = <-copyDone: // copy finished + errorTypeTrailer = trailers.Get(FunctionErrorTypeTrailer) + errorBodyTrailer = trailers.Get(FunctionErrorBodyTrailer) + if copyDoneResult.Error != nil && errorTypeTrailer == "" { // truncated payload, error type not known + errorTypeTrailer = string(fatalerror.TruncatedResponse) + } + case reset := <-interruptedResponseChan: // reset initiated + cancel() + if request != nil { + // In case of reset: + // * to interrupt copying when runtime called /response (a potential stuck on Body.Read() operation), + // we close the underlying connection using .Close() method on the request object + // * for /error case, the whole body is already read in /error handler, so we don't need additional handling + // when sending streaming invoke error response + connErr := request.Cancel() + if connErr != nil { + log.Warnf("Failed to close underlying connection: %s", connErr) + } + } else { + log.Warn("Cannot close underlying connection. Request object is nil") + } + copyDoneResult = <-copyDone + reset.InvokeResponseMetrics = copyDoneResult.Metrics + interruptedResponseChan <- nil + errorTypeTrailer = string(getErrorTypeFromResetReason(reset.Reason)) + } + w.Header().Set(FunctionErrorTypeTrailer, errorTypeTrailer) + w.Header().Set(FunctionErrorBodyTrailer, errorBodyTrailer) + + copyDoneResult.Metrics.RuntimeCalledResponse = runtimeCalledResponse + sendResponseChan <- copyDoneResult.Metrics + + if copyDoneResult.Error != nil { + log.Errorf("Error while streaming response payload: %s", copyDoneResult.Error) + err = &interop.ErrTruncatedResponse{} + } + return +} + +func sendStreamingInvokeErrorResponse(payload io.Reader, w http.ResponseWriter, + interruptedResponseChan chan *interop.Reset, sendResponseChan chan *interop.InvokeResponseMetrics, runtimeCalledResponse bool) (err error) { + + copyDone, cancel, err := asyncPayloadCopy(w, payload) + if err != nil { + renderInternalServerError(w, err.Error()) + return err + } + + var copyDoneResult CopyDoneResult + select { + case copyDoneResult = <-copyDone: // copy finished + case reset := <-interruptedResponseChan: // reset initiated + cancel() + copyDoneResult = <-copyDone + reset.InvokeResponseMetrics = copyDoneResult.Metrics + interruptedResponseChan <- nil + } + + copyDoneResult.Metrics.RuntimeCalledResponse = runtimeCalledResponse + sendResponseChan <- copyDoneResult.Metrics + + if copyDoneResult.Error != nil { + log.Errorf("Error while streaming error response payload: %s", copyDoneResult.Error) + err = &interop.ErrTruncatedResponse{} + } + return +} + +// parseFunctionResponseMode fetches the mode from the header +// If the header is absent, it returns buffered mode. +func parseFunctionResponseMode(w http.ResponseWriter) (interop.FunctionResponseMode, error) { + headerValue := w.Header().Get(FunctionResponseModeHeader) + // the header is optional, so it's ok to be absent + if headerValue == "" { + return interop.FunctionResponseModeBuffered, nil + } + + return interop.ConvertToFunctionResponseMode(headerValue) +} + +func sendPayloadLimitedResponse(payload io.Reader, trailers http.Header, w http.ResponseWriter, sendResponseChan chan *interop.InvokeResponseMetrics, runtimeCalledResponse bool) (err error) { + functionResponseMode, err := parseFunctionResponseMode(w) + if err != nil { + return err + } + + // non-streaming invoke request but runtime is streaming: predefine Trailer headers + if functionResponseMode == interop.FunctionResponseModeStreaming { + w.Header().Add("Trailer", FunctionErrorTypeTrailer) + w.Header().Add("Trailer", FunctionErrorBodyTrailer) + } + + startReadingResponseMonoTimeMs := metering.Monotime() + written, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) // +1 because we do allow 10MB but not 10MB + 1 byte + + // non-streaming invoke request but runtime is streaming: set response trailers + if functionResponseMode == interop.FunctionResponseModeStreaming { + w.Header().Set(FunctionErrorTypeTrailer, trailers.Get(FunctionErrorTypeTrailer)) + w.Header().Set(FunctionErrorBodyTrailer, trailers.Get(FunctionErrorBodyTrailer)) } - n, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) // +1 because we do allow 10MB but not 10MB + 1 byte if err != nil { w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) - } else if n == MaxDirectResponseSize+1 { + err = &interop.ErrTruncatedResponse{} + } else if MaxDirectResponseSize != -1 && written == MaxDirectResponseSize+1 { w.Header().Set(EndOfResponseTrailer, EndOfResponseOversized) err = &interop.ErrorResponseTooLargeDI{ ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ - ResponseSize: int(n), + ResponseSize: int(written), MaxResponseSize: int(MaxDirectResponseSize), }, } } else { w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) } - return err + + sendResponseChan <- &interop.InvokeResponseMetrics{ + ProducedBytes: int64(written), + StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, + FinishReadingResponseMonoTimeMs: metering.Monotime(), + TimeShapedNs: int64(-1), + OutboundThroughputBps: int64(-1), + // FIXME: + // We should use InvokeResponseMode here, because only when it's streaming we're interested + // on it. If the invoke is buffered, we don't generate streaming metrics, even if the + // function response mode is streaming. + FunctionResponseMode: interop.FunctionResponseModeBuffered, + RuntimeCalledResponse: runtimeCalledResponse, + } + return +} + +func SendDirectInvokeResponse(additionalHeaders map[string]string, payload io.Reader, trailers http.Header, + w http.ResponseWriter, interruptedResponseChan chan *interop.Reset, + sendResponseChan chan *interop.InvokeResponseMetrics, request *interop.CancellableRequest, runtimeCalledResponse bool) error { + + for k, v := range additionalHeaders { + w.Header().Add(k, v) + } + + if isStreamingInvoke() { // unlimited payload; response streaming mode + if isErrorResponse(additionalHeaders) { // send streamed error response when runtime called /error + return sendStreamingInvokeErrorResponse(payload, w, interruptedResponseChan, sendResponseChan, runtimeCalledResponse) + } + // send streamed response when runtime called /response + return sendStreamingInvokeResponse(payload, trailers, w, interruptedResponseChan, sendResponseChan, request, runtimeCalledResponse) + } + + return sendPayloadLimitedResponse(payload, trailers, w, sendResponseChan, runtimeCalledResponse) } diff --git a/lambda/core/directinvoke/directinvoke_test.go b/lambda/core/directinvoke/directinvoke_test.go new file mode 100644 index 0000000..4e26161 --- /dev/null +++ b/lambda/core/directinvoke/directinvoke_test.go @@ -0,0 +1,358 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package directinvoke + +import ( + "bytes" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/interop" +) + +func NewResponseWriterWithoutFlushMethod() *ResponseWriterWithoutFlushMethod { + return &ResponseWriterWithoutFlushMethod{} +} + +type ResponseWriterWithoutFlushMethod struct{} + +func (*ResponseWriterWithoutFlushMethod) Header() http.Header { return http.Header{} } +func (*ResponseWriterWithoutFlushMethod) Write([]byte) (n int, err error) { return } +func (*ResponseWriterWithoutFlushMethod) WriteHeader(_ int) {} + +func NewSimpleResponseWriter() *SimpleResponseWriter { + return &SimpleResponseWriter{ + buffer: bytes.NewBuffer(nil), + trailers: make(http.Header), + } +} + +type SimpleResponseWriter struct { + buffer *bytes.Buffer + trailers http.Header +} + +func (w *SimpleResponseWriter) Header() http.Header { return w.trailers } +func (w *SimpleResponseWriter) Write(p []byte) (n int, err error) { return w.buffer.Write(p) } +func (*SimpleResponseWriter) WriteHeader(_ int) {} +func (*SimpleResponseWriter) Flush() {} + +func NewInterruptableResponseWriter(interruptAfter int) (*InterruptableResponseWriter, chan struct{}) { + interruptedTestWriterChan := make(chan struct{}) + return &InterruptableResponseWriter{ + buffer: bytes.NewBuffer(nil), + trailers: make(http.Header), + interruptAfter: interruptAfter, + interruptedTestWriterChan: interruptedTestWriterChan, + }, interruptedTestWriterChan +} + +type InterruptableResponseWriter struct { + buffer *bytes.Buffer + trailers http.Header + interruptAfter int // expect Writer to be interrupted after 'interruptAfter' number of writes + interruptedTestWriterChan chan struct{} +} + +func (w *InterruptableResponseWriter) Header() http.Header { return w.trailers } +func (w *InterruptableResponseWriter) Write(p []byte) (n int, err error) { + if w.interruptAfter >= 1 { + w.interruptAfter-- + } else if w.interruptAfter == 0 { + w.interruptedTestWriterChan <- struct{}{} // ready to be interrupted + <-w.interruptedTestWriterChan // wait until interrupted + } + n, err = w.buffer.Write(p) + return +} +func (*InterruptableResponseWriter) WriteHeader(_ int) {} +func (*InterruptableResponseWriter) Flush() {} + +// This is a simple reader implementing io.Reader interface. It's based on strings.Reader, but it doesn't have extra +// methods that allow faster copying such as .WriteTo() method. +func NewReader(s string) *Reader { return &Reader{s, 0, -1} } + +type Reader struct { + s string + i int64 // current reading index + prevRune int // index of previous rune; or < 0 +} + +func (r *Reader) Read(b []byte) (n int, err error) { + if r.i >= int64(len(r.s)) { + return 0, io.EOF + } + r.prevRune = -1 + n = copy(b, r.s[r.i:]) + r.i += int64(n) + return +} + +func TestSendDirectInvokeWithIncompatibleResponseWriter(t *testing.T) { + MaxDirectResponseSize = -1 + err := SendDirectInvokeResponse(nil, nil, nil, NewResponseWriterWithoutFlushMethod(), nil, nil, nil, false) + require.Error(t, err) + require.Equal(t, "ErrInternalPlatformError", err.Error()) +} + +func TestAsyncPayloadCopySuccess(t *testing.T) { + payloadString := strings.Repeat("a", 10*1024*1024) + writer := NewSimpleResponseWriter() + + expectedPayloadString := payloadString + + copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) + require.Nil(t, err) + + <-copyDone + require.Equal(t, expectedPayloadString, writer.buffer.String()) +} + +// We use an interruptable response writer which informs on a channel that it's ready to be interrupted after +// 'interruptAfter' number of writes, then it waits for interruption completion to resume the current write operation. +// For this test, after initiating copying, we wait for one chunk of 32 KiB to be copied. Then, we use cancel() to +// interrupt copying. At this point, only ongoing .Write() operations can be performed. We inform the writer about +// interruption completion, and the writer resumes the current .Write() operation, which gives us another 32 KiB chunk +// that is copied. After that, copying returns, and we receive a signal on <-copyDone channel. +func TestAsyncPayloadCopySuccessAfterCancel(t *testing.T) { + payloadString := strings.Repeat("a", 10*1024*1024) // 10 MiB + writer, interruptedTestWriterChan := NewInterruptableResponseWriter(1) + + expectedPayloadString := strings.Repeat("a", 64*1024) // 64 KiB (2 chunks) + + copyDone, cancel, err := asyncPayloadCopy(writer, NewReader(payloadString)) + require.Nil(t, err) + + <-interruptedTestWriterChan // wait for writing 'interruptAfter' number of chunks + cancel() // interrupt copying + interruptedTestWriterChan <- struct{}{} // inform test writer about interruption + + <-copyDone + require.Equal(t, expectedPayloadString, writer.buffer.String()) +} + +func TestAsyncPayloadCopyWithIncompatibleResponseWriter(t *testing.T) { + copyDone, cancel, err := asyncPayloadCopy(&ResponseWriterWithoutFlushMethod{}, nil) + require.Nil(t, copyDone) + require.Nil(t, cancel) + require.Error(t, err) + require.Equal(t, "ErrInternalPlatformError", err.Error()) +} + +func TestSendStreamingInvokeResponseSuccess(t *testing.T) { + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + trailers := http.Header{} + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := payloadString + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) + require.Nil(t, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) + <-testFinished +} + +func TestSendPayloadLimitedResponseWithinThresholdWithStreamingFunction(t *testing.T) { + payloadSize := 1 + payloadString := strings.Repeat("a", payloadSize) + payload := NewReader(payloadString) + trailers := http.Header{} + writer := NewSimpleResponseWriter() + writer.Header().Set("Lambda-Runtime-Function-Response-Mode", "streaming") + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + MaxDirectResponseSize = int64(payloadSize + 1) + + go func() { + err := sendPayloadLimitedResponse(payload, trailers, writer, sendResponseChan, true) + require.Nil(t, err) + testFinished <- struct{}{} + }() + + metrics := <-sendResponseChan + require.Equal(t, interop.FunctionResponseModeBuffered, metrics.FunctionResponseMode) + require.Equal(t, len(payloadString), len(writer.buffer.String())) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) + <-testFinished + + // Reset to its default value, just in case other tests use them + MaxDirectResponseSize = interop.MaxPayloadSize +} + +func TestSendPayloadLimitedResponseAboveThresholdWithStreamingFunction(t *testing.T) { + payloadSize := 2 + payloadString := strings.Repeat("a", payloadSize) + payload := NewReader(payloadString) + trailers := http.Header{} + writer := NewSimpleResponseWriter() + writer.Header().Set("Lambda-Runtime-Function-Response-Mode", "streaming") + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + MaxDirectResponseSize = int64(payloadSize - 1) + expectedError := &interop.ErrorResponseTooLargeDI{ + ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ + MaxResponseSize: int(MaxDirectResponseSize), + ResponseSize: payloadSize, + }, + } + + go func() { + err := sendPayloadLimitedResponse(payload, trailers, writer, sendResponseChan, true) + require.Equal(t, expectedError, err) + testFinished <- struct{}{} + }() + + metrics := <-sendResponseChan + require.Equal(t, interop.FunctionResponseModeBuffered, metrics.FunctionResponseMode) + require.Equal(t, len(payloadString), len(writer.buffer.String())) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Oversized", writer.Header().Get("End-Of-Response")) + <-testFinished + + // Reset to its default value, just in case other tests use them + MaxDirectResponseSize = interop.MaxPayloadSize +} + +func TestSendStreamingInvokeResponseSuccessWithTrailers(t *testing.T) { + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + trailers := http.Header{ + "Lambda-Runtime-Function-Error-Type": []string{"ErrorType"}, + "Lambda-Runtime-Function-Error-Body": []string{"ErrorBody"}, + } + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := payloadString + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) + require.Nil(t, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "ErrorType", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "ErrorBody", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) + <-testFinished +} + +func TestSendStreamingInvokeResponseReset(t *testing.T) { // Reset initiated after writing two chunks of 32 KiB + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + trailers := http.Header{} + writer, interruptedTestWriterChan := NewInterruptableResponseWriter(1) + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := strings.Repeat("a", 64*1024) // 64 KiB + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, true) + require.Error(t, err) + require.Equal(t, "ErrTruncatedResponse", err.Error()) + testFinished <- struct{}{} + }() + + reset := &interop.Reset{Reason: "timeout"} + require.Nil(t, reset.InvokeResponseMetrics) + + <-interruptedTestWriterChan // wait for writing 'interruptAfter' number of chunks + interruptedResponseChan <- reset // send reset + time.Sleep(10 * time.Millisecond) // wait for cancel() being called (first instruction after getting reset) + interruptedTestWriterChan <- struct{}{} // inform test writer about interruption + <-interruptedResponseChan // wait for copy done after interruption + require.NotNil(t, reset.InvokeResponseMetrics) + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "Sandbox.Timeout", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Truncated", writer.Header().Get("End-Of-Response")) + <-testFinished +} + +func TestSendStreamingInvokeErrorResponseSuccess(t *testing.T) { + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := payloadString + + go func() { + err := sendStreamingInvokeErrorResponse(payload, writer, interruptedResponseChan, sendResponseChan, false) + require.Nil(t, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) + <-testFinished +} + +func TestSendStreamingInvokeErrorResponseReset(t *testing.T) { // Reset initiated after writing two chunks of 32 KiB + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + writer, interruptedTestWriterChan := NewInterruptableResponseWriter(1) + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := strings.Repeat("a", 64*1024) // 64 KiB + + go func() { + err := sendStreamingInvokeErrorResponse(payload, writer, interruptedResponseChan, sendResponseChan, true) + require.Error(t, err) + require.Equal(t, "ErrTruncatedResponse", err.Error()) + testFinished <- struct{}{} + }() + + reset := &interop.Reset{Reason: "timeout"} + require.Nil(t, reset.InvokeResponseMetrics) + + <-interruptedTestWriterChan // wait for writing 'interruptAfter' number of chunks + interruptedResponseChan <- reset // send reset + time.Sleep(10 * time.Millisecond) // wait for cancel() being called (first instruction after getting reset) + interruptedTestWriterChan <- struct{}{} // inform test writer about interruption + <-interruptedResponseChan // wait for copy done after interruption + require.NotNil(t, reset.InvokeResponseMetrics) + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Truncated", writer.Header().Get("End-Of-Response")) + <-testFinished +} diff --git a/lambda/core/directinvoke/util.go b/lambda/core/directinvoke/util.go new file mode 100644 index 0000000..511d656 --- /dev/null +++ b/lambda/core/directinvoke/util.go @@ -0,0 +1,84 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package directinvoke + +import ( + "context" + "errors" + "go.amzn.com/lambda/core/bandwidthlimiter" + "io" + "net/http" + "time" + + log "github.com/sirupsen/logrus" +) + +const DefaultRefillIntervalMs = 125 // default refill interval in milliseconds + +func NewStreamedResponseWriter(w http.ResponseWriter) (*bandwidthlimiter.BandwidthLimitingWriter, context.CancelFunc, error) { + flushingWriter, err := NewFlushingWriter(w) // after writing a chunk we have to flush it to avoid additional buffering by ResponseWriter + if err != nil { + return nil, nil, err + } + cancellableWriter, cancel := NewCancellableWriter(flushingWriter) // cancelling prevents next calls to Write() from happening + + refillNumber := ResponseBandwidthRate * DefaultRefillIntervalMs / 1000 // refillNumber is calculated based on 'ResponseBandwidthRate' and bucket refill interval + refillInterval := DefaultRefillIntervalMs * time.Millisecond + + // Initial bucket for token bucket algorithm allows for a burst of up to 6 MiB, and an average transmission rate of 2 MiB/s + bucket, err := bandwidthlimiter.NewBucket(ResponseBandwidthBurstSize, ResponseBandwidthBurstSize, refillNumber, refillInterval) + if err != nil { + cancel() // free resources + return nil, nil, err + } + + bandwidthLimitingWriter, err := bandwidthlimiter.NewBandwidthLimitingWriter(cancellableWriter, bucket) + if err != nil { + cancel() // free resources + return nil, nil, err + } + + return bandwidthLimitingWriter, cancel, nil +} + +func NewFlushingWriter(w io.Writer) (*FlushingWriter, error) { + flusher, ok := w.(http.Flusher) + if !ok { + errorMsg := "expected http.ResponseWriter to be an http.Flusher" + log.Error(errorMsg) + return nil, errors.New(errorMsg) + } + return &FlushingWriter{ + w: w, + flusher: flusher, + }, nil +} + +type FlushingWriter struct { + w io.Writer + flusher http.Flusher +} + +func (w *FlushingWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + w.flusher.Flush() + return +} + +func NewCancellableWriter(w io.Writer) (*CancellableWriter, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + return &CancellableWriter{w: w, ctx: ctx}, cancel +} + +type CancellableWriter struct { + w io.Writer + ctx context.Context +} + +func (w *CancellableWriter) Write(p []byte) (int, error) { + if err := w.ctx.Err(); err != nil { + return 0, err + } + return w.w.Write(p) +} diff --git a/lambda/core/doc.go b/lambda/core/doc.go index 23c1539..4a7157f 100644 --- a/lambda/core/doc.go +++ b/lambda/core/doc.go @@ -2,26 +2,23 @@ // SPDX-License-Identifier: Apache-2.0 /* - Package core provides state objects and synchronization primitives for managing data flow in the system. - -States +# States Runtime and Agent implement state object design pattern. Runtime state interface: -type RuntimeState interface { - InitError() error - Ready() error - InvocationResponse() error - InvocationErrorResponse() error -} + type RuntimeState interface { + InitError() error + Ready() error + InvocationResponse() error + InvocationErrorResponse() error + } - -Gates +# Gates Gates provide synchornization primitives for managing data flow in the system. @@ -31,8 +28,9 @@ set of operations being performed in other threads completes. To better understand gates, consider two examples below: Example 1: main thread is awaiting registered threads to walk through the gate, - and after the last registered thread walked through the gate, gate - condition will be satisfied and main thread will proceed: + + and after the last registered thread walked through the gate, gate + condition will be satisfied and main thread will proceed: [main] // register threads with the gate and start threads ... [main] g.AwaitGateCondition() @@ -42,27 +40,25 @@ Example 1: main thread is awaiting registered threads to walk through the gate, [thread] // not blocked Example 2: main thread is awaiting registered threads to arrive at the gate, - and after the last registered thread arrives at the gate, gate - condition will be satisfied and main thread, along with registered - threads will proceed: + + and after the last registered thread arrives at the gate, gate + condition will be satisfied and main thread, along with registered + threads will proceed: [main] // register threads with the gate and start threads ... [main] g.AwaitGateCondition() [main] // blocked until gate condition is satisfied - -Flow +# Flow Flow wraps a set of specific gates required to implement specific data flow in the system. Example flows would be INIT, INVOKE and RESET. - -Registrations +# Registrations Registration service manages registrations, it maintains the mapping between registered parties are events they are registered. Parties not registered in the system will not be issued events. - */ package core diff --git a/lambda/core/externalagent.go b/lambda/core/externalagent.go index 792f356..cd367d2 100644 --- a/lambda/core/externalagent.go +++ b/lambda/core/externalagent.go @@ -22,7 +22,6 @@ type ExternalAgent struct { currentState ExternalAgentState stateLastModified time.Time - Pid int StartedState ExternalAgentState RegisteredState ExternalAgentState diff --git a/lambda/core/flow.go b/lambda/core/flow.go index 3c22b84..b2cb538 100644 --- a/lambda/core/flow.go +++ b/lambda/core/flow.go @@ -19,6 +19,9 @@ type InitFlowSynchronization interface { CancelWithError(error) + RuntimeRestoreReady() error + AwaitRuntimeRestoreReady() error + Clear() } @@ -26,6 +29,7 @@ type initFlowSynchronizationImpl struct { externalAgentsRegisteredGate Gate runtimeReadyGate Gate agentReadyGate Gate + runtimeRestoreReadyGate Gate } // SetExternalAgentsRegisterCount notifies init flow that N /extension/register calls should be done in future by external agents @@ -43,6 +47,11 @@ func (s *initFlowSynchronizationImpl) AwaitRuntimeReady() error { return s.runtimeReadyGate.AwaitGateCondition() } +// AwaitRuntimeRestoreReady awaits runtime restore ready state (/restore/next is called by runtime) +func (s *initFlowSynchronizationImpl) AwaitRuntimeRestoreReady() error { + return s.runtimeRestoreReadyGate.AwaitGateCondition() +} + // AwaitExternalAgentsRegistered awaits for all subscribed agents to report registered func (s *initFlowSynchronizationImpl) AwaitExternalAgentsRegistered() error { return s.externalAgentsRegisteredGate.AwaitGateCondition() @@ -58,6 +67,11 @@ func (s *initFlowSynchronizationImpl) RuntimeReady() error { return s.runtimeReadyGate.WalkThrough() } +// Ready called by runtime when restore is completed (i.e. /next is called after /restore/next) +func (s *initFlowSynchronizationImpl) RuntimeRestoreReady() error { + return s.runtimeRestoreReadyGate.WalkThrough() +} + // Ready called by agent when initialized func (s *initFlowSynchronizationImpl) AgentReady() error { return s.agentReadyGate.WalkThrough() @@ -73,6 +87,7 @@ func (s *initFlowSynchronizationImpl) CancelWithError(err error) { s.externalAgentsRegisteredGate.CancelWithError(err) s.runtimeReadyGate.CancelWithError(err) s.agentReadyGate.CancelWithError(err) + s.runtimeRestoreReadyGate.CancelWithError(err) } // Clear gates state @@ -80,6 +95,7 @@ func (s *initFlowSynchronizationImpl) Clear() { s.externalAgentsRegisteredGate.Clear() s.runtimeReadyGate.Clear() s.agentReadyGate.Clear() + s.runtimeRestoreReadyGate.Clear() } // NewInitFlowSynchronization returns new InitFlowSynchronization instance. @@ -88,6 +104,7 @@ func NewInitFlowSynchronization() InitFlowSynchronization { runtimeReadyGate: NewGate(1), externalAgentsRegisteredGate: NewGate(0), agentReadyGate: NewGate(maxAgentsLimit), + runtimeRestoreReadyGate: NewGate(1), } return initFlow } diff --git a/lambda/core/registrations.go b/lambda/core/registrations.go index dca9d90..f68612c 100644 --- a/lambda/core/registrations.go +++ b/lambda/core/registrations.go @@ -10,8 +10,11 @@ import ( "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/interop" "github.com/google/uuid" + + log "github.com/sirupsen/logrus" ) type registrationServiceState int @@ -70,6 +73,7 @@ type FunctionMetadata struct { FunctionName string FunctionVersion string Handler string + RuntimeInfo interop.RuntimeInfo } // RegistrationService keeps track of registered parties, including external agents, threads, and runtime. @@ -94,6 +98,7 @@ type RegistrationService interface { CountAgents() int Clear() AgentsInfo() []AgentInfo + CancelFlows(err error) } type registrationServiceImpl struct { @@ -105,6 +110,7 @@ type registrationServiceImpl struct { initFlow InitFlowSynchronization invokeFlow InvokeFlowSynchronization functionMetadata FunctionMetadata + cancelOnce sync.Once } func (s *registrationServiceImpl) Clear() { @@ -115,6 +121,7 @@ func (s *registrationServiceImpl) Clear() { s.internalAgents.Clear() s.externalAgents.Clear() s.state = registrationServiceOn + s.cancelOnce = sync.Once{} } func (s *registrationServiceImpl) InitFlow() InitFlowSynchronization { @@ -373,6 +380,19 @@ func (s *registrationServiceImpl) TurnOff() { s.state = registrationServiceOff } +// CancelFlows cancels init and invoke flows with error. +func (s *registrationServiceImpl) CancelFlows(err error) { + s.mutex.Lock() + defer s.mutex.Unlock() + // The following block protects us from overwriting the error + // which was first used to cancel flows. + s.cancelOnce.Do(func() { + log.Debugf("Canceling flows: %s", err) + s.initFlow.CancelWithError(err) + s.invokeFlow.CancelWithError(err) + }) +} + // NewRegistrationService returns new RegistrationService instance. func NewRegistrationService(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchronization) RegistrationService { return ®istrationServiceImpl{ @@ -382,5 +402,6 @@ func NewRegistrationService(initFlow InitFlowSynchronization, invokeFlow InvokeF externalAgents: NewExternalAgentsMap(), initFlow: initFlow, invokeFlow: invokeFlow, + cancelOnce: sync.Once{}, } } diff --git a/lambda/core/registrations_test.go b/lambda/core/registrations_test.go index d8857a6..5956ac3 100644 --- a/lambda/core/registrations_test.go +++ b/lambda/core/registrations_test.go @@ -63,7 +63,7 @@ func TestRegistrationServiceHappyPathDuringInit(t *testing.T) { assert.NoError(t, runtime.Ready()) }() - assert.NoError(t, initFlow.AwaitRuntimeReady()) + assert.NoError(t, initFlow.AwaitRuntimeRestoreReady()) registrationService.TurnOff() // Agents Ready diff --git a/lambda/core/runtime_state_names.go b/lambda/core/runtime_state_names.go index b20b9f8..b04ba5d 100644 --- a/lambda/core/runtime_state_names.go +++ b/lambda/core/runtime_state_names.go @@ -5,10 +5,14 @@ package core // String values of possibles runtime states const ( - RuntimeStartedStateName = "Started" - RuntimeInitErrorStateName = "InitError" - RuntimeReadyStateName = "Ready" - RuntimeRunningStateName = "Running" + RuntimeStartedStateName = "Started" + RuntimeInitErrorStateName = "InitError" + RuntimeReadyStateName = "Ready" + RuntimeRunningStateName = "Running" + // RuntimeStartedState -> RuntimeRestoreReadyState + RuntimeRestoreReadyStateName = "RestoreReady" + // RuntimeRestoreReadyState -> RuntimeRestoringState + RuntimeRestoringStateName = "Restoring" RuntimeInvocationResponseStateName = "InvocationResponse" RuntimeInvocationErrorResponseStateName = "InvocationErrorResponse" RuntimeResponseSentStateName = "RuntimeResponseSentState" diff --git a/lambda/core/states.go b/lambda/core/states.go index bc7359d..a5e2010 100644 --- a/lambda/core/states.go +++ b/lambda/core/states.go @@ -72,6 +72,7 @@ var ErrConcurrentStateModification = errors.New("Concurrent state modification") type RuntimeState interface { InitError() error Ready() error + RestoreReady() error InvocationResponse() error InvocationErrorResponse() error ResponseSent() error @@ -82,6 +83,7 @@ type disallowEveryTransitionByDefault struct{} func (s *disallowEveryTransitionByDefault) InitError() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) Ready() error { return ErrNotAllowed } +func (s *disallowEveryTransitionByDefault) RestoreReady() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) InvocationResponse() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) InvocationErrorResponse() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) ResponseSent() error { return ErrNotAllowed } @@ -92,13 +94,14 @@ type Runtime struct { currentState RuntimeState stateLastModified time.Time - Pid int responseTime time.Time RuntimeStartedState RuntimeState RuntimeInitErrorState RuntimeState RuntimeReadyState RuntimeState RuntimeRunningState RuntimeState + RuntimeRestoreReadyState RuntimeState + RuntimeRestoringState RuntimeState RuntimeInvocationResponseState RuntimeState RuntimeInvocationErrorResponseState RuntimeState RuntimeResponseSentState RuntimeState @@ -135,6 +138,12 @@ func (s *Runtime) Ready() error { return s.currentState.Ready() } +func (s *Runtime) RestoreReady() error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.RestoreReady() +} + // InvocationResponse delegates to state implementation. func (s *Runtime) InvocationResponse() error { s.ManagedThread.Lock() @@ -196,6 +205,8 @@ func NewRuntime(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchroni runtime.RuntimeInvocationResponseState = &RuntimeInvocationResponseState{runtime: runtime, invokeFlow: invokeFlow} runtime.RuntimeInvocationErrorResponseState = &RuntimeInvocationErrorResponseState{runtime: runtime, invokeFlow: invokeFlow} runtime.RuntimeResponseSentState = &RuntimeResponseSentState{runtime: runtime, invokeFlow: invokeFlow} + runtime.RuntimeRestoreReadyState = &RuntimeRestoreReadyState{} + runtime.RuntimeRestoringState = &RuntimeRestoringState{runtime: runtime, initFlow: initFlow} runtime.setStateUnsafe(runtime.RuntimeStartedState) return runtime @@ -211,7 +222,14 @@ type RuntimeStartedState struct { // Ready call when runtime init done. func (s *RuntimeStartedState) Ready() error { s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) - err := s.initFlow.RuntimeReady() + // runtime called /next without calling /restore/next + // that means it's not interested in restore phase + err := s.initFlow.RuntimeRestoreReady() + if err != nil { + return err + } + + err = s.initFlow.RuntimeReady() if err != nil { return err } @@ -225,6 +243,22 @@ func (s *RuntimeStartedState) Ready() error { return nil } +func (s *RuntimeStartedState) RestoreReady() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeRestoreReadyState) + err := s.initFlow.RuntimeRestoreReady() + if err != nil { + return err + } + + s.runtime.ManagedThread.SuspendUnsafe() + if s.runtime.currentState != s.runtime.RuntimeRestoreReadyState && s.runtime.currentState != s.runtime.RuntimeRestoringState { + return ErrConcurrentStateModification + } + + s.runtime.setStateUnsafe(s.runtime.RuntimeRestoringState) + return nil +} + // InitError move runtime to init error state. func (s *RuntimeStartedState) InitError() error { s.runtime.setStateUnsafe(s.runtime.RuntimeInitErrorState) @@ -236,6 +270,38 @@ func (s *RuntimeStartedState) Name() string { return RuntimeStartedStateName } +type RuntimeRestoringState struct { + disallowEveryTransitionByDefault + runtime *Runtime + initFlow InitFlowSynchronization +} + +// Runtime is healthy after restore and called /next +func (s *RuntimeRestoringState) Ready() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) + err := s.initFlow.RuntimeReady() + if err != nil { + return err + } + s.runtime.ManagedThread.SuspendUnsafe() + if s.runtime.currentState != s.runtime.RuntimeReadyState && s.runtime.currentState != s.runtime.RuntimeRunningState { + return ErrConcurrentStateModification + } + + s.runtime.setStateUnsafe(s.runtime.RuntimeRunningState) + return nil +} + +// Runtime has thrown an exception when executing restore hooks and called /init/error +func (s *RuntimeRestoringState) InitError() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeInitErrorState) + return nil +} + +func (s *RuntimeRestoringState) Name() string { + return RuntimeRestoringStateName +} + // RuntimeInitErrorState runtime started state. type RuntimeInitErrorState struct { disallowEveryTransitionByDefault @@ -297,6 +363,14 @@ func (s *RuntimeRunningState) Name() string { return RuntimeRunningStateName } +type RuntimeRestoreReadyState struct { + disallowEveryTransitionByDefault +} + +func (s *RuntimeRestoreReadyState) Name() string { + return RuntimeRestoreReadyStateName +} + // RuntimeInvocationResponseState runtime response is available. // Start state for runtime response submission. type RuntimeInvocationResponseState struct { diff --git a/lambda/core/states_test.go b/lambda/core/states_test.go index 4b01838..37f38e2 100644 --- a/lambda/core/states_test.go +++ b/lambda/core/states_test.go @@ -39,10 +39,7 @@ func TestRuntimeInitErrorAfterReady(t *testing.T) { } func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // Started assert.Equal(t, runtime.RuntimeStartedState, runtime.GetState()) // Started -> InitError @@ -53,6 +50,10 @@ func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { runtime.SetState(runtime.RuntimeStartedState) assert.NoError(t, runtime.Ready()) assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // Started -> RestoreReady + runtime.SetState(runtime.RuntimeStartedState) + assert.NoError(t, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) // Started -> ResponseSent runtime.SetState(runtime.RuntimeStartedState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -68,10 +69,7 @@ func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { } func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // InitError -> InitError runtime.SetState(runtime.RuntimeInitErrorState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -80,6 +78,10 @@ func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { runtime.SetState(runtime.RuntimeInitErrorState) assert.Equal(t, ErrNotAllowed, runtime.Ready()) assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) + // InitError -> RestoreReady + runtime.SetState(runtime.RuntimeInitErrorState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) // InitError -> ResponseSent runtime.SetState(runtime.RuntimeInitErrorState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -95,10 +97,7 @@ func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { } func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // Ready -> InitError runtime.SetState(runtime.RuntimeReadyState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -107,6 +106,10 @@ func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { runtime.SetState(runtime.RuntimeReadyState) assert.NoError(t, runtime.Ready()) assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // Ready -> RestoreReady + runtime.SetState(runtime.RuntimeReadyState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) // Ready -> ResponseSent runtime.SetState(runtime.RuntimeReadyState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -122,10 +125,7 @@ func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { } func TestRuntimeStateTransitionsFromRunningState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // Running -> InitError runtime.SetState(runtime.RuntimeRunningState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -134,6 +134,10 @@ func TestRuntimeStateTransitionsFromRunningState(t *testing.T) { runtime.SetState(runtime.RuntimeRunningState) assert.NoError(t, runtime.Ready()) assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // Running -> RestoreReady + runtime.SetState(runtime.RuntimeRunningState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) // Running -> ResponseSent runtime.SetState(runtime.RuntimeRunningState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -149,10 +153,7 @@ func TestRuntimeStateTransitionsFromRunningState(t *testing.T) { } func TestRuntimeStateTransitionsFromInvocationResponseState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // InvocationResponse -> InitError runtime.SetState(runtime.RuntimeInvocationResponseState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -161,6 +162,10 @@ func TestRuntimeStateTransitionsFromInvocationResponseState(t *testing.T) { runtime.SetState(runtime.RuntimeInvocationResponseState) assert.Equal(t, ErrNotAllowed, runtime.Ready()) assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) + // InvocationResponse -> RestoreReady + runtime.SetState(runtime.RuntimeInvocationResponseState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) // InvocationResponse -> ResponseSent runtime.SetState(runtime.RuntimeInvocationResponseState) assert.NoError(t, runtime.ResponseSent()) @@ -177,10 +182,7 @@ func TestRuntimeStateTransitionsFromInvocationResponseState(t *testing.T) { } func TestRuntimeStateTransitionsFromInvocationErrorResponseState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // InvocationErrorResponse -> InitError runtime.SetState(runtime.RuntimeInvocationErrorResponseState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -189,6 +191,10 @@ func TestRuntimeStateTransitionsFromInvocationErrorResponseState(t *testing.T) { runtime.SetState(runtime.RuntimeInvocationErrorResponseState) assert.Equal(t, ErrNotAllowed, runtime.Ready()) assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) + // InvocationErrorResponse -> RestoreReady + runtime.SetState(runtime.RuntimeInvocationErrorResponseState) + assert.Equal(t, ErrNotAllowed, runtime.Ready()) + assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) // InvocationErrorResponse -> ResponseSent runtime.SetState(runtime.RuntimeInvocationErrorResponseState) assert.NoError(t, runtime.ResponseSent()) @@ -204,10 +210,7 @@ func TestRuntimeStateTransitionsFromInvocationErrorResponseState(t *testing.T) { } func TestRuntimeStateTransitionsFromResponseSentState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // ResponseSent -> InitError runtime.SetState(runtime.RuntimeResponseSentState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -216,6 +219,10 @@ func TestRuntimeStateTransitionsFromResponseSentState(t *testing.T) { runtime.SetState(runtime.RuntimeResponseSentState) assert.NoError(t, runtime.Ready()) assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // ResponseSent -> RestoreReady + runtime.SetState(runtime.RuntimeResponseSentState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) // ResponseSent -> ResponseSent runtime.SetState(runtime.RuntimeResponseSentState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -230,6 +237,71 @@ func TestRuntimeStateTransitionsFromResponseSentState(t *testing.T) { assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) } +func TestRuntimeStateTransitionsFromRestoreReadyState(t *testing.T) { + runtime := newRuntime() + // RestoreReady -> InitError + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> Ready + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.Ready()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> RestoreReady() + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> ResponseSent + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> InvocationResponse + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> InvocationErrorResponse + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) +} + +func TestRuntimeStateTransitionsFromRestoringState(t *testing.T) { + runtime := newRuntime() + // RestoreRunning -> InitError + runtime.SetState(runtime.RuntimeRestoringState) + assert.NoError(t, runtime.InitError()) + assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) + // RestoreRunning -> Ready + runtime.SetState(runtime.RuntimeRestoringState) + assert.NoError(t, runtime.Ready()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // RestoreRunning -> RestoreReady + runtime.SetState(runtime.RuntimeRestoringState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) + // RestoreRunning -> ResponseSent + runtime.SetState(runtime.RuntimeRestoringState) + assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) + // RestoreRunning -> InvocationResponse + runtime.SetState(runtime.RuntimeRestoringState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) + // RestoreRunning -> InvocationErrorResponse + runtime.SetState(runtime.RuntimeRestoringState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) +} + +func newRuntime() *Runtime { + initFlow := &mockInitFlowSynchronization{} + invokeFlow := &mockInvokeFlowSynchronization{} + runtime := NewRuntime(initFlow, invokeFlow) + runtime.ManagedThread = &mockthread.MockManagedThread{} + + return runtime +} + type mockInitFlowSynchronization struct { mock.Mock ReadyCond *sync.Cond @@ -272,6 +344,12 @@ func (s *mockInitFlowSynchronization) CancelWithError(err error) { s.Called(err) } func (s *mockInitFlowSynchronization) Clear() {} +func (s *mockInitFlowSynchronization) RuntimeRestoreReady() error { + return nil +} +func (s *mockInitFlowSynchronization) AwaitRuntimeRestoreReady() error { + return nil +} type mockInvokeFlowSynchronization struct{ mock.Mock } diff --git a/lambda/core/watchdog.go b/lambda/core/watchdog.go deleted file mode 100644 index bf57d01..0000000 --- a/lambda/core/watchdog.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "fmt" - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/fatalerror" - "sync" -) - -type WaitableProcess interface { - // Wait blocks until process exits and returns error in case of non-zero exit code - Wait() error - // Pid returnes process ID - Pid() int - // Name returnes process executable name (for logging) - Name() string -} - -// Watchdog watches started goroutines. -type Watchdog struct { - cancelOnce sync.Once - initFlow InitFlowSynchronization - invokeFlow InvokeFlowSynchronization - exitPidChan chan<- int - appCtx appctx.ApplicationContext - mutedMutex sync.Mutex - muted bool -} - -func (w *Watchdog) Mute() { - w.mutedMutex.Lock() - defer w.mutedMutex.Unlock() - w.muted = true -} - -func (w *Watchdog) Unmute() { - w.mutedMutex.Lock() - defer w.mutedMutex.Unlock() - w.muted = false -} - -func (w *Watchdog) Muted() bool { - w.mutedMutex.Lock() - defer w.mutedMutex.Unlock() - return w.muted -} - -// GoWait waits for process to complete in separate goroutine and handles the process termination -// Returns PID of the process -func (w *Watchdog) GoWait(p WaitableProcess, errorType fatalerror.ErrorType) int { - pid := p.Pid() - name := p.Name() - appCtx := w.appCtx - go func() { - err := p.Wait() - - if !w.Muted() { - appctx.StoreFirstFatalError(appCtx, errorType) - - if err == nil { - err = fmt.Errorf("exit code 0") - } - log.Warnf("Process %d(%s) exited: %s", pid, name, err) - } - - w.CancelFlows(err) - w.exitPidChan <- pid - }() - - return pid -} - -// CancelFlows cancels init and invoke flows with error. -func (w *Watchdog) CancelFlows(err error) { - // The following block protects us from overwriting the error - // which was first used to cancel flows. - w.cancelOnce.Do(func() { - log.Debugf("Canceling flows: %s", err) - w.initFlow.CancelWithError(err) - w.invokeFlow.CancelWithError(err) - }) -} - -// Clear watchdog state -func (w *Watchdog) Clear() { - w.cancelOnce = sync.Once{} -} - -// NewWatchdog returns new instance of a Watchdog. -func NewWatchdog(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchronization, exitPidChan chan<- int, appCtx appctx.ApplicationContext) *Watchdog { - return &Watchdog{ - initFlow: initFlow, - invokeFlow: invokeFlow, - exitPidChan: exitPidChan, - appCtx: appCtx, - mutedMutex: sync.Mutex{}, - } -} diff --git a/lambda/core/watchdog_test.go b/lambda/core/watchdog_test.go deleted file mode 100644 index 84f8342..0000000 --- a/lambda/core/watchdog_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "errors" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/fatalerror" - "testing" -) - -var errTest = errors.New("ErrTest") - -type MockProcess struct { -} - -func (s *MockProcess) Wait() error { return errTest } -func (s *MockProcess) Pid() int { return 0 } -func (s *MockProcess) Name() string { return "" } - -func TestWatchdogCallback(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - initFlow.On("CancelWithError", mock.Anything) - invokeFlow.On("CancelWithError", mock.Anything) - - pidChan := make(chan int) - appCtx := appctx.NewApplicationContext() - w := NewWatchdog(initFlow, invokeFlow, pidChan, appCtx) - - w.GoWait(&MockProcess{}, fatalerror.AgentExitError) - w.GoWait(&MockProcess{}, fatalerror.AgentExitError) - - <-pidChan - initFlow.AssertCalled(t, "CancelWithError", errTest) - initFlow.AssertNumberOfCalls(t, "CancelWithError", 1) - invokeFlow.AssertCalled(t, "CancelWithError", errTest) - invokeFlow.AssertNumberOfCalls(t, "CancelWithError", 1) - - <-pidChan - initFlow.AssertNumberOfCalls(t, "CancelWithError", 1) - invokeFlow.AssertNumberOfCalls(t, "CancelWithError", 1) - - err, found := appctx.LoadFirstFatalError(appCtx) - require.True(t, found) - require.Equal(t, err, fatalerror.AgentExitError) -} diff --git a/lambda/fatalerror/fatalerror.go b/lambda/fatalerror/fatalerror.go index 7292baf..bb8a86a 100644 --- a/lambda/fatalerror/fatalerror.go +++ b/lambda/fatalerror/fatalerror.go @@ -3,7 +3,7 @@ package fatalerror -// This package defines constant error types returned to slicer with DONE(failure) +// This package defines constant error types returned to slicer with DONE(failure), and also sandbox errors // Separate package for namespacing // ErrorType is returned to slicer inside DONE @@ -18,5 +18,8 @@ const ( InvalidEntrypoint ErrorType = "Runtime.InvalidEntrypoint" InvalidWorkingDir ErrorType = "Runtime.InvalidWorkingDir" InvalidTaskConfig ErrorType = "Runtime.InvalidTaskConfig" + TruncatedResponse ErrorType = "Runtime.TruncatedResponse" + SandboxFailure ErrorType = "Sandbox.Failure" + SandboxTimeout ErrorType = "Sandbox.Timeout" Unknown ErrorType = "Unknown" ) diff --git a/lambda/interop/bootstrap.go b/lambda/interop/bootstrap.go new file mode 100644 index 0000000..4a9b6af --- /dev/null +++ b/lambda/interop/bootstrap.go @@ -0,0 +1,18 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "os" + + "go.amzn.com/lambda/fatalerror" +) + +type Bootstrap interface { + Cmd() ([]string, error) // returns the args of bootstrap, where args[0] is the path to executable + Env(e EnvironmentVariables) map[string]string // returns the environment variables to be passed to the bootstrapped process + Cwd() (string, error) // returns the working directory of the bootstrap process + ExtraFiles() []*os.File // returns the extra file descriptors apart from 1 & 2 to be passed to runtime + CachedFatalError(err error) (fatalerror.ErrorType, string, bool) +} diff --git a/lambda/interop/cancellable_request.go b/lambda/interop/cancellable_request.go new file mode 100644 index 0000000..7e8fca5 --- /dev/null +++ b/lambda/interop/cancellable_request.go @@ -0,0 +1,27 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "net" + "net/http" +) + +type key int + +const ( + HTTPConnKey key = iota +) + +func GetConn(r *http.Request) net.Conn { + return r.Context().Value(HTTPConnKey).(net.Conn) +} + +type CancellableRequest struct { + Request *http.Request +} + +func (c *CancellableRequest) Cancel() error { + return GetConn(c.Request).Close() +} diff --git a/lambda/interop/environment_variables.go b/lambda/interop/environment_variables.go new file mode 100644 index 0000000..46bdf8b --- /dev/null +++ b/lambda/interop/environment_variables.go @@ -0,0 +1,14 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +type EnvironmentVariables interface { + AgentExecEnv() map[string]string + RuntimeExecEnv() map[string]string + SetHandler(handler string) + StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress string) + StoreEnvironmentVariablesFromInit(customerEnv map[string]string, + handler, awsKey, awsSecret, awsSession, funcName, funcVer string) + StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) +} diff --git a/lambda/interop/model.go b/lambda/interop/model.go index 5cdf63f..cc9c7d0 100644 --- a/lambda/interop/model.go +++ b/lambda/interop/model.go @@ -8,17 +8,99 @@ import ( "fmt" "io" "net/http" + "strings" "time" "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/supervisor/model" + + log "github.com/sirupsen/logrus" ) // MaxPayloadSize max event body size declared as LAMBDA_EVENT_BODY_SIZE -const MaxPayloadSize = 6*1024*1024 + 100 // 6 MiB + 100 bytes +const ( + MaxPayloadSize = 6*1024*1024 + 100 // 6 MiB + 100 bytes + + ResponseBandwidthRate = 2 * 1024 * 1024 // default average rate of 2 MiB/s + ResponseBandwidthBurstSize = 6 * 1024 * 1024 // default burst size of 6 MiB + + MinResponseBandwidthRate = 32 * 1024 // 32 KiB/s + MaxResponseBandwidthRate = 64 * 1024 * 1024 // 64 MiB/s + + MinResponseBandwidthBurstSize = 32 * 1024 // 32 KiB + MaxResponseBandwidthBurstSize = 64 * 1024 * 1024 // 64 MiB +) const functionResponseSizeTooLargeType = "Function.ResponseSizeTooLarge" +// ResponseMode are top-level constants used in combination with the various types of +// modes we have for responses, such as invoke's response mode and function's response mode. +// In the future we might have invoke's request mode or similar, so these help set the ground +// for consistency. +type ResponseMode string + +const ResponseModeBuffered = "Buffered" +const ResponseModeStreaming = "Streaming" + +type InvokeResponseMode string + +const InvokeResponseModeBuffered InvokeResponseMode = ResponseModeBuffered +const InvokeResponseModeStreaming InvokeResponseMode = ResponseModeStreaming + +var AllInvokeResponseModes = []string{ + string(InvokeResponseModeBuffered), string(InvokeResponseModeStreaming), +} + +// ConvertToInvokeResponseMode converts the given string to a InvokeResponseMode +// It is case insensitive and if there is no match, an error is thrown. +func ConvertToInvokeResponseMode(value string) (InvokeResponseMode, error) { + // buffered + if strings.EqualFold(value, string(InvokeResponseModeBuffered)) { + return InvokeResponseModeBuffered, nil + } + + // streaming + if strings.EqualFold(value, string(InvokeResponseModeStreaming)) { + return InvokeResponseModeStreaming, nil + } + + // unknown + allowedValues := strings.Join(AllInvokeResponseModes, ", ") + log.Errorf("Unlable to map %s to %s.", value, allowedValues) + return "", ErrInvalidInvokeResponseMode +} + +// FunctionResponseMode is passed by Runtime to tell whether the response should be +// streamed or not. +type FunctionResponseMode string + +const FunctionResponseModeBuffered FunctionResponseMode = ResponseModeBuffered +const FunctionResponseModeStreaming FunctionResponseMode = ResponseModeStreaming + +var AllFunctionResponseModes = []string{ + string(FunctionResponseModeBuffered), string(FunctionResponseModeStreaming), +} + +// ConvertToFunctionResponseMode converts the given string to a FunctionResponseMode +// It is case insensitive and if there is no match, an error is thrown. +func ConvertToFunctionResponseMode(value string) (FunctionResponseMode, error) { + // buffered + if strings.EqualFold(value, string(FunctionResponseModeBuffered)) { + return FunctionResponseModeBuffered, nil + } + + // streaming + if strings.EqualFold(value, string(FunctionResponseModeStreaming)) { + return FunctionResponseModeStreaming, nil + } + + // unknown + allowedValues := strings.Join(AllFunctionResponseModes, ", ") + log.Errorf("Unlable to map %s to %s.", value, allowedValues) + return "", ErrInvalidFunctionResponseMode +} + // Message is a generic interop message. type Message interface{} @@ -37,11 +119,10 @@ type Invoke struct { ContentType string Payload io.Reader NeedDebugLogs bool - CorrelationID string // internal use only ReservationToken string VersionID string InvokeReceivedTime int64 - ResyncState Resync + InvokeResponseMetrics *InvokeResponseMetrics } type Token struct { @@ -54,21 +135,13 @@ type Token struct { LambdaSegmentID string InvokeMetadata string NeedDebugLogs bool - ResyncState Resync -} - -type Resync struct { - IsResyncReceived bool - AwsKey string - AwsSecret string - AwsSession string - ReceivedTime time.Time } type ErrorResponse struct { // Payload sent via shared memory. - Payload []byte `json:"Payload,omitempty"` - ContentType string `json:"-"` + Payload []byte `json:"Payload,omitempty"` + ContentType string `json:"-"` + FunctionResponseMode string `json:"-"` // When error response body (Payload) is not provided, e.g. // not retrievable, error type and error message will be @@ -92,48 +165,80 @@ type SandboxType string const SandboxPreWarmed SandboxType = "PreWarmed" const SandboxClassic SandboxType = "Classic" -// Start message received from the slicer, part of the protocol. -type Start struct { - InvokeID string - Handler string - AwsKey string - AwsSecret string - AwsSession string - SuppressInit bool - XRayDaemonAddress string // only in standalone - FunctionName string // only in standalone - FunctionVersion string // only in standalone - CorrelationID string // internal use only - // TODO: define new Init type that has the Start fields as well as env vars below. - // In standalone mode, these env vars come from test/init but from environment otherwise. - CustomerEnvironmentVariables map[string]string - SandboxType SandboxType +// RuntimeInfo contains metadata about the runtime used by the Sandbox +type RuntimeInfo struct { + ImageJSON string // image config, e.g {\"layers\":[]} + Arn string // runtime ARN, e.g. arn:awstest:lambda:us-west-2::runtime:python3.8::alpha + Version string // human-readable runtime arn equivalent, e.g. python3.8.v999 } -// Running message is sent to the slicer, part of the protocol. -type Running struct { - WaitStartTimeNs int64 - WaitEndTimeNs int64 - PreLoadTimeNs int64 - PostLoadTimeNs int64 - ExtensionsEnabled bool +// Captures configuration of the operator and runtime domain +// that are only known after INIT is received +type DynamicDomainConfig struct { + // extra hooks to execute at domain start. Currently used for filesystem and network hooks. + // It can be empty. + AdditionalStartHooks []model.Hook + Mounts []model.DriveMount + //TODO: other dynamic configurations for the domain go here } // Reset message is sent to rapid to initiate reset sequence type Reset struct { - Reason string - DeadlineNs int64 - CorrelationID string // internal use only + Reason string + DeadlineNs int64 + InvokeResponseMetrics *InvokeResponseMetrics + TraceID string + LambdaSegmentID string +} + +// Restore message is sent to rapid to restore runtime to make it ready for consecutive invokes +type Restore struct { + AwsKey string + AwsSecret string + AwsSession string + CredentialsExpiry time.Time +} + +type Resync struct { } // Shutdown message is sent to rapid to initiate graceful shutdown type Shutdown struct { - DeadlineNs int64 - CorrelationID string // internal use only + DeadlineNs int64 +} + +// Metrics for response status of LogsAPI/TelemetryAPI `/subscribe` calls +type TelemetrySubscriptionMetrics map[string]int + +func MergeSubscriptionMetrics(logsAPIMetrics TelemetrySubscriptionMetrics, telemetryAPIMetrics TelemetrySubscriptionMetrics) TelemetrySubscriptionMetrics { + metrics := make(map[string]int) + for metric, value := range logsAPIMetrics { + metrics[metric] = value + } + + for metric, value := range telemetryAPIMetrics { + metrics[metric] += value + } + return metrics +} + +// InvokeResponseMetrics are produced while sending streaming invoke response to WP +type InvokeResponseMetrics struct { + StartReadingResponseMonoTimeMs int64 + FinishReadingResponseMonoTimeMs int64 + TimeShapedNs int64 + ProducedBytes int64 + OutboundThroughputBps int64 // in bytes per second + FunctionResponseMode FunctionResponseMode + RuntimeCalledResponse bool } -// Metrics for response status of LogsAPI `/subscribe` calls -type LogsAPIMetrics map[string]int +func IsResponseStreamingMetrics(metrics *InvokeResponseMetrics) bool { + if metrics == nil { + return false + } + return metrics.FunctionResponseMode == FunctionResponseModeStreaming +} type DoneMetadata struct { NumActiveExtensions int @@ -141,25 +246,26 @@ type DoneMetadata struct { ExtensionNames string RuntimeRelease string // Metrics for response status of LogsAPI `/subscribe` calls - LogsAPIMetrics LogsAPIMetrics - InvokeRequestReadTimeNs int64 - InvokeRequestSizeBytes int64 - InvokeCompletionTimeNs int64 - InvokeReceivedTime int64 - RuntimeReadyTime int64 + LogsAPIMetrics TelemetrySubscriptionMetrics + InvokeRequestReadTimeNs int64 + InvokeRequestSizeBytes int64 + InvokeCompletionTimeNs int64 + InvokeReceivedTime int64 + RuntimeReadyTime int64 + RuntimeTimeThrottledMs int64 + RuntimeProducedBytes int64 + RuntimeOutboundThroughputBps int64 } type Done struct { - WaitForExit bool - ErrorType fatalerror.ErrorType - CorrelationID string // internal use only - Meta DoneMetadata + WaitForExit bool + ErrorType fatalerror.ErrorType + Meta DoneMetadata } type DoneFail struct { - ErrorType fatalerror.ErrorType - CorrelationID string // internal use only - Meta DoneMetadata + ErrorType fatalerror.ErrorType + Meta DoneMetadata } // ErrInvalidInvokeID is returned when invokeID provided in Invoke2 does not match one provided in Token @@ -171,6 +277,22 @@ var ErrInvalidReservationToken = fmt.Errorf("ErrInvalidReservationToken") // ErrInvalidFunctionVersion is returned when functionVersion provided in Invoke2 does not match one provided in Token var ErrInvalidFunctionVersion = fmt.Errorf("ErrInvalidFunctionVersion") +// ErrInvalidFunctionResponseMode is returned when the value sent by runtime during Invoke2 +// is not a constant of type interop.FunctionResponseMode +var ErrInvalidFunctionResponseMode = fmt.Errorf("ErrInvalidFunctionResponseMode") + +// ErrInvalidInvokeResponseMode is returned when optional InvokeResponseMode header provided in Invoke2 is not a constant of type interop.InvokeResponseMode +var ErrInvalidInvokeResponseMode = fmt.Errorf("ErrInvalidInvokeResponseMode") + +// ErrInvalidMaxPayloadSize is returned when optional MaxPayloadSize header provided in Invoke2 is invalid +var ErrInvalidMaxPayloadSize = fmt.Errorf("ErrInvalidMaxPayloadSize") + +// ErrInvalidResponseBandwidthRate is returned when optional ResponseBandwidthRate header provided in Invoke2 is invalid +var ErrInvalidResponseBandwidthRate = fmt.Errorf("ErrInvalidResponseBandwidthRate") + +// ErrInvalidResponseBandwidthBurstSize is returned when optional ResponseBandwidthBurstSize header provided in Invoke2 is invalid +var ErrInvalidResponseBandwidthBurstSize = fmt.Errorf("ErrInvalidResponseBandwidthBurstSize") + // ErrMalformedCustomerHeaders is returned when customer headers format is invalid var ErrMalformedCustomerHeaders = fmt.Errorf("ErrMalformedCustomerHeaders") @@ -180,6 +302,20 @@ var ErrResponseSent = fmt.Errorf("ErrResponseSent") // ErrReservationExpired is returned when invoke arrived after InvackDeadline var ErrReservationExpired = fmt.Errorf("ErrReservationExpired") +// ErrInternalPlatformError is returned when internal platform error occurred +type ErrInternalPlatformError struct{} + +func (s *ErrInternalPlatformError) Error() string { + return "ErrInternalPlatformError" +} + +// ErrTruncatedResponse is returned when response is truncated +type ErrTruncatedResponse struct{} + +func (s *ErrTruncatedResponse) Error() string { + return "ErrTruncatedResponse" +} + // ErrorResponseTooLarge is returned when response Payload exceeds shared memory buffer size type ErrorResponseTooLarge struct { MaxResponseSize int @@ -211,17 +347,21 @@ func (s *ErrorResponseTooLarge) AsInteropError() *ErrorResponse { return &resp } -// Server implements Slicer communication protocol. +// Server used for sending messages and sharing data between the Runtime API handlers and the +// internal platform facing servers. For example, +// +// responseCtx.SendResponse(...) +// +// will send the response payload and metadata provided by the runtime to the platform, through the internal +// protocol used by the specific implementation +// TODO: rename this to InvokeResponseContext, used to send responses from handlers to platform-facing server type Server interface { - // StartAcceptingDirectInvokes starts accepting on direct invoke socket (if one is available) - StartAcceptingDirectInvokes() error - - // SendErrorResponse sends response. + // SendResponse sends response. // Errors returned: // ErrInvalidInvokeID - validation error indicating that provided invokeID doesn't match current invokeID // ErrResponseSent - validation error indicating that response with given invokeID was already sent // Non-nil error - non-nil error indicating transport failure - SendResponse(invokeID string, contentType string, response io.Reader) error + SendResponse(invokeID string, headers map[string]string, response io.Reader, trailers http.Header, request *CancellableRequest) error // SendErrorResponse sends error response. // Errors returned: @@ -229,61 +369,36 @@ type Server interface { // ErrResponseSent - validation error indicating that response with given invokeID was already sent // Non-nil error - non-nil error indicating transport failure SendErrorResponse(invokeID string, response *ErrorResponse) error + SendInitErrorResponse(invokeID string, response *ErrorResponse) error // GetCurrentInvokeID returns current invokeID. // NOTE, in case of INIT, when invokeID is not known in advance (e.g. provisioned concurrency), // returned invokeID will contain empty value. GetCurrentInvokeID() string - // CommitMessage confirms that the message written through SendResponse and SendErrorResponse is complete. - CommitResponse() error - - // SendRunning sends GIRD RUNNING. - // Returns error on transport failure. - SendRunning(*Running) error - - // SendRuntimeReady sends GIRD RTREADY + // SendRuntimeReady sends a message indicating the runtime has called /invocation/next. + // The checkpoint allows us to compute the overhead due to Extensions by substracting it + // from the time when all extensions have called /next. + // TODO: this method is a lifecycle event used only for metrics, and doesn't belong here SendRuntimeReady() error +} - // SendDone sends GIRD DONE. - // Returns error on transport failure. - SendDone(*Done) error - - // SendDone sends GIRD DONEFAIL. - // Returns error on transport failure. - SendDoneFail(*DoneFail) error - - // StartChan returns Start emitter - StartChan() <-chan *Start - - // InvokeChan returns Invoke emitter - InvokeChan() <-chan *Invoke - - // ResetChan returns Reset emitter - ResetChan() <-chan *Reset - - // ShutdownChan returns Shutdown emitter - ShutdownChan() <-chan *Shutdown - - // TransportErrorChan emits errors if there was parsing/connection issue - TransportErrorChan() <-chan error - - // Clear is called on rapid reset. It should leave server prepared for new invocations - Clear() - - // IsResponseSent exposes is response sent flag - IsResponseSent() bool - - // The following are used by standalone rapid only - // TODO refactor to decouple the interfaces +type InternalStateGetter func() statejson.InternalStateDescription - SetInternalStateGetter(cb InternalStateGetter) +const OnDemandInitTelemetrySource string = "on-demand" +const ProvisionedConcurrencyInitTelemetrySource string = "provisioned-concurrency" +const InitCachingInitTelemetrySource string = "snap-start" - Init(i *Start, invokeTimeoutMs int64) +func InferTelemetryInitSource(initCachingEnabled bool, sandboxType SandboxType) string { + initSource := OnDemandInitTelemetrySource - Invoke(responseWriter http.ResponseWriter, invoke *Invoke) error + // ToDo: Unify this selection of SandboxType by using the START message + // after having a roadmap on the combination of INIT modes + if initCachingEnabled { + initSource = InitCachingInitTelemetrySource + } else if sandboxType == SandboxPreWarmed { + initSource = ProvisionedConcurrencyInitTelemetrySource + } - Shutdown(shutdown *Shutdown) *statejson.InternalStateDescription + return initSource } - -type InternalStateGetter func() statejson.InternalStateDescription diff --git a/lambda/interop/model_test.go b/lambda/interop/model_test.go new file mode 100644 index 0000000..9ad4d17 --- /dev/null +++ b/lambda/interop/model_test.go @@ -0,0 +1,27 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMergeSubscriptionMetrics(t *testing.T) { + logsAPIMetrics := map[string]int{ + "server_error": 1, + "client_error": 2, + } + + telemetryAPIMetrics := map[string]int{ + "server_error": 1, + "success": 5, + } + + metrics := MergeSubscriptionMetrics(logsAPIMetrics, telemetryAPIMetrics) + assert.Equal(t, 5, metrics["success"]) + assert.Equal(t, 2, metrics["server_error"]) + assert.Equal(t, 2, metrics["client_error"]) +} diff --git a/lambda/interop/sandbox_model.go b/lambda/interop/sandbox_model.go index dddfcf2..b5d15b0 100644 --- a/lambda/interop/sandbox_model.go +++ b/lambda/interop/sandbox_model.go @@ -3,20 +3,183 @@ package interop -// Init represents an init message and is currently only used in standalone +import ( + "time" + + "go.amzn.com/lambda/fatalerror" +) + +// Init represents an init message +// In Rapid Shim, this is a START GirD message +// In Rapid Daemon, this is an INIT GirP message type Init struct { InvokeID string Handler string AwsKey string AwsSecret string AwsSession string + CredentialsExpiry time.Time SuppressInit bool + InvokeTimeoutMs int64 // timeout duration of whole invoke + InitTimeoutMs int64 // timeout duration for init only XRayDaemonAddress string // only in standalone FunctionName string // only in standalone FunctionVersion string // only in standalone - CorrelationID string // internal use only - // TODO: define new Init type that has the Start fields as well as env vars below. // In standalone mode, these env vars come from test/init but from environment otherwise. CustomerEnvironmentVariables map[string]string - SandboxType + SandboxType SandboxType + // there is no dynamic config at the moment for the runtime domain + OperatorDomainExtraConfig DynamicDomainConfig + RuntimeInfo RuntimeInfo + Bootstrap Bootstrap + EnvironmentVariables EnvironmentVariables // contains env vars for agents and runtime procs +} + +// InitStarted contains metadata about the initialized sandbox +// In Rapid Shim, this translates to a RUNNING GirD message to Slicer +// In Rapid Daemon, this is followed by a SANDBOX GirP message to MM +type InitStarted struct { + WaitStartTimeNs int64 + WaitEndTimeNs int64 + PreLoadTimeNs int64 + PostLoadTimeNs int64 + ExtensionsEnabled bool + Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent +} + +// InitSuccess indicates that runtime/extensions initialization completed successfully +// In Rapid Shim, this translates to a DONE GirD message to Slicer +// In Rapid Daemon, this is followed by a DONEDONE GirP message to MM +type InitSuccess struct { + NumActiveExtensions int // indicates number of active extensions + ExtensionNames string // file names of extensions in /opt/extensions + RuntimeRelease string + LogsAPIMetrics TelemetrySubscriptionMetrics // used if telemetry API enabled + Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent +} + +// InitFailure indicates that runtime/extensions initialization failed due to process exit or /error calls +// In Rapid Shim, this translates to either a DONE or a DONEFAIL GirD message to Slicer (depending on extensions mode) +// However, even on failure, the next invoke is expected to work with a suppressed init - i.e. we init again as aprt of the invoke +type InitFailure struct { + ResetReceived bool // indicates if failure happened due to a reset received + RequestReset bool // Indicates whether reset should be requested on init failure + ErrorType fatalerror.ErrorType + ErrorMessage error + NumActiveExtensions int + RuntimeRelease string // value of the User Agent HTTP header provided by runtime + LogsAPIMetrics TelemetrySubscriptionMetrics + Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent +} + +// ResponseMetrics groups metrics related to the response stream +type ResponseMetrics struct { + RuntimeTimeThrottledMs int64 + RuntimeProducedBytes int64 + RuntimeOutboundThroughputBps int64 +} + +// InvokeMetrics groups metrics related to the invoke phase +type InvokeMetrics struct { + InvokeRequestReadTimeNs int64 + InvokeRequestSizeBytes int64 + RuntimeReadyTime int64 +} + +// InvokeSuccess is the success response to invoke phase end +type InvokeSuccess struct { + RuntimeRelease string // value of the User Agent HTTP header provided by runtime + NumActiveExtensions int + ExtensionNames string + InvokeCompletionTimeNs int64 + InvokeReceivedTime int64 + LogsAPIMetrics TelemetrySubscriptionMetrics + ResponseMetrics ResponseMetrics + InvokeMetrics InvokeMetrics +} + +// InvokeFailure is the failure response to invoke phase end +type InvokeFailure struct { + ResetReceived bool // indicates if failure happened due to a reset received + RequestReset bool // indicates if reset must be requested after the failure + ErrorType fatalerror.ErrorType + ErrorMessage error + RuntimeRelease string // value of the User Agent HTTP header provided by runtime + NumActiveExtensions int + InvokeReceivedTime int64 + LogsAPIMetrics TelemetrySubscriptionMetrics + ResponseMetrics ResponseMetrics + InvokeMetrics InvokeMetrics + ExtensionNames string + DefaultErrorResponse *ErrorResponse // error resp constructed by platform during fn errors +} + +// ResetSuccess is the success response to reset request +type ResetSuccess struct { + ExtensionsResetMs int64 + ErrorType fatalerror.ErrorType + ResponseMetrics ResponseMetrics +} + +// ResetFailure is the failure response to reset request +type ResetFailure struct { + ExtensionsResetMs int64 + ErrorType fatalerror.ErrorType + ResponseMetrics ResponseMetrics +} + +// ShutdownSuccess is the response to a shutdown request +type ShutdownSuccess struct { + ErrorType fatalerror.ErrorType +} + +// SandboxInfoFromInit captures data from init request that +// is required during invoke (e.g. for suppressed init) +type SandboxInfoFromInit struct { + EnvironmentVariables EnvironmentVariables // contains agent env vars (creds, customer, platform) + SandboxType SandboxType // indicating Pre-Warmed, On-Demand etc + RuntimeBootstrap Bootstrap // contains the runtime bootstrap binary path, Cwd, Args, Env, Cmd +} + +// RapidContext expose methods for functionality of the Rapid Core library +type RapidContext interface { + HandleInit(i *Init, started chan<- InitStarted, success chan<- InitSuccess, failure chan<- InitFailure) + HandleInvoke(i *Invoke, sbMetadata SandboxInfoFromInit) (InvokeSuccess, *InvokeFailure) + HandleReset(reset *Reset, invokeReceivedTime int64, InvokeResponseMetrics *InvokeResponseMetrics) (ResetSuccess, *ResetFailure) + HandleShutdown(shutdown *Shutdown) ShutdownSuccess + HandleRestore(restore *Restore) error + Clear() +} + +// SandboxContext represents the sandbox lifecycle context +type SandboxContext interface { + Init(i *Init, timeoutMs int64) (InitStarted, InitContext) + Reset(reset *Reset) (ResetSuccess, *ResetFailure) + Shutdown(shutdown *Shutdown) ShutdownSuccess + Restore(restore *Restore) error + + // TODO: refactor this + // invokeReceivedTime and InvokeResponseMetrics are needed to compute the runtimeDone metrics + // in case of a Reset during an invoke (reset.reason=failure or reset.reason=timeout). + // Ideally: + // - the InvokeContext will have a Reset method to deal with Reset during an invoke and will hold invokeReceivedTime and InvokeResponseMetrics + // - the SandboxContext will have its own Reset/Spindown method + SetInvokeReceivedTime(invokeReceivedTime int64) + SetInvokeResponseMetrics(metrics *InvokeResponseMetrics) +} + +// InitContext represents the lifecycle of a sandbox initialization +type InitContext interface { + Wait() (InitSuccess, *InitFailure) + Reserve() InvokeContext +} + +// InvokeContext represents the lifecycle of a sandbox reservation +type InvokeContext interface { + SendRequest(i *Invoke) + Wait() (InvokeSuccess, *InvokeFailure) +} + +// Restored message is sent to Slicer to inform Runtime Restore Hook execution was successful +type Restored struct { } diff --git a/lambda/logging/doc.go b/lambda/logging/doc.go index 92637a1..a1f7e95 100644 --- a/lambda/logging/doc.go +++ b/lambda/logging/doc.go @@ -2,24 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 /* - RAPID emits or proxies the following sources of logging: -1. Internal logs: RAPID's own application logs into stderr for operational use, visible only internally -2. Function stream-based logs: Runtime's stdout and stderr, read as newline separated lines -3. Function message-based logs: Stock runtimes communicate using a custom TLV protocol over a Unix pipe -4. Extension stream-based logs: Extension's stdout and stderr, read as newline separated lines -5. Platform logs: Logs that RAPID generates, but is visible in customer's logs. - - -It has the following log sinks, which may further be egressed to other sinks (e.g. CloudWatch) by external telemetry agents: - -1. Internal Log File (stderr): stderr is redirected to a file specified by Sandbox Factory via env-vars, and accessible via StreamQuery -2. Stdout: stream-based function logs are output to RAPID's stdout process, and read by a telemetry agent -3. Telemetry API MSG-verb events: function messages-based logs are written using GirP protocol into the console socket specified by Sandbox Factory env-vars -4. Telemetry API LOGX-verb events: extension stream-based logs are written using GirP protocol into the console socket specified by Sandbox Factory env-vars -5. Telemetry API LOGP-verb events: platform logs are written using GirP protocol into the console socket specified by Sandbox Factory env-vars -6. Tail logs: a truncated version of function stream-based and message-based logs are written along with the invocation response to the frontend when 'debug logging' is enabled - + 1. Internal logs: RAPID's own application logs into stderr for operational use, visible only internally + 2. Function stream-based logs: Runtime's stdout and stderr, read as newline separated lines + 3. Function message-based logs: Stock runtimes communicate using a custom TLV protocol over a Unix pipe + 4. Extension stream-based logs: Extension's stdout and stderr, read as newline separated lines + 5. Platform logs: Logs that RAPID generates, but is visible either in customer's logs or via Logs API + (e.g. EXTENSION, RUNTIME, RUNTIMEDONE, IMAGE) */ package logging diff --git a/lambda/logging/internal_log_test.go b/lambda/logging/internal_log_test.go index b94ac88..3ec537f 100644 --- a/lambda/logging/internal_log_test.go +++ b/lambda/logging/internal_log_test.go @@ -8,7 +8,7 @@ import ( "fmt" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "io/ioutil" + "io" "log" "testing" ) @@ -67,14 +67,14 @@ func TestInternalFormatter(t *testing.T) { } func BenchmarkLogPrint(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) for n := 0; n < b.N; n++ { log.Print(1, "two", true) } } func BenchmarkLogrusPrint(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) for n := 0; n < b.N; n++ { logrus.Print(1, "two", true) } @@ -83,21 +83,21 @@ func BenchmarkLogrusPrint(b *testing.B) { func BenchmarkLogrusPrintInternalFormatter(b *testing.B) { var l = logrus.New() l.SetFormatter(&InternalFormatter{}) - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) for n := 0; n < b.N; n++ { l.Print(1, "two", true) } } func BenchmarkLogPrintf(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) for n := 0; n < b.N; n++ { log.Printf("field:%v,field:%v,field:%v", 1, "two", true) } } func BenchmarkLogrusPrintf(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) for n := 0; n < b.N; n++ { logrus.Printf("field:%v,field:%v,field:%v", 1, "two", true) } @@ -106,14 +106,14 @@ func BenchmarkLogrusPrintf(b *testing.B) { func BenchmarkLogrusPrintfInternalFormatter(b *testing.B) { var l = logrus.New() l.SetFormatter(&InternalFormatter{}) - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) for n := 0; n < b.N; n++ { l.Printf("field:%v,field:%v,field:%v", 1, "two", true) } } func BenchmarkLogrusDebugLogLevelDisabled(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) logrus.SetLevel(logrus.InfoLevel) for n := 0; n < b.N; n++ { logrus.Debug(1, "two", true) @@ -122,7 +122,7 @@ func BenchmarkLogrusDebugLogLevelDisabled(b *testing.B) { func BenchmarkLogrusDebugLogLevelDisabledInternalFormatter(b *testing.B) { var l = logrus.New() - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) l.SetLevel(logrus.InfoLevel) for n := 0; n < b.N; n++ { l.Debug(1, "two", true) @@ -130,7 +130,7 @@ func BenchmarkLogrusDebugLogLevelDisabledInternalFormatter(b *testing.B) { } func BenchmarkLogrusDebugLogLevelEnabled(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) logrus.SetLevel(logrus.DebugLevel) for n := 0; n < b.N; n++ { logrus.Debug(1, "two", true) @@ -140,7 +140,7 @@ func BenchmarkLogrusDebugLogLevelEnabled(b *testing.B) { func BenchmarkLogrusDebugLogLevelEnabledInternalFormatter(b *testing.B) { var l = logrus.New() l.SetFormatter(&InternalFormatter{}) - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) l.SetLevel(logrus.DebugLevel) for n := 0; n < b.N; n++ { l.Debug(1, "two", true) @@ -148,7 +148,7 @@ func BenchmarkLogrusDebugLogLevelEnabledInternalFormatter(b *testing.B) { } func BenchmarkLogrusDebugWithFieldLogLevelDisabled(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) logrus.SetLevel(logrus.InfoLevel) for n := 0; n < b.N; n++ { logrus.WithField("field", "value").Debug(1, "two", true) @@ -158,7 +158,7 @@ func BenchmarkLogrusDebugWithFieldLogLevelDisabled(b *testing.B) { func BenchmarkLogrusDebugWithFieldLogLevelDisabledInternalFormatter(b *testing.B) { var l = logrus.New() l.SetFormatter(&InternalFormatter{}) - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) l.SetLevel(logrus.InfoLevel) for n := 0; n < b.N; n++ { l.WithField("field", "value").Debug(1, "two", true) diff --git a/lambda/logging/platform_log.go b/lambda/logging/platform_log.go deleted file mode 100644 index 5154f93..0000000 --- a/lambda/logging/platform_log.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "fmt" - "io" - "log" - "strings" -) - -// TODO PlatformLogger interface has this LogExtensionInitEvent() method so it's easier to assert against it in standalone tests; -// TODO However, this makes interface harder to maintain (you are supposed to add new method to PlatformLogger for each event type) -// TODO We need to remove those methods and make PlatformLogger just a log.Logger interface - -// PlatformLogger is a logger that logs platform lines to customers' logs -type PlatformLogger interface { - Printf(fmt string, args ...interface{}) - LogExtensionInitEvent(agentName, state, errorType string, subscriptions []string) -} - -// FormattedPlatformLogger formats and logs platform lines to customers' logs via Telemetry API -type FormattedPlatformLogger struct { - logger *log.Logger -} - -// NewPlatformLogger is a logger for logging Platform log lines into customers' logs -func NewPlatformLogger(output, tailLogWriter io.Writer) *FormattedPlatformLogger { - prefix, flags := "", 0 - return &FormattedPlatformLogger{ - logger: log.New(io.MultiWriter(output, tailLogWriter), prefix, flags), - } -} - -// LogExtensionInitEvent formats and logs a line containing agent info -func (l *FormattedPlatformLogger) LogExtensionInitEvent(agentName, state, errorType string, subscriptions []string) { - format := "EXTENSION\tName: %s\tState: %s\tEvents: [%s]" - line := fmt.Sprintf(format, agentName, state, strings.Join(subscriptions, ",")) - if len(errorType) > 0 { - line += fmt.Sprintf("\tError Type: %s", errorType) - } - l.logger.Println(line) -} - -func (l *FormattedPlatformLogger) Printf(fmt string, args ...interface{}) { - fmt += "\n" // we append newline to the logline because that's how they are separated on recepient - l.logger.Printf(fmt, args...) -} - -func SupernovaInvalidTaskConfigRepr(err error) func(error) string { - return func(unused error) string { - return fmt.Sprintf("IMAGE\tInvalid task config: %s", err) - } -} - -func SupernovaLaunchErrorRepr(entrypoint []string, cmd []string, workingDir string) func(error) string { - return func(err error) string { - return fmt.Sprintf("IMAGE\tLaunch error: %s\tEntrypoint: [%s]\tCmd: [%s]\tWorkingDir: [%s]", - err, - strings.Join(entrypoint, ","), - strings.Join(cmd, ","), - workingDir) - } -} diff --git a/lambda/logging/platform_log_test.go b/lambda/logging/platform_log_test.go deleted file mode 100644 index 8b01778..0000000 --- a/lambda/logging/platform_log_test.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestPlatformLogExtensionLine(t *testing.T) { - var buf bytes.Buffer - var tailLogBuf bytes.Buffer - logger := NewPlatformLogger(&buf, &tailLogBuf) - - logger.LogExtensionInitEvent("agentName", "Registered", "", []string{"INVOKE", "SHUTDOWN"}) - require.Equal(t, "EXTENSION\tName: agentName\tState: Registered\tEvents: [INVOKE,SHUTDOWN]\n", buf.String()) - require.Equal(t, "EXTENSION\tName: agentName\tState: Registered\tEvents: [INVOKE,SHUTDOWN]\n", tailLogBuf.String()) -} - -func TestPlatformLogExtensionLineWithError(t *testing.T) { - var buf bytes.Buffer - var tailLogBuf bytes.Buffer - logger := NewPlatformLogger(&buf, &tailLogBuf) - - errorType := "Extension.FooBar" - logger.LogExtensionInitEvent("agentName", "Registered", errorType, []string{"INVOKE", "SHUTDOWN"}) - require.Equal(t, "EXTENSION\tName: agentName\tState: Registered\tEvents: [INVOKE,SHUTDOWN]\tError Type: "+errorType+"\n", buf.String()) - require.Equal(t, "EXTENSION\tName: agentName\tState: Registered\tEvents: [INVOKE,SHUTDOWN]\tError Type: "+errorType+"\n", tailLogBuf.String()) -} - -func TestPlatformLogPrintf(t *testing.T) { - var buf bytes.Buffer - var tailLogBuf bytes.Buffer - logger := NewPlatformLogger(&buf, &tailLogBuf) - - logger.Printf("bebe %s %d", "as", 12) - require.Equal(t, "bebe as 12\n", buf.String()) - require.Equal(t, "bebe as 12\n", tailLogBuf.String()) -} diff --git a/lambda/logging/taillog.go b/lambda/logging/taillog.go deleted file mode 100644 index 9fe5352..0000000 --- a/lambda/logging/taillog.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "io" - "sync" -) - -// TailLogWriter writes tail/debug log to provided io.Writer -type TailLogWriter struct { - out io.Writer - enabled bool - mutex sync.Mutex -} - -// Enable enables log writer. -func (lw *TailLogWriter) Enable() { - lw.mutex.Lock() - defer lw.mutex.Unlock() - - lw.enabled = true -} - -// Disable disables log writer. -func (lw *TailLogWriter) Disable() { - lw.mutex.Lock() - defer lw.mutex.Unlock() - - lw.enabled = false -} - -// Writer wraps the basic io.Write method -func (lw *TailLogWriter) Write(p []byte) (n int, err error) { - lw.mutex.Lock() - defer lw.mutex.Unlock() - - if lw.enabled { - return lw.out.Write(p) - } - // Else returns a successful write so that MultiWriter won't stop - return len(p), nil -} - -// NewTailLogWriter returns a new invoke tail log writer, default output is discarded until output is configured. -func NewTailLogWriter(w io.Writer) *TailLogWriter { - return &TailLogWriter{ - out: w, - enabled: false, - } -} diff --git a/lambda/logging/taillog_test.go b/lambda/logging/taillog_test.go deleted file mode 100644 index 2bc444c..0000000 --- a/lambda/logging/taillog_test.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestDisableDebugLog(t *testing.T) { - buf := new(bytes.Buffer) - tailLogWriter := NewTailLogWriter(buf) - tailLogWriter.Disable() - - tailLogWriter.Write([]byte("hello_world")) - assert.Len(t, buf.String(), 0) -} - -func TestEnableDebugLog(t *testing.T) { - buf := new(bytes.Buffer) - tailLogWriter := NewTailLogWriter(buf) - tailLogWriter.Enable() - - tailLogWriter.Write([]byte("hello_world")) - assert.Equal(t, "hello_world", buf.String()) -} diff --git a/lambda/rapi/handler/agentiniterror_test.go b/lambda/rapi/handler/agentiniterror_test.go index 571031e..50b9143 100644 --- a/lambda/rapi/handler/agentiniterror_test.go +++ b/lambda/rapi/handler/agentiniterror_test.go @@ -6,7 +6,7 @@ package handler import ( "context" "encoding/json" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -60,7 +60,7 @@ func TestAgentInitErrorMissingErrorHeader(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, errAgentMissingHeader, errorResponse.ErrorType) } @@ -77,7 +77,7 @@ func TestAgentInitErrorUnknownAgent(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, errAgentIdentifierUnknown, errorResponse.ErrorType) } @@ -97,7 +97,7 @@ func TestAgentInitErrorAgentInvalidState(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, errAgentInvalidState, errorResponse.ErrorType) } @@ -118,7 +118,7 @@ func TestAgentInitErrorRequestAccepted(t *testing.T) { assert.Equal(t, http.StatusAccepted, responseRecorder.Code) var response model.StatusResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, "OK", response.Status) diff --git a/lambda/rapi/handler/agentnext_test.go b/lambda/rapi/handler/agentnext_test.go index ef14e49..003c4b6 100644 --- a/lambda/rapi/handler/agentnext_test.go +++ b/lambda/rapi/handler/agentnext_test.go @@ -7,7 +7,7 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -52,7 +52,7 @@ func TestRenderAgentInvokeUnknownAgent(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, errAgentIdentifierUnknown, errorResponse.ErrorType) @@ -75,7 +75,7 @@ func TestRenderAgentInvokeInvalidAgentState(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, errAgentInvalidState, errorResponse.ErrorType) @@ -118,7 +118,7 @@ func TestRenderAgentInvokeNextHappy(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentInvokeEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, agent.RunningState, agent.GetState()) @@ -167,7 +167,7 @@ func TestRenderAgentInternalInvokeNextHappy(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentInvokeEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, agent.RunningState, agent.GetState()) @@ -212,7 +212,7 @@ func TestRenderAgentInternalShutdownEvent(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentShutdownEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, agent.RunningState, agent.GetState()) @@ -254,7 +254,7 @@ func TestRenderAgentExternalShutdownEvent(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentShutdownEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, agent.RunningState, agent.GetState()) @@ -297,7 +297,7 @@ func TestRenderAgentInvokeNextHappyEmptyTraceID(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentInvokeEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Nil(t, response.Tracing) diff --git a/lambda/rapi/handler/agentregister.go b/lambda/rapi/handler/agentregister.go index 776ac28..8882965 100644 --- a/lambda/rapi/handler/agentregister.go +++ b/lambda/rapi/handler/agentregister.go @@ -6,10 +6,9 @@ package handler import ( "encoding/json" "errors" - "io/ioutil" + "io" "net/http" - "github.com/go-chi/render" log "github.com/sirupsen/logrus" "go.amzn.com/lambda/core" "go.amzn.com/lambda/rapi/model" @@ -26,7 +25,7 @@ type RegisterRequest struct { } func parseRegister(request *http.Request) (*RegisterRequest, error) { - body, err := ioutil.ReadAll(request.Body) + body, err := io.ReadAll(request.Body) if err != nil { return nil, err } @@ -70,7 +69,6 @@ func (h *agentRegisterHandler) ServeHTTP(writer http.ResponseWriter, request *ht } func (h *agentRegisterHandler) renderResponse(agentID string, writer http.ResponseWriter, request *http.Request) { - render.Status(request, http.StatusOK) writer.Header().Set(LambdaAgentIdentifier, agentID) metadata := h.registrationService.GetFunctionMetadata() @@ -81,7 +79,10 @@ func (h *agentRegisterHandler) renderResponse(agentID string, writer http.Respon Handler: metadata.Handler, } - render.JSON(writer, request, resp) + if err := rendering.RenderJSON(http.StatusOK, writer, request, resp); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(writer, err.Error(), http.StatusInternalServerError) + } } func (h *agentRegisterHandler) registerExternalAgent(agent *core.ExternalAgent, registerRequest *RegisterRequest, writer http.ResponseWriter, request *http.Request) { diff --git a/lambda/rapi/handler/agentregister_test.go b/lambda/rapi/handler/agentregister_test.go index 185f249..35456ee 100644 --- a/lambda/rapi/handler/agentregister_test.go +++ b/lambda/rapi/handler/agentregister_test.go @@ -7,7 +7,6 @@ import ( "bytes" "encoding/json" "io" - "io/ioutil" "net/http" "net/http/httptest" "testing" @@ -41,7 +40,7 @@ func TestRenderAgentRegisterInvalidAgentName(t *testing.T) { require.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) require.Equal(t, http.StatusForbidden, responseRecorder.Code) require.Equal(t, errAgentNameInvalid, errorResponse.ErrorType) @@ -63,7 +62,7 @@ func TestRenderAgentRegisterRegistrationClosed(t *testing.T) { require.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) require.Equal(t, http.StatusForbidden, responseRecorder.Code) require.Equal(t, errAgentRegistrationClosed, errorResponse.ErrorType) @@ -88,7 +87,7 @@ func TestRenderAgentRegisterInvalidAgentState(t *testing.T) { require.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) require.Equal(t, http.StatusForbidden, responseRecorder.Code) require.Equal(t, errAgentInvalidState, errorResponse.ErrorType) @@ -311,7 +310,7 @@ func TestRenderAgentResponse(t *testing.T) { require.Equal(t, http.StatusOK, responseRecorder.Code) registerResponse := ExtensionRegisterResponseWithConfig{} - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, ®isterResponse) assert.Equal(t, tt.expectedRegistrationResponse.FunctionName, registerResponse.FunctionName) assert.Equal(t, tt.expectedRegistrationResponse.FunctionVersion, registerResponse.FunctionVersion) diff --git a/lambda/rapi/handler/constants.go b/lambda/rapi/handler/constants.go index 01553f3..5912d71 100644 --- a/lambda/rapi/handler/constants.go +++ b/lambda/rapi/handler/constants.go @@ -20,7 +20,6 @@ const ( errAgentMissingHeader string = "Extension.MissingHeader" errTooManyExtensions string = "Extension.TooManyExtensions" errInvalidEventType string = "Extension.InvalidEventType" - errLogsSubscriptionClosed string = "Logs.SubscriptionClosed" errInvalidRequestFormat string = "InvalidRequestFormat" StateTransitionFailedForExtensionMessageFormat string = "State transition from %s to %s failed for extension %s. Error: %s" diff --git a/lambda/rapi/handler/credentials_test.go b/lambda/rapi/handler/credentials_test.go index fa4a2bd..d5a1090 100644 --- a/lambda/rapi/handler/credentials_test.go +++ b/lambda/rapi/handler/credentials_test.go @@ -21,13 +21,11 @@ const InitCachingAwsKey = "sampleAwsKey" const InitCachingAwsSecret = "sampleAwsSecret" const InitCachingAwsSessionToken = "sampleAwsSessionToken" -func getRequestContext(isServiceBlocked bool) (http.Handler, *http.Request, *httptest.ResponseRecorder) { +func getRequestContext() (http.Handler, *http.Request, *httptest.ResponseRecorder) { flowTest := testdata.NewFlowTest() - if isServiceBlocked { - flowTest.ConfigureForBlockedInitCaching(InitCachingToken, InitCachingAwsKey, InitCachingAwsSecret, InitCachingAwsSessionToken) - } else { - flowTest.ConfigureForInitCaching(InitCachingToken, InitCachingAwsKey, InitCachingAwsSecret, InitCachingAwsSessionToken) - } + + flowTest.ConfigureForInitCaching(InitCachingToken, InitCachingAwsKey, InitCachingAwsSecret, InitCachingAwsSessionToken) + handler := NewCredentialsHandler(flowTest.CredentialsService) responseRecorder := httptest.NewRecorder() appCtx := flowTest.AppCtx @@ -38,14 +36,14 @@ func getRequestContext(isServiceBlocked bool) (http.Handler, *http.Request, *htt } func TestEmptyAuthorizationHeader(t *testing.T) { - handler, request, responseRecorder := getRequestContext(false) + handler, request, responseRecorder := getRequestContext() handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusNotFound, responseRecorder.Code) } func TestArbitraryAuthorizationHeader(t *testing.T) { - handler, request, responseRecorder := getRequestContext(false) + handler, request, responseRecorder := getRequestContext() request.Header.Set("Authorization", "randomAuthToken") handler.ServeHTTP(responseRecorder, request) @@ -53,7 +51,7 @@ func TestArbitraryAuthorizationHeader(t *testing.T) { } func TestSuccessfulGet(t *testing.T) { - handler, request, responseRecorder := getRequestContext(false) + handler, request, responseRecorder := getRequestContext() request.Header.Set("Authorization", InitCachingToken) handler.ServeHTTP(responseRecorder, request) @@ -67,25 +65,6 @@ func TestSuccessfulGet(t *testing.T) { expirationTime, err := time.Parse(time.RFC3339, responseMap["Expiration"]) assert.NoError(t, err) durationUntilExpiration := time.Until(expirationTime) - assert.True(t, durationUntilExpiration.Minutes() <= 16 && durationUntilExpiration.Minutes() > 15 && durationUntilExpiration.Hours() < 1) + assert.True(t, durationUntilExpiration.Minutes() <= 30 && durationUntilExpiration.Minutes() > 29 && durationUntilExpiration.Hours() < 1) log.Println(responseRecorder.Body.String()) } - -func TestBlockedGet(t *testing.T) { - handler, request, responseRecorder := getRequestContext(true) - request.Header.Set("Authorization", InitCachingToken) - - timeout := time.After(1 * time.Second) - done := make(chan bool) - - go func() { - handler.ServeHTTP(responseRecorder, request) - done <- true - }() - - select { - case <-done: - t.Fatal("Endpoint should be blocked!") - case <-timeout: - } -} diff --git a/lambda/rapi/handler/initerror.go b/lambda/rapi/handler/initerror.go index 4015a11..d28e2d4 100644 --- a/lambda/rapi/handler/initerror.go +++ b/lambda/rapi/handler/initerror.go @@ -5,11 +5,12 @@ package handler import ( "encoding/json" - "io/ioutil" + "io" "net/http" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/telemetry" "go.amzn.com/lambda/core" "go.amzn.com/lambda/rapi/rendering" @@ -19,6 +20,7 @@ import ( type initErrorHandler struct { registrationService core.RegistrationService + eventsAPI telemetry.EventsAPI } func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { @@ -30,6 +32,10 @@ func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.R } runtime := h.registrationService.GetRuntime() + + // the previousStateName is needed to define if the init/error is called for INIT or RESTORE + previousStateName := runtime.GetState().Name() + if err := runtime.InitError(); err != nil { log.Warn(err) rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, @@ -39,18 +45,24 @@ func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.R errorType := request.Header.Get("Lambda-Runtime-Function-Error-Type") - errorBody, err := ioutil.ReadAll(request.Body) + errorBody, err := io.ReadAll(request.Body) if err != nil { log.WithError(err).Warn("Failed to read error body") } + if previousStateName == core.RuntimeRestoringStateName { + h.sendRestoreRuntimeDoneLogEvent() + } else { + h.sendInitRuntimeDoneLogEvent(appCtx) + } + response := &interop.ErrorResponse{ ErrorType: errorType, Payload: errorBody, ContentType: determineJSONContentType(errorBody), } - if err := server.SendErrorResponse(server.GetCurrentInvokeID(), response); err != nil { + if err := server.SendInitErrorResponse(server.GetCurrentInvokeID(), response); err != nil { rendering.RenderInteropError(writer, request, err) return } @@ -62,9 +74,10 @@ func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.R // NewInitErrorHandler returns a new instance of http handler // for serving /runtime/init/error. -func NewInitErrorHandler(registrationService core.RegistrationService) http.Handler { +func NewInitErrorHandler(registrationService core.RegistrationService, eventsAPI telemetry.EventsAPI) http.Handler { return &initErrorHandler{ registrationService: registrationService, + eventsAPI: eventsAPI, } } @@ -74,3 +87,24 @@ func determineJSONContentType(body []byte) string { } return "application/octet-stream" } + +func (h *initErrorHandler) sendInitRuntimeDoneLogEvent(appCtx appctx.ApplicationContext) { + // ToDo: Convert this to an enum for the whole package to increase readability. + initCachingEnabled := appctx.LoadInitType(appCtx) == appctx.InitCaching + + initSource := interop.InferTelemetryInitSource(initCachingEnabled, appctx.LoadSandboxType(appCtx)) + runtimeDoneData := &telemetry.InitRuntimeDoneData{ + InitSource: initSource, + Status: telemetry.RuntimeDoneFailure, + } + + if err := h.eventsAPI.SendInitRuntimeDone(runtimeDoneData); err != nil { + log.Errorf("Failed to send INITRD: %s", err) + } +} + +func (h *initErrorHandler) sendRestoreRuntimeDoneLogEvent() { + if err := h.eventsAPI.SendRestoreRuntimeDone(telemetry.RuntimeDoneFailure); err != nil { + log.Errorf("Failed to send RESTRD: %s", err) + } +} diff --git a/lambda/rapi/handler/initerror_test.go b/lambda/rapi/handler/initerror_test.go index c2d3d89..c9a5a83 100644 --- a/lambda/rapi/handler/initerror_test.go +++ b/lambda/rapi/handler/initerror_test.go @@ -27,7 +27,7 @@ func runTestInitErrorHandler(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - handler := NewInitErrorHandler(flowTest.RegistrationService) + handler := NewInitErrorHandler(flowTest.RegistrationService, flowTest.EventsAPI) responseRecorder := httptest.NewRecorder() appCtx := flowTest.AppCtx @@ -49,7 +49,7 @@ func runTestInitErrorHandler(t *testing.T) { require.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", responseRecorder.Code, http.StatusAccepted) require.JSONEq(t, fmt.Sprintf("{\"status\":\"%s\"}\n", "OK"), responseRecorder.Body.String()) - require.Equal(t, "application/json; charset=utf-8", responseRecorder.Header().Get("Content-Type")) + require.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) // Validate init error persisted in the application context. errorResponse := flowTest.InteropServer.ErrorResponse diff --git a/lambda/rapi/handler/invocationerror.go b/lambda/rapi/handler/invocationerror.go index d60b5d6..170c0cb 100644 --- a/lambda/rapi/handler/invocationerror.go +++ b/lambda/rapi/handler/invocationerror.go @@ -6,7 +6,7 @@ package handler import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "go.amzn.com/lambda/interop" @@ -25,6 +25,11 @@ const errorWithCauseContentType = "application/vnd.aws.lambda.error.cause+json" const xrayErrorCauseHeaderName = "Lambda-Runtime-Function-XRay-Error-Cause" const invalidErrorBodyMessage = "Invalid error body" +const ( + contentTypeHeader = "Content-Type" + functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" +) + type invocationErrorHandler struct { registrationService core.RegistrationService } @@ -52,7 +57,7 @@ func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request * var contentType string var err error - switch request.Header.Get("Content-Type") { + switch request.Header.Get(contentTypeHeader) { case errorWithCauseContentType: errorBody, errorCause, err = h.getErrorBodyForErrorCauseContentType(request) contentType = "application/json" @@ -62,18 +67,20 @@ func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request * default: errorBody, err = h.getErrorBody(request) errorCause = h.getValidatedErrorCause(request.Header) - contentType = request.Header.Get("Content-Type") + contentType = request.Header.Get(contentTypeHeader) } + functionResponseMode := request.Header.Get(functionResponseModeHeader) if err != nil { log.WithError(err).Warn("Failed to parse error body") } response := &interop.ErrorResponse{ - ErrorType: errorType, - Payload: errorBody, - ErrorCause: errorCause, - ContentType: contentType, + ErrorType: errorType, + Payload: errorBody, + ErrorCause: errorCause, + ContentType: contentType, + FunctionResponseMode: functionResponseMode, } if err := server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), response); err != nil { @@ -95,7 +102,7 @@ func (h *invocationErrorHandler) getErrorType(headers http.Header) string { } func (h *invocationErrorHandler) getErrorBody(request *http.Request) ([]byte, error) { - errorBody, err := ioutil.ReadAll(request.Body) + errorBody, err := io.ReadAll(request.Body) if err != nil { return nil, fmt.Errorf("error reading request body: %s", err) } @@ -120,7 +127,7 @@ func (h *invocationErrorHandler) getValidatedErrorCause(headers http.Header) jso } func (h *invocationErrorHandler) getErrorBodyForErrorCauseContentType(request *http.Request) ([]byte, json.RawMessage, error) { - errorBody, err := ioutil.ReadAll(request.Body) + errorBody, err := io.ReadAll(request.Body) if err != nil { return nil, nil, fmt.Errorf("error reading request body: %s", err) } diff --git a/lambda/rapi/handler/invocationerror_test.go b/lambda/rapi/handler/invocationerror_test.go index 6defa14..2f177fe 100644 --- a/lambda/rapi/handler/invocationerror_test.go +++ b/lambda/rapi/handler/invocationerror_test.go @@ -77,7 +77,7 @@ func runTestInvocationErrorHandler(t *testing.T) { assert.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", responseRecorder.Code, http.StatusAccepted) assert.JSONEq(t, fmt.Sprintf("{\"status\":\"%s\"}\n", "OK"), responseRecorder.Body.String()) - assert.Equal(t, "application/json; charset=utf-8", responseRecorder.Header().Get("Content-Type")) + assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) errorResponse := flowTest.InteropServer.ErrorResponse assert.NotNil(t, errorResponse) @@ -268,7 +268,8 @@ func TestInvocationResponsePayloadIsDefaultErrorMessageWhenRequestParsingFailsFo invoke := &interop.Invoke{TraceID: "Root=TraceID;Parent=ParentID;Sampled=1", ID: "InvokeID"} request := httptest.NewRequest("POST", "/", bytes.NewReader(invalidRequestBody)) request = addInvocationID(request, invoke.ID) - request.Header.Set("Content-Type", errorWithCauseContentType) + request.Header.Set(contentTypeHeader, errorWithCauseContentType) + request.Header.Set(functionResponseModeHeader, "function-response-mode") // Corresponding invoke must be placed into appCtx. flowTest.ConfigureForInvoke(context.Background(), invoke) @@ -280,6 +281,7 @@ func TestInvocationResponsePayloadIsDefaultErrorMessageWhenRequestParsingFailsFo assert.NotNil(t, errorResponse) assert.Nil(t, flowTest.InteropServer.Response) assert.Equal(t, "application/octet-stream", flowTest.InteropServer.ResponseContentType) + assert.Equal(t, "function-response-mode", flowTest.InteropServer.FunctionResponseMode) invokeResponsePayload := errorResponse.Payload diff --git a/lambda/rapi/handler/invocationnext_test.go b/lambda/rapi/handler/invocationnext_test.go index beebd97..5bddb86 100644 --- a/lambda/rapi/handler/invocationnext_test.go +++ b/lambda/rapi/handler/invocationnext_test.go @@ -92,7 +92,7 @@ func TestRenderInvoke(t *testing.T) { assert.Equal(t, invokePayload, responseRecorder.Body.String()) } -//Cgo calls removed due to crashes while spawning threads under memory pressure. +// Cgo calls removed due to crashes while spawning threads under memory pressure. func TestRenderInvokeDoesNotCallCgo(t *testing.T) { cgoCallsBefore := runtime.NumCgoCall() TestRenderInvoke(t) diff --git a/lambda/rapi/handler/invocationresponse.go b/lambda/rapi/handler/invocationresponse.go index 7c15342..7e47d2e 100644 --- a/lambda/rapi/handler/invocationresponse.go +++ b/lambda/rapi/handler/invocationresponse.go @@ -15,7 +15,10 @@ import ( log "github.com/sirupsen/logrus" ) -const contentTypeOverrideHeaderName = "Content-Type" +const ( + StreamingFunctionResponseMode = "streaming" + ErrInvalidResponseModeHeader = "Runtime.InvalidResponseModeHeader" +) type invocationResponseHandler struct { registrationService core.RegistrationService @@ -39,9 +42,23 @@ func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, reques invokeID := chi.URLParam(request, "awsrequestid") - responseContentType := request.Header.Get(contentTypeOverrideHeaderName) + headers := map[string]string{contentTypeHeader: request.Header.Get(contentTypeHeader)} + if functionResponseMode := request.Header.Get(functionResponseModeHeader); functionResponseMode != "" { + switch functionResponseMode { + case StreamingFunctionResponseMode: + headers[functionResponseModeHeader] = functionResponseMode + default: + errorResponse := &interop.ErrorResponse{ + ErrorType: ErrInvalidResponseModeHeader, + ContentType: request.Header.Get(contentTypeHeader), + } + _ = server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), errorResponse) + rendering.RenderInvalidFunctionResponseMode(writer, request) + return + } + } - if err := server.SendResponse(invokeID, responseContentType, request.Body); err != nil { + if err := server.SendResponse(invokeID, headers, request.Body, request.Trailer, &interop.CancellableRequest{Request: request}); err != nil { switch err := err.(type) { case *interop.ErrorResponseTooLarge: if server.SendErrorResponse(invokeID, err.AsInteropError()) != nil { @@ -66,6 +83,19 @@ func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, reques rendering.RenderRequestEntityTooLarge(writer, request) return + + case *interop.ErrTruncatedResponse: + if err := runtime.ResponseSent(); err != nil { + log.Panic(err) + } + + rendering.RenderTruncatedHTTPRequestError(writer, request) + return + + case *interop.ErrInternalPlatformError: + rendering.RenderInternalServerError(writer, request) + return + default: rendering.RenderInteropError(writer, request, err) return diff --git a/lambda/rapi/handler/invocationresponse_test.go b/lambda/rapi/handler/invocationresponse_test.go index e40a5bf..7c0b220 100644 --- a/lambda/rapi/handler/invocationresponse_test.go +++ b/lambda/rapi/handler/invocationresponse_test.go @@ -8,7 +8,7 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -55,7 +55,7 @@ func TestResponseTooLarge(t *testing.T) { responseRecorder.Code, http.StatusRequestEntityTooLarge) expectedAPIResponse := fmt.Sprintf("{\"errorMessage\":\"Exceeded maximum allowed payload size (%d bytes).\",\"errorType\":\"RequestEntityTooLarge\"}\n", interop.MaxPayloadSize) - body, err := ioutil.ReadAll(responseRecorder.Body) + body, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) test.AssertJsonsEqual(t, []byte(expectedAPIResponse), body) @@ -98,7 +98,8 @@ func TestResponseAccepted(t *testing.T) { request := httptest.NewRequest("", "/", bytes.NewReader(responseBody)) request = addInvocationID(request, invoke.ID) - request.Header.Set(contentTypeOverrideHeaderName, "application/json") + request.Header.Set(contentTypeHeader, "application/json") + request.Header.Set(functionResponseModeHeader, "streaming") handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) // Assertions @@ -106,7 +107,7 @@ func TestResponseAccepted(t *testing.T) { responseRecorder.Code, http.StatusAccepted) expectedAPIResponse := "{\"status\":\"OK\"}\n" - body, err := ioutil.ReadAll(responseRecorder.Body) + body, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) test.AssertJsonsEqual(t, []byte(expectedAPIResponse), body) @@ -114,6 +115,92 @@ func TestResponseAccepted(t *testing.T) { assert.NotNil(t, response) assert.Nil(t, flowTest.InteropServer.ErrorResponse) assert.Equal(t, "application/json", flowTest.InteropServer.ResponseContentType) + assert.Equal(t, "streaming", flowTest.InteropServer.FunctionResponseMode) assert.Equal(t, responseBody, response, "Persisted response data in app context must match the submitted.") } + +func TestResponseWithDifferentFunctionResponseModes(t *testing.T) { + type testCase struct { + providedFunctionResponseMode string + expectedFunctionResponseMode string + expectedAPIResponse string + expectedStatusCode int + expectedErrorResponse bool + } + testCases := []testCase{ + { + providedFunctionResponseMode: "", + expectedFunctionResponseMode: "", + expectedAPIResponse: "{\"status\":\"OK\"}\n", + expectedStatusCode: http.StatusAccepted, + expectedErrorResponse: false, + }, + { + providedFunctionResponseMode: "streaming", + expectedFunctionResponseMode: "streaming", + expectedAPIResponse: "{\"status\":\"OK\"}\n", + expectedStatusCode: http.StatusAccepted, + expectedErrorResponse: false, + }, + { + providedFunctionResponseMode: "invalid-mode", + expectedFunctionResponseMode: "", + expectedAPIResponse: "{\"errorMessage\":\"Invalid function response mode\", \"errorType\":\"InvalidFunctionResponseMode\"}\n", + expectedStatusCode: http.StatusBadRequest, + expectedErrorResponse: true, + }, + } + + for _, testCase := range testCases { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + flowTest.Runtime.Ready() + handler := NewInvocationResponseHandler(flowTest.RegistrationService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + // Invoke that we are sending response for must be placed into appCtx. + invoke := &interop.Invoke{ + ID: "InvocationID1", + InvokedFunctionArn: "arn::dummy1", + CognitoIdentityID: "CognitoidentityID1", + CognitoIdentityPoolID: "CognitoidentityPollID1", + DeadlineNs: "deadlinens1", + ClientContext: "clientcontext1", + ContentType: "application/json", + Payload: strings.NewReader(`{"message": "hello"}`), + } + + flowTest.ConfigureForInvoke(context.Background(), invoke) + + // Invocation response submitted by runtime. + var responseBody = []byte("{'foo': 'bar'}") + + request := httptest.NewRequest("", "/", bytes.NewReader(responseBody)) + request = addInvocationID(request, invoke.ID) + request.Header.Set(functionResponseModeHeader, testCase.providedFunctionResponseMode) + handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) + + // Assertions + assert.Equal(t, testCase.expectedStatusCode, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", + responseRecorder.Code, testCase.expectedStatusCode) + + body, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + test.AssertJsonsEqual(t, []byte(testCase.expectedAPIResponse), body) + + if testCase.expectedErrorResponse { + assert.NotNil(t, flowTest.InteropServer.ErrorResponse) + assert.Nil(t, flowTest.InteropServer.Response) + assert.Equal(t, "Runtime.InvalidResponseModeHeader", flowTest.InteropServer.ErrorResponse.ErrorType) + } else { + assert.NotNil(t, flowTest.InteropServer.Response) + assert.Nil(t, flowTest.InteropServer.ErrorResponse) + assert.Equal(t, responseBody, flowTest.InteropServer.Response, + "Persisted response data in app context must match the submitted.") + } + + assert.Equal(t, testCase.expectedFunctionResponseMode, flowTest.InteropServer.FunctionResponseMode) + } +} diff --git a/lambda/rapi/handler/restorenext.go b/lambda/rapi/handler/restorenext.go new file mode 100644 index 0000000..ecff059 --- /dev/null +++ b/lambda/rapi/handler/restorenext.go @@ -0,0 +1,40 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "net/http" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/rapi/rendering" +) + +type restoreNextHandler struct { + registrationService core.RegistrationService + renderingService *rendering.EventRenderingService +} + +func (h *restoreNextHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + runtime := h.registrationService.GetRuntime() + err := runtime.RestoreReady() + if err != nil { + log.Warn(err) + rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, runtime.GetState().Name(), core.RuntimeReadyStateName, err) + return + } + err = h.renderingService.RenderRuntimeEvent(writer, request) + if err != nil { + log.Error(err) + rendering.RenderInternalServerError(writer, request) + return + } +} + +func NewRestoreNextHandler(registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { + return &restoreNextHandler{ + registrationService: registrationService, + renderingService: renderingService, + } +} diff --git a/lambda/rapi/handler/restorenext_test.go b/lambda/rapi/handler/restorenext_test.go new file mode 100644 index 0000000..7018d98 --- /dev/null +++ b/lambda/rapi/handler/restorenext_test.go @@ -0,0 +1,87 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/telemetry" + "go.amzn.com/lambda/testdata" +) + +func TestRenderRestoreNext(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + handler := NewRestoreNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + flowTest.ConfigureForRestore() + request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusOK, responseRecorder.Code) +} + +func TestBrokenRenderer(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + handler := NewRestoreNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + flowTest.ConfigureForRestore() + flowTest.RenderingService.SetRenderer(&mockBrokenRenderer{}) + request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) + + assert.JSONEq(t, `{"errorMessage":"Internal Server Error","errorType":"InternalServerError"}`, responseRecorder.Body.String()) +} + +func TestRenderRestoreAfterInvoke(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + handler := NewInvocationNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + deadlineNs := 12345 + invokePayload := "Payload" + invoke := &interop.Invoke{ + TraceID: "Root=RootID;Parent=LambdaFrontend;Sampled=1", + ID: "ID", + InvokedFunctionArn: "InvokedFunctionArn", + CognitoIdentityID: "CognitoIdentityId1", + CognitoIdentityPoolID: "CognitoIdentityPoolId1", + ClientContext: "ClientContext", + DeadlineNs: strconv.Itoa(deadlineNs), + ContentType: "image/png", + Payload: strings.NewReader(invokePayload), + } + + ctx := telemetry.NewTraceContext(context.Background(), "RootID", "InvocationSubegmentID") + flowTest.ConfigureForInvoke(ctx, invoke) + + request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + + restoreHandler := NewRestoreNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + restoreRequest := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + responseRecorder = httptest.NewRecorder() + restoreHandler.ServeHTTP(responseRecorder, restoreRequest) + + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) +} diff --git a/lambda/rapi/handler/runtimelogs.go b/lambda/rapi/handler/runtimelogs.go index 9b4e406..99941b0 100644 --- a/lambda/rapi/handler/runtimelogs.go +++ b/lambda/rapi/handler/runtimelogs.go @@ -7,7 +7,7 @@ import ( "bytes" "errors" "fmt" - "io/ioutil" + "io" "net/http" "go.amzn.com/lambda/core" @@ -20,8 +20,8 @@ import ( ) type runtimeLogsHandler struct { - registrationService core.RegistrationService - logsSubscriptionAPI telemetry.LogsSubscriptionAPI + registrationService core.RegistrationService + telemetrySubscription telemetry.SubscriptionAPI } func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { @@ -31,10 +31,10 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http switch err := err.(type) { case *ErrAgentIdentifierUnknown: rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension "+err.agentID.String()) - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) } return } @@ -45,21 +45,21 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http if err != nil { log.Error(err) rendering.RenderInternalServerError(writer, request) - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) return } - respBody, status, headers, err := h.logsSubscriptionAPI.Subscribe(agentName, bytes.NewReader(body), request.Header) + respBody, status, headers, err := h.telemetrySubscription.Subscribe(agentName, bytes.NewReader(body), request.Header) if err != nil { log.Errorf("Telemetry API error: %s", err) switch err { case logsapi.ErrTelemetryServiceOff: rendering.RenderForbiddenWithTypeMsg(writer, request, - errLogsSubscriptionClosed, "Logs API subscription is closed already") - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.GetServiceClosedErrorType(), h.telemetrySubscription.GetServiceClosedErrorMessage()) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) } return } @@ -67,11 +67,11 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http rendering.RenderRuntimeLogsResponse(writer, respBody, status, headers) switch status / 100 { case 2: // 2xx - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeSuccess, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeSuccess, 1) case 4: // 4xx - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) case 5: // 5xx - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) } } @@ -114,7 +114,7 @@ func (h *runtimeLogsHandler) getAgentName(agentID uuid.UUID) (string, bool) { } func (h *runtimeLogsHandler) getBody(writer http.ResponseWriter, request *http.Request) ([]byte, error) { - body, err := ioutil.ReadAll(request.Body) + body, err := io.ReadAll(request.Body) if err != nil { return nil, fmt.Errorf("Failed to read error body: %s", err) } @@ -122,11 +122,11 @@ func (h *runtimeLogsHandler) getBody(writer http.ResponseWriter, request *http.R return body, nil } -// NewRuntimeLogsHandler returns a new instance of http handler +// NewRuntimeTelemetrySubscriptionHandler returns a new instance of http handler // for serving /runtime/logs -func NewRuntimeLogsHandler(registrationService core.RegistrationService, logsSubscriptionAPI telemetry.LogsSubscriptionAPI) http.Handler { +func NewRuntimeTelemetrySubscriptionHandler(registrationService core.RegistrationService, telemetrySubscription telemetry.SubscriptionAPI) http.Handler { return &runtimeLogsHandler{ - registrationService: registrationService, - logsSubscriptionAPI: logsSubscriptionAPI, + registrationService: registrationService, + telemetrySubscription: telemetrySubscription, } } diff --git a/lambda/rapi/handler/runtimelogs_stub.go b/lambda/rapi/handler/runtimelogs_stub.go index 0ce472e..f540e9b 100644 --- a/lambda/rapi/handler/runtimelogs_stub.go +++ b/lambda/rapi/handler/runtimelogs_stub.go @@ -6,27 +6,48 @@ package handler import ( "net/http" + log "github.com/sirupsen/logrus" "go.amzn.com/lambda/rapi/model" - - "github.com/go-chi/render" + "go.amzn.com/lambda/rapi/rendering" ) const ( - telemetryAPIDisabledErrorType = "Logs.NotSupported" + logsAPIDisabledErrorType = "Logs.NotSupported" + telemetryAPIDisabledErrorType = "Telemetry.NotSupported" ) -type runtimeLogsStubHandler struct{} +type runtimeLogsStubAPIHandler struct{} -func (h *runtimeLogsStubHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - render.Status(request, http.StatusAccepted) - render.JSON(writer, request, &model.ErrorResponse{ - ErrorType: telemetryAPIDisabledErrorType, +func (h *runtimeLogsStubAPIHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if err := rendering.RenderJSON(http.StatusAccepted, writer, request, &model.ErrorResponse{ + ErrorType: logsAPIDisabledErrorType, ErrorMessage: "Logs API is not supported", - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(writer, err.Error(), http.StatusInternalServerError) + } +} + +// NewRuntimeLogsAPIStubHandler returns a new instance of http handler +// for serving /runtime/logs when a telemetry service implementation is absent +func NewRuntimeLogsAPIStubHandler() http.Handler { + return &runtimeLogsStubAPIHandler{} +} + +type runtimeTelemetryAPIStubHandler struct{} + +func (h *runtimeTelemetryAPIStubHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if err := rendering.RenderJSON(http.StatusAccepted, writer, request, &model.ErrorResponse{ + ErrorType: telemetryAPIDisabledErrorType, + ErrorMessage: "Telemetry API is not supported", + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(writer, err.Error(), http.StatusInternalServerError) + } } -// NewRuntimeLogsStubHandler returns a new instance of http handler +// NewRuntimeTelemetryAPIStubHandler returns a new instance of http handler // for serving /runtime/logs when a telemetry service implementation is absent -func NewRuntimeLogsStubHandler() http.Handler { - return &runtimeLogsStubHandler{} +func NewRuntimeTelemetryAPIStubHandler() http.Handler { + return &runtimeTelemetryAPIStubHandler{} } diff --git a/lambda/rapi/handler/runtimelogs_stub_test.go b/lambda/rapi/handler/runtimelogs_stub_test.go index 4826d12..5b27983 100644 --- a/lambda/rapi/handler/runtimelogs_stub_test.go +++ b/lambda/rapi/handler/runtimelogs_stub_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/assert" ) -func TestSuccessfulRuntimeLogsStub202Response(t *testing.T) { - handler := NewRuntimeLogsStubHandler() +func TestSuccessfulRuntimeLogsAPIStub202Response(t *testing.T) { + handler := NewRuntimeLogsAPIStubHandler() requestBody := []byte(`foobar`) request := httptest.NewRequest("PUT", "/logs", bytes.NewBuffer(requestBody)) responseRecorder := httptest.NewRecorder() @@ -23,3 +23,15 @@ func TestSuccessfulRuntimeLogsStub202Response(t *testing.T) { assert.Equal(t, http.StatusAccepted, responseRecorder.Code) assert.JSONEq(t, `{"errorMessage":"Logs API is not supported","errorType":"Logs.NotSupported"}`, responseRecorder.Body.String()) } + +func TestSuccessfulRuntimeTelemetryAPIStub202Response(t *testing.T) { + handler := NewRuntimeTelemetryAPIStubHandler() + requestBody := []byte(`foobar`) + request := httptest.NewRequest("PUT", "/telemetry", bytes.NewBuffer(requestBody)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.JSONEq(t, `{"errorMessage":"Telemetry API is not supported","errorType":"Telemetry.NotSupported"}`, responseRecorder.Body.String()) +} diff --git a/lambda/rapi/handler/runtimelogs_test.go b/lambda/rapi/handler/runtimelogs_test.go index b7db6df..892d61e 100644 --- a/lambda/rapi/handler/runtimelogs_test.go +++ b/lambda/rapi/handler/runtimelogs_test.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" "testing" @@ -23,30 +22,45 @@ import ( "go.amzn.com/lambda/rapidcore/telemetry/logsapi" ) -type mockLogsSubscriptionAPI struct{ mock.Mock } +type mockSubscriptionAPI struct{ mock.Mock } -func (s *mockLogsSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { +func (s *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { args := s.Called(agentName, body, headers) return args.Get(0).([]byte), args.Int(1), args.Get(2).(map[string][]string), args.Error(3) } -func (s *mockLogsSubscriptionAPI) RecordCounterMetric(metricName string, count int) { +func (s *mockSubscriptionAPI) RecordCounterMetric(metricName string, count int) { s.Called(metricName, count) } -func (s *mockLogsSubscriptionAPI) FlushMetrics() interop.LogsAPIMetrics { +func (s *mockSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { args := s.Called() - return args.Get(0).(interop.LogsAPIMetrics) + return args.Get(0).(interop.TelemetrySubscriptionMetrics) } -func (s *mockLogsSubscriptionAPI) Clear() { +func (s *mockSubscriptionAPI) Clear() { s.Called() } -func (s *mockLogsSubscriptionAPI) TurnOff() { +func (s *mockSubscriptionAPI) TurnOff() { s.Called() } +func (s *mockSubscriptionAPI) GetEndpointURL() string { + args := s.Called() + return args.Get(0).(string) +} + +func (s *mockSubscriptionAPI) GetServiceClosedErrorMessage() string { + args := s.Called() + return args.Get(0).(string) +} + +func (s *mockSubscriptionAPI) GetServiceClosedErrorType() string { + args := s.Called() + return args.Get(0).(string) +} + func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} respBody, respStatus, respHeaders := []byte(`barbaz`), http.StatusNotFound, map[string][]string{"K": []string{"V1", "V2"}} @@ -60,11 +74,11 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { agent, err := registrationService.CreateExternalAgent(agentName) assert.NoError(t, err) - logsSubscriptionAPI := &mockLogsSubscriptionAPI{} - logsSubscriptionAPI.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return(respBody, respStatus, respHeaders, nil) - logsSubscriptionAPI.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return(respBody, respStatus, respHeaders, nil) + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -77,10 +91,10 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - logsSubscriptionAPI.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders) - logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) - recordedBody, err := ioutil.ReadAll(responseRecorder.Body) + recordedBody, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) assert.Equal(t, respStatus, responseRecorder.Code) @@ -98,10 +112,10 @@ func TestErrorUnregisteredAgentID(t *testing.T) { core.NewInvokeFlowSynchronization(), ) - logsSubscriptionAPI := &mockLogsSubscriptionAPI{} - logsSubscriptionAPI.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -114,16 +128,16 @@ func TestErrorUnregisteredAgentID(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - recordedBody, err := ioutil.ReadAll(responseRecorder.Body) + recordedBody, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) expectedErrorBody := fmt.Sprintf(`{"errorMessage":"Unknown extension %s","errorType":"Extension.UnknownExtensionIdentifier"}`+"\n", invalidAgentID.String()) - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json; charset=utf-8"}}) + expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, expectedErrorBody, string(recordedBody)) assert.Equal(t, expectedHeaders, responseRecorder.Header()) - logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) } func TestErrorTelemetryAPICallFailure(t *testing.T) { @@ -139,11 +153,11 @@ func TestErrorTelemetryAPICallFailure(t *testing.T) { agent, err := registrationService.CreateExternalAgent(agentName) assert.NoError(t, err) - logsSubscriptionAPI := &mockLogsSubscriptionAPI{} - logsSubscriptionAPI.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) - logsSubscriptionAPI.On("RecordCounterMetric", serverErrMetric, 1) + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("RecordCounterMetric", serverErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -156,16 +170,16 @@ func TestErrorTelemetryAPICallFailure(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - recordedBody, err := ioutil.ReadAll(responseRecorder.Body) + recordedBody, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) expectedErrorBody := `{"errorMessage":"Internal Server Error","errorType":"InternalServerError"}` + "\n" - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json; charset=utf-8"}}) + expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) assert.Equal(t, expectedErrorBody, string(recordedBody)) assert.Equal(t, expectedHeaders, responseRecorder.Header()) - logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", serverErrMetric, 1) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", serverErrMetric, 1) } func TestRenderLogsSubscriptionClosed(t *testing.T) { @@ -181,11 +195,13 @@ func TestRenderLogsSubscriptionClosed(t *testing.T) { agent, err := registrationService.CreateExternalAgent(agentName) assert.NoError(t, err) - logsSubscriptionAPI := &mockLogsSubscriptionAPI{} - logsSubscriptionAPI.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) - logsSubscriptionAPI.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Logs API subscription is closed already") + telemetrySubscription.On("GetServiceClosedErrorType").Return("Logs.SubscriptionClosed") - handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -198,14 +214,58 @@ func TestRenderLogsSubscriptionClosed(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - recordedBody, err := ioutil.ReadAll(responseRecorder.Body) + recordedBody, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) expectedErrorBody := `{"errorMessage":"Logs API subscription is closed already","errorType":"Logs.SubscriptionClosed"}` + "\n" - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json; charset=utf-8"}}) + expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) + + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + assert.Equal(t, expectedErrorBody, string(recordedBody)) + assert.Equal(t, expectedHeaders, responseRecorder.Header()) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) +} + +func TestRenderTelemetrySubscriptionClosed(t *testing.T) { + agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} + apiError := logsapi.ErrTelemetryServiceOff + clientErrMetric := logsapi.SubscribeClientErr + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization(), + core.NewInvokeFlowSynchronization(), + ) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Telemetry API subscription is closed already") + telemetrySubscription.On("GetServiceClosedErrorType").Return("Telemetry.SubscriptionClosed") + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + expectedErrorBody := `{"errorMessage":"Telemetry API subscription is closed already","errorType":"Telemetry.SubscriptionClosed"}` + "\n" + expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, expectedErrorBody, string(recordedBody)) assert.Equal(t, expectedHeaders, responseRecorder.Header()) - logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) } diff --git a/lambda/rapi/middleware/middleware_test.go b/lambda/rapi/middleware/middleware_test.go index 7b37de9..a0b9134 100644 --- a/lambda/rapi/middleware/middleware_test.go +++ b/lambda/rapi/middleware/middleware_test.go @@ -7,7 +7,7 @@ import ( "bytes" "context" "encoding/json" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -58,7 +58,7 @@ func TestAgentUniqueIdentifierHeaderValidatorForbidden(t *testing.T) { responseRecorder := httptest.NewRecorder() router.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, handler.ErrAgentIdentifierMissing, errorResponse.ErrorType) @@ -66,7 +66,7 @@ func TestAgentUniqueIdentifierHeaderValidatorForbidden(t *testing.T) { request.Header.Set(handler.LambdaAgentIdentifier, "invalid-unique-identifier") router.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - respBody, _ = ioutil.ReadAll(responseRecorder.Body) + respBody, _ = io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, handler.ErrAgentIdentifierInvalid, errorResponse.ErrorType) } diff --git a/lambda/rapi/model/tracing.go b/lambda/rapi/model/tracing.go index af90e8f..83f97e8 100644 --- a/lambda/rapi/model/tracing.go +++ b/lambda/rapi/model/tracing.go @@ -3,14 +3,21 @@ package model +type TracingType string + const ( // XRayTracingType represents an X-Ray Tracing object type - XRayTracingType = "X-Amzn-Trace-Id" + XRayTracingType TracingType = "X-Amzn-Trace-Id" +) + +const ( + XRaySampled = "1" + XRayNonSampled = "0" ) // Tracing object returned as part of agent Invoke event type Tracing struct { - Type string `json:"type"` + Type TracingType `json:"type"` XRayTracing } diff --git a/lambda/rapi/rendering/doc.go b/lambda/rapi/rendering/doc.go index 4573638..bc359a1 100644 --- a/lambda/rapi/rendering/doc.go +++ b/lambda/rapi/rendering/doc.go @@ -2,7 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 /* - Package rendering provides stateful event rendering service. State of the rendering service should be set from the main event dispatch thread @@ -17,6 +16,5 @@ Example of INVOKE event: [main] // release threads registered for INVOKE event [thread] // receives INVOKE event - */ package rendering diff --git a/lambda/rapi/rendering/render_json.go b/lambda/rapi/rendering/render_json.go new file mode 100644 index 0000000..8cea816 --- /dev/null +++ b/lambda/rapi/rendering/render_json.go @@ -0,0 +1,33 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rendering + +import ( + "bytes" + "encoding/json" + log "github.com/sirupsen/logrus" + "net/http" +) + +// RenderJSON: +// - marshals 'v' to JSON, automatically escaping HTML +// - sets the Content-Type as application/json +// - sets the HTTP response status code +// - returns an error if it occurred before writing to response +func RenderJSON(status int, w http.ResponseWriter, r *http.Request, v interface{}) error { + buf := &bytes.Buffer{} + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(true) + if err := enc.Encode(v); err != nil { + return err + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if _, err := w.Write(buf.Bytes()); err != nil { + log.WithError(err).Warn("Error while writing response body") + } + + return nil +} diff --git a/lambda/rapi/rendering/rendering.go b/lambda/rapi/rendering/rendering.go index c75d010..0edfb68 100644 --- a/lambda/rapi/rendering/rendering.go +++ b/lambda/rapi/rendering/rendering.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "strconv" "sync" @@ -19,7 +18,6 @@ import ( "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapi/model" - "github.com/go-chi/render" "github.com/google/uuid" log "github.com/sirupsen/logrus" ) @@ -33,6 +31,8 @@ const ( ErrorTypeInvalidRequestID = "InvalidRequestID" // ErrorTypeRequestEntityTooLarge error type for payload too large ErrorTypeRequestEntityTooLarge = "RequestEntityTooLarge" + // ErrorTypeTruncatedHTTPRequest error type for truncated HTTP request + ErrorTypeTruncatedHTTPRequest = "TruncatedHTTPRequest" ) // ErrRenderingServiceStateNotSet returned when state not set @@ -100,6 +100,9 @@ type InvokeRenderer struct { metrics InvokeRendererMetrics } +type RestoreRenderer struct { +} + // NewAgentInvokeEvent forms a new AgentInvokeEvent from INVOKE request func NewAgentInvokeEvent(req *interop.Invoke) (*model.AgentInvokeEvent, error) { deadlineMono, err := strconv.ParseInt(req.DeadlineNs, 10, 64) @@ -145,7 +148,7 @@ func (s *InvokeRenderer) bufferInvokeRequest() error { if nil == s.requestBuffer { reader := io.LimitReader(s.invoke.Payload, interop.MaxPayloadSize) start := time.Now() - s.requestBuffer, err = ioutil.ReadAll(reader) + s.requestBuffer, err = io.ReadAll(reader) s.metrics = InvokeRendererMetrics{ ReadTime: time.Since(start), SizeBytes: len(s.requestBuffer), @@ -193,6 +196,15 @@ func (s *InvokeRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request return nil } +func (s *RestoreRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { + writer.WriteHeader(http.StatusOK) + return nil +} + +func (s *RestoreRenderer) RenderAgentEvent(writer http.ResponseWriter, request *http.Request) error { + return nil +} + // NewInvokeRenderer returns new invoke event renderer func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, traceParser func(context.Context, *interop.Invoke) string) *InvokeRenderer { return &InvokeRenderer{ @@ -204,6 +216,10 @@ func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, traceParser } } +func NewRestoreRenderer() *RestoreRenderer { + return &RestoreRenderer{} +} + func (s *InvokeRenderer) GetMetrics() InvokeRendererMetrics { s.requestMutex.Lock() defer s.requestMutex.Unlock() @@ -283,46 +299,78 @@ func renderAgentInvokeHeaders(writer http.ResponseWriter, eventID uuid.UUID) { // RenderForbiddenWithTypeMsg method for rendering error response func RenderForbiddenWithTypeMsg(w http.ResponseWriter, r *http.Request, errorType string, format string, args ...interface{}) { - render.Status(r, http.StatusForbidden) - render.JSON(w, r, &model.ErrorResponse{ + if err := RenderJSON(http.StatusForbidden, w, r, &model.ErrorResponse{ ErrorType: errorType, ErrorMessage: fmt.Sprintf(format, args...), - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderInternalServerError method for rendering error response func RenderInternalServerError(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusInternalServerError) - render.JSON(w, r, &model.ErrorResponse{ + if err := RenderJSON(http.StatusInternalServerError, w, r, &model.ErrorResponse{ ErrorMessage: "Internal Server Error", ErrorType: ErrorTypeInternalServerError, - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderRequestEntityTooLarge method for rendering error response func RenderRequestEntityTooLarge(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusRequestEntityTooLarge) - render.JSON(w, r, &model.ErrorResponse{ + if err := RenderJSON(http.StatusRequestEntityTooLarge, w, r, &model.ErrorResponse{ ErrorMessage: fmt.Sprintf("Exceeded maximum allowed payload size (%d bytes).", interop.MaxPayloadSize), ErrorType: ErrorTypeRequestEntityTooLarge, - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderTruncatedHTTPRequestError method for rendering error response +func RenderTruncatedHTTPRequestError(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "HTTP request detected as truncated", + ErrorType: ErrorTypeTruncatedHTTPRequest, + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderInvalidRequestID renders invalid request ID error response func RenderInvalidRequestID(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusBadRequest) - render.JSON(w, r, &model.ErrorResponse{ + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ ErrorMessage: "Invalid request ID", ErrorType: "InvalidRequestID", - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderInvalidFunctionResponseMode renders invalid function response mode response +func RenderInvalidFunctionResponseMode(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "Invalid function response mode", + ErrorType: "InvalidFunctionResponseMode", + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderAccepted method for rendering accepted status response func RenderAccepted(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusAccepted) - render.JSON(w, r, &model.StatusResponse{ + if err := RenderJSON(http.StatusAccepted, w, r, &model.StatusResponse{ Status: "OK", - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderInteropError is a convenience method for interpreting interop errors diff --git a/lambda/rapi/router.go b/lambda/rapi/router.go index 1d2766a..5c2a56d 100644 --- a/lambda/rapi/router.go +++ b/lambda/rapi/router.go @@ -19,7 +19,7 @@ import ( // NewRouter returns a new instance of chi router implementing // Runtime API specification. -func NewRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { +func NewRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService, eventsAPI telemetry.EventsAPI) http.Handler { router := chi.NewRouter() router.Use(middleware.AppCtxMiddleware(appCtx)) @@ -46,7 +46,11 @@ func NewRouter(appCtx appctx.ApplicationContext, registrationService core.Regist handler.NewInvocationErrorHandler(registrationService)).ServeHTTP) router.Post("/runtime/init/error", - handler.NewInitErrorHandler(registrationService).ServeHTTP) + handler.NewInitErrorHandler(registrationService, eventsAPI).ServeHTTP) + + if appctx.LoadInitType(appCtx) == appctx.InitCaching { + router.Get("/runtime/restore/next", handler.NewRestoreNextHandler(registrationService, renderingService).ServeHTTP) + } return router } @@ -80,14 +84,14 @@ func ExtensionsRouter(appCtx appctx.ApplicationContext, registrationService core // LogsAPIRouter returns a new instance of chi router implementing // Logs API specification. -func LogsAPIRouter(registrationService core.RegistrationService, logsSubscriptionAPI telemetry.LogsSubscriptionAPI) http.Handler { +func LogsAPIRouter(registrationService core.RegistrationService, logsSubscriptionAPI telemetry.SubscriptionAPI) http.Handler { router := chi.NewRouter() router.Use(middleware.AccessLogMiddleware()) router.Use(middleware.AllowIfExtensionsEnabled) router.Put("/logs", middleware.AgentUniqueIdentifierHeaderValidator( - handler.NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI)).ServeHTTP) + handler.NewRuntimeTelemetrySubscriptionHandler(registrationService, logsSubscriptionAPI)).ServeHTTP) return router } @@ -98,7 +102,32 @@ func LogsAPIRouter(registrationService core.RegistrationService, logsSubscriptio func LogsAPIStubRouter() http.Handler { router := chi.NewRouter() - router.Put("/logs", handler.NewRuntimeLogsStubHandler().ServeHTTP) + router.Put("/logs", handler.NewRuntimeLogsAPIStubHandler().ServeHTTP) + + return router +} + +// TelemetryRouter returns a new instance of chi router implementing +// Telemetry API specification. +func TelemetryAPIRouter(registrationService core.RegistrationService, telemetrySubscriptionAPI telemetry.SubscriptionAPI) http.Handler { + router := chi.NewRouter() + router.Use(middleware.AccessLogMiddleware()) + router.Use(middleware.AllowIfExtensionsEnabled) + + router.Put("/telemetry", + middleware.AgentUniqueIdentifierHeaderValidator( + handler.NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscriptionAPI)).ServeHTTP) + + return router +} + +// TelemetryStubRouter returns a new instance of chi router implementing +// a stub of Telemetry API that always returns a non-committal response to +// prevent customer code from crashing when Telemetry API is disabled locally +func TelemetryAPIStubRouter() http.Handler { + router := chi.NewRouter() + + router.Put("/telemetry", handler.NewRuntimeTelemetryAPIStubHandler().ServeHTTP) return router } diff --git a/lambda/rapi/router_test.go b/lambda/rapi/router_test.go index f1cbde8..73cbde1 100644 --- a/lambda/rapi/router_test.go +++ b/lambda/rapi/router_test.go @@ -60,7 +60,7 @@ func assertResponseErrorType(t *testing.T, expectedErrorType string, response *h // rendered as JSON, regardless of the value provided // in "Accept" header. // -// When using render.Render(...), chi rendering library +// When using render.Render(...), rendering function // would attempt to render response using content type // specified in the "Accept" header. // @@ -69,7 +69,7 @@ func assertResponseErrorType(t *testing.T, expectedErrorType string, response *h func TestAcceptXML(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) responseRecorder := httptest.NewRecorder() request := httptest.NewRequest("POST", "/runtime/invocation/x-y-z/error", bytes.NewReader([]byte(""))) // Tell server that client side accepts "application/xml". @@ -90,7 +90,7 @@ func TestAcceptXML(t *testing.T) { func Test404PageNotFound(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/unsupported", bytes.NewReader([]byte("")))) assert.Equal(t, http.StatusNotFound, responseRecorder.Code) assert.Equal(t, "404 page not found\n", responseRecorder.Body.String()) @@ -99,7 +99,7 @@ func Test404PageNotFound(t *testing.T) { func Test405MethodNotAllowed(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("DELETE", "/runtime/invocation/ABC/error", bytes.NewReader([]byte("")))) assert.Equal(t, http.StatusMethodNotAllowed, responseRecorder.Code) } @@ -107,7 +107,7 @@ func Test405MethodNotAllowed(t *testing.T) { func TestInitErrorAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/init/error", bytes.NewReader([]byte("{}")))) assert.Equal(t, http.StatusAccepted, responseRecorder.Code) } @@ -115,7 +115,7 @@ func TestInitErrorAccepted(t *testing.T) { func TestInitErrorForbidden(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -126,7 +126,7 @@ func TestInitErrorForbidden(t *testing.T) { func TestInvokeResponseAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -137,7 +137,7 @@ func TestInvokeResponseAccepted(t *testing.T) { func TestInvokeErrorResponseAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -148,7 +148,7 @@ func TestInvokeErrorResponseAccepted(t *testing.T) { func TestInvokeNextTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -159,7 +159,7 @@ func TestInvokeNextTwice(t *testing.T) { func TestInvokeResponseInvalidRequestID(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -171,7 +171,7 @@ func TestInvokeResponseInvalidRequestID(t *testing.T) { func TestInvokeErrorResponseInvalidRequestID(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -183,7 +183,7 @@ func TestInvokeErrorResponseInvalidRequestID(t *testing.T) { func TestInvokeResponseTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -197,7 +197,7 @@ func TestInvokeResponseTwice(t *testing.T) { func TestInvokeErrorResponseTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -211,7 +211,7 @@ func TestInvokeErrorResponseTwice(t *testing.T) { func TestInvokeResponseAfterErrorResponse(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -225,7 +225,7 @@ func TestInvokeResponseAfterErrorResponse(t *testing.T) { func TestInvokeErrorResponseAfterResponse(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -239,7 +239,7 @@ func TestInvokeErrorResponseAfterResponse(t *testing.T) { func TestMoreThanOneInvoke(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) var responseRecorder *httptest.ResponseRecorder for _, id := range []string{"A", "B", "C"} { flowTest.ConfigureForInvoke(context.Background(), createInvoke(id)) @@ -250,12 +250,25 @@ func TestMoreThanOneInvoke(t *testing.T) { } } +func TestInitCachingAPIDisabledForPlainInit(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + var responseRecorder *httptest.ResponseRecorder + + responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/restore/next", nil)) + assert.Equal(t, http.StatusNotFound, responseRecorder.Code) + + responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/credentials", nil)) + assert.Equal(t, http.StatusNotFound, responseRecorder.Code) +} + func benchmarkInvoke(b *testing.B, payload []byte) { b.StopTimer() b.ReportAllocs() flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) for i := 0; i < b.N; i++ { id := uuid.New().String() flowTest.ConfigureForInvoke(context.Background(), createInvoke(id)) diff --git a/lambda/rapi/security_test.go b/lambda/rapi/security_test.go index 3f869d5..5312b43 100644 --- a/lambda/rapi/security_test.go +++ b/lambda/rapi/security_test.go @@ -20,7 +20,7 @@ func TestInvokeValidId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) @@ -53,7 +53,7 @@ func TestSecurityInvokeResponseBadRequestId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) @@ -100,7 +100,7 @@ func TestSecurityInvokeErrorBadRequestId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) diff --git a/lambda/rapi/server.go b/lambda/rapi/server.go index e2c6ad4..dd027f4 100644 --- a/lambda/rapi/server.go +++ b/lambda/rapi/server.go @@ -13,6 +13,7 @@ import ( "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/rendering" "go.amzn.com/lambda/telemetry" @@ -23,6 +24,7 @@ const version20180601 = "/2018-06-01" const version20200101 = "/2020-01-01" const version20200815 = "/2020-08-15" const version20210423 = "/2021-04-23" +const version20220701 = "/2022-07-01" // Server is a Runtime API server type Server struct { @@ -33,6 +35,10 @@ type Server struct { exit chan error } +func SaveConnInContext(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, interop.HTTPConnKey, c) +} + // NewServer creates a new Runtime API Server // // Unlike net/http server's ListenAndServe, we separate Listen() @@ -44,28 +50,30 @@ func NewServer(host string, port int, appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService, telemetryAPIEnabled bool, - logsSubscriptionAPI telemetry.LogsSubscriptionAPI, initCachingEnabled bool, credentialsService core.CredentialsService) *Server { + logsSubscriptionAPI telemetry.SubscriptionAPI, telemetrySubscriptionAPI telemetry.SubscriptionAPI, credentialsService core.CredentialsService, eventsAPI telemetry.EventsAPI) *Server { exitErrors := make(chan error, 1) router := chi.NewRouter() - router.Mount(version20180601, NewRouter(appCtx, registrationService, renderingService)) + router.Mount(version20180601, NewRouter(appCtx, registrationService, renderingService, eventsAPI)) router.Mount(version20200101, ExtensionsRouter(appCtx, registrationService, renderingService)) if telemetryAPIEnabled { router.Mount(version20200815, LogsAPIRouter(registrationService, logsSubscriptionAPI)) + router.Mount(version20220701, TelemetryAPIRouter(registrationService, telemetrySubscriptionAPI)) } else { router.Mount(version20200815, LogsAPIStubRouter()) + router.Mount(version20220701, TelemetryAPIStubRouter()) } - if initCachingEnabled { + if appctx.LoadInitType(appCtx) == appctx.InitCaching { router.Mount(version20210423, CredentialsAPIRouter(credentialsService)) } return &Server{ host: host, port: port, - server: &http.Server{Handler: router}, + server: &http.Server{Handler: router, ConnContext: SaveConnInContext}, listener: nil, exit: exitErrors, } diff --git a/lambda/rapi/server_test.go b/lambda/rapi/server_test.go index ce6e4e6..cf31fab 100644 --- a/lambda/rapi/server_test.go +++ b/lambda/rapi/server_test.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "testing" "time" @@ -52,7 +51,7 @@ func TestServerReturnsSuccessfulResponse(t *testing.T) { if err != nil { assert.FailNowf(t, "Failed to get response", err.Error()) } - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { assert.FailNowf(t, "Failed to read response body", err.Error()) } diff --git a/lambda/rapid/bootstrap.go b/lambda/rapid/bootstrap.go deleted file mode 100644 index e82ec6c..0000000 --- a/lambda/rapid/bootstrap.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapid - -import ( - "os" - - "go.amzn.com/lambda/fatalerror" -) - -type Bootstrap interface { - Cmd() ([]string, error) // returns the args of bootstrap, where args[0] is the path to executable - Env(e EnvironmentVariables) []string // returns the environment variables to be passed to the bootstrapped process - Cwd() (string, error) // returns the working directory of the bootstrap process - ExtraFiles() []*os.File // returns the extra file descriptors apart from 1 & 2 to be passed to runtime - CachedFatalError(err error) (fatalerror.ErrorType, string, bool) -} diff --git a/lambda/rapid/exit.go b/lambda/rapid/exit.go index af5cb72..e45f3a4 100644 --- a/lambda/rapid/exit.go +++ b/lambda/rapid/exit.go @@ -5,8 +5,9 @@ package rapid import ( "fmt" - "os" + "time" + "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/extensions" "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" @@ -15,102 +16,116 @@ import ( log "github.com/sirupsen/logrus" ) -func checkInteropError(format string, err error) { - if err == interop.ErrInvalidInvokeID || err == interop.ErrResponseSent { - log.Warnf(format, err) - } else { - log.Panicf(format, err) +func handleInvokeError(execCtx *rapidContext, invokeRequest *interop.Invoke, invokeMx *invokeMetrics, err error) *interop.InvokeFailure { + invokeFailure := newInvokeFailureMsg(execCtx, invokeRequest, invokeMx, err) + resp := model.ErrorResponse{ + ErrorType: string(invokeFailure.ErrorType), + ErrorMessage: fmt.Sprintf("Error: %v", invokeFailure.ErrorMessage), } -} -func trySendDefaultErrorResponse(interopServer interop.Server, invokeID string, errorType fatalerror.ErrorType, err error) { - resp := model.ErrorResponse{ - ErrorType: string(errorType), - ErrorMessage: fmt.Sprintf("Error: %v", err), + if invokeRequest.ID != "" { + resp.ErrorMessage = fmt.Sprintf("RequestId: %s Error: %v", invokeRequest.ID, invokeFailure.ErrorMessage) } - if invokeID != "" { - resp.ErrorMessage = fmt.Sprintf("RequestId: %s Error: %v", invokeID, err) + // This is the default error response that gets sent back as the function response in failure cases + invokeFailure.DefaultErrorResponse = resp.AsInteropError() + + // Invoke with extensions disabled maintains behaviour parity with pre-extensions rapid + if !extensions.AreEnabled() { + invokeFailure.RequestReset = false + return invokeFailure } - if err := interopServer.SendErrorResponse(invokeID, resp.AsInteropError()); err != nil { - checkInteropError("Failed to send default error response: %s", err) + if err == errResetReceived { + // errResetReceived is returned when execution flow was interrupted by the Reset message, + // hence this error deserves special handling and we yield to main receive loop to handle it + invokeFailure.ResetReceived = true + return invokeFailure } + + invokeFailure.RequestReset = true + return invokeFailure } -func reportErrorAndExit(doneFailMsg *interop.DoneFail, invokeID string, interopServer interop.Server, err error) { - // This function maintains compatibility of exit sequence behaviour - // with Sandbox Factory in the absence of extensions - - // NOTE this check will prevent us from sending FAULT message in case - // response (positive or negative) has already been sent. This is done - // to maintain legacy behavior of RAPID. - // ALSO NOTE, this works in case of positive response because this will - // be followed by RAPID exit. - if !interopServer.IsResponseSent() { - trySendDefaultErrorResponse(interopServer, invokeID, doneFailMsg.ErrorType, err) +func newInvokeFailureMsg(execCtx *rapidContext, invokeRequest *interop.Invoke, invokeMx *invokeMetrics, err error) *interop.InvokeFailure { + errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) + if !found { + errorType = fatalerror.Unknown } - if err := interopServer.CommitResponse(); err != nil { - checkInteropError("Failed to commit error response: %s", err) + invokeFailure := &interop.InvokeFailure{ + ErrorType: errorType, + ErrorMessage: err, + RequestReset: true, + ResetReceived: false, + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + InvokeReceivedTime: invokeRequest.InvokeReceivedTime, } - // old behavior: no DoneFails - doneMsg := &interop.Done{ - WaitForExit: true, - CorrelationID: doneFailMsg.CorrelationID, // required for standalone mode - Meta: doneFailMsg.Meta, + if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { + invokeFailure.ResponseMetrics.RuntimeTimeThrottledMs = invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) + invokeFailure.ResponseMetrics.RuntimeProducedBytes = invokeRequest.InvokeResponseMetrics.ProducedBytes + invokeFailure.ResponseMetrics.RuntimeOutboundThroughputBps = invokeRequest.InvokeResponseMetrics.OutboundThroughputBps } - if err := interopServer.SendDone(doneMsg); err != nil { - checkInteropError("Failed to send DONE during exit: %s", err) + if invokeMx != nil { + invokeFailure.InvokeMetrics.InvokeRequestReadTimeNs = invokeMx.rendererMetrics.ReadTime.Nanoseconds() + invokeFailure.InvokeMetrics.InvokeRequestSizeBytes = int64(invokeMx.rendererMetrics.SizeBytes) + invokeFailure.InvokeMetrics.RuntimeReadyTime = int64(invokeMx.runtimeReadyTime) + invokeFailure.ExtensionNames = execCtx.GetExtensionNames() } - os.Exit(1) -} - -func reportErrorAndRequestReset(doneFailMsg *interop.DoneFail, invokeID string, interopServer interop.Server, err error) { - if err == errResetReceived { - // errResetReceived is returned when execution flow was interrupted by the Reset message, - // hence this error deserves special handling and we yield to main receive loop to handle it - return + if execCtx.telemetryAPIEnabled { + invokeFailure.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } - trySendDefaultErrorResponse(interopServer, invokeID, doneFailMsg.ErrorType, err) - - if err := interopServer.CommitResponse(); err != nil { - checkInteropError("Failed to commit error response: %s", err) - } + return invokeFailure +} - if err := interopServer.SendDoneFail(doneFailMsg); err != nil { - checkInteropError("Failed to send DONEFAIL: %s", err) +func generateInitFailureMsg(execCtx *rapidContext, err error) interop.InitFailure { + errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) + if !found { + errorType = fatalerror.Unknown } -} -func handleInitError(doneFailMsg *interop.DoneFail, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error) { - if execCtx.standaloneMode { - reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) - return + initFailureMsg := interop.InitFailure{ + RequestReset: true, + ErrorType: errorType, + ErrorMessage: err, + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + Ack: make(chan struct{}), } - if !execCtx.HasActiveExtensions() { - // we don't expect Slicer to send RESET during INIT, that's why we Exit here - reportErrorAndExit(doneFailMsg, invokeID, interopServer, err) + if execCtx.telemetryAPIEnabled { + initFailureMsg.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } - reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) + return initFailureMsg } -func handleInvokeError(doneFailMsg *interop.DoneFail, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error) { - if execCtx.standaloneMode { - reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) +func handleInitError(execCtx *rapidContext, invokeID string, err error, initFailureResponse chan<- interop.InitFailure) { + log.WithError(err).WithField("InvokeID", invokeID).Error("Init failed") + initFailureMsg := generateInitFailureMsg(execCtx, err) + + if err == errResetReceived { + // errResetReceived is returned when execution flow was interrupted by the Reset message, + // hence this error deserves special handling and we yield to main receive loop to handle it + initFailureMsg.ResetReceived = true + initFailureResponse <- initFailureMsg + <-initFailureMsg.Ack return } - // Invoke with extensions disabled maintains behaviour parity with pre-extensions rapid - if !extensions.AreEnabled() { - reportErrorAndExit(doneFailMsg, invokeID, interopServer, err) + if !execCtx.HasActiveExtensions() && !execCtx.standaloneMode { + // different behaviour when no extensions are present, + // for compatibility with previous implementations + initFailureMsg.RequestReset = false + } else { + initFailureMsg.RequestReset = true } - reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) + initFailureResponse <- initFailureMsg + <-initFailureMsg.Ack } diff --git a/lambda/rapid/graceful_shutdown.go b/lambda/rapid/graceful_shutdown.go deleted file mode 100644 index 5ad1326..0000000 --- a/lambda/rapid/graceful_shutdown.go +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapid - -import ( - "syscall" - "time" - - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapi/model" - "go.amzn.com/lambda/rapi/rendering" - - log "github.com/sirupsen/logrus" -) - -func sigkillProcessGroup(pid int, sigkilledPids map[int]bool) map[int]bool { - pgid, err := syscall.Getpgid(pid) - if err == nil { - syscall.Kill(-pgid, 9) // Negative pid sends signal to all in process group - } else { - syscall.Kill(pid, 9) - } - sigkilledPids[pid] = true - - return sigkilledPids -} - -func awaitSigkilledProcessesToExit(exitPidChan chan int, processesExited, sigkilledPidsToAwait map[int]bool) { - for pid := range processesExited { - delete(sigkilledPidsToAwait, pid) - } - - for len(sigkilledPidsToAwait) != 0 { - pid := <-exitPidChan - _, found := sigkilledPidsToAwait[pid] - if !found { - log.Warnf("Unexpected process %d exited while waiting for sigkilled processes to exit", pid) - } else { - delete(sigkilledPidsToAwait, pid) - } - } -} - -func gracefulShutdown(execCtx *rapidContext, watchdog *core.Watchdog, profiler *metering.ExtensionsResetDurationProfiler, deadlineNs int64, killAgents bool, reason string) { - watchdog.Mute() - defer watchdog.Unmute() - - if execCtx.registrationService.CountAgents() == 0 { - // We do not spend any compute time on runtime graceful shutdown if there are no agents - if runtime := execCtx.registrationService.GetRuntime(); runtime != nil && runtime.Pid != 0 { - sigkilledPids := sigkillProcessGroup(runtime.Pid, map[int]bool{}) - if execCtx.standaloneMode { - processesExited := map[int]bool{} - awaitSigkilledProcessesToExit(execCtx.exitPidChan, processesExited, sigkilledPids) - } - } - return - } - - mono := metering.Monotime() - - availableNs := deadlineNs - mono - - if availableNs < 0 { - log.Warnf("Deadline is in the past: %v, %v, %v", mono, deadlineNs, availableNs) - availableNs = 0 - } - - profiler.AvailableNs = availableNs - - start := time.Now() - profiler.Start() - - runtimeDeadline := start.Add(time.Duration(float64(availableNs) * runtimeDeadlineShare)) - agentsDeadline := start.Add(time.Duration(availableNs)) - - sigkilledPids := make(map[int]bool) // Track process ids that were sent sigkill - processesExited := make(map[int]bool) // Don't send sigkill to processes that exit out of order - - processesExited, sigkilledPids = shutdownRuntime(execCtx, start, runtimeDeadline, processesExited, sigkilledPids) - processesExited, sigkilledPids = shutdownAgents(execCtx, start, profiler, agentsDeadline, killAgents, reason, processesExited, sigkilledPids) - if execCtx.standaloneMode { - awaitSigkilledProcessesToExit(execCtx.exitPidChan, processesExited, sigkilledPids) - } - - profiler.Stop() -} - -func shutdownRuntime(execCtx *rapidContext, start time.Time, deadline time.Time, processesExited, sigkilledPids map[int]bool) (map[int]bool, map[int]bool) { - // If runtime is started: - // 1. SIGTERM and wait until timeout - // 2. SIGKILL on timeout - - log.Debug("shutdown runtime") - runtime := execCtx.registrationService.GetRuntime() - if runtime == nil || runtime.Pid == 0 { - log.Warn("Runtime not started") - return processesExited, sigkilledPids - } - - syscall.Kill(runtime.Pid, syscall.SIGTERM) - - runtimeTimeout := deadline.Sub(start) - runtimeTimer := time.NewTimer(runtimeTimeout) - - for { - select { - case pid := <-execCtx.exitPidChan: - processesExited[pid] = true - if pid == runtime.Pid { - log.Info("runtime exited") - return processesExited, sigkilledPids - } - - log.Warnf("Process %d exited unexpectedly", pid) - case <-runtimeTimer.C: - log.Warnf("Timeout: no SIGCHLD from Runtime after %d ms; dispatching SIGKILL to runtime process group", int64(runtimeTimeout/time.Millisecond)) - sigkilledPids = sigkillProcessGroup(runtime.Pid, sigkilledPids) - return processesExited, sigkilledPids - } - } -} - -func shutdownAgents(execCtx *rapidContext, start time.Time, profiler *metering.ExtensionsResetDurationProfiler, deadline time.Time, killAgents bool, reason string, processesExited, sigkilledPids map[int]bool) (map[int]bool, map[int]bool) { - // For each external agent, if agent is launched: - // 1. Send Shutdown event if subscribed for it, else send SIGKILL to process group - // 2. Wait for all Shutdown-subscribed agents to exit with timeout - // 3. Send SIGKILL to process group for Shutdown-subscribed agents on timeout - - log.Debug("shutdown agents") - execCtx.renderingService.SetRenderer( - &rendering.ShutdownRenderer{ - AgentEvent: model.AgentShutdownEvent{ - AgentEvent: &model.AgentEvent{ - EventType: "SHUTDOWN", - DeadlineMs: deadline.UnixNano() / (1000 * 1000), - }, - ShutdownReason: reason, - }, - }) - - pidsToShutdown := make(map[int]*core.ExternalAgent) - for _, a := range execCtx.registrationService.GetExternalAgents() { - if a.Pid == 0 { - log.Warnf("Agent %s failed not launched; skipping shutdown", a) - continue - } - if a.IsSubscribed(core.ShutdownEvent) { - pidsToShutdown[a.Pid] = a - a.Release() - } else { - if !processesExited[a.Pid] { - sigkilledPids = sigkillProcessGroup(a.Pid, sigkilledPids) - } - } - } - profiler.NumAgentsRegisteredForShutdown = len(pidsToShutdown) - - var timerChan <-chan time.Time // default timerChan - if killAgents { - timerChan = time.NewTimer(deadline.Sub(start)).C // timerChan with deadline - } - - timeoutExceeded := false - for !timeoutExceeded && len(pidsToShutdown) != 0 { - select { - case pid := <-execCtx.exitPidChan: - processesExited[pid] = true - a, found := pidsToShutdown[pid] - if !found { - log.Warnf("Process %d exited unexpectedly", pid) - } else { - if err := a.Exited(); err != nil { - log.Warnf("%s failed to transition to EXITED: %s (current state: %s)", a.String(), err, a.GetState().Name()) - } - delete(pidsToShutdown, pid) - } - case <-timerChan: - timeoutExceeded = true - } - } - - if len(pidsToShutdown) != 0 { - for pid, agent := range pidsToShutdown { - if err := agent.ShutdownFailed(); err != nil { - log.Warnf("%s failed to transition to ShutdownFailed: %s (current state: %s)", agent, err, agent.GetState().Name()) - } - log.Warnf("Killing agent %s which failed to shutdown", agent) - if !processesExited[pid] { - sigkilledPids = sigkillProcessGroup(pid, sigkilledPids) - } - } - } - - return processesExited, sigkilledPids -} diff --git a/lambda/rapid/sandbox.go b/lambda/rapid/sandbox.go index a5614b0..9259514 100644 --- a/lambda/rapid/sandbox.go +++ b/lambda/rapid/sandbox.go @@ -7,49 +7,43 @@ import ( "context" "fmt" "io" + "sync" + "time" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/logging" "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapi" "go.amzn.com/lambda/rapi/rendering" + supvmodel "go.amzn.com/lambda/supervisor/model" "go.amzn.com/lambda/telemetry" -) -type EnvironmentVariables interface { - AgentExecEnv() []string - RuntimeExecEnv() []string - SetHandler(handler string) - StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress string) - StoreEnvironmentVariablesFromInit(customerEnv map[string]string, - handler, awsKey, awsSecret, awsSession, funcName, funcVer string) - StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) -} + log "github.com/sirupsen/logrus" +) type Sandbox struct { - EnableTelemetryAPI bool - StandaloneMode bool - Bootstrap Bootstrap - InteropServer interop.Server - Tracer telemetry.Tracer - LogsSubscriptionAPI telemetry.LogsSubscriptionAPI - LogsEgressAPI telemetry.LogsEgressAPI - Environment EnvironmentVariables - DebugTailLogger *logging.TailLogWriter - PlatformLogger logging.PlatformLogger - RuntimeStdoutWriter io.Writer - RuntimeStderrWriter io.Writer - PreLoadTimeNs int64 - Handler string - SignalCtx context.Context - EventsAPI telemetry.EventsAPI - InitCachingEnabled bool + EnableTelemetryAPI bool + StandaloneMode bool + InteropServer interop.Server + Tracer telemetry.Tracer + LogsSubscriptionAPI telemetry.SubscriptionAPI + TelemetrySubscriptionAPI telemetry.SubscriptionAPI + LogsEgressAPI telemetry.StdLogsEgressAPI + RuntimeStdoutWriter io.Writer + RuntimeStderrWriter io.Writer + PreLoadTimeNs int64 + Handler string + SignalCtx context.Context + EventsAPI telemetry.EventsAPI + InitCachingEnabled bool + Supervisor supvmodel.Supervisor + RuntimeAPIHost string + RuntimeAPIPort int } // Start is a public version of start() that exports only configurable parameters -func Start(s *Sandbox) { +func Start(s *Sandbox) (interop.RapidContext, interop.InternalStateGetter, string) { appCtx := appctx.NewApplicationContext() initFlow := core.NewInitFlowSynchronization() invokeFlow := core.NewInvokeFlowSynchronization() @@ -57,19 +51,18 @@ func Start(s *Sandbox) { renderingService := rendering.NewRenderingService() credentialsService := core.NewCredentialsService() - if s.StandaloneMode { - s.InteropServer.SetInternalStateGetter(registrationService.GetInternalStateDescriptor(appCtx)) - } - server := rapi.NewServer(RuntimeAPIHost, RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, s.LogsSubscriptionAPI, s.InitCachingEnabled, credentialsService) - - postLoadTimeNs := metering.Monotime() + appctx.StoreInitType(appCtx, s.InitCachingEnabled) + server := rapi.NewServer(s.RuntimeAPIHost, s.RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, s.LogsSubscriptionAPI, s.TelemetrySubscriptionAPI, credentialsService, s.EventsAPI) runtimeAPIAddr := fmt.Sprintf("%s:%d", server.Host(), server.Port()) - s.Environment.StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddr) + postLoadTimeNs := metering.Monotime() + + // TODO: pass this directly down to HTTP servers and handlers, instead of using + // global state to share the interop server implementation appctx.StoreInteropServer(appCtx, s.InteropServer) - start(s.SignalCtx, &rapidContext{ + execCtx := &rapidContext{ server: server, appCtx: appCtx, postLoadTimeNs: postLoadTimeNs, @@ -78,24 +71,86 @@ func Start(s *Sandbox) { invokeFlow: invokeFlow, registrationService: registrationService, renderingService: renderingService, - exitPidChan: make(chan int), - resetChan: make(chan *interop.Reset), credentialsService: credentialsService, - telemetryAPIEnabled: s.EnableTelemetryAPI, - logsSubscriptionAPI: s.LogsSubscriptionAPI, - logsEgressAPI: s.LogsEgressAPI, - bootstrap: s.Bootstrap, - interopServer: s.InteropServer, - xray: s.Tracer, - environment: s.Environment, - standaloneMode: s.StandaloneMode, - debugTailLogger: s.DebugTailLogger, - platformLogger: s.PlatformLogger, - runtimeStdoutWriter: s.RuntimeStdoutWriter, - runtimeStderrWriter: s.RuntimeStderrWriter, - preLoadTimeNs: s.PreLoadTimeNs, - eventsAPI: s.EventsAPI, - initCachingEnabled: s.InitCachingEnabled, - }) + telemetryAPIEnabled: s.EnableTelemetryAPI, + logsSubscriptionAPI: s.LogsSubscriptionAPI, + telemetrySubscriptionAPI: s.TelemetrySubscriptionAPI, + logsEgressAPI: s.LogsEgressAPI, + interopServer: s.InteropServer, + xray: s.Tracer, + standaloneMode: s.StandaloneMode, + preLoadTimeNs: s.PreLoadTimeNs, + eventsAPI: s.EventsAPI, + initCachingEnabled: s.InitCachingEnabled, + signalCtx: s.SignalCtx, + supervisor: s.Supervisor, + executionMutex: sync.Mutex{}, + shutdownContext: newShutdownContext(), + } + + // We call /ping on Supervisor before starting Rapid, since Rapid + // depends on Supervisor setting up networking dependencies + var startupErr error + for retries := 1; retries <= 5; retries++ { + if startupErr = s.Supervisor.Ping(); startupErr == nil { + break + } + // Retry timeout: 5s, same order-of-mag as test client PING retries + // TODO: revisit retry timeout, identify appropriate value for prod. + time.Sleep(1000 * time.Millisecond) + } + + if startupErr != nil { + log.Panicf("Application ping to Supervisor failed, terminating Rapid Startup: %s", startupErr) + } + + go start(s.SignalCtx, execCtx) + + return execCtx, registrationService.GetInternalStateDescriptor(appCtx), runtimeAPIAddr +} + +func (r *rapidContext) HandleInit(init *interop.Init, initStartedResponseChan chan<- interop.InitStarted, initSuccessResponseChan chan<- interop.InitSuccess, initFailureResponseChan chan<- interop.InitFailure) { + r.executionMutex.Lock() + defer r.executionMutex.Unlock() + handleInit(r, init, initStartedResponseChan, initSuccessResponseChan, initFailureResponseChan) +} + +func (r *rapidContext) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { + r.executionMutex.Lock() + defer r.executionMutex.Unlock() + // Clear the context used by the last invok + r.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) + return handleInvoke(r, invoke, sbInfoFromInit) +} + +func (r *rapidContext) HandleReset(reset *interop.Reset, invokeReceivedTime int64, InvokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { + // In the event of a Reset during init/invoke, CancelFlows cancels execution + // flows and return with the errResetReceived err - this error is special-cased + // and not handled by the init/invoke (unexpected) error handling functions + r.registrationService.CancelFlows(errResetReceived) + + // Wait until invoke error handling has returned before continuing execution + r.executionMutex.Lock() + defer r.executionMutex.Unlock() + + // Clear the context used by the last invoke, i.e. error message etc. + r.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) + return handleReset(r, reset, invokeReceivedTime, InvokeResponseMetrics) +} + +func (r *rapidContext) HandleShutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { + // Wait until invoke error handling has returned before continuing execution + r.executionMutex.Lock() + defer r.executionMutex.Unlock() + // Shutdown doesn't cancel flows, so it can block forever + return handleShutdown(r, shutdown, standaloneShutdownReason) +} + +func (r *rapidContext) HandleRestore(restore *interop.Restore) error { + return handleRestore(r, restore) +} + +func (r *rapidContext) Clear() { + reinitialize(r) } diff --git a/lambda/rapid/shutdown.go b/lambda/rapid/shutdown.go new file mode 100644 index 0000000..fe23a9f --- /dev/null +++ b/lambda/rapid/shutdown.go @@ -0,0 +1,366 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package rapid implements synchronous even dispatch loop. +package rapid + +import ( + "fmt" + "sync" + "time" + + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/metering" + "go.amzn.com/lambda/rapi/model" + "go.amzn.com/lambda/rapi/rendering" + supvmodel "go.amzn.com/lambda/supervisor/model" + + log "github.com/sirupsen/logrus" +) + +const ( + // supervisor shutdown and kill operations block until the exit status of the + // interested process has been collected, or until the specified timeotuw + // expires (in which case the operation fails). + // Note that this timeout is mainly relevant when any of the domain + // processes are in uninterruptible sleep state (notable examples: syscall + // to read/write a newtorked driver) + // + // We set a non nil value for these timeouts so that RAPID doesn't block + // forever in one of the cases above. + supervisorBlockingMaxMillis = 9000 + runtimeDeadlineShare = 0.3 +) + +type shutdownContext struct { + // Adding a mutex around shuttingDown because there may be concurrent reads/writes. + // Because the code in shutdown() and the seperate go routine created in setupEventsWatcher() + // could be concurrently accessing the field shuttingDown. + shuttingDownMutex sync.Mutex + shuttingDown bool + agentsAwaitingExit map[string]*core.ExternalAgent + // Adding a mutex around runtimeDomainExited because there may be concurrent reads/writes. + // The first reason this can be caused is by different go routines reading/writing different keys. + // The second reason this can be caused is between the code shutting down the runtime/extensions and + // handleProcessExit in a separate go routine, reading and writing to the same key. Caused by + // unexpected exits. + runtimeDomainExitedMutex sync.Mutex + // used to synchronize on processes exits. We create the channel when a + // process is started and we close it upon exit notification from + // supervisor. Closing the channel is basically a persistent broadcast of process exit. + // We never write anything to the channels + runtimeDomainExited map[string]chan struct{} +} + +func newShutdownContext() *shutdownContext { + return &shutdownContext{ + shuttingDownMutex: sync.Mutex{}, + shuttingDown: false, + agentsAwaitingExit: make(map[string]*core.ExternalAgent), + runtimeDomainExited: make(map[string]chan struct{}), + runtimeDomainExitedMutex: sync.Mutex{}, + } +} + +func (s *shutdownContext) isShuttingDown() bool { + s.shuttingDownMutex.Lock() + defer s.shuttingDownMutex.Unlock() + return s.shuttingDown +} + +func (s *shutdownContext) setShuttingDown(value bool) { + s.shuttingDownMutex.Lock() + defer s.shuttingDownMutex.Unlock() + s.shuttingDown = value +} + +func (s *shutdownContext) handleProcessExit(termination supvmodel.ProcessTermination) { + + name := *termination.Name + agent, found := s.agentsAwaitingExit[name] + + // If it is an agent registered to receive a shutdown event. + if found { + log.Debugf("Handling termination for %s", name) + exitStatus := termination.Exited() + if exitStatus != nil && *exitStatus == 0 { + // If the agent exited by itself after receiving the shutdown event. + stateErr := agent.Exited() + if stateErr != nil { + log.Warnf("%s failed to transition to EXITED: %s (current state: %s)", agent.String(), stateErr, agent.GetState().Name()) + } + } else { + // If the agent did not exit by itself, had to be SIGKILLed (only in standalone mode). + stateErr := agent.ShutdownFailed() + if stateErr != nil { + log.Warnf("%s failed to transition to ShutdownFailed: %s (current state: %s)", agent, stateErr, agent.GetState().Name()) + } + } + } + + exitedChannel, found := s.getExitedChannel(name) + + if !found { + log.Panicf("Unable to find an exitedChannel for '%s', it should have been created just after it was execed.", name) + } + // we close the channel so that whoever is blocked on it + // or will try to block on it in the future unblocks immediately + close(exitedChannel) +} + +func (s *shutdownContext) getExitedChannel(name string) (chan struct{}, bool) { + s.runtimeDomainExitedMutex.Lock() + defer s.runtimeDomainExitedMutex.Unlock() + exitedChannel, found := s.runtimeDomainExited[name] + return exitedChannel, found +} + +func (s *shutdownContext) createExitedChannel(name string) { + s.runtimeDomainExitedMutex.Lock() + defer s.runtimeDomainExitedMutex.Unlock() + + _, found := s.runtimeDomainExited[name] + + if found { + log.Panicf("Tried to create an exited channel for '%s' but one already exists.", name) + } + s.runtimeDomainExited[name] = make(chan struct{}) +} + +// Blocks until all the processes in the runtime domain generation have exited. +// This helps us have a nice sync point on Shutdown where we know for sure that +// all the processes have exited and the state has been cleared. +// +// It is OK not to hold the lock because we know that this is called only during +// shutdown and nobody will start a new process during shutdown +func (s *shutdownContext) clearExitedChannel() { + s.runtimeDomainExitedMutex.Lock() + mapLen := len(s.runtimeDomainExited) + channels := make([]chan struct{}, 0, mapLen) + for _, v := range s.runtimeDomainExited { + channels = append(channels, v) + } + s.runtimeDomainExitedMutex.Unlock() + + for _, v := range channels { + <-v + } + + s.runtimeDomainExitedMutex.Lock() + s.runtimeDomainExited = make(map[string]chan struct{}, mapLen) + s.runtimeDomainExitedMutex.Unlock() +} + +func (s *shutdownContext) shutdownRuntime(execCtx *rapidContext, start time.Time, deadline time.Time) { + // If runtime is started: + // 1. SIGTERM and wait until timeout + // 2. SIGKILL on timeout + log.Debug("Shutting down the runtime.") + name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) + exitedChannel, found := s.getExitedChannel(name) + + if found { + + err := execCtx.supervisor.Terminate(&supvmodel.TerminateRequest{ + Domain: RuntimeDomain, + Name: name, + }) + if err != nil { + // We are not reporting the error upstream because we will anyway + // shut the domain out at the end of the shutdown sequence + log.WithError(err).Warn("Failed sending Termination signal to runtime") + } + + runtimeTimeout := deadline.Sub(start) + log.Tracef("The runtime timeout is %v.", runtimeTimeout) + runtimeTimer := time.NewTimer(runtimeTimeout) + select { + case <-runtimeTimer.C: + log.Warnf("Timeout: The runtime did not exit after %d ms; Killing it.", int64(runtimeTimeout/time.Millisecond)) + supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) + err = execCtx.supervisor.Kill(&supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Timeout: &supervisorBlockingMaxMillis, + }) + + if err != nil { + // We are not reporting the error upstream because we will anyway + // shut the domain out at the end of the shutdown sequence + log.WithError(err).Warn("Failed sending Kill signal to runtime") + } + case <-exitedChannel: + } + } else { + log.Warn("The runtime was not started.") + } + log.Debug("Shutdown the runtime.") +} + +func (s *shutdownContext) shutdownAgents(execCtx *rapidContext, start time.Time, deadline time.Time, reason string) { + // For each external agent, if agent is launched: + // 1. Send Shutdown event if subscribed for it, else send SIGKILL to process group + // 2. Wait for all Shutdown-subscribed agents to exit with timeout + // 3. Send SIGKILL to process group for Shutdown-subscribed agents on timeout + + log.Debug("Shutting down the agents.") + execCtx.renderingService.SetRenderer( + &rendering.ShutdownRenderer{ + AgentEvent: model.AgentShutdownEvent{ + AgentEvent: &model.AgentEvent{ + EventType: "SHUTDOWN", + DeadlineMs: deadline.UnixNano() / (1000 * 1000), + }, + ShutdownReason: reason, + }, + }) + + var wg sync.WaitGroup + + // clear agentsAwaitingExit from last shutdownAgents + s.agentsAwaitingExit = make(map[string]*core.ExternalAgent) + + for _, a := range execCtx.registrationService.GetExternalAgents() { + name := fmt.Sprintf("extension-%s-%d", a.Name, execCtx.runtimeDomainGeneration) + exitedChannel, found := s.getExitedChannel(name) + supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) + + if !found { + log.Warnf("Agent %s failed to launch, therefore skipping shutting it down.", a) + continue + } + + wg.Add(1) + + if a.IsSubscribed(core.ShutdownEvent) { + log.Debugf("Agent %s is registered for the shutdown event.", a) + s.agentsAwaitingExit[name] = a + + go func(name string, agent *core.ExternalAgent) { + defer wg.Done() + + agent.Release() + + agentTimeout := deadline.Sub(start) + var agentTimeoutChan <-chan time.Time + if execCtx.standaloneMode { + agentTimeoutChan = time.NewTimer(agentTimeout).C + } + + select { + case <-agentTimeoutChan: + log.Warnf("Timeout: the agent %s did not exit after %d ms; Killing it.", name, int64(agentTimeout/time.Millisecond)) + err := execCtx.supervisor.Kill(&supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Timeout: &supervisorBlockingMaxMillis, + }) + if err != nil { + // We are not reporting the error upstream because we will anyway + // shut the domain out at the end of the shutdown sequence + log.WithError(err).Warn("Failed sending Kill signal to runtime") + } + case <-exitedChannel: + } + }(name, a) + } else { + log.Debugf("Agent %s is not registered for the shutdown event, so just killing it.", a) + + go func(name string) { + defer wg.Done() + + execCtx.supervisor.Kill(&supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Timeout: &supervisorBlockingMaxMillis, + }) + }(name) + } + } + + // Wait on the agents subscribed to the shutdown event to voluntary shutting down after receiving the shutdown event or be sigkilled. + // In addition to waiting on the agents not subscribed to the shutdown event being sigkilled. + wg.Wait() + log.Debug("Shutdown the agents.") +} + +func (s *shutdownContext) shutdown(execCtx *rapidContext, deadlineNs int64, reason string) (int64, bool, error) { + var err error + s.setShuttingDown(true) + defer s.setShuttingDown(false) + + // Fatal errors such as Runtime exit and Extension.Crash + // are ignored by the events watcher when shutting down + execCtx.appCtx.Delete(appctx.AppCtxFirstFatalErrorKey) + + runtimeDomainProfiler := &metering.ExtensionsResetDurationProfiler{} + supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) + + // We do not spend any compute time on runtime graceful shutdown if there are no agents + if execCtx.registrationService.CountAgents() == 0 { + name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) + + _, found := s.getExitedChannel(name) + + if found { + log.Debug("SIGKILLing the runtime as no agents are registered.") + err = execCtx.supervisor.Kill(&supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Timeout: &supervisorBlockingMaxMillis, + }) + if err != nil { + // We are not reporting the error upstream because we will anyway + // shut the domain out at the end of the shutdown sequence + log.WithError(err).Warn("Failed sending Kill signal to runtime") + } + } else { + log.Debugf("Could not find runtime process %s in processes map. Already exited/never started", name) + } + } else { + mono := metering.Monotime() + availableNs := deadlineNs - mono + + if availableNs < 0 { + log.Warnf("Deadline is in the past: %v, %v, %v", mono, deadlineNs, availableNs) + availableNs = 0 + } + + start := time.Now() + + runtimeDeadline := start.Add(time.Duration(float64(availableNs) * runtimeDeadlineShare)) + agentsDeadline := start.Add(time.Duration(availableNs)) + + runtimeDomainProfiler.AvailableNs = availableNs + runtimeDomainProfiler.Start() + + s.shutdownRuntime(execCtx, start, runtimeDeadline) + s.shutdownAgents(execCtx, start, agentsDeadline, reason) + + runtimeDomainProfiler.NumAgentsRegisteredForShutdown = len(s.agentsAwaitingExit) + } + log.Info("Stopping runtime domain") + err = execCtx.supervisor.Stop(&supvmodel.StopRequest{ + Domain: RuntimeDomain, + Timeout: &supervisorBlockingMaxMillis, + }) + if err != nil { + log.WithError(err).Error("Failed shutting runtime domain down") + } else { + log.Info("Waiting for runtime domain processes termination") + s.clearExitedChannel() + log.Info("Stopping operator domain") + err = execCtx.supervisor.Stop(&supvmodel.StopRequest{ + Domain: OperatorDomain, + Timeout: &supervisorBlockingMaxMillis, + }) + if err != nil { + log.WithError(err).Error("Failed shutting operator domain down") + } + } + + runtimeDomainProfiler.Stop() + extensionsRestMs, timeout := runtimeDomainProfiler.CalculateExtensionsResetMs() + return extensionsRestMs, timeout, err +} diff --git a/lambda/rapid/start.go b/lambda/rapid/start.go index 087ef13..76337af 100644 --- a/lambda/rapid/start.go +++ b/lambda/rapid/start.go @@ -7,9 +7,11 @@ package rapid import ( "context" "errors" - "io" + "fmt" "os" + "path" "strings" + "sync" "time" "go.amzn.com/lambda/agents" @@ -18,11 +20,11 @@ import ( "go.amzn.com/lambda/extensions" "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/logging" "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapi" + "go.amzn.com/lambda/rapi/model" "go.amzn.com/lambda/rapi/rendering" - "go.amzn.com/lambda/runtimecmd" + supvmodel "go.amzn.com/lambda/supervisor/model" "go.amzn.com/lambda/telemetry" "github.com/google/uuid" @@ -31,11 +33,11 @@ import ( ) const ( - RuntimeAPIHost = "127.0.0.1" - RuntimeAPIPort = 9001 + RuntimeDomain = "runtime" + OperatorDomain = "operator" defaultAgentLocation = "/opt/extensions" - runtimeDeadlineShare = 0.3 disableExtensionsFile = "/opt/disable-extensions-jwigqn8j" + runtimeProcessName = "runtime" ) const ( @@ -47,37 +49,38 @@ const ( var errResetReceived = errors.New("errResetReceived") type rapidContext struct { - bootstrap Bootstrap - interopServer interop.Server - server *rapi.Server - appCtx appctx.ApplicationContext - preLoadTimeNs int64 - postLoadTimeNs int64 - startRequest *interop.Start - initDone bool - initFlow core.InitFlowSynchronization - invokeFlow core.InvokeFlowSynchronization - registrationService core.RegistrationService - renderingService *rendering.EventRenderingService - telemetryAPIEnabled bool - logsSubscriptionAPI telemetry.LogsSubscriptionAPI - logsEgressAPI telemetry.LogsEgressAPI - xray telemetry.Tracer - exitPidChan chan int - resetChan chan *interop.Reset - environment EnvironmentVariables - standaloneMode bool - debugTailLogger *logging.TailLogWriter - platformLogger logging.PlatformLogger - runtimeStdoutWriter io.Writer - runtimeStderrWriter io.Writer - eventsAPI telemetry.EventsAPI - initCachingEnabled bool - credentialsService core.CredentialsService + interopServer interop.Server + server *rapi.Server + appCtx appctx.ApplicationContext + preLoadTimeNs int64 + postLoadTimeNs int64 + initDone bool + supervisor supvmodel.Supervisor + runtimeDomainGeneration uint32 + initFlow core.InitFlowSynchronization + invokeFlow core.InvokeFlowSynchronization + registrationService core.RegistrationService + renderingService *rendering.EventRenderingService + telemetryAPIEnabled bool + logsSubscriptionAPI telemetry.SubscriptionAPI + telemetrySubscriptionAPI telemetry.SubscriptionAPI + logsEgressAPI telemetry.StdLogsEgressAPI + xray telemetry.Tracer + standaloneMode bool + eventsAPI telemetry.EventsAPI + initCachingEnabled bool + credentialsService core.CredentialsService + signalCtx context.Context + executionMutex sync.Mutex + shutdownContext *shutdownContext } +// Validate interface compliance +var _ interop.RapidContext = (*rapidContext)(nil) + type invokeMetrics struct { - rendererMetrics rendering.InvokeRendererMetrics + rendererMetrics rendering.InvokeRendererMetrics + runtimeReadyTime int64 } @@ -102,7 +105,7 @@ func (c *rapidContext) GetExtensionNames() string { func logAgentsInitStatus(execCtx *rapidContext) { for _, agent := range execCtx.registrationService.AgentsInfo() { - execCtx.platformLogger.LogExtensionInitEvent(agent.Name, agent.State, agent.ErrorType, agent.Subscriptions) + execCtx.eventsAPI.SendExtensionInit(agent.Name, agent.State, agent.ErrorType, agent.Subscriptions) } } @@ -113,8 +116,7 @@ func agentLaunchError(agent *core.ExternalAgent, appCtx appctx.ApplicationContex appctx.StoreFirstFatalError(appCtx, fatalerror.AgentLaunchError) } -func doInitExtensions(execCtx *rapidContext, watchdog *core.Watchdog) error { - agentPaths := agents.ListExternalAgentPaths(defaultAgentLocation) +func doInitExtensions(domain string, agentPaths []string, execCtx *rapidContext, env interop.EnvironmentVariables) error { initFlow := execCtx.registrationService.InitFlow() // we don't bring it into the loop below because we don't want unnecessary broadcasts on agent gate @@ -123,38 +125,42 @@ func doInitExtensions(execCtx *rapidContext, watchdog *core.Watchdog) error { } for _, agentPath := range agentPaths { - env := execCtx.environment.AgentExecEnv() - - agentStdoutWriter, agentStderrWriter, err := execCtx.logsEgressAPI.GetExtensionSockets() + // Using path.Base(agentPath) not agentName because the agent name is contact, as standalone can get the internal state. + agent, err := execCtx.registrationService.CreateExternalAgent(path.Base(agentPath)) if err != nil { return err } - // Compose debug log writer with all log sinks. Debug log writer w - // will not write logs when disabled by invoke parameter - agentStdoutWriter = io.MultiWriter(execCtx.debugTailLogger, agentStdoutWriter) - agentStderrWriter = io.MultiWriter(execCtx.debugTailLogger, agentStderrWriter) + if execCtx.registrationService.CountAgents() > core.MaxAgentsAllowed { + agentLaunchError(agent, execCtx.appCtx, core.ErrTooManyExtensions) + return core.ErrTooManyExtensions + } - agentProc := agents.NewExternalAgentProcess(agentPath, env, agentStdoutWriter, agentStderrWriter) + env := env.AgentExecEnv() - agent, err := execCtx.registrationService.CreateExternalAgent(agentProc.Name()) + agentStdoutWriter, agentStderrWriter, err := execCtx.logsEgressAPI.GetExtensionSockets() if err != nil { return err } + agentName := fmt.Sprintf("extension-%s-%d", path.Base(agentPath), execCtx.runtimeDomainGeneration) - if execCtx.registrationService.CountAgents() > core.MaxAgentsAllowed { - agentLaunchError(agent, execCtx.appCtx, core.ErrTooManyExtensions) - return core.ErrTooManyExtensions - } + err = execCtx.supervisor.Exec(&supvmodel.ExecRequest{ + Domain: domain, + Name: agentName, + Path: agentPath, + Env: &env, + StdoutWriter: agentStdoutWriter, + StderrWriter: agentStderrWriter, + }) - if err := agentProc.Start(); err != nil { + if err != nil { agentLaunchError(agent, execCtx.appCtx, err) return err } - agent.Pid = watchdog.GoWait(&agentProc, fatalerror.AgentCrash) + execCtx.shutdownContext.createExitedChannel(agentName) } if err := initFlow.AwaitExternalAgentsRegistered(); err != nil { @@ -164,20 +170,154 @@ func doInitExtensions(execCtx *rapidContext, watchdog *core.Watchdog) error { return nil } -func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) error { +func doRuntimeBootstrap(execCtx *rapidContext, sbInfoFromInit interop.SandboxInfoFromInit) ([]string, map[string]string, string, []*os.File, error) { + env := sbInfoFromInit.EnvironmentVariables + runtimeBootstrap := sbInfoFromInit.RuntimeBootstrap + bootstrapCmd, err := runtimeBootstrap.Cmd() + if err != nil { + if fatalError, formattedLog, hasError := runtimeBootstrap.CachedFatalError(err); hasError { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) + execCtx.eventsAPI.SendImageErrorLog(formattedLog) + } else { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) + } + return []string{}, map[string]string{}, "", []*os.File{}, err + } + + bootstrapEnv := runtimeBootstrap.Env(env) + bootstrapCwd, err := runtimeBootstrap.Cwd() + if err != nil { + if fatalError, formattedLog, hasError := runtimeBootstrap.CachedFatalError(err); hasError { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) + execCtx.eventsAPI.SendImageErrorLog(formattedLog) + } else { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidWorkingDir) + } + return []string{}, map[string]string{}, "", []*os.File{}, err + } + + bootstrapExtraFiles := runtimeBootstrap.ExtraFiles() + + return bootstrapCmd, bootstrapEnv, bootstrapCwd, bootstrapExtraFiles, nil +} + +func (c *rapidContext) setupEventsWatcher(events <-chan supvmodel.Event) { + go func() { + for event := range events { + var err error = nil + log.Debugf("The events handler received the event %+v.", event) + if loss := event.Event.EventLoss(); loss != nil { + log.Panicf("Lost %d events from supervisor", *loss) + } + termination := event.Event.ProcessTerminated() + + // If we are not shutting down then we care if an unexpected exit happens. + if !c.shutdownContext.isShuttingDown() { + runtimeProcessName := fmt.Sprintf("%s-%d", runtimeProcessName, c.runtimeDomainGeneration) + + // If event from the runtime. + if *termination.Name == runtimeProcessName { + if termination.Success() { + err = fmt.Errorf("Runtime exited without providing a reason") + } else { + err = fmt.Errorf("Runtime exited with error: %s", termination.String()) + } + appctx.StoreFirstFatalError(c.appCtx, fatalerror.RuntimeExit) + } else { + if termination.Success() { + err = fmt.Errorf("exit code 0") + } else { + err = fmt.Errorf(termination.String()) + } + + appctx.StoreFirstFatalError(c.appCtx, fatalerror.AgentCrash) + } + + log.Warnf("Process %s exited: %+v", *termination.Name, termination) + } + + // At the moment we only get termination events. + // When their are other event types then we would need to be selective, + // about what we send to handleShutdownEvent(). + c.shutdownContext.handleProcessExit(*termination) + c.registrationService.CancelFlows(err) + } + }() +} + +func doOperatorDomainInit(ctx context.Context, execCtx *rapidContext, operatorDomainExtraConfig interop.DynamicDomainConfig) error { + events, err := execCtx.supervisor.Events() + if err != nil { + log.WithError(err).Panic("Could not get events stream from supervsior") + } + execCtx.setupEventsWatcher(events) + + log.Info("Configuring and starting Operator Domain") + conf := operatorDomainExtraConfig + err = execCtx.supervisor.Configure(&supvmodel.ConfigureRequest{ + Domain: OperatorDomain, + AdditionalStartHooks: conf.AdditionalStartHooks, + Mounts: conf.Mounts, + }) + + if err != nil { + log.WithError(err).Error("Failed to configure operator domain") + return err + } + + err = execCtx.supervisor.Start(&supvmodel.StartRequest{ + Domain: OperatorDomain, + }) + + if err != nil { + log.WithError(err).Error("Failed to start operator domain") + return err + } + + // we configure the runtime domain only once and not at + // every init phase (e.g., suppressed or reset). + err = execCtx.supervisor.Configure(&supvmodel.ConfigureRequest{ + Domain: RuntimeDomain, + }) + + if err != nil { + log.WithError(err).Error("Failed to configure operator domain") + return err + } + + return nil + +} + +func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromInit interop.SandboxInfoFromInit) error { execCtx.xray.RecordInitStartTime() defer execCtx.xray.RecordInitEndTime() - if extensions.AreEnabled() { - defer func() { + defer func() { + if extensions.AreEnabled() { logAgentsInitStatus(execCtx) - }() + } + }() + + log.Info("Starting runtime domain") + err := execCtx.supervisor.Start(&supvmodel.StartRequest{ + Domain: RuntimeDomain, + }) + if err != nil { + log.WithError(err).Panic("Failed configuring runtime domain") + } + execCtx.runtimeDomainGeneration++ - if err := doInitExtensions(execCtx, watchdog); err != nil { + if extensions.AreEnabled() { + runtimeExtensions := agents.ListExternalAgentPaths(defaultAgentLocation, + execCtx.supervisor.RuntimeConfig.RootPath) + if err := doInitExtensions(RuntimeDomain, runtimeExtensions, execCtx, sbInfoFromInit.EnvironmentVariables); err != nil { return err } } + appctx.StoreSandboxType(execCtx.appCtx, sbInfoFromInit.SandboxType) + initFlow := execCtx.registrationService.InitFlow() // Runtime state machine @@ -188,56 +328,66 @@ func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) // runtime is implicitly subscribed for certain lifecycle events. log.Debug("Preregister runtime") registrationService := execCtx.registrationService - if err := registrationService.PreregisterRuntime(runtime); err != nil { + err = registrationService.PreregisterRuntime(runtime) + + if err != nil { return err } - bootstrap := execCtx.bootstrap - bootstrapCmd, err := bootstrap.Cmd() + bootstrapCmd, bootstrapEnv, bootstrapCwd, bootstrapExtraFiles, err := doRuntimeBootstrap(execCtx, sbInfoFromInit) + if err != nil { - if fatalError, formattedLog, hasError := bootstrap.CachedFatalError(err); hasError { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.platformLogger.Printf("%s", formattedLog) - } else { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) - } return err } - bootstrapEnv := bootstrap.Env(execCtx.environment) - bootstrapCwd, err := bootstrap.Cwd() + runtimeStdoutWriter, runtimeStderrWriter, err := execCtx.logsEgressAPI.GetRuntimeSockets() + if err != nil { - if fatalError, formattedLog, hasError := bootstrap.CachedFatalError(err); hasError { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.platformLogger.Printf("%s", formattedLog) - } else { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidWorkingDir) - } return err } - bootstrapExtraFiles := bootstrap.ExtraFiles() - runtimeCmd := runtimecmd.NewCustomRuntimeCmd(ctx, bootstrapCmd, bootstrapCwd, bootstrapEnv, execCtx.runtimeStdoutWriter, execCtx.runtimeStderrWriter, bootstrapExtraFiles) - log.Debug("Start runtime") - err = runtimeCmd.Start() + checkCredentials(execCtx, bootstrapEnv) + name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) + err = execCtx.supervisor.Exec(&supvmodel.ExecRequest{ + Domain: RuntimeDomain, + Name: name, + Cwd: &bootstrapCwd, + Path: bootstrapCmd[0], + Args: bootstrapCmd[1:], + Env: &bootstrapEnv, + StdoutWriter: runtimeStdoutWriter, + StderrWriter: runtimeStderrWriter, + ExtraFiles: &bootstrapExtraFiles, + }) + + runtimeDoneStatus := telemetry.RuntimeDoneSuccess + + defer func() { + sendInitRuntimeDoneLogEvent(execCtx, sbInfoFromInit.SandboxType, runtimeDoneStatus) + }() + if err != nil { - if fatalError, formattedLog, hasError := bootstrap.CachedFatalError(err); hasError { + if fatalError, formattedLog, hasError := sbInfoFromInit.RuntimeBootstrap.CachedFatalError(err); hasError { appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.platformLogger.Printf("%s", formattedLog) + execCtx.eventsAPI.SendImageErrorLog(formattedLog) } else { appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) } + runtimeDoneStatus = telemetry.RuntimeDoneFailure return err } - registrationService.GetRuntime().Pid = watchdog.GoWait(runtimeCmd, fatalerror.RuntimeExit) + execCtx.shutdownContext.createExitedChannel(name) - if err := initFlow.AwaitRuntimeReady(); err != nil { + if err := initFlow.AwaitRuntimeRestoreReady(); err != nil { + runtimeDoneStatus = telemetry.RuntimeDoneFailure return err } + runtimeDoneStatus = telemetry.RuntimeDoneSuccess + // Registration phase finished for agents - no more agents can be registered with the system registrationService.TurnOff() if extensions.AreEnabled() { @@ -253,22 +403,17 @@ func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) // Logs API subscription phase finished for agents - no more agents can be subscribed to the Logs API if execCtx.telemetryAPIEnabled { execCtx.logsSubscriptionAPI.TurnOff() + execCtx.telemetrySubscriptionAPI.TurnOff() } execCtx.initDone = true + return nil } -func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, invokeRequest *interop.Invoke, mx *invokeMetrics) error { +func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop.Invoke, mx *invokeMetrics, sbInfoFromInit interop.SandboxInfoFromInit) error { execCtx.eventsAPI.SetCurrentRequestID(invokeRequest.ID) appCtx := execCtx.appCtx - appctx.StoreErrorResponse(appCtx, nil) - - if invokeRequest.NeedDebugLogs { - execCtx.debugTailLogger.Enable() - } else { - execCtx.debugTailLogger.Disable() - } xray := execCtx.xray xray.Configure(invokeRequest) @@ -277,11 +422,11 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo if !execCtx.initDone { // do inline init if err := xray.CaptureInitSubsegment(ctx, func(ctx context.Context) error { - return doInit(ctx, execCtx, watchdog) + return doRuntimeDomainInit(ctx, execCtx, sbInfoFromInit) }); err != nil { return err } - } else if execCtx.startRequest.SandboxType != interop.SandboxPreWarmed { + } else if sbInfoFromInit.SandboxType != interop.SandboxPreWarmed { xray.SendInitSubsegmentWithRecordedTimesOnce(ctx) } @@ -317,16 +462,20 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo if extensions.AreEnabled() { log.Debug("Release agents conditions") for _, agent := range extAgents { + //TODO handle Supervisors listening channel agent.Release() } for _, agent := range intAgents { + //TODO handle Supervisors listening channel agent.Release() } } log.Debug("Release runtime condition") + //TODO handle Supervisors listening channel runtime.Release() log.Debug("Await runtime response") + //TODO handle Supervisors listening channel return invokeFlow.AwaitRuntimeResponse() })); err != nil { return err @@ -335,12 +484,21 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo // Runtime overhead if err := xray.CaptureOverheadSubsegment(ctx, func(ctx context.Context) error { log.Debug("Await runtime ready") + //TODO handle Supervisors listening channel return invokeFlow.AwaitRuntimeReady() }); err != nil { return err } mx.runtimeReadyTime = metering.Monotime() - if err := execCtx.eventsAPI.SendRuntimeDone("success"); err != nil { + + runtimeDoneEventData := telemetry.InvokeRuntimeDoneData{ + Status: telemetry.RuntimeDoneSuccess, + Metrics: telemetry.GetRuntimeDoneInvokeMetrics(invokeRequest.InvokeReceivedTime, invokeRequest.InvokeResponseMetrics, mx.runtimeReadyTime), + InternalMetrics: invokeRequest.InvokeResponseMetrics, + Tracing: telemetry.BuildTracingCtx(model.XRayTracingType, invokeRequest.TraceID, invokeRequest.LambdaSegmentID), + Spans: telemetry.GetRuntimeDoneSpans(invokeRequest.InvokeReceivedTime, invokeRequest.InvokeResponseMetrics), + } + if err := execCtx.eventsAPI.SendRuntimeDone(runtimeDoneEventData); err != nil { log.Errorf("Failed to send RUNDONE: %s", err) } @@ -348,6 +506,7 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo if execCtx.HasActiveExtensions() { execCtx.interopServer.SendRuntimeReady() log.Debug("Await agents ready") + //TODO handle Supervisors listening channel if err := invokeFlow.AwaitAgentsReady(); err != nil { log.Warnf("AwaitAgentsReady() = %s", err) return err @@ -364,177 +523,148 @@ func extensionsDisabledByLayer() bool { return err == nil } -// acceptStartRequest is a second initialization phase, performed after receiving START +// acceptInitRequest is a second initialization phase, performed after receiving START // initialized entities: _HANDLER, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN -func (c *rapidContext) acceptStartRequest(startRequest *interop.Start) { - c.startRequest = startRequest - c.environment.StoreEnvironmentVariablesFromInit( - startRequest.CustomerEnvironmentVariables, - startRequest.Handler, - startRequest.AwsKey, - startRequest.AwsSecret, - startRequest.AwsSession, - startRequest.FunctionName, - startRequest.FunctionVersion) +func (c *rapidContext) acceptInitRequest(initRequest *interop.Init) *interop.Init { + initRequest.EnvironmentVariables.StoreEnvironmentVariablesFromInit( + initRequest.CustomerEnvironmentVariables, + initRequest.Handler, + initRequest.AwsKey, + initRequest.AwsSecret, + initRequest.AwsSession, + initRequest.FunctionName, + initRequest.FunctionVersion) c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ - FunctionName: startRequest.FunctionName, - FunctionVersion: startRequest.FunctionVersion, - Handler: startRequest.Handler, + FunctionName: initRequest.FunctionName, + FunctionVersion: initRequest.FunctionVersion, + Handler: initRequest.Handler, + RuntimeInfo: initRequest.RuntimeInfo, }) if extensionsDisabledByLayer() { extensions.Disable() } + + return initRequest } -func (c *rapidContext) acceptStartRequestForInitCaching(startRequest *interop.Start) error { +func (c *rapidContext) acceptInitRequestForInitCaching(initRequest *interop.Init) (*interop.Init, error) { log.Info("Configure environment for Init Caching.") - c.startRequest = startRequest randomUUID, err := uuid.NewRandom() if err != nil { - return err + return initRequest, err } initCachingToken := randomUUID.String() - c.environment.StoreEnvironmentVariablesFromInitForInitCaching( - RuntimeAPIHost, - RuntimeAPIPort, - startRequest.CustomerEnvironmentVariables, - startRequest.Handler, - startRequest.FunctionName, - startRequest.FunctionVersion, + initRequest.EnvironmentVariables.StoreEnvironmentVariablesFromInitForInitCaching( + c.server.Host(), + c.server.Port(), + initRequest.CustomerEnvironmentVariables, + initRequest.Handler, + initRequest.FunctionName, + initRequest.FunctionVersion, initCachingToken) c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ - FunctionName: startRequest.FunctionName, - FunctionVersion: startRequest.FunctionVersion, - Handler: startRequest.Handler, + FunctionName: initRequest.FunctionName, + FunctionVersion: initRequest.FunctionVersion, + Handler: initRequest.Handler, }) - c.credentialsService.SetCredentials(initCachingToken, startRequest.AwsKey, startRequest.AwsSecret, startRequest.AwsSession) + c.credentialsService.SetCredentials(initCachingToken, initRequest.AwsKey, initRequest.AwsSecret, initRequest.AwsSession, initRequest.CredentialsExpiry) if extensionsDisabledByLayer() { extensions.Disable() } - return nil + return initRequest, nil } -func handleStart(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, startRequest *interop.Start) { +func handleInit(execCtx *rapidContext, initRequest *interop.Init, + initStartedResponse chan<- interop.InitStarted, + initSuccessResponse chan<- interop.InitSuccess, + initFailureResponse chan<- interop.InitFailure) { + ctx := execCtx.signalCtx + if execCtx.initCachingEnabled { - if err := execCtx.acceptStartRequestForInitCaching(startRequest); err != nil { - handleStartError(execCtx, startRequest.InvokeID, startRequest.CorrelationID, err) + var err error + if initRequest, err = execCtx.acceptInitRequestForInitCaching(initRequest); err != nil { + // TODO: call handleInitError only after sending the RUNNING, since + // Slicer will fail receiving DONEFAIL here as it is expecting RUNNING + handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) return } - - execCtx.credentialsService.UnblockService() - defer execCtx.credentialsService.BlockService() } else { - execCtx.acceptStartRequest(startRequest) + initRequest = execCtx.acceptInitRequest(initRequest) } - interopServer, appCtx := execCtx.interopServer, execCtx.appCtx - - if err := interopServer.SendRunning(&interop.Running{ + initStartedMsg := interop.InitStarted{ PreLoadTimeNs: execCtx.preLoadTimeNs, PostLoadTimeNs: execCtx.postLoadTimeNs, WaitStartTimeNs: execCtx.postLoadTimeNs, WaitEndTimeNs: metering.Monotime(), ExtensionsEnabled: extensions.AreEnabled(), - }); err != nil { - log.Panic(err) + Ack: make(chan struct{}), } - if !startRequest.SuppressInit { - if err := doInit(ctx, execCtx, watchdog); err != nil { - handleStartError(execCtx, startRequest.InvokeID, startRequest.CorrelationID, err) - return - } - } + initStartedResponse <- initStartedMsg + <-initStartedMsg.Ack - doneMsg := &interop.Done{ - CorrelationID: startRequest.CorrelationID, - Meta: interop.DoneMetadata{ - RuntimeRelease: appctx.GetRuntimeRelease(appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - ExtensionNames: execCtx.GetExtensionNames(), - }, - } - if execCtx.telemetryAPIEnabled { - doneMsg.Meta.LogsAPIMetrics = execCtx.logsSubscriptionAPI.FlushMetrics() - } - if err := interopServer.SendDone(doneMsg); err != nil { - log.Panic(err) - } - - if err := interopServer.StartAcceptingDirectInvokes(); err != nil { - log.Panic(err) - } -} - -func handleStartError(execCtx *rapidContext, invokeID string, correlationID string, err error) { - log.WithError(err).WithField("InvokeID", invokeID).Error("Init failed") - doneFailMsg := generateDoneFail(execCtx, correlationID, nil, 0) - handleInitError(doneFailMsg, execCtx, invokeID, execCtx.interopServer, err) -} - -func generateDoneFail(execCtx *rapidContext, correlationID string, invokeMx *invokeMetrics, invokeReceivedTime int64) *interop.DoneFail { - errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) - if !found { - errorType = fatalerror.Unknown + // Operator domain init happens only once, it's never suppressed, + // and it's terminal in case of failures + if err := doOperatorDomainInit(ctx, execCtx, initRequest.OperatorDomainExtraConfig); err != nil { + // TODO: I believe we need to handle this specially, because we want + // to consider any failure here as terminal + handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) + return } - doneFailMsg := &interop.DoneFail{ - ErrorType: errorType, - CorrelationID: correlationID, // required for standalone mode - Meta: interop.DoneMetadata{ - RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - InvokeReceivedTime: invokeReceivedTime, - }, + if !initRequest.SuppressInit { + // doRuntimeDomainInit() is used in both init/invoke, so the signature requires sbInfo arg + sbInfo := interop.SandboxInfoFromInit{ + EnvironmentVariables: initRequest.EnvironmentVariables, + SandboxType: initRequest.SandboxType, + RuntimeBootstrap: initRequest.Bootstrap, + } + if err := doRuntimeDomainInit(ctx, execCtx, sbInfo); err != nil { + handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) + return + } } - if invokeMx != nil { - doneFailMsg.Meta.InvokeRequestReadTimeNs = invokeMx.rendererMetrics.ReadTime.Nanoseconds() - doneFailMsg.Meta.InvokeRequestSizeBytes = int64(invokeMx.rendererMetrics.SizeBytes) - doneFailMsg.Meta.RuntimeReadyTime = int64(invokeMx.runtimeReadyTime) - doneFailMsg.Meta.ExtensionNames = execCtx.GetExtensionNames() + initSuccessMsg := interop.InitSuccess{ + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + ExtensionNames: execCtx.GetExtensionNames(), + Ack: make(chan struct{}), } if execCtx.telemetryAPIEnabled { - doneFailMsg.Meta.LogsAPIMetrics = execCtx.logsSubscriptionAPI.FlushMetrics() + initSuccessMsg.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } - return doneFailMsg + initSuccessResponse <- initSuccessMsg + <-initSuccessMsg.Ack } -func handleInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, invokeRequest *interop.Invoke) { - interopServer, appCtx := execCtx.interopServer, execCtx.appCtx - +func handleInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { + ctx := execCtx.signalCtx invokeMx := invokeMetrics{} - if invokeRequest.ResyncState.IsResyncReceived { - err := execCtx.credentialsService.UpdateCredentials(invokeRequest.ResyncState.AwsKey, invokeRequest.ResyncState.AwsSecret, invokeRequest.ResyncState.AwsSession) - execCtx.credentialsService.UnblockService() - - if err != nil { - log.WithError(err).WithField("InvokeID", invokeRequest.ID).Error("Resync for Invoke failed") - doneFailMsg := generateDoneFail(execCtx, invokeRequest.CorrelationID, &invokeMx, invokeRequest.InvokeReceivedTime) - handleInvokeError(doneFailMsg, execCtx, invokeRequest.ID, interopServer, err) - } - } - - if err := doInvoke(ctx, execCtx, watchdog, invokeRequest, &invokeMx); err != nil { + if err := doInvoke(ctx, execCtx, invokeRequest, &invokeMx, sbInfoFromInit); err != nil { log.WithError(err).WithField("InvokeID", invokeRequest.ID).Error("Invoke failed") - doneFailMsg := generateDoneFail(execCtx, invokeRequest.CorrelationID, &invokeMx, invokeRequest.InvokeReceivedTime) - handleInvokeError(doneFailMsg, execCtx, invokeRequest.ID, interopServer, err) - return - } + invokeFailure := handleInvokeError(execCtx, invokeRequest, &invokeMx, err) - if err := execCtx.interopServer.CommitResponse(); err != nil { - log.Panic(err) + if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { + invokeFailure.ResponseMetrics = interop.ResponseMetrics{ + RuntimeTimeThrottledMs: invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond), + RuntimeProducedBytes: invokeRequest.InvokeResponseMetrics.ProducedBytes, + RuntimeOutboundThroughputBps: invokeRequest.InvokeResponseMetrics.OutboundThroughputBps, + } + } + return interop.InvokeSuccess{}, invokeFailure } var invokeCompletionTimeNs int64 @@ -542,30 +672,35 @@ func handleInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Wat invokeCompletionTimeNs = time.Now().UnixNano() - responseTimeNs } - doneMsg := &interop.Done{ - CorrelationID: invokeRequest.CorrelationID, - Meta: interop.DoneMetadata{ - RuntimeRelease: appctx.GetRuntimeRelease(appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - ExtensionNames: execCtx.GetExtensionNames(), + invokeSuccessMsg := interop.InvokeSuccess{ + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + ExtensionNames: execCtx.GetExtensionNames(), + InvokeMetrics: interop.InvokeMetrics{ InvokeRequestReadTimeNs: invokeMx.rendererMetrics.ReadTime.Nanoseconds(), InvokeRequestSizeBytes: int64(invokeMx.rendererMetrics.SizeBytes), - InvokeCompletionTimeNs: invokeCompletionTimeNs, - InvokeReceivedTime: invokeRequest.InvokeReceivedTime, RuntimeReadyTime: invokeMx.runtimeReadyTime, }, + InvokeCompletionTimeNs: invokeCompletionTimeNs, + InvokeReceivedTime: invokeRequest.InvokeReceivedTime, } - if execCtx.telemetryAPIEnabled { - doneMsg.Meta.LogsAPIMetrics = execCtx.logsSubscriptionAPI.FlushMetrics() + + if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { + invokeSuccessMsg.ResponseMetrics = interop.ResponseMetrics{ + RuntimeTimeThrottledMs: invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond), + RuntimeProducedBytes: invokeRequest.InvokeResponseMetrics.ProducedBytes, + RuntimeOutboundThroughputBps: invokeRequest.InvokeResponseMetrics.OutboundThroughputBps, + } } - if err := interopServer.SendDone(doneMsg); err != nil { - log.Panic(err) + if execCtx.telemetryAPIEnabled { + invokeSuccessMsg.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } + + return invokeSuccessMsg, nil } -func reinitialize(execCtx *rapidContext, watchdog *core.Watchdog) { - execCtx.interopServer.Clear() +func reinitialize(execCtx *rapidContext) { execCtx.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) execCtx.appCtx.Delete(appctx.AppCtxRuntimeReleaseKey) execCtx.appCtx.Delete(appctx.AppCtxFirstFatalErrorKey) @@ -576,90 +711,125 @@ func reinitialize(execCtx *rapidContext, watchdog *core.Watchdog) { execCtx.invokeFlow.Clear() if execCtx.telemetryAPIEnabled { execCtx.logsSubscriptionAPI.Clear() + execCtx.telemetrySubscriptionAPI.Clear() } - watchdog.Clear() -} - -func blockForever() { - select {} } // handle notification of reset -func handleReset(execCtx *rapidContext, watchdog *core.Watchdog, reset *interop.Reset) { - log.Warnf("Reset initiated: %s", reset.Reason) - if execCtx.initCachingEnabled { - execCtx.credentialsService.UnblockService() +func handleReset(execCtx *rapidContext, resetEvent *interop.Reset, invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { + log.Warnf("Reset initiated: %s", resetEvent.Reason) + + // Only send RuntimeDone event if we get a reset during an Invoke + if resetEvent.Reason == "failure" || resetEvent.Reason == "timeout" { + runtimeDoneEventData := telemetry.InvokeRuntimeDoneData{ + Status: resetEvent.Reason, + InternalMetrics: invokeResponseMetrics, + Metrics: telemetry.GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, metering.Monotime()), + Tracing: telemetry.BuildTracingCtx(model.XRayTracingType, resetEvent.TraceID, resetEvent.LambdaSegmentID), + Spans: telemetry.GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics), + } + if err := execCtx.eventsAPI.SendRuntimeDone(runtimeDoneEventData); err != nil { + log.Errorf("Failed to send RUNDONE: %s", err) + } } - if err := execCtx.eventsAPI.SendRuntimeDone(reset.Reason); err != nil { - log.Errorf("Failed to send RUNDONE: %s", err) + extensionsResetMs, resetTimeout, _ := execCtx.shutdownContext.shutdown(execCtx, resetEvent.DeadlineNs, resetEvent.Reason) + + log.Info("Starting runtime domain") + err := execCtx.supervisor.Start(&supvmodel.StartRequest{ + Domain: RuntimeDomain, + }) + if err != nil { + log.WithError(err).Panic("Failed booting runtime domain") } + execCtx.runtimeDomainGeneration++ - profiler := metering.ExtensionsResetDurationProfiler{} - gracefulShutdown(execCtx, watchdog, &profiler, reset.DeadlineNs, execCtx.standaloneMode, reset.Reason) + // Only used by standalone for more indepth assertions. + var fatalErrorType fatalerror.ErrorType - extensionsResetMs, resetTimeout := profiler.CalculateExtensionsResetMs() + if execCtx.standaloneMode { + fatalErrorType, _ = appctx.LoadFirstFatalError(execCtx.appCtx) + } - meta := interop.DoneMetadata{ - ExtensionsResetMs: extensionsResetMs, + var responseMetrics interop.ResponseMetrics + if resetEvent.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(resetEvent.InvokeResponseMetrics) { + responseMetrics.RuntimeTimeThrottledMs = resetEvent.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) + responseMetrics.RuntimeProducedBytes = resetEvent.InvokeResponseMetrics.ProducedBytes + responseMetrics.RuntimeOutboundThroughputBps = resetEvent.InvokeResponseMetrics.OutboundThroughputBps } - if !execCtx.standaloneMode { - // GIRP interopServer implementation sends GIRP RSTFAIL/RSTDONE - if resetTimeout { - // TODO: DoneFail must contain a reset timeout ErrorType for rapid local to distinguish errors - doneFail := &interop.DoneFail{CorrelationID: reset.CorrelationID, Meta: meta} - if err := execCtx.interopServer.SendDoneFail(doneFail); err != nil { - log.Panicf("Failed to SendDoneFail: %s", err) - } - } else { - done := &interop.Done{CorrelationID: reset.CorrelationID, Meta: meta} - if err := execCtx.interopServer.SendDone(done); err != nil { - log.Panicf("Failed to SendDone: %s", err) - } + if resetTimeout { + return interop.ResetSuccess{}, &interop.ResetFailure{ + ExtensionsResetMs: extensionsResetMs, + ErrorType: fatalErrorType, + ResponseMetrics: responseMetrics, } - - os.Exit(0) } - reinitialize(execCtx, watchdog) + return interop.ResetSuccess{ + ExtensionsResetMs: extensionsResetMs, + ErrorType: fatalErrorType, + ResponseMetrics: responseMetrics, + }, nil +} + +// handle notification of shutdown +func handleShutdown(execCtx *rapidContext, shutdownEvent *interop.Shutdown, reason string) interop.ShutdownSuccess { + log.Warnf("Shutdown initiated: %s", reason) + // TODO Handle shutdown error + _, _, _ = execCtx.shutdownContext.shutdown(execCtx, shutdownEvent.DeadlineNs, reason) - fatalErrorType, _ := appctx.LoadFirstFatalError(execCtx.appCtx) + // Only used by standalone for more indepth assertions. + var fatalErrorType fatalerror.ErrorType - if resetTimeout { - doneFail := &interop.DoneFail{CorrelationID: reset.CorrelationID, ErrorType: fatalErrorType, Meta: meta} - if err := execCtx.interopServer.SendDoneFail(doneFail); err != nil { - log.Panicf("Failed to SendDoneFail: %s", err) - } - } else { - done := &interop.Done{CorrelationID: reset.CorrelationID, ErrorType: fatalErrorType, Meta: meta} - if err := execCtx.interopServer.SendDone(done); err != nil { - log.Panicf("Failed to SendDone: %s", err) - } + if execCtx.standaloneMode { + fatalErrorType, _ = appctx.LoadFirstFatalError(execCtx.appCtx) } + + return interop.ShutdownSuccess{ErrorType: fatalErrorType} } -// handle notification of shutdown -func handleShutdown(execCtx *rapidContext, watchdog *core.Watchdog, shutdown *interop.Shutdown, reason string) { - log.Warnf("Shutdown initiated") +func handleRestore(execCtx *rapidContext, restore *interop.Restore) error { + err := execCtx.credentialsService.UpdateCredentials(restore.AwsKey, restore.AwsSecret, restore.AwsSession, restore.CredentialsExpiry) + restoreStatus := telemetry.RuntimeDoneSuccess + + defer func() { + sendRestoreRuntimeDoneLogEvent(execCtx, restoreStatus) + }() + + if err != nil { + return fmt.Errorf("error when updating credentials: %s", err) + } + renderer := rendering.NewRestoreRenderer() + execCtx.renderingService.SetRenderer(renderer) + + registrationService := execCtx.registrationService + runtime := registrationService.GetRuntime() + + // If runtime has not called /restore/next then just return + // instead of releasing the Runtime since there is no need to release. + // Then the runtime should be released only during Invoke + if runtime.GetState() != runtime.RuntimeRestoreReadyState { + restoreStatus = telemetry.RuntimeDoneSuccess + log.Infof("Runtime is in state: %s just returning", runtime.GetState().Name()) + return nil + } - gracefulShutdown(execCtx, watchdog, &metering.ExtensionsResetDurationProfiler{}, shutdown.DeadlineNs, true, reason) + runtime.Release() - fatalErrorType, _ := appctx.LoadFirstFatalError(execCtx.appCtx) + initFlow := execCtx.initFlow + err = initFlow.AwaitRuntimeReady() - if err := execCtx.interopServer.SendDone(&interop.Done{CorrelationID: shutdown.CorrelationID, ErrorType: fatalErrorType}); err != nil { - log.Panicf("Failed to SendDone: %s", err) + if err != nil { + restoreStatus = telemetry.RuntimeDoneFailure + } else { + restoreStatus = telemetry.RuntimeDoneSuccess } - // Shutdown induces a terminal state and no further messages will be processed - blockForever() + return err } func start(signalCtx context.Context, execCtx *rapidContext) { - watchdog := core.NewWatchdog(execCtx.registrationService.InitFlow(), execCtx.invokeFlow, execCtx.exitPidChan, execCtx.appCtx) - - interopServer := execCtx.interopServer - // Start Runtime API Server err := execCtx.server.Listen() if err != nil { @@ -670,30 +840,40 @@ func start(signalCtx context.Context, execCtx *rapidContext) { // Note, most of initialization code should run before blocking to receive START, // code before START runs in parallel with code downloads. +} - go func() { - for { - reset := <-interopServer.ResetChan() - // In the event of a Reset during init/invoke, CancelFlows cancels execution - // flows and return with the errResetReceived err - this error is special-cased - // and not handled by the init/invoke (unexpected) error handling functions - watchdog.CancelFlows(errResetReceived) - execCtx.resetChan <- reset - } - }() +func sendRestoreRuntimeDoneLogEvent(execCtx *rapidContext, status string) { + if err := execCtx.eventsAPI.SendRestoreRuntimeDone(status); err != nil { + log.Errorf("Failed to send RESTRD: %s", err) + } +} + +func sendInitRuntimeDoneLogEvent(execCtx *rapidContext, sandboxType interop.SandboxType, status string) { + initSource := interop.InferTelemetryInitSource(execCtx.initCachingEnabled, sandboxType) + + runtimeDoneData := &telemetry.InitRuntimeDoneData{ + InitSource: initSource, + Status: status, + } - for { - select { - case start := <-interopServer.StartChan(): - handleStart(signalCtx, execCtx, watchdog, start) - case invoke := <-interopServer.InvokeChan(): - handleInvoke(signalCtx, execCtx, watchdog, invoke) - case err := <-interopServer.TransportErrorChan(): - log.Panicf("Transport error emitted by interop server: %s", err) - case reset := <-execCtx.resetChan: - handleReset(execCtx, watchdog, reset) - case shutdown := <-interopServer.ShutdownChan(): // only in standalone mode - handleShutdown(execCtx, watchdog, shutdown, standaloneShutdownReason) + if err := execCtx.eventsAPI.SendInitRuntimeDone(runtimeDoneData); err != nil { + log.Errorf("Failed to send INITRD: %s", err) + } +} + +// This function will log a line if AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, or AWS_SESSION_TOKEN is missing +// This is expected to happen in cases when credentials provider is not needed +func checkCredentials(execCtx *rapidContext, bootstrapEnv map[string]string) { + credentialsKeys := []string{"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"} + missingCreds := []string{} + + for _, credEnvVar := range credentialsKeys { + if val, keyExists := bootstrapEnv[credEnvVar]; !keyExists || val == "" { + missingCreds = append(missingCreds, credEnvVar) } } + + if len(missingCreds) > 0 { + log.Infof("Starting runtime without %s , Expected?: %t", strings.Join(missingCreds[:], ", "), execCtx.initCachingEnabled) + } } diff --git a/lambda/rapid/start_test.go b/lambda/rapid/start_test.go index 2363705..ffb446f 100644 --- a/lambda/rapid/start_test.go +++ b/lambda/rapid/start_test.go @@ -7,7 +7,7 @@ import ( "context" "fmt" "go.amzn.com/lambda/core" - "io/ioutil" + "io" "net/http" "regexp" "strconv" @@ -142,7 +142,7 @@ func TestListen(t *testing.T) { ctx := context.Background() telemetryAPIEnabled := true - server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.LogsSubscriptionAPI, false, flowTest.CredentialsService) + server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.TelemetrySubscription, flowTest.TelemetrySubscription, flowTest.CredentialsService, flowTest.EventsAPI) err := server.Listen() assert.NoError(t, err) @@ -161,7 +161,7 @@ func TestListen(t *testing.T) { resp, err1 := http.Get(fmt.Sprintf("http://%s:%d/2018-06-01/runtime/invocation/next", server.Host(), server.Port())) assert.Nil(t, err1) - body, err2 := ioutil.ReadAll(resp.Body) + body, err2 := io.ReadAll(resp.Body) assert.Nil(t, err2) assert.Equal(t, "MyTest", string(body)) @@ -171,3 +171,31 @@ func TestListen(t *testing.T) { <-done } + +func TestInferSandboxInitTypeOnDemand(t *testing.T) { + initCachingEnabled := false + sandboxType := interop.SandboxClassic + initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) + assert.Equal(t, "on-demand", initSource) +} + +func TestInferSandboxInitTypeProvisionedConcurrency(t *testing.T) { + initCachingEnabled := false + sandboxType := interop.SandboxPreWarmed + initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) + assert.Equal(t, "provisioned-concurrency", initSource) +} + +func TestInferSandboxInitTypeInitCaching(t *testing.T) { + initCachingEnabled := true + sandboxType := interop.SandboxClassic + initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) + assert.Equal(t, "snap-start", initSource) +} + +func TestInferSandboxInitTypeInitCachingWithPC(t *testing.T) { + initCachingEnabled := true + sandboxType := interop.SandboxPreWarmed + initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) + assert.Equal(t, "snap-start", initSource) +} diff --git a/lambda/rapidcore/bootstrap.go b/lambda/rapidcore/bootstrap.go index 9faf518..165f532 100644 --- a/lambda/rapidcore/bootstrap.go +++ b/lambda/rapidcore/bootstrap.go @@ -6,11 +6,12 @@ package rapidcore import ( "fmt" "os" + "path" "path/filepath" + "strings" "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/logging" - "go.amzn.com/lambda/rapid" + "go.amzn.com/lambda/interop" log "github.com/sirupsen/logrus" ) @@ -21,6 +22,7 @@ type BootstrapError func() (fatalerror.ErrorType, LogFormatter) // Bootstrap represents a list of executable bootstrap // candidates in order of priority and exec metadata type Bootstrap struct { + runtimeDomainRoot string orderedLookupPaths []string validCmd []string workingDir string @@ -29,8 +31,11 @@ type Bootstrap struct { bootstrapError BootstrapError } +// Validate interface compliance +var _ interop.Bootstrap = (*Bootstrap)(nil) + // NewBootstrap returns an instance of bootstrap defined by given params -func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string) *Bootstrap { +func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string, runtimeDomainRoot string) *Bootstrap { var orderedLookupBootstrapPaths []string for _, args := range cmdCandidates { // Empty args is an error, but we want to detect it later (in Cmd() call) when we are able to report a descriptive error @@ -44,23 +49,32 @@ func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string) *Bootstrap currentWorkingDir = "/" } + if runtimeDomainRoot == "" { + runtimeDomainRoot = "/" + } + return &Bootstrap{ orderedLookupPaths: orderedLookupBootstrapPaths, workingDir: currentWorkingDir, cmdCandidates: cmdCandidates, + runtimeDomainRoot: runtimeDomainRoot, } } -func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string) *Bootstrap { +func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string, runtimeDomainRoot string) *Bootstrap { if currentWorkingDir == "" { // use the root directory as the default working directory currentWorkingDir = "/" } + if runtimeDomainRoot == "" { + runtimeDomainRoot = "/" + } // a single candidate command makes it automatically valid return &Bootstrap{ - validCmd: cmd, - workingDir: currentWorkingDir, + validCmd: cmd, + workingDir: currentWorkingDir, + runtimeDomainRoot: runtimeDomainRoot, } } @@ -68,16 +82,28 @@ func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string) *Bootstrap { // actual bootstrap, given a list of possible files func (b *Bootstrap) locateBootstrap() error { for i, bootstrapCandidate := range b.orderedLookupPaths { - if file, err := os.Stat(bootstrapCandidate); !os.IsNotExist(err) && !file.IsDir() { - b.validCmd = b.cmdCandidates[i] - return nil + // validate path relatively to the domain's root + candidatPath := path.Join(b.runtimeDomainRoot, bootstrapCandidate) + file, err := os.Stat(candidatPath) + if err != nil { + if !os.IsNotExist(err) { + log.WithError(err).Warnf("Could not validate %s. Ignoring it.", bootstrapCandidate) + } + continue } + if file.IsDir() { + log.Warnf("%s is a directory. Ignoring it", bootstrapCandidate) + continue + } + b.validCmd = b.cmdCandidates[i] + return nil } log.WithField("bootstrapPathsChecked", b.orderedLookupPaths).Warn("Couldn't find valid bootstrap(s)") return fmt.Errorf("Couldn't find valid bootstrap(s): %s", b.orderedLookupPaths) } -// Cmd returns the args of bootstrap, where args[0] +// Cmd returns the args of bootstrap, relative to the +// chroot idenfied by `root`, where args[0] // is the path to executable func (b *Bootstrap) Cmd() ([]string, error) { if len(b.validCmd) > 0 { @@ -94,16 +120,21 @@ func (b *Bootstrap) Cmd() ([]string, error) { // Env returns the environment variables available to // the bootstrap process -func (b *Bootstrap) Env(e rapid.EnvironmentVariables) []string { +func (b *Bootstrap) Env(e interop.EnvironmentVariables) map[string]string { return e.RuntimeExecEnv() } // Cwd returns the working directory of the bootstrap process +// The path is validated against the chroot identified by `root` func (b *Bootstrap) Cwd() (string, error) { if !filepath.IsAbs(b.workingDir) { return "", fmt.Errorf("the working directory '%s' is invalid, it needs to be an absolute path", b.workingDir) - } else if _, err := os.Stat(b.workingDir); os.IsNotExist(err) { - return "", fmt.Errorf("the working directory doesn't exist: %s", b.workingDir) + } + + // evaluate the path relatively to the domain's mnt namespace root + domainPath := path.Join(b.runtimeDomainRoot, b.workingDir) + if _, err := os.Stat(domainPath); os.IsNotExist(err) { + return "", fmt.Errorf("the working directory doesn't exist: %s", domainPath) } return b.workingDir, nil @@ -140,19 +171,35 @@ func (b *Bootstrap) SetCachedFatalError(bootstrapErrFn BootstrapError) { // BootstrapErrInvalidLCISTaskConfig represents an error while parsing LCIS task config func BootstrapErrInvalidLCISTaskConfig(err error) BootstrapError { return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidTaskConfig, logging.SupernovaInvalidTaskConfigRepr(err) + return fatalerror.InvalidTaskConfig, SupernovaInvalidTaskConfigRepr(err) } } // BootstrapErrInvalidLCISEntrypoint represents an invalid LCIS entrypoint error func BootstrapErrInvalidLCISEntrypoint(entrypoint []string, cmd []string, workingdir string) BootstrapError { return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidEntrypoint, logging.SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) + return fatalerror.InvalidEntrypoint, SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) } } func BootstrapErrInvalidLCISWorkingDir(entrypoint []string, cmd []string, workingdir string) BootstrapError { return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidWorkingDir, logging.SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) + return fatalerror.InvalidWorkingDir, SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) + } +} + +func SupernovaInvalidTaskConfigRepr(err error) func(error) string { + return func(unused error) string { + return fmt.Sprintf("IMAGE\tInvalid task config: %s", err) + } +} + +func SupernovaLaunchErrorRepr(entrypoint []string, cmd []string, workingDir string) func(error) string { + return func(err error) string { + return fmt.Sprintf("IMAGE\tLaunch error: %s\tEntrypoint: [%s]\tCmd: [%s]\tWorkingDir: [%s]", + err, + strings.Join(entrypoint, ","), + strings.Join(cmd, ","), + workingDir) } } diff --git a/lambda/rapidcore/bootstrap_test.go b/lambda/rapidcore/bootstrap_test.go index 4700130..b43520d 100644 --- a/lambda/rapidcore/bootstrap_test.go +++ b/lambda/rapidcore/bootstrap_test.go @@ -4,8 +4,10 @@ package rapidcore import ( - "io/ioutil" "os" + "path" + "path/filepath" + "reflect" "testing" "go.amzn.com/lambda/rapidcore/env" @@ -14,11 +16,11 @@ import ( ) func TestBootstrap(t *testing.T) { - tmpDir, err := ioutil.TempDir("", "lcis-test-invalid-bootstrap") + tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") assert.NoError(t, err) defer os.RemoveAll(tmpDir) - tmpFile, err := ioutil.TempFile("", "lcis-test-bootstrap") + tmpFile, err := os.CreateTemp("", "lcis-test-bootstrap") assert.NoError(t, err) defer os.Remove(tmpFile.Name()) @@ -38,11 +40,57 @@ func TestBootstrap(t *testing.T) { environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") // Test - b := NewBootstrap(cmdCandidates, cwd) + b := NewBootstrap(cmdCandidates, cwd, "") bCwd, err := b.Cwd() assert.NoError(t, err) assert.Equal(t, cwd, bCwd) - assert.ElementsMatch(t, environment.RuntimeExecEnv(), b.Env(environment)) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) + + cmd, err := b.Cmd() + assert.NoError(t, err) + assert.Equal(t, file, cmd) +} + +// When running bootstraps in separate mount namespaces +// we want to verify and discover paths relative to +// a root different from "/" +func TestBootstrapChroot(t *testing.T) { + tmpRoot, err := os.MkdirTemp(os.TempDir(), "domain-root") + assert.NoError(t, err) + defer os.RemoveAll(tmpRoot) + tmpDir, err := os.MkdirTemp(tmpRoot, "lcis-test-invalid-bootstrap") + assert.NoError(t, err) + defer os.RemoveAll(tmpDir) + + tmpFile, err := os.CreateTemp(tmpRoot, "lcis-test-bootstrap") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // Setup cmd candidates + nonExistent := []string{"/foo/bar/baz"} + baseName := filepath.Base(tmpDir) + dir := []string{"/" + baseName, "--arg1", "foo"} + baseName = filepath.Base(tmpFile.Name()) + file := []string{"/" + baseName, "--arg1 s", "foo"} + cmdCandidates := [][]string{nonExistent, dir, file} + + // Setup working dir + cwd, err := os.MkdirTemp(tmpRoot, "cwd") + assert.NoError(t, err) + defer os.RemoveAll(cwd) + + // Setup environment + environment := env.NewEnvironment() + environment.StoreRuntimeAPIEnvironmentVariable("host:port") + environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") + + // Test + baseName = filepath.Base(cwd) + b := NewBootstrap(cmdCandidates, "/"+baseName, tmpRoot) + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, cwd, path.Join(tmpRoot, bCwd)) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) cmd, err := b.Cmd() assert.NoError(t, err) @@ -53,17 +101,24 @@ func TestBootstrapEmptyCandidate(t *testing.T) { // we expect newBootstrap to succeed and bootstrap.Cmd() to fail. // We want to postpone the failure to be able to propagate error description to slicer and write it to customer log invalidBootstrapCandidate := []string{} - bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/") + bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/", "") + _, err := bs.Cmd() + assert.Error(t, err) +} + +func TestBootstrapChrootNonExistingRoot(t *testing.T) { + invalidBootstrapCandidate := []string{"/bin/bash", "-c"} + bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/", "/does_not_exist") _, err := bs.Cmd() assert.Error(t, err) } func TestBootstrapSingleCmd(t *testing.T) { - tmpDir, err := ioutil.TempDir("", "lcis-test-invalid-bootstrap") + tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") assert.NoError(t, err) defer os.RemoveAll(tmpDir) - tmpFile, err := ioutil.TempFile("", "lcis-test-bootstrap") + tmpFile, err := os.CreateTemp("", "lcis-test-bootstrap") assert.NoError(t, err) defer os.Remove(tmpFile.Name()) @@ -81,11 +136,11 @@ func TestBootstrapSingleCmd(t *testing.T) { environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") // Test - b := NewBootstrapSingleCmd(cmdCandidate, cwd) + b := NewBootstrapSingleCmd(cmdCandidate, cwd, "") bCwd, err := b.Cwd() assert.NoError(t, err) assert.Equal(t, cwd, bCwd) - assert.ElementsMatch(t, environment.RuntimeExecEnv(), b.Env(environment)) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) cmd, err := b.Cmd() assert.NoError(t, err) @@ -93,7 +148,7 @@ func TestBootstrapSingleCmd(t *testing.T) { } func TestBootstrapSingleCmdNonExistingCandidate(t *testing.T) { - tmpDir, err := ioutil.TempDir("", "lcis-test-invalid-bootstrap") + tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") assert.NoError(t, err) defer os.RemoveAll(tmpDir) @@ -111,11 +166,11 @@ func TestBootstrapSingleCmdNonExistingCandidate(t *testing.T) { environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") // Test - b := NewBootstrapSingleCmd(cmdCandidate, cwd) + b := NewBootstrapSingleCmd(cmdCandidate, cwd, "") bCwd, err := b.Cwd() assert.NoError(t, err) assert.Equal(t, cwd, bCwd) - assert.ElementsMatch(t, environment.RuntimeExecEnv(), b.Env(environment)) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) // No validations run against single candidates cmd, err := b.Cmd() @@ -125,100 +180,100 @@ func TestBootstrapSingleCmdNonExistingCandidate(t *testing.T) { // Test our ability to locate bootstrap files in the file system func TestFindCustomRuntimeIfExists(t *testing.T) { - tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp-") + tmpFile, err := os.CreateTemp(os.TempDir(), "tmp-") if err != nil { t.Fatal("Cannot create temporary file", err) } defer os.Remove(tmpFile.Name()) - tmpFile2, err := ioutil.TempFile(os.TempDir(), "tmp-") + tmpFile2, err := os.CreateTemp(os.TempDir(), "tmp-") if err != nil { t.Fatal("Cannot create temporary file", err) } defer os.Remove(tmpFile2.Name()) // one bootstrap argument was given and it exists - bootstrap := NewBootstrap([][]string{[]string{tmpFile.Name()}}, "/") + bootstrap := NewBootstrap([][]string{{tmpFile.Name()}}, "/", "") cmd, err := bootstrap.Cmd() assert.NoError(t, err) assert.Equal(t, []string{tmpFile.Name()}, cmd) assert.Nil(t, err) // two bootstrap arguments given, both exist but first one is returned - bootstrap = NewBootstrap([][]string{[]string{tmpFile.Name()}, []string{tmpFile2.Name()}}, "/") + bootstrap = NewBootstrap([][]string{{tmpFile.Name()}, {tmpFile2.Name()}}, "/", "") cmd, err = bootstrap.Cmd() assert.NoError(t, err) assert.Equal(t, []string{tmpFile.Name()}, cmd) assert.Nil(t, err) // two bootstrap arguments given, first one does not exist, second exists and is returned - bootstrap = NewBootstrap([][]string{[]string{"mk"}, []string{tmpFile2.Name()}}, "/") + bootstrap = NewBootstrap([][]string{{"mk"}, {tmpFile2.Name()}}, "/", "") cmd, err = bootstrap.Cmd() assert.NoError(t, err) assert.Equal(t, []string{tmpFile2.Name()}, cmd) assert.Nil(t, err) // two bootstrap arguments given, none exists - bootstrap = NewBootstrap([][]string{[]string{"mk"}, []string{"mk2"}}, "/") + bootstrap = NewBootstrap([][]string{{"mk"}, {"mk2"}}, "/", "") cmd, err = bootstrap.Cmd() assert.EqualError(t, err, "Couldn't find valid bootstrap(s): [mk mk2]") assert.Equal(t, []string{}, cmd) } func TestCwdIsAbsolute(t *testing.T) { - tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp-") + tmpFile, err := os.CreateTemp(os.TempDir(), "tmp-") if err != nil { t.Fatal("Cannot create temporary file", err) } defer os.Remove(tmpFile.Name()) - cmdCandidates := [][]string{[]string{tmpFile.Name()}} + cmdCandidates := [][]string{{tmpFile.Name()}} // no errors when currentWorkingDir is absolute - bootstrap := NewBootstrap(cmdCandidates, "/tmp") + bootstrap := NewBootstrap(cmdCandidates, "/tmp", "") cwd, err := bootstrap.Cwd() assert.Nil(t, err) assert.Equal(t, "/tmp", cwd) - bootstrap = NewBootstrap(cmdCandidates, "tmp") + bootstrap = NewBootstrap(cmdCandidates, "tmp", "") _, err = bootstrap.Cwd() assert.EqualError(t, err, "the working directory 'tmp' is invalid, it needs to be an absolute path") - bootstrap = NewBootstrap(cmdCandidates, "./") + bootstrap = NewBootstrap(cmdCandidates, "./", "") _, err = bootstrap.Cwd() assert.EqualError(t, err, "the working directory './' is invalid, it needs to be an absolute path") } func TestBootstrapMissingWorkingDirectory(t *testing.T) { - tmpFile, err := ioutil.TempFile(os.TempDir(), "cwd-test-bootstrap") + tmpFile, err := os.CreateTemp(os.TempDir(), "cwd-test-bootstrap") assert.NoError(t, err) defer os.Remove(tmpFile.Name()) - tmpDir, err := ioutil.TempDir("", "cwd-test") + tmpDir, err := os.MkdirTemp("", "cwd-test") assert.NoError(t, err) defer os.RemoveAll(tmpDir) // cwd argument exists - bootstrap := NewBootstrap([][]string{[]string{tmpFile.Name()}}, tmpDir) + bootstrap := NewBootstrap([][]string{{tmpFile.Name()}}, tmpDir, "") cwd, err := bootstrap.Cwd() assert.Equal(t, cwd, tmpDir) assert.NoError(t, err) // cwd argument doesn't exist - bootstrap = NewBootstrap([][]string{[]string{tmpFile.Name()}}, "/foo") + bootstrap = NewBootstrap([][]string{{tmpFile.Name()}}, "/foo", "") _, err = bootstrap.Cwd() assert.EqualError(t, err, "the working directory doesn't exist: /foo") } func TestDefaultWorkeringDirectory(t *testing.T) { - bootstrap := NewBootstrap([][]string{[]string{}}, "") + bootstrap := NewBootstrap([][]string{{}}, "", "") cwd, err := bootstrap.Cwd() assert.NoError(t, err) assert.Equal(t, "/", cwd) } func TestBootstrapSingleCmdDefaultWorkingDir(t *testing.T) { - b := NewBootstrapSingleCmd([]string{}, "") + b := NewBootstrapSingleCmd([]string{}, "", "") bCwd, err := b.Cwd() assert.NoError(t, err) assert.Equal(t, "/", bCwd) diff --git a/lambda/rapidcore/env/environment.go b/lambda/rapidcore/env/environment.go index 699abda..be0584c 100644 --- a/lambda/rapidcore/env/environment.go +++ b/lambda/rapidcore/env/environment.go @@ -149,26 +149,24 @@ func (e *Environment) mergeCustomerEnvironmentVariables(envVars map[string]strin // RuntimeExecEnv returns the key=value strings of all environment variables // passed to runtime process on exec() -func (e *Environment) RuntimeExecEnv() []string { +func (e *Environment) RuntimeExecEnv() map[string]string { if !e.initEnvVarsSet || !e.runtimeAPISet { log.Fatal("credentials, customer and runtime API address must be set") } - return asEnviron(mapUnion(e.Customer, e.PlatformUnreserved, e.Credentials, e.Runtime, e.Platform)) + return mapUnion(e.Customer, e.PlatformUnreserved, e.Credentials, e.Runtime, e.Platform) } // AgentExecEnv returns the key=value strings of all environment variables // passed to agent process on exec() -func (e *Environment) AgentExecEnv() []string { +func (e *Environment) AgentExecEnv() map[string]string { if !e.initEnvVarsSet || !e.runtimeAPISet { log.Fatal("credentials, customer and runtime API address must be set") } excludedKeys := extensionExcludedKeys() excludeCondition := func(key string) bool { return excludedKeys[key] || strings.HasPrefix(key, "_") } - environ := asEnviron(mapExclude(mapUnion(e.Customer, e.Credentials, e.Platform), excludeCondition)) - - return environ + return mapExclude(mapUnion(e.Customer, e.Credentials, e.Platform), excludeCondition) } // RAPIDInternalConfig returns the rapid config parsed from environment vars @@ -249,14 +247,6 @@ func mapUnion(maps ...map[string]string) map[string]string { return union } -func asEnviron(m map[string]string) []string { - keySepValArray := []string{} - for key, val := range m { - keySepValArray = append(keySepValArray, key+"="+val) - } - return keySepValArray -} - func mapExclude(m map[string]string, excludeCondition func(string) bool) map[string]string { res := map[string]string{} for key, val := range m { diff --git a/lambda/rapidcore/env/environment_test.go b/lambda/rapidcore/env/environment_test.go index cdfef24..ed3043c 100644 --- a/lambda/rapidcore/env/environment_test.go +++ b/lambda/rapidcore/env/environment_test.go @@ -6,12 +6,21 @@ package env import ( "fmt" "os" - "strings" "testing" "github.com/stretchr/testify/assert" ) +func envToSlice(env map[string]string) []string { + ret := make([]string, len(env)) + i := 0 + for key, val := range env { + ret[i] = key + "=" + val + i++ + } + return ret +} + func TestRAPIDInternalConfig(t *testing.T) { os.Clearenv() os.Setenv("_LAMBDA_SB_ID", "sbid") @@ -121,34 +130,35 @@ func TestRuntimeExecEnvironmentVariables(t *testing.T) { rapidEnvVars := env.RuntimeExecEnv() var rapidEnvKeys []string - for _, keyval := range rapidEnvVars { - key := strings.Split(keyval, "=")[0] + for key := range rapidEnvVars { rapidEnvKeys = append(rapidEnvKeys, key) } + rapidEnvVarsSlice := envToSlice(rapidEnvVars) + for key := range env.RAPID { assert.NotContains(t, rapidEnvKeys, key) } for key, val := range env.Runtime { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } for key, val := range env.Platform { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } for key, val := range env.PlatformUnreserved { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) assert.NotContains(t, env.Customer, key) } for key, val := range env.Credentials { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } for key, val := range env.Customer { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) assert.NotContains(t, env.PlatformUnreserved, key) } } @@ -191,7 +201,7 @@ func TestRuntimeExecEnvironmentVariablesPriority(t *testing.T) { assert.Equal(t, len(predefinedInternalEnvVarKeys()), len(env.RAPID)) assert.Equal(t, len(predefinedRuntimeEnvVarKeys()), len(env.Runtime)) - rapidEnvVars := env.RuntimeExecEnv() + rapidEnvVars := envToSlice(env.RuntimeExecEnv()) // Customer env vars cannot override platform/internal ones assert.NotContains(t, rapidEnvVars, conflictPlatformKeyFromInit+"="+customerEnvVal) @@ -224,7 +234,7 @@ func TestCustomerEnvironmentVariablesFromInitCanOverrideEnvironmentVariablesFrom assert.Equal(t, env.Customer["LCIS_ARG1"], lcisCLIArgEnvVal) assert.Equal(t, env.Customer["MY_FOOBAR_ENV_1"], customerEnvVal) - rapidEnvVars := env.RuntimeExecEnv() + rapidEnvVars := envToSlice(env.RuntimeExecEnv()) assert.Contains(t, rapidEnvVars, "LCIS_ARG1="+lcisCLIArgEnvVal) assert.Contains(t, rapidEnvVars, "MY_FOOBAR_ENV_1="+customerEnvVal) @@ -250,17 +260,18 @@ func TestAgentExecEnvironmentVariables(t *testing.T) { agentEnvVars := env.AgentExecEnv() var agentEnvKeys []string - for _, keyval := range agentEnvVars { - key := strings.Split(keyval, "=")[0] + for key := range agentEnvVars { agentEnvKeys = append(agentEnvKeys, key) } + agentEnvVarsSlice := envToSlice(agentEnvVars) + for key := range env.RAPID { assert.NotContains(t, agentEnvKeys, key) } for key, val := range env.Runtime { - assert.NotContains(t, agentEnvKeys, key+"="+val) + assert.NotContains(t, agentEnvVarsSlice, key+"="+val) } for key := range env.Platform { @@ -272,10 +283,10 @@ func TestAgentExecEnvironmentVariables(t *testing.T) { } for key, val := range env.Credentials { - assert.Contains(t, agentEnvVars, key+"="+val) + assert.Contains(t, agentEnvVarsSlice, key+"="+val) } - assert.Contains(t, agentEnvVars, runtimeAPIAddressKey+"="+env.Platform[runtimeAPIAddressKey]) + assert.Contains(t, agentEnvVarsSlice, runtimeAPIAddressKey+"="+env.Platform[runtimeAPIAddressKey]) } func TestStoreEnvironmentVariablesFromInitCaching(t *testing.T) { diff --git a/lambda/rapidcore/errors.go b/lambda/rapidcore/errors.go index 06a4830..7f35ca8 100644 --- a/lambda/rapidcore/errors.go +++ b/lambda/rapidcore/errors.go @@ -5,9 +5,9 @@ package rapidcore import "errors" -var ErrInitAlreadyDone = errors.New("InitAlreadyDone") var ErrInitDoneFailed = errors.New("InitDoneFailed") -var ErrInitError = errors.New("InitError") +var ErrInitNotStarted = errors.New("InitNotStarted") +var ErrInitResetReceived = errors.New("InitResetReceived") var ErrNotReserved = errors.New("NotReserved") var ErrAlreadyReserved = errors.New("AlreadyReserved") @@ -23,5 +23,3 @@ var ErrReleaseReservationDone = errors.New("ReleaseReservationDone") var ErrInternalServerError = errors.New("InternalServerError") var ErrInvokeTimeout = errors.New("InvokeTimeout") - -var ErrTerminated = errors.New("SandboxTerminated") // sent to signal a process exit diff --git a/lambda/rapidcore/sandbox.go b/lambda/rapidcore/sandbox.go deleted file mode 100644 index 7d5a8a9..0000000 --- a/lambda/rapidcore/sandbox.go +++ /dev/null @@ -1,259 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "context" - "io" - "io/ioutil" - "net/http" - "os" - "os/signal" - "syscall" - - "go.amzn.com/lambda/core/statejson" - "go.amzn.com/lambda/extensions" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/logging" - "go.amzn.com/lambda/rapid" - "go.amzn.com/lambda/rapidcore/env" - "go.amzn.com/lambda/telemetry" - - log "github.com/sirupsen/logrus" -) - -const ( - defaultSigtermResetTimeoutMs = int64(2000) -) - -type Sandbox interface { - Init(i *interop.Init, invokeTimeoutMs int64) - Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error - InteropServer() InteropServer -} - -type ReserveResponse struct { - Token interop.Token - InternalState *statejson.InternalStateDescription -} - -type InteropServer interface { - FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error - Reserve(id string, traceID, lambdaSegmentID string) (*ReserveResponse, error) - Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) - AwaitRelease() (*statejson.InternalStateDescription, error) - Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription - InternalState() (*statejson.InternalStateDescription, error) - CurrentToken() *interop.Token -} - -type SandboxBuilder struct { - sandbox *rapid.Sandbox - defaultInteropServer *Server - useCustomInteropServer bool - shutdownFuncs []context.CancelFunc - debugTailLogWriter io.Writer - platformLogWriter io.Writer -} - -type logSink int - -const ( - RuntimeLogSink logSink = iota - ExtensionLogSink -) - -func NewSandboxBuilder(bootstrap *Bootstrap) *SandboxBuilder { - defaultInteropServer := NewServer(context.Background()) - signalCtx, cancelSignalCtx := context.WithCancel(context.Background()) - logsEgressAPI := &telemetry.NoOpLogsEgressAPI{} - runtimeStdoutWriter, runtimeStderrWriter, _ := logsEgressAPI.GetRuntimeSockets() - - b := &SandboxBuilder{ - sandbox: &rapid.Sandbox{ - Bootstrap: bootstrap, - PreLoadTimeNs: 0, // TODO - StandaloneMode: true, - RuntimeStdoutWriter: runtimeStdoutWriter, - RuntimeStderrWriter: runtimeStderrWriter, - LogsEgressAPI: logsEgressAPI, - EnableTelemetryAPI: false, - Environment: env.NewEnvironment(), - Tracer: telemetry.NewNoOpTracer(), - SignalCtx: signalCtx, - EventsAPI: &telemetry.NoOpEventsAPI{}, - InitCachingEnabled: false, - }, - defaultInteropServer: defaultInteropServer, - shutdownFuncs: []context.CancelFunc{}, - debugTailLogWriter: ioutil.Discard, - platformLogWriter: ioutil.Discard, - } - - b.AddShutdownFunc(context.CancelFunc(func() { - log.Info("Shutting down...") - defaultInteropServer.Reset("SandboxTerminated", defaultSigtermResetTimeoutMs) - cancelSignalCtx() - })) - - return b -} - -func (b *SandboxBuilder) SetInteropServer(interopServer interop.Server) *SandboxBuilder { - b.sandbox.InteropServer = interopServer - b.useCustomInteropServer = true - return b -} - -func (b *SandboxBuilder) SetEventsAPI(eventsAPI telemetry.EventsAPI) *SandboxBuilder { - b.sandbox.EventsAPI = eventsAPI - return b -} - -func (b *SandboxBuilder) SetTracer(tracer telemetry.Tracer) *SandboxBuilder { - b.sandbox.Tracer = tracer - return b -} - -func (b *SandboxBuilder) DisableStandaloneMode() *SandboxBuilder { - b.sandbox.StandaloneMode = false - return b -} - -func (b *SandboxBuilder) SetExtensionsFlag(extensionsEnabled bool) *SandboxBuilder { - if extensionsEnabled { - extensions.Enable() - } else { - extensions.Disable() - } - return b -} - -func (b *SandboxBuilder) SetInitCachingFlag(initCachingEnabled bool) *SandboxBuilder { - b.sandbox.InitCachingEnabled = initCachingEnabled - return b -} - -func (b *SandboxBuilder) SetPreLoadTimeNs(preLoadTimeNs int64) *SandboxBuilder { - b.sandbox.PreLoadTimeNs = preLoadTimeNs - return b -} - -func (b *SandboxBuilder) SetEnvironmentVariables(environment *env.Environment) *SandboxBuilder { - b.sandbox.Environment = environment - return b -} - -func (b *SandboxBuilder) SetPlatformLogOutput(w io.Writer) *SandboxBuilder { - b.platformLogWriter = w - return b -} - -func (b *SandboxBuilder) SetTailLogOutput(w io.Writer) *SandboxBuilder { - b.debugTailLogWriter = w - return b -} - -func (b *SandboxBuilder) SetLogsSubscriptionAPI(logsSubscriptionAPI telemetry.LogsSubscriptionAPI) *SandboxBuilder { - b.sandbox.EnableTelemetryAPI = true - b.sandbox.LogsSubscriptionAPI = logsSubscriptionAPI - return b -} - -func (b *SandboxBuilder) SetLogsEgressAPI(logsEgressAPI telemetry.LogsEgressAPI) *SandboxBuilder { - runtimeStdoutWriter, runtimeStderrWriter, err := logsEgressAPI.GetRuntimeSockets() - - if err != nil { - log.WithError(err).Fatal("failed to get the Runtime sockets from the logs egress API") - } - - b.sandbox.LogsEgressAPI = logsEgressAPI - b.sandbox.RuntimeStdoutWriter = runtimeStdoutWriter - b.sandbox.RuntimeStderrWriter = runtimeStderrWriter - return b -} - -func (b *SandboxBuilder) SetHandler(handler string) *SandboxBuilder { - b.sandbox.Handler = handler - return b -} - -func (b *SandboxBuilder) AddShutdownFunc(shutdownFunc context.CancelFunc) *SandboxBuilder { - b.shutdownFuncs = append(b.shutdownFuncs, shutdownFunc) - return b -} - -func (b *SandboxBuilder) setupLoggingWithDebugLogs() { - // Compose debug log writer with all log sinks. Debug log writer w - // will not write logs when disabled by invoke parameter - b.sandbox.DebugTailLogger = logging.NewTailLogWriter(b.debugTailLogWriter) - b.sandbox.PlatformLogger = logging.NewPlatformLogger(b.platformLogWriter, b.sandbox.DebugTailLogger) - b.sandbox.RuntimeStdoutWriter = io.MultiWriter(b.sandbox.DebugTailLogger, b.sandbox.RuntimeStdoutWriter) - b.sandbox.RuntimeStderrWriter = io.MultiWriter(b.sandbox.DebugTailLogger, b.sandbox.RuntimeStderrWriter) -} - -func (b *SandboxBuilder) Create() { - if len(b.sandbox.Handler) > 0 { - b.sandbox.Environment.SetHandler(b.sandbox.Handler) - } - - if !b.useCustomInteropServer { - b.sandbox.InteropServer = b.defaultInteropServer - } - - b.setupLoggingWithDebugLogs() - - go signalHandler(b.shutdownFuncs) - - rapid.Start(b.sandbox) -} - -func (b *SandboxBuilder) Init(i *interop.Init, timeoutMs int64) { - b.sandbox.InteropServer.Init(&interop.Start{ - Handler: i.Handler, - CorrelationID: i.CorrelationID, - AwsKey: i.AwsKey, - AwsSecret: i.AwsSecret, - AwsSession: i.AwsSession, - XRayDaemonAddress: i.XRayDaemonAddress, - FunctionName: i.FunctionName, - FunctionVersion: i.FunctionVersion, - CustomerEnvironmentVariables: i.CustomerEnvironmentVariables, - }, timeoutMs) -} - -func (b *SandboxBuilder) Invoke(w http.ResponseWriter, i *interop.Invoke) error { - return b.sandbox.InteropServer.Invoke(w, i) -} - -func (b *SandboxBuilder) InteropServer() InteropServer { - return b.defaultInteropServer -} - -// SetLogLevel sets the log level for internal logging. Needs to be called very -// early during startup to configure logs emitted during initialization -func SetLogLevel(logLevel string) { - level, err := log.ParseLevel(logLevel) - if err != nil { - log.WithError(err).Fatal("Failed to set log level. Valid log levels are:", log.AllLevels) - } - - log.SetLevel(level) - log.SetFormatter(&logging.InternalFormatter{}) -} - -func SetInternalLogOutput(w io.Writer) { - logging.SetOutput(w) -} - -// Trap SIGINT and SIGTERM signals and call shutdown function -func signalHandler(shutdownFuncs []context.CancelFunc) { - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) - sigReceived := <-sig - log.WithField("signal", sigReceived.String()).Info("Received signal") - for _, shutdownFunc := range shutdownFuncs { - shutdownFunc() - } -} diff --git a/lambda/rapidcore/sandbox_api.go b/lambda/rapidcore/sandbox_api.go new file mode 100644 index 0000000..0c7052e --- /dev/null +++ b/lambda/rapidcore/sandbox_api.go @@ -0,0 +1,147 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "go.amzn.com/lambda/interop" +) + +// SandboxContext and other structs form the implementation of the SandboxAPI +// interface defined in interop/sandbox_model.go, using the implementation of +// Init, Invoke and Reset handlers in rapid/sandbox.go +type SandboxContext struct { + rapidCtx interop.RapidContext + handler string + runtimeAPIAddress string + + InvokeReceivedTime int64 + InvokeResponseMetrics *interop.InvokeResponseMetrics +} + +type initContext struct { + initSuccessChan chan interop.InitSuccess + initFailureChan chan interop.InitFailure + rapidCtx interop.RapidContext + sbInfoFromInit interop.SandboxInfoFromInit // contains data that needs to be persisted from init for suppressed inits during invoke +} + +type invokeContext struct { + rapidCtx interop.RapidContext + invokeRequestChan chan *interop.Invoke + invokeSuccessChan chan interop.InvokeSuccess + invokeFailureChan chan interop.InvokeFailure +} + +// Validate interface compliance +var _ interop.SandboxContext = (*SandboxContext)(nil) +var _ interop.InitContext = (*initContext)(nil) +var _ interop.InvokeContext = (*invokeContext)(nil) + +func (s SandboxContext) Init(init *interop.Init, timeoutMs int64) (interop.InitStarted, interop.InitContext) { + initStartedResponseChan := make(chan interop.InitStarted) + initSuccessResponseChan := make(chan interop.InitSuccess) + initFailureResponseChan := make(chan interop.InitFailure) + + if len(s.handler) > 0 { + init.EnvironmentVariables.SetHandler(s.handler) + } + + init.EnvironmentVariables.StoreRuntimeAPIEnvironmentVariable(s.runtimeAPIAddress) + + go s.rapidCtx.HandleInit(init, initStartedResponseChan, initSuccessResponseChan, initFailureResponseChan) + initStarted := <-initStartedResponseChan + + sbMetadata := interop.SandboxInfoFromInit{ + EnvironmentVariables: init.EnvironmentVariables, + SandboxType: init.SandboxType, + RuntimeBootstrap: init.Bootstrap, + } + return initStarted, newInitContext(s.rapidCtx, sbMetadata, initSuccessResponseChan, initFailureResponseChan) +} + +func (s SandboxContext) Reset(reset *interop.Reset) (interop.ResetSuccess, *interop.ResetFailure) { + defer s.rapidCtx.Clear() + return s.rapidCtx.HandleReset(reset, s.InvokeReceivedTime, s.InvokeResponseMetrics) +} + +func (s SandboxContext) Shutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { + return s.rapidCtx.HandleShutdown(shutdown) +} + +func (s SandboxContext) Restore(restore *interop.Restore) error { + return s.rapidCtx.HandleRestore(restore) +} + +func (s *SandboxContext) SetInvokeReceivedTime(invokeReceivedTime int64) { + s.InvokeReceivedTime = invokeReceivedTime +} + +func (s *SandboxContext) SetInvokeResponseMetrics(metrics *interop.InvokeResponseMetrics) { + s.InvokeResponseMetrics = metrics +} + +func newInitContext(r interop.RapidContext, sbMetadata interop.SandboxInfoFromInit, + initSuccessChan chan interop.InitSuccess, initFailureChan chan interop.InitFailure) initContext { + return initContext{ + initSuccessChan: initSuccessChan, + initFailureChan: initFailureChan, + rapidCtx: r, + sbInfoFromInit: sbMetadata, + } +} + +func (i initContext) Wait() (interop.InitSuccess, *interop.InitFailure) { + select { + case initSuccess, isOpen := <-i.initSuccessChan: + if !isOpen { + // If init has already suceeded, we return quickly + return interop.InitSuccess{}, nil + } + return initSuccess, nil + case initFailure, isOpen := <-i.initFailureChan: + if !isOpen { + // If init has already failed, we return quickly for init to be suppressed + return interop.InitSuccess{}, &initFailure + } + return interop.InitSuccess{}, &initFailure + } +} + +func (i initContext) Reserve() interop.InvokeContext { + + invokeRequestChan := make(chan *interop.Invoke) + invokeSuccessChan := make(chan interop.InvokeSuccess) + invokeFailureChan := make(chan interop.InvokeFailure) + + go func() { + invoke := <-invokeRequestChan + // For suppressed inits, invoke needs the runtime and agent env vars + invokeSuccess, invokeFailure := i.rapidCtx.HandleInvoke(invoke, i.sbInfoFromInit) + if invokeFailure != nil { + invokeFailureChan <- *invokeFailure + } else { + invokeSuccessChan <- invokeSuccess + } + }() + + return invokeContext{ + rapidCtx: i.rapidCtx, + invokeRequestChan: invokeRequestChan, + invokeSuccessChan: invokeSuccessChan, + invokeFailureChan: invokeFailureChan, + } +} + +func (invCtx invokeContext) SendRequest(i *interop.Invoke) { + invCtx.invokeRequestChan <- i +} + +func (invCtx invokeContext) Wait() (interop.InvokeSuccess, *interop.InvokeFailure) { + select { + case invokeSuccess := <-invCtx.invokeSuccessChan: + return invokeSuccess, nil + case invokeFailure := <-invCtx.invokeFailureChan: + return interop.InvokeSuccess{}, &invokeFailure + } +} diff --git a/lambda/rapidcore/sandbox_builder.go b/lambda/rapidcore/sandbox_builder.go new file mode 100644 index 0000000..ce016a0 --- /dev/null +++ b/lambda/rapidcore/sandbox_builder.go @@ -0,0 +1,217 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "context" + "io" + "net" + "os" + "os/signal" + "strconv" + "syscall" + + "go.amzn.com/lambda/extensions" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/logging" + "go.amzn.com/lambda/rapid" + "go.amzn.com/lambda/supervisor" + supvmodel "go.amzn.com/lambda/supervisor/model" + "go.amzn.com/lambda/telemetry" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultSigtermResetTimeoutMs = int64(2000) +) + +type SandboxBuilder struct { + sandbox *rapid.Sandbox + sandboxContext interop.SandboxContext + lambdaInvokeAPI LambdaInvokeAPI + defaultInteropServer *Server + useCustomInteropServer bool + shutdownFuncs []context.CancelFunc + handler string +} + +type logSink int + +const ( + RuntimeLogSink logSink = iota + ExtensionLogSink +) + +func NewSandboxBuilder() *SandboxBuilder { + defaultInteropServer := NewServer(context.Background()) + signalCtx, cancelSignalCtx := context.WithCancel(context.Background()) + + b := &SandboxBuilder{ + sandbox: &rapid.Sandbox{ + PreLoadTimeNs: 0, // TODO + StandaloneMode: true, + LogsEgressAPI: &telemetry.NoOpLogsEgressAPI{}, + EnableTelemetryAPI: false, + Tracer: telemetry.NewNoOpTracer(), + SignalCtx: signalCtx, + EventsAPI: &telemetry.NoOpEventsAPI{}, + InitCachingEnabled: false, + Supervisor: supervisor.NewLocalSupervisor(), + RuntimeAPIHost: "127.0.0.1", + RuntimeAPIPort: 9001, + }, + defaultInteropServer: defaultInteropServer, + shutdownFuncs: []context.CancelFunc{}, + lambdaInvokeAPI: NewEmulatorAPI(defaultInteropServer), + } + + b.AddShutdownFunc(context.CancelFunc(func() { + log.Info("Shutting down...") + defaultInteropServer.Reset("SandboxTerminated", defaultSigtermResetTimeoutMs) + cancelSignalCtx() + })) + + return b +} + +func (b *SandboxBuilder) SetSupervisor(supervisor supvmodel.Supervisor) *SandboxBuilder { + b.sandbox.Supervisor = supervisor + return b +} + +func (b *SandboxBuilder) SetRuntimeAPIAddress(runtimeAPIAddress string) *SandboxBuilder { + host, port, err := net.SplitHostPort(runtimeAPIAddress) + if err != nil { + log.WithError(err).Warnf("Failed to parse RuntimeApiAddress: %s:", runtimeAPIAddress) + return b + } + + portInt, err := strconv.Atoi(port) + if err != nil { + log.WithError(err).Warnf("Failed to parse RuntimeApiPort: %s:", port) + return b + } + + b.sandbox.RuntimeAPIHost = host + b.sandbox.RuntimeAPIPort = portInt + return b +} + +func (b *SandboxBuilder) SetInteropServer(interopServer interop.Server) *SandboxBuilder { + b.sandbox.InteropServer = interopServer + b.useCustomInteropServer = true + return b +} + +func (b *SandboxBuilder) SetEventsAPI(eventsAPI telemetry.EventsAPI) *SandboxBuilder { + b.sandbox.EventsAPI = eventsAPI + return b +} + +func (b *SandboxBuilder) SetTracer(tracer telemetry.Tracer) *SandboxBuilder { + b.sandbox.Tracer = tracer + return b +} + +func (b *SandboxBuilder) DisableStandaloneMode() *SandboxBuilder { + b.sandbox.StandaloneMode = false + return b +} + +func (b *SandboxBuilder) SetExtensionsFlag(extensionsEnabled bool) *SandboxBuilder { + if extensionsEnabled { + extensions.Enable() + } else { + extensions.Disable() + } + return b +} + +func (b *SandboxBuilder) SetInitCachingFlag(initCachingEnabled bool) *SandboxBuilder { + b.sandbox.InitCachingEnabled = initCachingEnabled + return b +} + +func (b *SandboxBuilder) SetPreLoadTimeNs(preLoadTimeNs int64) *SandboxBuilder { + b.sandbox.PreLoadTimeNs = preLoadTimeNs + return b +} + +func (b *SandboxBuilder) SetTelemetrySubscription(logsSubscriptionAPI telemetry.SubscriptionAPI, telemetrySubscriptionAPI telemetry.SubscriptionAPI) *SandboxBuilder { + b.sandbox.EnableTelemetryAPI = true + b.sandbox.LogsSubscriptionAPI = logsSubscriptionAPI + b.sandbox.TelemetrySubscriptionAPI = telemetrySubscriptionAPI + return b +} + +func (b *SandboxBuilder) SetLogsEgressAPI(logsEgressAPI telemetry.StdLogsEgressAPI) *SandboxBuilder { + b.sandbox.LogsEgressAPI = logsEgressAPI + return b +} + +func (b *SandboxBuilder) SetHandler(handler string) *SandboxBuilder { + b.handler = handler + return b +} + +func (b *SandboxBuilder) AddShutdownFunc(shutdownFunc context.CancelFunc) *SandboxBuilder { + b.shutdownFuncs = append(b.shutdownFuncs, shutdownFunc) + return b +} + +func (b *SandboxBuilder) Create() (interop.SandboxContext, interop.InternalStateGetter) { + if !b.useCustomInteropServer { + b.sandbox.InteropServer = b.defaultInteropServer + } + + go signalHandler(b.shutdownFuncs) + + rapidCtx, internalStateFn, runtimeAPIAddr := rapid.Start(b.sandbox) + + b.sandboxContext = &SandboxContext{ + rapidCtx: rapidCtx, + handler: b.handler, + runtimeAPIAddress: runtimeAPIAddr, + InvokeReceivedTime: int64(0), + InvokeResponseMetrics: nil, + } + + return b.sandboxContext, internalStateFn +} + +func (b *SandboxBuilder) DefaultInteropServer() *Server { + return b.defaultInteropServer +} + +func (b *SandboxBuilder) LambdaInvokeAPI() LambdaInvokeAPI { + return b.lambdaInvokeAPI +} + +// SetLogLevel sets the log level for internal logging. Needs to be called very +// early during startup to configure logs emitted during initialization +func SetLogLevel(logLevel string) { + level, err := log.ParseLevel(logLevel) + if err != nil { + log.WithError(err).Fatal("Failed to set log level. Valid log levels are:", log.AllLevels) + } + + log.SetLevel(level) + log.SetFormatter(&logging.InternalFormatter{}) +} + +func SetInternalLogOutput(w io.Writer) { + logging.SetOutput(w) +} + +// Trap SIGINT and SIGTERM signals and call shutdown function +func signalHandler(shutdownFuncs []context.CancelFunc) { + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) + sigReceived := <-sig + log.WithField("signal", sigReceived.String()).Info("Received signal") + for _, shutdownFunc := range shutdownFuncs { + shutdownFunc() + } +} diff --git a/lambda/rapidcore/sandbox_emulator_api.go b/lambda/rapidcore/sandbox_emulator_api.go new file mode 100644 index 0000000..6737631 --- /dev/null +++ b/lambda/rapidcore/sandbox_emulator_api.go @@ -0,0 +1,52 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "go.amzn.com/lambda/interop" + + "net/http" +) + +// LambdaInvokeAPI are the methods used by the Runtime Interface Emulator +type LambdaInvokeAPI interface { + Init(i *interop.Init, invokeTimeoutMs int64) + Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error +} + +// EmulatorAPI wraps the standalone interop server to provide a convenient interface +// for Rapid Standalone +type EmulatorAPI struct { + server *Server +} + +// Validate interface compliance +var _ LambdaInvokeAPI = (*EmulatorAPI)(nil) + +func NewEmulatorAPI(s *Server) *EmulatorAPI { + return &EmulatorAPI{s} +} + +// Init method is only used by the Runtime interface emulator +func (l *EmulatorAPI) Init(i *interop.Init, timeoutMs int64) { + l.server.Init(&interop.Init{ + Handler: i.Handler, + AwsKey: i.AwsKey, + AwsSecret: i.AwsSecret, + AwsSession: i.AwsSession, + XRayDaemonAddress: i.XRayDaemonAddress, + FunctionName: i.FunctionName, + FunctionVersion: i.FunctionVersion, + CustomerEnvironmentVariables: i.CustomerEnvironmentVariables, + RuntimeInfo: i.RuntimeInfo, + SandboxType: i.SandboxType, + Bootstrap: i.Bootstrap, + EnvironmentVariables: i.EnvironmentVariables, + }, timeoutMs) +} + +// Invoke method is only used by the Runtime interface emulator +func (l *EmulatorAPI) Invoke(w http.ResponseWriter, i *interop.Invoke) error { + return l.server.Invoke(w, i) +} diff --git a/lambda/rapidcore/server.go b/lambda/rapidcore/server.go index e3e01b6..e652130 100644 --- a/lambda/rapidcore/server.go +++ b/lambda/rapidcore/server.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "math" "net/http" "sync" @@ -17,6 +16,7 @@ import ( "go.amzn.com/lambda/core/directinvoke" "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" @@ -33,6 +33,12 @@ const ( resetDefaultTimeoutMs = 2000 ) +const ( + contentTypeHeader = "Content-Type" + errorTypeHeader = "Error-Type" + functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" +) + type rapidPhase int const ( @@ -46,7 +52,6 @@ type runtimeState int const ( runtimeNotStarted = iota - runtimeInitStarted runtimeInitError runtimeInitComplete runtimeInitFailed @@ -76,14 +81,11 @@ type InvokeContext struct { type Server struct { InternalStateGetter interop.InternalStateGetter - invokeChanOut chan *interop.Invoke - startChanOut chan *interop.Start - resetChanOut chan *interop.Reset - shutdownChanOut chan *interop.Shutdown - errorChanOut chan error + initChanOut chan *interop.Init + interruptedResponseChan chan *interop.Reset - sendRunningChan chan *interop.Running - sendResponseChan chan struct{} + sendRunningChan chan *interop.InitStarted + sendResponseChan chan *interop.InvokeResponseMetrics doneChan chan *interop.Done InitDoneChan chan DoneWithState @@ -100,12 +102,17 @@ type Server struct { rapidPhase rapidPhase runtimeState runtimeState -} -func (s *Server) StartAcceptingDirectInvokes() error { - return nil + sandboxContext interop.SandboxContext + initContext interop.InitContext + invoker interop.InvokeContext + initFailures chan interop.InitFailure + cachedInitErrorResponse *interop.ErrorResponse } +// Validate interface compliance +var _ interop.Server = (*Server)(nil) + func (s *Server) setRapidPhase(phase rapidPhase) { s.mutex.Lock() defer s.mutex.Unlock() @@ -185,6 +192,11 @@ func (s *Server) setNewInvokeContext(invokeID string, traceID, lambdaSegmentID s return resp, nil } +type ReserveResponse struct { + Token interop.Token + InternalState *statejson.InternalStateDescription +} + // Reserve allocates invoke context func (s *Server) Reserve(id string, traceID, lambdaSegmentID string) (*ReserveResponse, error) { invokeID := uuid.New().String() @@ -196,37 +208,28 @@ func (s *Server) Reserve(id string, traceID, lambdaSegmentID string) (*ReserveRe return nil, err } - resp.InternalState, err = s.waitInit() + // The two errors reserve returns in standalone mode are INIT timeout + // and INIT failure (two types of failure: runtime exit, /init/error). Both require suppressed + // initialization, so we succeed the reservation. + invCtx := s.initContext.Reserve() + s.invoker = invCtx + resp.InternalState, err = s.InternalState() + return resp, err } -func (s *Server) waitInit() (*statejson.InternalStateDescription, error) { - for { - select { - - case doneWithState, chanOpen := <-s.InitDoneChan: - if !chanOpen { - // init only happens once - return nil, ErrInitAlreadyDone - } - - close(s.InitDoneChan) // this was first call to reserve - - if s.getRuntimeState() == runtimeInitFailed { - return &doneWithState.State, ErrInitError - } - - if len(doneWithState.ErrorType) > 0 { - log.Errorf("INIT DONE failed: %s", doneWithState.ErrorType) - return &doneWithState.State, ErrInitDoneFailed - } - - return &doneWithState.State, nil - - case <-s.reservationContext.Done(): - return nil, ErrReserveReservationDone - } +func (s *Server) awaitInitCompletion() { + initSuccess, initFailure := s.initContext.Wait() + if initFailure != nil { + // In standalone, we don't have to block rapid start() goroutine until init failure is consumed + // because there is no channel back to the invoker until an invoke arrives via a Reserve() + initFailure.Ack <- struct{}{} + s.initFailures <- *initFailure + } else { + initSuccess.Ack <- struct{}{} } + // always closing the channel makes this method idempotent + close(s.initFailures) } func (s *Server) setReplyStream(w http.ResponseWriter, direct bool) (string, error) { @@ -263,6 +266,8 @@ func (s *Server) Release() error { s.reservationCancel() } + s.sandboxContext.SetInvokeReceivedTime(0) + s.sandboxContext.SetInvokeResponseMetrics(nil) s.invokeCtx = nil return nil } @@ -279,37 +284,18 @@ func (s *Server) GetCurrentInvokeID() string { return s.invokeCtx.Token.InvokeID } +// SetSandboxContext is used to set the sandbox context after intiialization of interop server. +// After refactoring all messages, this needs to be removed and made an struct parameter on initialization. +func (s *Server) SetSandboxContext(sbCtx interop.SandboxContext) { + s.sandboxContext = sbCtx +} + // SetInternalStateGetter is used to set callback which returnes internal state for /test/internalState request func (s *Server) SetInternalStateGetter(cb interop.InternalStateGetter) { s.InternalStateGetter = cb } -// StartChan returns Start emitter -func (s *Server) StartChan() <-chan *interop.Start { - return s.startChanOut -} - -// InvokeChan returns Invoke emitter -func (s *Server) InvokeChan() <-chan *interop.Invoke { - return s.invokeChanOut -} - -// ResetChan returns Reset emitter -func (s *Server) ResetChan() <-chan *interop.Reset { - return s.resetChanOut -} - -// ShutdownChan returns Shutdown emitter -func (s *Server) ShutdownChan() <-chan *interop.Shutdown { - return s.shutdownChanOut -} - -// InvalidMessageChan emits errors if there was something we could not parse -func (s *Server) TransportErrorChan() <-chan error { - return s.errorChanOut -} - -func (s *Server) sendResponseUnsafe(invokeID string, contentType string, status int, payload io.Reader) error { +func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[string]string, status int, payload io.Reader, trailers http.Header, request *interop.CancellableRequest, runtimeCalledResponse bool) error { if s.invokeCtx == nil || invokeID != s.invokeCtx.Token.InvokeID { return interop.ErrInvalidInvokeID } @@ -322,26 +308,15 @@ func (s *Server) sendResponseUnsafe(invokeID string, contentType string, status return fmt.Errorf("ReplyStream not available") } - // TODO: earlier, status was set to 500 if runtime called /invocation/error. However, the integration - // tests do not differentiate between /invocation/error and /invocation/response, but they check the error type: - // To identify user-errors, we should also allowlist custom errortypes and propagate them via headers. - - // s.invokeCtx.ReplyStream.WriteHeader(status) - var reportedErr error if s.invokeCtx.Direct { - if err := directinvoke.SendDirectInvokeResponse(map[string]string{"Content-Type": contentType}, payload, s.invokeCtx.ReplyStream); err != nil { + if err := directinvoke.SendDirectInvokeResponse(additionalHeaders, payload, trailers, s.invokeCtx.ReplyStream, s.interruptedResponseChan, s.sendResponseChan, request, runtimeCalledResponse); err != nil { // TODO: Do we need to drain the reader in case of a large payload and connection reuse? log.Errorf("Failed to write response to %s: %s", invokeID, err) - flusher, ok := s.invokeCtx.ReplyStream.(http.Flusher) - if !ok { - log.Error("Failed to flush response") - } - flusher.Flush() reportedErr = err } } else { - data, err := ioutil.ReadAll(payload) + data, err := io.ReadAll(payload) if err != nil { return fmt.Errorf("Failed to read response on %s: %s", invokeID, err) } @@ -352,73 +327,103 @@ func (s *Server) sendResponseUnsafe(invokeID string, contentType string, status } } - s.invokeCtx.ReplyStream.Header().Add("Content-Type", contentType) - if _, err := s.invokeCtx.ReplyStream.Write(data); err != nil { + startReadingResponseMonoTimeMs := metering.Monotime() + s.invokeCtx.ReplyStream.Header().Add(contentTypeHeader, additionalHeaders[contentTypeHeader]) + written, err := s.invokeCtx.ReplyStream.Write(data) + if err != nil { return fmt.Errorf("Failed to write response to %s: %s", invokeID, err) } + + s.sendResponseChan <- &interop.InvokeResponseMetrics{ + ProducedBytes: int64(written), + StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, + FinishReadingResponseMonoTimeMs: metering.Monotime(), + TimeShapedNs: int64(-1), + OutboundThroughputBps: int64(-1), + // FIXME: + // The runtime tells whether the function response mode is streaming or not. + // Ideally, we would want to use that value here. Since I'm just rebasing, I will leave + // as-is, but we should use that instead of relying on our memory to set this here + // because we "know" it's a streaming code path. + FunctionResponseMode: interop.FunctionResponseModeBuffered, + RuntimeCalledResponse: runtimeCalledResponse, + } } - s.sendResponseChan <- struct{}{} s.invokeCtx.ReplySent = true s.invokeCtx.Direct = false return reportedErr } -func (s *Server) SendResponse(invokeID string, contentType string, reader io.Reader) error { +func (s *Server) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error { s.setRuntimeState(runtimeInvokeResponseSent) s.mutex.Lock() defer s.mutex.Unlock() - return s.sendResponseUnsafe(invokeID, contentType, http.StatusOK, reader) -} - -func (s *Server) CommitResponse() error { return nil } - -func (s *Server) SendRunning(run *interop.Running) error { - s.setRuntimeState(runtimeInitStarted) - s.sendRunningChan <- run - return nil + runtimeCalledResponse := true + return s.sendResponseUnsafe(invokeID, headers, http.StatusOK, reader, trailers, request, runtimeCalledResponse) } -func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorResponse) error { - switch s.getRapidPhase() { - case phaseInitializing: - s.setRuntimeState(runtimeInitError) - return nil - case phaseInvoking: - // This branch can also occur during a suppressed init error, which is reported as invoke error - s.setRuntimeState(runtimeInvokeError) - s.mutex.Lock() - defer s.mutex.Unlock() - return s.sendResponseUnsafe(invokeID, resp.ContentType, http.StatusInternalServerError, bytes.NewReader(resp.Payload)) - default: - panic("received unexpected error response outside invoke or init phases") +func (s *Server) SendInitErrorResponse(invokeID string, resp *interop.ErrorResponse) error { + log.Debugf("Sending Init Error Response: %s", resp.ErrorType) + if s.getRapidPhase() == phaseInvoking { + // This branch occurs during suppressed init + return s.SendErrorResponse(invokeID, resp) } -} -func (s *Server) SendDone(done *interop.Done) error { - s.doneChan <- done + // Handle an /init/error outside of the invoke phase + s.setCachedInitErrorResponse(resp) + s.setRuntimeState(runtimeInitError) return nil } -func (s *Server) SendDoneFail(doneFail *interop.DoneFail) error { - s.doneChan <- &interop.Done{ - ErrorType: doneFail.ErrorType, - CorrelationID: doneFail.CorrelationID, // filipovi: correlationID is required to dispatch message into correct channel - Meta: doneFail.Meta, +func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorResponse) error { + log.Debugf("Sending Error Response: %s", resp.ErrorType) + s.setRuntimeState(runtimeInvokeError) + s.mutex.Lock() + defer s.mutex.Unlock() + additionalHeaders := map[string]string{contentTypeHeader: resp.ContentType, errorTypeHeader: resp.ErrorType} + if functionResponseMode := resp.FunctionResponseMode; functionResponseMode != "" { + additionalHeaders[functionResponseModeHeader] = functionResponseMode } - return nil + runtimeCalledResponse := false // we are sending an error here, so runtime called /error or crashed/timeout + return s.sendResponseUnsafe(invokeID, additionalHeaders, http.StatusInternalServerError, bytes.NewReader(resp.Payload), nil, nil, runtimeCalledResponse) } func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) { // pass reset to rapid - s.resetChanOut <- &interop.Reset{ - Reason: reason, - DeadlineNs: deadlineNsFromTimeoutMs(timeoutMs), - CorrelationID: "resetCorrelationID", + reset := &interop.Reset{ + Reason: reason, + DeadlineNs: deadlineNsFromTimeoutMs(timeoutMs), } + go func() { + select { + case s.interruptedResponseChan <- reset: + <-s.interruptedResponseChan // wait for response streaming metrics being added to reset struct + s.sandboxContext.SetInvokeResponseMetrics(reset.InvokeResponseMetrics) + default: + } + + resetSuccess, resetFailure := s.sandboxContext.Reset(reset) + s.Clear() // clear server state to prepare for new invokes + s.setRapidPhase(phaseIdle) + s.setRuntimeState(runtimeNotStarted) + + var meta interop.DoneMetadata + if reset.InvokeResponseMetrics != nil { + meta.RuntimeTimeThrottledMs = reset.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) + meta.RuntimeProducedBytes = reset.InvokeResponseMetrics.ProducedBytes + meta.RuntimeOutboundThroughputBps = reset.InvokeResponseMetrics.OutboundThroughputBps + } + + if resetFailure != nil { + meta.ExtensionsResetMs = resetFailure.ExtensionsResetMs + s.ResetDoneChan <- &interop.Done{ErrorType: resetFailure.ErrorType, Meta: meta} + } else { + meta.ExtensionsResetMs = resetSuccess.ExtensionsResetMs + s.ResetDoneChan <- &interop.Done{ErrorType: resetSuccess.ErrorType, Meta: meta} + } + }() - // TODO do not block on reset, instead consume ResetDoneChan in waitForRelease handler, - // this will get us more aligned on async reset notification handling. done := <-s.ResetDoneChan s.Release() @@ -431,14 +436,11 @@ func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescript func NewServer(ctx context.Context) *Server { s := &Server{ - startChanOut: make(chan *interop.Start), - invokeChanOut: make(chan *interop.Invoke), - errorChanOut: make(chan error), - resetChanOut: make(chan *interop.Reset), - shutdownChanOut: make(chan *interop.Shutdown), - - sendRunningChan: make(chan *interop.Running), - sendResponseChan: make(chan struct{}), + initChanOut: make(chan *interop.Init), + interruptedResponseChan: make(chan *interop.Reset), + + sendRunningChan: make(chan *interop.InitStarted), + sendResponseChan: make(chan *interop.InvokeResponseMetrics), doneChan: make(chan *interop.Done), // These two channels are buffered, because they are depleted asynchronously (by reserve and waitUntilRelease) and we don't want to block in SendDone until they are called @@ -449,47 +451,9 @@ func NewServer(ctx context.Context) *Server { ShutdownDoneChan: make(chan *interop.Done), } - go s.dispatchDone() - return s } -func (s *Server) setInitDoneRuntimeState(done *interop.Done) { - if len(done.ErrorType) > 0 { - s.setRuntimeState(runtimeInitFailed) // donefail - } else { - s.setRuntimeState(runtimeInitComplete) // done - } -} - -// Note, the dispatch loop below has potential to block, when -// channel is not drained. E.g. if test assumes sandbox init -// completion before dispatching reset, then reset will block -// until init channel is drained. -func (s *Server) dispatchDone() { - for { - done := <-s.doneChan - log.Debug("Dispatching DONE:", done.CorrelationID) - internalState := s.InternalStateGetter() - s.setRapidPhase(phaseIdle) - if done.CorrelationID == "initCorrelationID" { - s.setInitDoneRuntimeState(done) - s.InitDoneChan <- DoneWithState{Done: done, State: internalState} - } else if done.CorrelationID == "invokeCorrelationID" { - s.setRuntimeState(runtimeInvokeComplete) - s.InvokeDoneChan <- DoneWithState{Done: done, State: internalState} - } else if done.CorrelationID == "resetCorrelationID" { - s.setRuntimeState(runtimeNotStarted) - s.ResetDoneChan <- done - } else if done.CorrelationID == "shutdownCorrelationID" { - s.setRuntimeState(runtimeNotStarted) - s.ShutdownDoneChan <- done - } else { - panic("Received DONE without correlation ID") - } - } -} - func drainChannel(c chan DoneWithState) { for { select { @@ -509,10 +473,6 @@ func (s *Server) Clear() { s.Release() } -func (s *Server) IsResponseSent() bool { - panic("unexpected call to unimplemented method in rapidcore: IsResponseSent()") -} - func (s *Server) SendRuntimeReady() error { // only called when extensions are enabled s.setRuntimeState(runtimeReady) @@ -524,16 +484,34 @@ func deadlineNsFromTimeoutMs(timeoutMs int64) int64 { return mono + timeoutMs*1000*1000 } -func (s *Server) Init(i *interop.Start, invokeTimeoutMs int64) { - s.SetInvokeTimeout(time.Duration(invokeTimeoutMs) * time.Millisecond) +func (s *Server) setInitFailuresChan() { + s.mutex.Lock() + defer s.mutex.Unlock() + s.initFailures = make(chan interop.InitFailure) +} + +func (s *Server) getInitFailuresChan() chan interop.InitFailure { + s.mutex.Lock() + defer s.mutex.Unlock() + return s.initFailures +} - s.startChanOut <- i +func (s *Server) Init(i *interop.Init, invokeTimeoutMs int64) error { + s.SetInvokeTimeout(time.Duration(invokeTimeoutMs) * time.Millisecond) s.setRapidPhase(phaseInitializing) - <-s.sendRunningChan - log.Debug("Received RUNNING") + s.setInitFailuresChan() + initStarted, initCtx := s.sandboxContext.Init(i, invokeTimeoutMs) + initStarted.Ack <- struct{}{} + + s.initContext = initCtx + go s.awaitInitCompletion() + + log.Debugf("Received RUNNING: %v", initStarted) + return nil } func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error { + s.sandboxContext.SetInvokeReceivedTime(i.InvokeReceivedTime) invokeID, err := s.setReplyStream(w, direct) if err != nil { return err @@ -544,16 +522,55 @@ func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct boo i.ID = invokeID select { - case s.invokeChanOut <- i: - break case <-s.sendResponseChan: // we didn't pass invoke to rapid yet, but rapid already has written some response // It can happend if runtime/agent crashed even before we passed invoke to it return ErrInvokeResponseAlreadyWritten + default: } + go func() { + if s.invoker == nil { + // Reset occurred, do not send invoke request + s.InvokeDoneChan <- DoneWithState{State: s.InternalStateGetter()} + s.setRuntimeState(runtimeInvokeComplete) + return + } + s.invoker.SendRequest(i) + invokeSuccess, invokeFailure := s.invoker.Wait() + if invokeFailure != nil { + if invokeFailure.ResetReceived { + return + } + + // Rapid constructs a response body itself when invoke fails, with error type. + // These are on the handleInvokeError path, may occur during timeout resets, + // failure reset (proc exit). It is expected to be non-nil on all invoke failures. + if invokeFailure.DefaultErrorResponse == nil { + log.Panicf("default error response was nil for invoke failure, %v", invokeFailure) + } + + if cachedInitError := s.getCachedInitErrorResponse(); cachedInitError != nil { + // /init/error was called + s.trySendDefaultErrorResponse(cachedInitError) + } else { + // sent only if /error and /response not called + s.trySendDefaultErrorResponse(invokeFailure.DefaultErrorResponse) + } + doneFail := doneFailFromInvokeFailure(invokeFailure) + s.InvokeDoneChan <- DoneWithState{ + Done: &interop.Done{ErrorType: doneFail.ErrorType, Meta: doneFail.Meta}, + State: s.InternalStateGetter(), + } + } else { + done := doneFromInvokeSuccess(invokeSuccess) + s.InvokeDoneChan <- DoneWithState{Done: done, State: s.InternalStateGetter()} + } + }() + select { - case <-s.sendResponseChan: + case i.InvokeResponseMetrics = <-s.sendResponseChan: + s.sandboxContext.SetInvokeResponseMetrics(i.InvokeResponseMetrics) break case <-s.reservationContext.Done(): return ErrInvokeReservationDone @@ -562,6 +579,26 @@ func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct boo return nil } +func (s *Server) setCachedInitErrorResponse(errResp *interop.ErrorResponse) { + s.mutex.Lock() + defer s.mutex.Unlock() + s.cachedInitErrorResponse = errResp +} + +func (s *Server) getCachedInitErrorResponse() *interop.ErrorResponse { + s.mutex.Lock() + defer s.mutex.Unlock() + return s.cachedInitErrorResponse +} + +func (s *Server) trySendDefaultErrorResponse(resp *interop.ErrorResponse) { + if err := s.SendErrorResponse(s.GetCurrentInvokeID(), resp); err != nil { + if err != interop.ErrResponseSent { + log.Panicf("Failed to send default error response: %s", err) + } + } +} + func (s *Server) CurrentToken() *interop.Token { s.mutex.Lock() defer s.mutex.Unlock() @@ -582,77 +619,158 @@ func (s *Server) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invo go func() { select { case <-time.After(s.GetInvokeTimeout()): + log.Debug("Invoke() timeout") timeoutChan <- ErrInvokeTimeout - s.Reset(autoresetReasonTimeout, resetDefaultTimeoutMs) case <-resetCtx.Done(): log.Debugf("execute finished, autoreset cancelled") } }() - reserveResp, err := s.Reserve(invoke.ID, "", "") - if err != nil { - switch err { - case ErrInitError: - // Simulate 'Suppressed Init' scenario - s.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) - reserveResp, err = s.Reserve("", "", "") - if err == ErrInitAlreadyDone { - break - } - return err - case ErrInitDoneFailed, ErrTerminated: - s.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) - return err - - case ErrInitAlreadyDone: - // This is a valid response (e.g. for 2nd and 3rd invokes) - // TODO: switch on ReserveResponse status instead of err, - // since these are valid values - if s.InternalStateGetter == nil { - responseWriter.Write([]byte("error: internal state callback not set")) - return ErrInternalServerError - } + initFailures := s.getInitFailuresChan() + if initFailures == nil { + return ErrInitNotStarted + } - default: - return err + releaseErrChan := make(chan error) + releaseSuccessChan := make(chan struct{}) + go func() { + // This thread can block in one of two method calls Reserve() & AwaitRelease(), + // corresponding to Init and Invoke phase. + // FastInvoke is intended to be 'async' response stream copying. + // When a timeout occurs, we send a 'Reset' with the timeout reason + // When a Reset is sent, the reset handler in rapid lib cancels existing flows, + // including init/invoke. This causes either initFailure/invokeFailure, and then + // the Reset is handled and processed. + // TODO: however, ideally Reserve() does not block on init, but FastInvoke does + // The logic would be almost identical, except that init failures could manifest + // through return values of FastInvoke and not Reserve() + + reserveResp, err := s.Reserve("", "", "") + if err != nil { + log.Infof("ReserveFailed: %s", err) } - } - invoke.DeadlineNs = fmt.Sprintf("%d", metering.Monotime()+reserveResp.Token.FunctionTimeout.Nanoseconds()) + invoke.DeadlineNs = fmt.Sprintf("%d", metering.Monotime()+reserveResp.Token.FunctionTimeout.Nanoseconds()) + go func() { + if initCompletionResp, err := s.awaitInitialized(); err != nil { + switch err { + case ErrInitResetReceived, ErrInitDoneFailed: + // For init failures, cache the response so they can be checked later + // We check if they have not already been set by a call to /init/error by runtime + if s.getCachedInitErrorResponse() == nil { + errType, errMsg := string(initCompletionResp.InitErrorType), initCompletionResp.InitErrorMessage.Error() + s.setCachedInitErrorResponse(&interop.ErrorResponse{ErrorType: errType, ErrorMessage: errMsg}) + } + } + } - invokeChan := make(chan error) - go func() { - if err := s.FastInvoke(responseWriter, invoke, false); err != nil { - invokeChan <- err + if err := s.FastInvoke(responseWriter, invoke, false); err != nil { + log.Debugf("FastInvoke() error: %s", err) + } + }() + + _, err = s.AwaitRelease() + if err != nil && err != ErrReleaseReservationDone { + log.Debugf("AwaitRelease() error: %s", err) + switch err { + case ErrReleaseReservationDone: // not an error, expected return value when Reset is called + if s.getCachedInitErrorResponse() != nil { + // For Init failures, AwaitRelease returns ErrReleaseReservationDone + // because the Reset calls Release & cancels the release context + // We rename the error to ErrInitDoneFailed + releaseErrChan <- ErrInitDoneFailed + } + case ErrInitDoneFailed, ErrInvokeDoneFailed: + // Reset when either init or invoke failrues occur, i.e. + // init/error, invocation/error, Runtime.ExitError, Extension.ExitError + s.Reset(autoresetReasonReleaseFail, resetDefaultTimeoutMs) + releaseErrChan <- err + default: + releaseErrChan <- err + } + return } - }() - releaseChan := make(chan error) - go func() { - _, err := s.AwaitRelease() - releaseChan <- err + releaseSuccessChan <- struct{}{} }() - // TODO: verify the order of channel receives. When timeouts happen, Reset() - // is called first, which also does Release() => this may signal a type - // Err<*>ReservationDone error to the non-timeout channels. This is currently - // handled by the http handler, which returns GatewayTimeout for reservation errors - // too. However, Timeouts should ideally be only represented by ErrInvokeTimeout. + var err error select { - case err = <-invokeChan: - case err = <-timeoutChan: - case err = <-releaseChan: - if err != nil { - s.Reset(autoresetReasonReleaseFail, resetDefaultTimeoutMs) + case timeoutErr := <-timeoutChan: + s.Reset(autoresetReasonTimeout, resetDefaultTimeoutMs) + select { + case releaseErr := <-releaseErrChan: // when AwaitRelease() has errors + log.Debugf("Invoke() release error on Execute() timeout: %s", releaseErr) + case <-releaseSuccessChan: // when AwaitRelease() finishes cleanly } + err = timeoutErr + case err = <-releaseErrChan: + log.Debug("Invoke() release error") + case <-releaseSuccessChan: + s.Release() + log.Debug("Invoke() success") } return err } +type initCompletionResponse struct { + InitErrorType fatalerror.ErrorType + InitErrorMessage error +} + +func (s *Server) awaitInitialized() (initCompletionResponse, error) { + initFailure, awaitingInitStatus := <-s.getInitFailuresChan() + resp := initCompletionResponse{} + + if initFailure.ResetReceived { + // Resets during Init are only received in standalone + // during an invoke timeout + s.setRuntimeState(runtimeInitFailed) + resp.InitErrorType = initFailure.ErrorType + resp.InitErrorMessage = initFailure.ErrorMessage + return resp, ErrInitResetReceived + } + + if awaitingInitStatus { + // channel not closed, received init failure + // Sandbox can be reserved even if init failed (due to function errors) + s.setRuntimeState(runtimeInitFailed) + resp.InitErrorType = initFailure.ErrorType + resp.InitErrorMessage = initFailure.ErrorMessage + return resp, ErrInitDoneFailed + } + + // not awaiting init status (channel closed) + return resp, nil +} + +// AwaitInitialized waits until init is complete. It must be idempotent, +// since it can be called twice when a caller wants to wait until init is complete +func (s *Server) AwaitInitialized() error { + if _, err := s.awaitInitialized(); err != nil { + if releaseErr := s.Release(); err != nil { + log.Infof("Error releasing after init failure %s: %s", err, releaseErr) + } + s.setRuntimeState(runtimeInitFailed) + return err + } + s.setRuntimeState(runtimeInitComplete) + return nil +} + func (s *Server) AwaitRelease() (*statejson.InternalStateDescription, error) { + defer func() { + s.setRapidPhase(phaseIdle) + s.setRuntimeState(runtimeInvokeComplete) + }() + select { case doneWithState := <-s.InvokeDoneChan: + if len(doneWithState.ErrorType) > 0 && string(doneWithState.ErrorType) == ErrInitDoneFailed.Error() { + return nil, ErrInitDoneFailed + } + if len(doneWithState.ErrorType) > 0 { log.Errorf("Invoke DONE failed: %s", doneWithState.ErrorType) return nil, ErrInvokeDoneFailed @@ -667,8 +785,13 @@ func (s *Server) AwaitRelease() (*statejson.InternalStateDescription, error) { } func (s *Server) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription { - s.shutdownChanOut <- shutdown - <-s.ShutdownDoneChan + shutdownSuccess := s.sandboxContext.Shutdown(shutdown) + if len(shutdownSuccess.ErrorType) > 0 { + log.Errorf("Shutdown first fatal error: %s", shutdownSuccess.ErrorType) + } + + s.setRapidPhase(phaseIdle) + s.setRuntimeState(runtimeNotStarted) state := s.InternalStateGetter() return &state @@ -682,3 +805,49 @@ func (s *Server) InternalState() (*statejson.InternalStateDescription, error) { state := s.InternalStateGetter() return &state, nil } + +func (s *Server) Restore(restore *interop.Restore) error { + return s.sandboxContext.Restore(restore) +} + +func doneFromInvokeSuccess(successMsg interop.InvokeSuccess) *interop.Done { + return &interop.Done{ + Meta: interop.DoneMetadata{ + RuntimeRelease: successMsg.RuntimeRelease, + NumActiveExtensions: successMsg.NumActiveExtensions, + ExtensionNames: successMsg.ExtensionNames, + InvokeRequestReadTimeNs: successMsg.InvokeMetrics.InvokeRequestReadTimeNs, + InvokeRequestSizeBytes: successMsg.InvokeMetrics.InvokeRequestSizeBytes, + RuntimeReadyTime: successMsg.InvokeMetrics.RuntimeReadyTime, + + InvokeCompletionTimeNs: successMsg.InvokeCompletionTimeNs, + InvokeReceivedTime: successMsg.InvokeReceivedTime, + RuntimeTimeThrottledMs: successMsg.ResponseMetrics.RuntimeTimeThrottledMs, + RuntimeProducedBytes: successMsg.ResponseMetrics.RuntimeProducedBytes, + RuntimeOutboundThroughputBps: successMsg.ResponseMetrics.RuntimeOutboundThroughputBps, + LogsAPIMetrics: successMsg.LogsAPIMetrics, + }, + } +} + +func doneFailFromInvokeFailure(failureMsg *interop.InvokeFailure) *interop.DoneFail { + return &interop.DoneFail{ + ErrorType: failureMsg.ErrorType, + Meta: interop.DoneMetadata{ + RuntimeRelease: failureMsg.RuntimeRelease, + NumActiveExtensions: failureMsg.NumActiveExtensions, + InvokeReceivedTime: failureMsg.InvokeReceivedTime, + + RuntimeTimeThrottledMs: failureMsg.ResponseMetrics.RuntimeTimeThrottledMs, + RuntimeProducedBytes: failureMsg.ResponseMetrics.RuntimeProducedBytes, + RuntimeOutboundThroughputBps: failureMsg.ResponseMetrics.RuntimeOutboundThroughputBps, + + InvokeRequestReadTimeNs: failureMsg.InvokeMetrics.InvokeRequestReadTimeNs, + InvokeRequestSizeBytes: failureMsg.InvokeMetrics.InvokeRequestSizeBytes, + RuntimeReadyTime: failureMsg.InvokeMetrics.RuntimeReadyTime, + + ExtensionNames: failureMsg.ExtensionNames, + LogsAPIMetrics: failureMsg.LogsAPIMetrics, + }, + } +} diff --git a/lambda/rapidcore/server_test.go b/lambda/rapidcore/server_test.go index 416304c..88eea3f 100644 --- a/lambda/rapidcore/server_test.go +++ b/lambda/rapidcore/server_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapidcore/env" ) func waitForChanWithTimeout(channel <-chan error, timeout time.Duration) error { @@ -26,15 +27,68 @@ func waitForChanWithTimeout(channel <-chan error, timeout time.Duration) error { } } +func sendInitStartedResponse(responseChannel chan<- interop.InitStarted, msg interop.InitStarted) { + msg.Ack = make(chan struct{}) + responseChannel <- msg + <-msg.Ack +} + +func sendInitSuccessResponse(responseChannel chan<- interop.InitSuccess, msg interop.InitSuccess) { + msg.Ack = make(chan struct{}) + responseChannel <- msg + <-msg.Ack +} + +func sendInitFailureResponse(responseChannel chan<- interop.InitFailure, msg interop.InitFailure) { + msg.Ack = make(chan struct{}) + responseChannel <- msg + <-msg.Ack +} + +type mockRapidCtx struct { + initHandler func(start chan<- interop.InitStarted, success chan<- interop.InitSuccess, fail chan<- interop.InitFailure) + invokeHandler func() (interop.InvokeSuccess, *interop.InvokeFailure) + resetHandler func() (interop.ResetSuccess, *interop.ResetFailure) +} + +func (r *mockRapidCtx) HandleInit(init *interop.Init, startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + r.initHandler(startResp, successResp, failureResp) +} + +func (r *mockRapidCtx) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { + return r.invokeHandler() +} + +func (r *mockRapidCtx) HandleReset(reset *interop.Reset, invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { + return r.resetHandler() +} + +func (r *mockRapidCtx) HandleShutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { + return interop.ShutdownSuccess{} +} + +func (r *mockRapidCtx) HandleRestore(restore *interop.Restore) error { + return nil +} + +func (r *mockRapidCtx) Clear() {} + func TestReserveDoesNotDeadlockWhenCalledMultipleTimes(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { <-srv.StartChan() }() - go srv.SendRunning(&interop.Running{}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ + initHandler, + func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, + func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, + }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - go srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"}) _, err := srv.Reserve("", "", "") // reserve successfully require.NoError(t, err) @@ -61,89 +115,120 @@ func TestInitSuccess(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) - }() + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ + initHandler, + func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, + func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, + }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) _, err := srv.Reserve("", "", "") require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitComplete), srv.getRuntimeState()) } func TestInitErrorBeforeReserve(t *testing.T) { + // Rapid thread sending init failure should not be blocked even if reserve hasn't arrived srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) initErrorResponseSent := make(chan error) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "foobar"})) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) initErrorResponseSent <- errors.New("initErrorResponseSent") - }() + } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ + initHandler, + func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, + func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, + }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) if msg := waitForChanWithTimeout(initErrorResponseSent, 1*time.Second); msg == nil { require.Fail(t, "Timed out waiting for init error response to be sent") } resp, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) + require.NoError(t, err) require.True(t, len(resp.Token.InvokeID) > 0) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) + + _, err = srv.AwaitRelease() + require.Error(t, err, ErrReleaseReservationDone) + require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) } func TestInitErrorDuringReserve(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "foobar"})) - }() + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) + } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ + initHandler, + func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, + func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, + }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) resp, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) + require.NoError(t, err) require.True(t, len(resp.Token.InvokeID) > 0) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) + + _, err = srv.AwaitRelease() + require.Error(t, err, ErrReleaseReservationDone) + require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) } func TestInvokeSuccess(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) + releaseRuntimeInit := make(chan struct{}) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + <-releaseRuntimeInit + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } - <-srv.InvokeChan() - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), "application/json", bytes.NewReader([]byte("response")))) + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader([]byte("response")), nil, nil)) require.NoError(t, srv.SendRuntimeReady()) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) - }() + return interop.InvokeSuccess{}, nil + } - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) + releaseRuntimeInit <- struct{}{} _, err := srv.Reserve("", "", "") require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitComplete), srv.getRuntimeState()) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) // Reserve does not wait for init completion + + awaitInitErr := srv.AwaitInitialized() + require.NoError(t, awaitInitErr) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "response", responseRecorder.Body.String()) require.Equal(t, "application/json", responseRecorder.Result().Header.Get("Content-Type")) @@ -157,28 +242,35 @@ func TestInvokeError(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) - - <-srv.InvokeChan() + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }"), ContentType: "application/json"})) require.NoError(t, srv.SendRuntimeReady()) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) - }() + return interop.InvokeSuccess{}, nil + } - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil + } + + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) _, err := srv.Reserve("", "", "") require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitComplete), srv.getRuntimeState()) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) + + awaitInitErr := srv.AwaitInitialized() + require.NoError(t, awaitInitErr) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "{ 'errorType': 'A.B' }", responseRecorder.Body.String()) require.Equal(t, "application/json", responseRecorder.Result().Header.Get("Content-Type")) @@ -203,43 +295,49 @@ func TestInvokeWithSuppressedInitSuccess(t *testing.T) { srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) initErrorCompleted := make(chan error) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "foobar"})) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) initErrorCompleted <- errors.New("initErrorSequenceCompleted") + } - <-srv.ResetChan() - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"})) + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), nil, bytes.NewReader([]byte("response")), nil, nil)) + return interop.InvokeSuccess{}, nil + } - <-srv.InvokeChan() // run only after FastInvoke is called - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), "", bytes.NewReader([]byte("response")))) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) - }() + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil + } + + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) if msg := waitForChanWithTimeout(initErrorCompleted, 1*time.Second); msg == nil { require.Fail(t, "Timed out waiting for init error sequence to be called") } - _, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + resp, err := srv.Reserve("", "", "") + require.NoError(t, err) + require.True(t, len(resp.Token.InvokeID) > 0) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for suppressed init require.NoError(t, err) _, err = srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitAlreadyDone.Error()) + require.NoError(t, err) responseRecorder := httptest.NewRecorder() successChan := make(chan error) go func() { directInvoke := false - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, directInvoke) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, directInvoke) require.NoError(t, invokeErr) successChan <- errors.New("invokeResponseWritten") }() @@ -261,39 +359,45 @@ func TestInvokeWithSuppressedInitErrorDueToInitError(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) + } + releaseChan := make(chan error) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "A.B"})) + releaseChan <- nil + return interop.InvokeSuccess{}, &interop.InvokeFailure{ErrorType: "A.B", RequestReset: true, DefaultErrorResponse: &interop.ErrorResponse{}} + } - <-srv.ResetChan() - srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"}) + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil + } - <-srv.InvokeChan() - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - releaseChan <- nil - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "invokeCorrelationID", ErrorType: "A.B"})) - }() + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - _, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + resp, err := srv.Reserve("", "", "") + require.NoError(t, err) + require.True(t, len(resp.Token.InvokeID) > 0) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for invoke with suppressed init require.NoError(t, err) require.Equal(t, phaseIdle, srv.getRapidPhase()) _, err = srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitAlreadyDone.Error()) + require.NoError(t, err) require.Equal(t, phaseIdle, srv.getRapidPhase()) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "{ 'errorType': 'A.B' }", responseRecorder.Body.String()) require.Equal(t, phaseInvoking, srv.getRapidPhase()) @@ -310,39 +414,43 @@ func TestInvokeWithSuppressedInitErrorDueToInvokeError(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "A.B"})) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) + } + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'B.C' }")})) + require.NoError(t, srv.SendRuntimeReady()) + return interop.InvokeSuccess{}, nil + } - <-srv.ResetChan() - srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"}) + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil + } - <-srv.InvokeChan() - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'B.C' }")})) - require.NoError(t, srv.SendRuntimeReady()) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) - }() + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) - _, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + resp, err := srv.Reserve("", "", "") + require.NoError(t, err) + require.True(t, len(resp.Token.InvokeID) > 0) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for invoke with suppressed init require.NoError(t, err) require.Equal(t, phaseIdle, srv.getRapidPhase()) _, err = srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitAlreadyDone.Error()) + require.NoError(t, err) require.Equal(t, phaseIdle, srv.getRapidPhase()) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "{ 'errorType': 'B.C' }", responseRecorder.Body.String()) @@ -356,39 +464,43 @@ func TestMultipleInvokeSuccess(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) - }() - - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) - - invokeFunc := func(i int) { - <-srv.InvokeChan() - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), "", bytes.NewReader([]byte("response-"+fmt.Sprint(i))))) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } + i := 0 + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), nil, bytes.NewReader([]byte("response-"+fmt.Sprint(i))), nil, nil)) require.NoError(t, srv.SendRuntimeReady()) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) + i++ + return interop.InvokeSuccess{}, nil + } + + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil } - go func() { - for i := 0; i < 3; i++ { - invokeFunc(i) - } - }() + + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) for i := 0; i < 3; i++ { _, err := srv.Reserve("", "", "") - require.Contains(t, []error{nil, ErrInitAlreadyDone}, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) + require.NoError(t, err) + + awaitInitErr := srv.AwaitInitialized() + require.NoError(t, awaitInitErr) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "response-"+fmt.Sprint(i), responseRecorder.Body.String()) + require.Equal(t, phaseInvoking, srv.getRapidPhase()) _, err = srv.AwaitRelease() require.NoError(t, err) + require.Equal(t, phaseIdle, srv.getRapidPhase()) require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) } } diff --git a/lambda/rapidcore/standalone/directInvokeHandler.go b/lambda/rapidcore/standalone/directInvokeHandler.go index a485deb..1c7e7cb 100644 --- a/lambda/rapidcore/standalone/directInvokeHandler.go +++ b/lambda/rapidcore/standalone/directInvokeHandler.go @@ -4,13 +4,15 @@ package standalone import ( - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/core/directinvoke" "go.amzn.com/lambda/rapidcore" + "net/http" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/core/directinvoke" ) -func DirectInvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func DirectInvokeHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { tok := s.CurrentToken() if tok == nil { log.Errorf("Attempt to call directInvoke without Reserve") @@ -24,6 +26,14 @@ func DirectInvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.Int return } + if err := s.AwaitInitialized(); err != nil { + w.WriteHeader(DoneFailedHTTPCode) + if state, err := s.InternalState(); err == nil { + w.Write(state.AsJSON()) + } + return + } + if err := s.FastInvoke(w, invoke, true); err != nil { switch err { case rapidcore.ErrNotReserved: diff --git a/lambda/rapidcore/standalone/executeHandler.go b/lambda/rapidcore/standalone/executeHandler.go index 36c257a..9bac400 100644 --- a/lambda/rapidcore/standalone/executeHandler.go +++ b/lambda/rapidcore/standalone/executeHandler.go @@ -8,16 +8,17 @@ import ( log "github.com/sirupsen/logrus" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapidcore" ) -func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandbox) { +func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.LambdaInvokeAPI) { invokePayload := &interop.Invoke{ - TraceID: r.Header.Get("X-Amzn-Trace-Id"), - LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: r.Body, - CorrelationID: "invokeCorrelationID", + TraceID: r.Header.Get("X-Amzn-Trace-Id"), + LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), + Payload: r.Body, + InvokeReceivedTime: metering.Monotime(), } // If we write to 'w' directly and waitUntilRelease fails, we won't be able to propagate error anymore @@ -38,17 +39,17 @@ func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandbox) case rapidcore.ErrInvokeResponseAlreadyWritten: return - case rapidcore.ErrInvokeTimeout: + case rapidcore.ErrInvokeTimeout, rapidcore.ErrInitResetReceived: w.WriteHeader(http.StatusGatewayTimeout) // DONE failures: - case rapidcore.ErrTerminated, rapidcore.ErrInitDoneFailed, rapidcore.ErrInvokeDoneFailed: + case rapidcore.ErrInvokeDoneFailed: copyHeaders(invokeResp, w) w.WriteHeader(DoneFailedHTTPCode) w.Write(invokeResp.Body) return // Reservation canceled errors - case rapidcore.ErrReserveReservationDone, rapidcore.ErrInvokeReservationDone, rapidcore.ErrReleaseReservationDone: + case rapidcore.ErrReserveReservationDone, rapidcore.ErrInvokeReservationDone, rapidcore.ErrReleaseReservationDone, rapidcore.ErrInitNotStarted: w.WriteHeader(http.StatusGatewayTimeout) } diff --git a/lambda/rapidcore/standalone/initHandler.go b/lambda/rapidcore/standalone/initHandler.go index d006b81..d60ec6f 100644 --- a/lambda/rapidcore/standalone/initHandler.go +++ b/lambda/rapidcore/standalone/initHandler.go @@ -7,21 +7,33 @@ import ( "fmt" "net/http" "os" + "time" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore" "go.amzn.com/lambda/rapidcore/env" ) +type RuntimeInfo struct { + ImageJSON string `json:"runtimeImageJSON,omitempty"` + Arn string `json:"runtimeArn,omitempty"` + Version string `json:"runtimeVersion,omitempty"` +} + // TODO: introduce suppress init flag type InitBody struct { - Handler string `json:"handler"` - FunctionName string `json:"functionName"` - FunctionVersion string `json:"functionVersion"` - InvokeTimeoutMs int64 `json:"invokeTimeoutMs"` + Handler string `json:"handler"` + FunctionName string `json:"functionName"` + FunctionVersion string `json:"functionVersion"` + InvokeTimeoutMs int64 `json:"invokeTimeoutMs"` + RuntimeInfo RuntimeInfo `json:"runtimeInfo"` Customer struct { Environment map[string]string `json:"environment"` } `json:"customer"` + AwsKey *string `json:"awskey"` + AwsSecret *string `json:"awssecret"` + AwsSession *string `json:"awssession"` + CredentialsExpiry time.Time `json:"credentialsExpiry"` + Throttled bool `json:"throttled"` } type InitRequest struct { @@ -44,7 +56,7 @@ func (c *InitBody) Validate() error { return nil } -func InitHandler(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandbox) { +func InitHandler(w http.ResponseWriter, r *http.Request, sandbox InteropServer, bs interop.Bootstrap) { init := InitBody{} if lerr := readBodyAndUnmarshalJSON(r, &init); lerr != nil { lerr.Send(w, r) @@ -61,20 +73,54 @@ func InitHandler(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandb // logic consistent across standalone-mode and girp-mode os.Setenv(envKey, envVal) } - // TODO generate CorrelationID + + awsKey, awsSecret, awsSession := getCredentials(init) + + sandboxType := interop.SandboxClassic + + if init.Throttled { + sandboxType = interop.SandboxPreWarmed + } // pass to rapid sandbox.Init(&interop.Init{ Handler: init.Handler, - CorrelationID: "initCorrelationID", - AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"), - AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"), - AwsSession: os.Getenv("AWS_SESSION_TOKEN"), + AwsKey: awsKey, + AwsSecret: awsSecret, + AwsSession: awsSession, + CredentialsExpiry: init.CredentialsExpiry, XRayDaemonAddress: "0.0.0.0:0", // TODO FunctionName: init.FunctionName, FunctionVersion: init.FunctionVersion, - + RuntimeInfo: interop.RuntimeInfo{ + ImageJSON: init.RuntimeInfo.ImageJSON, + Arn: init.RuntimeInfo.Arn, + Version: init.RuntimeInfo.Version}, CustomerEnvironmentVariables: env.CustomerEnvironmentVariables(), + SandboxType: sandboxType, + Bootstrap: bs, + EnvironmentVariables: env.NewEnvironment(), }, init.InvokeTimeoutMs) +} + +func getCredentials(init InitBody) (string, string, string) { + // ToDo(guvfatih): I think instead of passing and getting these credentials values via environment variables + // we need to make StandaloneTests passing these via the Init request to be compliant with the existing protocol. + awsKey := os.Getenv("AWS_ACCESS_KEY_ID") + awsSecret := os.Getenv("AWS_SECRET_ACCESS_KEY") + awsSession := os.Getenv("AWS_SESSION_TOKEN") + + if init.AwsKey != nil { + awsKey = *init.AwsKey + } + + if init.AwsSecret != nil { + awsSecret = *init.AwsSecret + } + + if init.AwsSession != nil { + awsSession = *init.AwsSession + } + return awsKey, awsSecret, awsSession } diff --git a/lambda/rapidcore/standalone/internalStateHandler.go b/lambda/rapidcore/standalone/internalStateHandler.go index ff1335a..cb40c1c 100644 --- a/lambda/rapidcore/standalone/internalStateHandler.go +++ b/lambda/rapidcore/standalone/internalStateHandler.go @@ -5,11 +5,9 @@ package standalone import ( "net/http" - - "go.amzn.com/lambda/rapidcore" ) -func InternalStateHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func InternalStateHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { state, err := s.InternalState() if err != nil { http.Error(w, "internal state callback not set", http.StatusInternalServerError) diff --git a/lambda/rapidcore/standalone/invokeHandler.go b/lambda/rapidcore/standalone/invokeHandler.go index 0d89f1c..3e9768c 100644 --- a/lambda/rapidcore/standalone/invokeHandler.go +++ b/lambda/rapidcore/standalone/invokeHandler.go @@ -14,7 +14,7 @@ import ( log "github.com/sirupsen/logrus" ) -func InvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func InvokeHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { tok := s.CurrentToken() if tok == nil { log.Errorf("Attempt to call directInvoke without Reserve") @@ -22,28 +22,20 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropSe return } - isResyncReceivedFlag := false - - awsKey := r.Header.Get("ResyncAwsKey") - awsSecret := r.Header.Get("ResyncAwsSecret") - awsSession := r.Header.Get("ResyncAwsSession") - - if len(awsKey) > 0 && len(awsSecret) > 0 && len(awsSession) > 0 { - isResyncReceivedFlag = true + invokePayload := &interop.Invoke{ + TraceID: r.Header.Get("X-Amzn-Trace-Id"), + LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), + Payload: r.Body, + DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), + InvokeReceivedTime: metering.Monotime(), } - invokePayload := &interop.Invoke{ - TraceID: r.Header.Get("X-Amzn-Trace-Id"), - LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: r.Body, - CorrelationID: "invokeCorrelationID", - DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), - ResyncState: interop.Resync{ - IsResyncReceived: isResyncReceivedFlag, - AwsKey: awsKey, - AwsSecret: awsSecret, - AwsSession: awsSession, - }, + if err := s.AwaitInitialized(); err != nil { + w.WriteHeader(DoneFailedHTTPCode) + if state, err := s.InternalState(); err == nil { + w.Write(state.AsJSON()) + } + return } if err := s.FastInvoke(w, invokePayload, false); err != nil { diff --git a/lambda/rapidcore/standalone/pingHandler.go b/lambda/rapidcore/standalone/pingHandler.go new file mode 100644 index 0000000..c6cb021 --- /dev/null +++ b/lambda/rapidcore/standalone/pingHandler.go @@ -0,0 +1,12 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package standalone + +import ( + "net/http" +) + +func PingHandler(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("pong")) +} diff --git a/lambda/rapidcore/standalone/reserveHandler.go b/lambda/rapidcore/standalone/reserveHandler.go index d3e0b9f..52b51cd 100644 --- a/lambda/rapidcore/standalone/reserveHandler.go +++ b/lambda/rapidcore/standalone/reserveHandler.go @@ -24,24 +24,14 @@ func tokenToHeaders(w http.ResponseWriter, token interop.Token) { w.Header().Set(directinvoke.VersionIDHeader, token.VersionID) } -func ReserveHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func ReserveHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { reserveResp, err := s.Reserve("", r.Header.Get("X-Amzn-Trace-Id"), r.Header.Get("X-Amzn-Segment-Id")) if err != nil { switch err { - case rapidcore.ErrInitAlreadyDone: - // init already happened before, just provide internal state and return - tokenToHeaders(w, reserveResp.Token) - InternalStateHandler(w, r, s) case rapidcore.ErrReserveReservationDone: // TODO use http.StatusBadGateway w.WriteHeader(http.StatusGatewayTimeout) - case rapidcore.ErrInitDoneFailed, rapidcore.ErrInitError: - w.WriteHeader(DoneFailedHTTPCode) - w.Write(reserveResp.InternalState.AsJSON()) - case rapidcore.ErrTerminated: - w.WriteHeader(DoneFailedHTTPCode) - w.Write(reserveResp.InternalState.AsJSON()) default: log.Errorf("Failed to reserve: %s", err) w.WriteHeader(400) diff --git a/lambda/rapidcore/standalone/resetHandler.go b/lambda/rapidcore/standalone/resetHandler.go index 1a719ff..4f2ca2e 100644 --- a/lambda/rapidcore/standalone/resetHandler.go +++ b/lambda/rapidcore/standalone/resetHandler.go @@ -5,8 +5,6 @@ package standalone import ( "net/http" - - "go.amzn.com/lambda/rapidcore" ) type resetAPIRequest struct { @@ -14,7 +12,7 @@ type resetAPIRequest struct { TimeoutMs int64 `json:"timeoutMs"` } -func ResetHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func ResetHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { reset := resetAPIRequest{} if lerr := readBodyAndUnmarshalJSON(r, &reset); lerr != nil { lerr.Send(w, r) diff --git a/lambda/rapidcore/standalone/restoreHandler.go b/lambda/rapidcore/standalone/restoreHandler.go new file mode 100644 index 0000000..190b6d8 --- /dev/null +++ b/lambda/rapidcore/standalone/restoreHandler.go @@ -0,0 +1,41 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package standalone + +import ( + "net/http" + "time" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/interop" +) + +type RestoreBody struct { + AwsKey string `json:"awskey"` + AwsSecret string `json:"awssecret"` + AwsSession string `json:"awssession"` + CredentialsExpiry time.Time `json:"credentialsExpiry"` +} + +func RestoreHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { + restoreRequest := RestoreBody{} + if lerr := readBodyAndUnmarshalJSON(r, &restoreRequest); lerr != nil { + lerr.Send(w, r) + return + } + + restore := &interop.Restore{ + AwsKey: restoreRequest.AwsKey, + AwsSecret: restoreRequest.AwsSecret, + AwsSession: restoreRequest.AwsSession, + CredentialsExpiry: restoreRequest.CredentialsExpiry, + } + + err := s.Restore(restore) + + if err != nil { + log.Errorf("Failed to restore: %s", err) + w.WriteHeader(http.StatusBadGateway) + } +} diff --git a/lambda/rapidcore/standalone/router.go b/lambda/rapidcore/standalone/router.go index 5a4ae7c..f1712ea 100644 --- a/lambda/rapidcore/standalone/router.go +++ b/lambda/rapidcore/standalone/router.go @@ -7,18 +7,35 @@ import ( "context" "net/http" + "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" "go.amzn.com/lambda/rapidcore/telemetry" "github.com/go-chi/chi" ) -func NewHTTPRouter(sandbox rapidcore.Sandbox, eventLog *telemetry.EventLog, shutdownFunc context.CancelFunc) *chi.Mux { - ipcSrv := sandbox.InteropServer() +type InteropServer interface { + Init(i *interop.Init, invokeTimeoutMs int64) error + AwaitInitialized() error + FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error + Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) + Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) + AwaitRelease() (*statejson.InternalStateDescription, error) + Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription + InternalState() (*statejson.InternalStateDescription, error) + CurrentToken() *interop.Token + Restore(restore *interop.Restore) error +} + +func NewHTTPRouter(ipcSrv InteropServer, lambdaInvokeAPI rapidcore.LambdaInvokeAPI, eventLog *telemetry.EventLog, shutdownFunc context.CancelFunc, bs interop.Bootstrap) *chi.Mux { r := chi.NewRouter() r.Use(standaloneAccessLogDecorator) - r.Post("/2015-03-31/functions/*/invocations", func(w http.ResponseWriter, r *http.Request) { Execute(w, r, sandbox) }) - r.Post("/test/init", func(w http.ResponseWriter, r *http.Request) { InitHandler(w, r, sandbox) }) + + r.Post("/2015-03-31/functions/*/invocations", func(w http.ResponseWriter, r *http.Request) { Execute(w, r, lambdaInvokeAPI) }) + r.Get("/test/ping", func(w http.ResponseWriter, r *http.Request) { PingHandler(w, r) }) + r.Post("/test/init", func(w http.ResponseWriter, r *http.Request) { InitHandler(w, r, ipcSrv, bs) }) + r.Post("/test/waitUntilInitialized", func(w http.ResponseWriter, r *http.Request) { WaitUntilInitializedHandler(w, r, ipcSrv) }) r.Post("/test/reserve", func(w http.ResponseWriter, r *http.Request) { ReserveHandler(w, r, ipcSrv) }) r.Post("/test/invoke", func(w http.ResponseWriter, r *http.Request) { InvokeHandler(w, r, ipcSrv) }) r.Post("/test/waitUntilRelease", func(w http.ResponseWriter, r *http.Request) { WaitUntilReleaseHandler(w, r, ipcSrv) }) @@ -27,6 +44,6 @@ func NewHTTPRouter(sandbox rapidcore.Sandbox, eventLog *telemetry.EventLog, shut r.Post("/test/directInvoke/{reservationtoken}", func(w http.ResponseWriter, r *http.Request) { DirectInvokeHandler(w, r, ipcSrv) }) r.Get("/test/internalState", func(w http.ResponseWriter, r *http.Request) { InternalStateHandler(w, r, ipcSrv) }) r.Get("/test/eventLog", func(w http.ResponseWriter, r *http.Request) { EventLogHandler(w, r, eventLog) }) - + r.Post("/test/restore", func(w http.ResponseWriter, r *http.Request) { RestoreHandler(w, r, ipcSrv) }) return r } diff --git a/lambda/rapidcore/standalone/shutdownHandler.go b/lambda/rapidcore/standalone/shutdownHandler.go index ee91277..8085541 100644 --- a/lambda/rapidcore/standalone/shutdownHandler.go +++ b/lambda/rapidcore/standalone/shutdownHandler.go @@ -9,14 +9,13 @@ import ( "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapidcore" ) type shutdownAPIRequest struct { TimeoutMs int64 `json:"timeoutMs"` } -func ShutdownHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer, shutdownFunc context.CancelFunc) { +func ShutdownHandler(w http.ResponseWriter, r *http.Request, s InteropServer, shutdownFunc context.CancelFunc) { shutdown := shutdownAPIRequest{} if lerr := readBodyAndUnmarshalJSON(r, &shutdown); lerr != nil { lerr.Send(w, r) @@ -24,8 +23,7 @@ func ShutdownHandler(w http.ResponseWriter, r *http.Request, s rapidcore.Interop } internalState := s.Shutdown(&interop.Shutdown{ - DeadlineNs: metering.Monotime() + int64(shutdown.TimeoutMs*1000*1000), - CorrelationID: "shutdownCorrelationID", + DeadlineNs: metering.Monotime() + int64(shutdown.TimeoutMs*1000*1000), }) w.Write(internalState.AsJSON()) diff --git a/lambda/rapidcore/standalone/util.go b/lambda/rapidcore/standalone/util.go index 21ee08f..7ba7420 100644 --- a/lambda/rapidcore/standalone/util.go +++ b/lambda/rapidcore/standalone/util.go @@ -6,7 +6,7 @@ package standalone import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" log "github.com/sirupsen/logrus" @@ -58,7 +58,7 @@ func (w *ResponseWriterProxy) IsError() bool { } func readBodyAndUnmarshalJSON(r *http.Request, dst interface{}) *ErrorReply { - bodyBytes, err := ioutil.ReadAll(r.Body) + bodyBytes, err := io.ReadAll(r.Body) if err != nil { return newErrorReply(ClientInvalidRequest, fmt.Sprintf("Failed to read full body: %s", err)) } diff --git a/lambda/rapidcore/standalone/waitUntilInitializedHandler.go b/lambda/rapidcore/standalone/waitUntilInitializedHandler.go new file mode 100644 index 0000000..95d64ac --- /dev/null +++ b/lambda/rapidcore/standalone/waitUntilInitializedHandler.go @@ -0,0 +1,23 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package standalone + +import ( + "net/http" + + "go.amzn.com/lambda/rapidcore" +) + +func WaitUntilInitializedHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { + err := s.AwaitInitialized() + if err != nil { + switch err { + case rapidcore.ErrInitDoneFailed: + w.WriteHeader(DoneFailedHTTPCode) + case rapidcore.ErrInitResetReceived: + w.WriteHeader(DoneFailedHTTPCode) + } + } + w.WriteHeader(http.StatusOK) +} diff --git a/lambda/rapidcore/standalone/waitUntilReleaseHandler.go b/lambda/rapidcore/standalone/waitUntilReleaseHandler.go index 9aab644..0a756dd 100644 --- a/lambda/rapidcore/standalone/waitUntilReleaseHandler.go +++ b/lambda/rapidcore/standalone/waitUntilReleaseHandler.go @@ -9,7 +9,7 @@ import ( "go.amzn.com/lambda/rapidcore" ) -func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { internalState, err := s.AwaitRelease() if err != nil { switch err { @@ -20,7 +20,7 @@ func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s rapidcore // TODO use http.StatusOK w.WriteHeader(http.StatusGatewayTimeout) return - case rapidcore.ErrTerminated: + case rapidcore.ErrInitDoneFailed: w.WriteHeader(DoneFailedHTTPCode) w.Write(internalState.AsJSON()) return diff --git a/lambda/rapidcore/telemetry/eventLog.go b/lambda/rapidcore/telemetry/eventLog.go index c66672c..2f809fa 100644 --- a/lambda/rapidcore/telemetry/eventLog.go +++ b/lambda/rapidcore/telemetry/eventLog.go @@ -9,6 +9,8 @@ import ( "time" ) +// TODO: Refactor to represent event structs below as a form of Events API entity + type XrayEvent struct { Msg string `json:"msg"` TraceID string `json:"traceID"` @@ -32,20 +34,13 @@ type FunctionLogEvent struct{} type ExtensionLogEvent struct{} type EventLog struct { + Events []SandboxEvent `json:"events,omitempty"` // populated by the StandaloneEventLog object Xray []XrayEvent `json:"xray,omitempty"` PlatformLog []PlatformLogEvent `json:"platformLogs,omitempty"` Logs []string `json:"rawLogs,omitempty"` mutex sync.Mutex } -func (p *EventLog) LogXrayEvent(msg string, traceID string, segmentName string, segmentID string) { - p.Xray = append(p.Xray, XrayEvent{Msg: msg, TraceID: traceID, SegmentName: segmentName, SegmentID: segmentID, Timestamp: time.Now().UnixNano() / int64(time.Millisecond)}) -} - -func (p *EventLog) LogExtensionInitEvent(agentName string, state string, subscriptions string, errorType string) { - p.PlatformLog = append(p.PlatformLog, PlatformLogEvent{agentName, state, errorType, strings.Split(subscriptions, ",")}) -} - func parseLogString(s string) []string { elems := strings.Split(s, "\t")[1:] for i, e := range elems { @@ -62,19 +57,7 @@ func (p *EventLog) dispatchLogEvent(logStr string) { if strings.HasPrefix(logStr, "XRAY") { // format: 'XRAY\tMessage: %s\tTraceID: %s\tSegmentName: %s\tSegmentID: %s' msg, traceID, segmentName, segmentID := elems[0], elems[1], elems[2], elems[3] - p.LogXrayEvent(msg, traceID, segmentName, segmentID) - } - - if strings.HasPrefix(logStr, "EXTENSION") && strings.Contains(logStr, "Error Type") { - // format: 'EXTENSION\tName: %s\tState: %s\tEvents: [%s]\tError Type: %s' - agentName, state, subscriptions, errorType := elems[0], elems[1], elems[2], elems[3] - p.LogExtensionInitEvent(agentName, state, subscriptions, errorType) - } - - if strings.HasPrefix(logStr, "EXTENSION") && !strings.Contains(logStr, "Error Type") { - // format: 'EXTENSION\tName: %s\tState: %s\tEvents: [%s]' - agentName, state, subscriptions, errorType := elems[0], elems[1], elems[2], "" - p.LogExtensionInitEvent(agentName, state, subscriptions, errorType) + p.Xray = append(p.Xray, XrayEvent{Msg: msg, TraceID: traceID, SegmentName: segmentName, SegmentID: segmentID, Timestamp: time.Now().UnixNano() / int64(time.Millisecond)}) } } diff --git a/lambda/rapidcore/telemetry/events_api.go b/lambda/rapidcore/telemetry/events_api.go new file mode 100644 index 0000000..7a882fd --- /dev/null +++ b/lambda/rapidcore/telemetry/events_api.go @@ -0,0 +1,97 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "sort" + "time" + + "go.amzn.com/lambda/telemetry" +) + +// EventType indicates the type of SandboxEvent. See full list: +type EventType = string + +const ( + PlatformInitRuntimeDone = EventType("platform.initRuntimeDone") + PlatformRestoreRuntimeDone = EventType("platform.restoreRuntimeDone") + PlatformRuntimeDone = EventType("platform.runtimeDone") + PlatformExtension = EventType("platform.extension") +) + +/* + SandboxEvent represents a generic sandbox event. For example: + {'time': '2021-03-16T13:10:42.358Z', + 'type': 'platform.extension', + 'record': { "name": "foo bar", "state": "Ready", "events": ["INVOKE", "SHUTDOWN"]}} +*/ +type SandboxEvent struct { + Time string `json:"time"` + Type EventType `json:"type"` + Record map[string]interface{} `json:"record"` +} + +type StandaloneEventLog struct { + requestID string + eventLog *EventLog +} + +func (s *StandaloneEventLog) SetCurrentRequestID(requestID string) { + s.requestID = requestID +} + +func (s *StandaloneEventLog) SendInitRuntimeDone(data *telemetry.InitRuntimeDoneData) error { + record := map[string]interface{}{"initializationType": data.InitSource, "status": data.Status} + s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformInitRuntimeDone, record}) + return nil +} + +func (s *StandaloneEventLog) SendRestoreRuntimeDone(status string) error { + record := map[string]interface{}{"status": status} + s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformRestoreRuntimeDone, record}) + return nil +} + +func (s *StandaloneEventLog) SendRuntimeDone(data telemetry.InvokeRuntimeDoneData) error { + // e.g. 'record': {'requestId': '1506eb3053d148f3bb7ec0fabe6f8d91','status': 'success', 'metrics': {...}, 'tracing': {...}} + record := map[string]interface{}{ + "requestId": s.requestID, + "status": data.Status, + "metrics": data.Metrics, + "internalMetrics": data.InternalMetrics, + "spans": data.Spans, + } + + if data.Tracing != nil { + record["tracing"] = map[string]string{ + "spanId": data.Tracing.SpanID, + "type": string(data.Tracing.Type), + "value": data.Tracing.Value, + } + } + + s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformRuntimeDone, record}) + return nil +} + +func (s *StandaloneEventLog) SendExtensionInit(agentName, state, errorType string, subscriptions []string) error { + // e.g. 'record': { "name": "", "state": "", errorType: "", events: [""] } + sort.Strings(subscriptions) + record := map[string]interface{}{"name": agentName, "state": state, "events": subscriptions} + if len(errorType) > 0 { + record["errorType"] = errorType + } + s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformExtension, record}) + return nil +} + +func (s *StandaloneEventLog) SendImageErrorLog(logline string) { + // Called on bootstrap exec errors for OCI error modes, e.g. InvalidEntrypoint etc. +} + +func NewStandaloneEventLog(eventLog *EventLog) *StandaloneEventLog { + return &StandaloneEventLog{ + eventLog: eventLog, + } +} diff --git a/lambda/runtimecmd/runtime_command.go b/lambda/runtimecmd/runtime_command.go deleted file mode 100644 index adf7886..0000000 --- a/lambda/runtimecmd/runtime_command.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package runtimecmd - -import ( - "context" - "fmt" - "io" - "os" - "os/exec" - "path" - "syscall" -) - -// CustomRuntimeCmd wraps exec.Cmd -type CustomRuntimeCmd struct { - *exec.Cmd -} - -// NewCustomRuntimeCmd returns a new CustomRuntimeCmd -func NewCustomRuntimeCmd(ctx context.Context, bootstrapCmd []string, dir string, env []string, stdoutWriter io.Writer, stderrWriter io.Writer, extraFiles []*os.File) *CustomRuntimeCmd { - cmd := exec.CommandContext(ctx, bootstrapCmd[0], bootstrapCmd[1:]...) - cmd.Dir = dir - - cmd.Stdout = stdoutWriter - cmd.Stderr = stderrWriter - - cmd.Env = env - - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - - if len(extraFiles) > 0 { - cmd.ExtraFiles = extraFiles - } - - return &CustomRuntimeCmd{cmd} -} - -// Name returnes runtime executable name -func (cmd *CustomRuntimeCmd) Name() string { - return path.Base(cmd.Path) -} - -// Pid returns the pid of a started runtime process -func (cmd *CustomRuntimeCmd) Pid() int { - return cmd.Process.Pid -} - -// Wait waits for the started customer runtime process to exit -func (cmd *CustomRuntimeCmd) Wait() error { - if err := cmd.Cmd.Wait(); err != nil { - return fmt.Errorf("Runtime exited with error: %v", err) - } - - return fmt.Errorf("Runtime exited without providing a reason") -} diff --git a/lambda/runtimecmd/runtime_command_test.go b/lambda/runtimecmd/runtime_command_test.go deleted file mode 100644 index f99599d..0000000 --- a/lambda/runtimecmd/runtime_command_test.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package runtimecmd - -import ( - "context" - "errors" - "io/ioutil" - "os" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestRuntimeCommandSetsEnvironmentVariables(t *testing.T) { - envVars := []string{"foo=1", "bar=2", "baz=3"} - - currentDir, err := os.Getwd() - assert.NoError(t, err, errors.New("Failed to get working directory to execute helper process")) - - execCmdArgs := []string{"foobar"} - runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, ioutil.Discard, nil) - - assert.ElementsMatch(t, envVars, runtimeCmd.Env) - assert.Equal(t, execCmdArgs, runtimeCmd.Args) -} - -func TestRuntimeCommandSetsCurrentWorkingDir(t *testing.T) { - envVars := []string{} - - currentDir, err := os.Getwd() - assert.NoError(t, err, errors.New("Failed to get working directory to execute helper process")) - - execCmdArgs := []string{"foobar"} - runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, ioutil.Discard, nil) - - assert.Equal(t, currentDir, runtimeCmd.Dir) -} - -func TestRuntimeCommandSetsMultipleArgs(t *testing.T) { - envVars := []string{} - - currentDir, err := os.Getwd() - assert.NoError(t, err, errors.New("Failed to get working directory to execute helper process")) - - execCmdArgs := []string{"foobar", "--baz", "22"} - runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, ioutil.Discard, nil) - - assert.Equal(t, execCmdArgs, runtimeCmd.Args) -} diff --git a/lambda/supervisor/local_supervisor.go b/lambda/supervisor/local_supervisor.go new file mode 100644 index 0000000..1174089 --- /dev/null +++ b/lambda/supervisor/local_supervisor.go @@ -0,0 +1,302 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package supervisor + +import ( + "errors" + "fmt" + "os/exec" + "sync" + "syscall" + "time" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/supervisor/model" +) + +// typecheck interface compliance +var _ model.SupervisorClient = (*LocalSupervisor)(nil) + +type process struct { + // pid of the running process + pid int + // channel that can be use to block + // while waiting on process termination. + termination chan struct{} +} + +type LocalSupervisor struct { + events chan model.Event + processMapLock sync.Mutex + processMap map[string]process +} + +func NewLocalSupervisor() model.Supervisor { + return model.Supervisor{ + SupervisorClient: &LocalSupervisor{ + events: make(chan model.Event), + processMap: make(map[string]process), + }, + OperatorConfig: model.DomainConfig{ + RootPath: "/", + }, + RuntimeConfig: model.DomainConfig{ + RootPath: "/", + }, + } +} + +func (*LocalSupervisor) Start(req *model.StartRequest) error { + return nil +} +func (*LocalSupervisor) Configure(req *model.ConfigureRequest) error { + return nil +} +func (s *LocalSupervisor) Exec(req *model.ExecRequest) error { + if req.Domain != "runtime" { + log.Debug("Exec is a no op if domain != runtime") + return nil + } + command := exec.Command(req.Path, req.Args...) + + if req.Env != nil { + envStrings := make([]string, 0, len(*req.Env)) + for key, value := range *req.Env { + envStrings = append(envStrings, key+"="+value) + } + command.Env = envStrings + } + + if req.Cwd != nil && *req.Cwd != "" { + command.Dir = *req.Cwd + } + + if req.ExtraFiles != nil { + command.ExtraFiles = *req.ExtraFiles + } + + command.Stdout = req.StdoutWriter + command.Stderr = req.StderrWriter + + command.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + err := command.Start() + + if err != nil { + return err + // TODO Use supevisor specific error + } + + pid := command.Process.Pid + termination := make(chan struct{}) + s.processMapLock.Lock() + s.processMap[req.Name] = process{ + pid: pid, + termination: termination, + } + s.processMapLock.Unlock() + + go func() { + err = command.Wait() + // close the termination channel to unblock whoever's blocked on + // it (used to implement kill's blocking behaviour) + close(termination) + + var cell int32 + var exitStatus *int32 + var signo *int32 + var exitErr *exec.ExitError + + if err == nil { + exitStatus = &cell + } else if errors.As(err, &exitErr) { + if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { + if code := status.ExitStatus(); code >= 0 { + cell = int32(code) + exitStatus = &cell + } else { + cell = int32(status.Signal()) + signo = &cell + } + } + } + + if signo == nil && exitStatus == nil { + log.Error("Cannot convert process exit status to unix WaitStatus. This is unexpected. Assuming ExitStatus 1") + cell = 1 + exitStatus = &cell + } + s.events <- model.Event{ + Time: uint64(time.Now().UnixMilli()), + Event: model.EventData{ + Domain: &req.Domain, + Name: &req.Name, + Signo: signo, + ExitStatus: exitStatus, + }, + } + }() + + return nil +} + +func kill(p process, name string, timeout *time.Duration) error { + // kill should report success if the process terminated by the time + //supervisor receives the request. + select { + // ifthis case is selected, the channel is closed, + // which means the process is terminated + case <-p.termination: + log.Debugf("Process %s already terminated.", name) + return nil + default: + log.Infof("Sending SIGKILL to %s(%d).", name, p.pid) + } + + if timeout != nil && *timeout <= 0 { + return fmt.Errorf("Timed out while trying to SIGKILL %s", name) + } + + pgid, err := syscall.Getpgid(p.pid) + + if err == nil { + // Negative pid sends signal to all in process group + syscall.Kill(-pgid, syscall.SIGKILL) + } else { + syscall.Kill(p.pid, syscall.SIGKILL) + } + + // the nil channel blocks forever + var timer <-chan time.Time + if timeout != nil { + timer = time.After(*timeout) + } + + // block until the (main) process exits + // or the timeout fires + select { + case <-p.termination: + return nil + case <-timer: + return fmt.Errorf("Timed out while trying to SIGKILL %s", name) + } +} + +func (s *LocalSupervisor) Kill(req *model.KillRequest) error { + if req.Domain != "runtime" { + log.Debug("Kill is a no op if domain != runtime") + return nil + } + s.processMapLock.Lock() + process, ok := s.processMap[req.Name] + s.processMapLock.Unlock() + if !ok { + msg := "Unknown process" + return &model.SupervisorError{ + Kind: model.NoSuchEntity, + Message: &msg, + } + } + timeout := convertTimeout(req.Timeout) + + return kill(process, req.Name, timeout) +} + +func (s *LocalSupervisor) Terminate(req *model.TerminateRequest) error { + if req.Domain != "runtime" { + log.Debug("Terminate is no op if domain != runtime") + return nil + } + s.processMapLock.Lock() + process, ok := s.processMap[req.Name] + pid := process.pid + s.processMapLock.Unlock() + if !ok { + msg := "Unknown process" + err := &model.SupervisorError{ + Kind: model.NoSuchEntity, + Message: &msg, + } + log.WithError(err).Errorf("Process %s not found in local supervisor map", req.Name) + return err + } + + pgid, err := syscall.Getpgid(pid) + + if err == nil { + // Negative pid sends signal to all in process group + // best effort, ignore errors + _ = syscall.Kill(-pgid, syscall.SIGTERM) + } else { + _ = syscall.Kill(pid, syscall.SIGTERM) + } + + return nil +} + +func (s *LocalSupervisor) Stop(req *model.StopRequest) error { + if req.Domain != "runtime" { + log.Debug("Shutdown is no op if domain != runtime") + return nil + } + timeout := convertTimeout(req.Timeout) + + // shut down kills all the processes in the map + s.processMapLock.Lock() + defer s.processMapLock.Unlock() + + nprocs := len(s.processMap) + + successes := make(chan struct{}) + errors := make(chan error) + for name, proc := range s.processMap { + go func(n string, p process) { + log.Debugf("Killing %s", n) + err := kill(p, n, timeout) + if err != nil { + errors <- err + } else { + successes <- struct{}{} + } + + }(name, proc) + } + + var err error + for i := 0; i < nprocs; i++ { + select { + case <-successes: + case e := <-errors: + if err == nil { + err = fmt.Errorf("Shutdown failed: %s", e.Error()) + } + } + + } + + s.processMap = make(map[string]process) + return err +} +func (*LocalSupervisor) Freeze(req *model.FreezeRequest) error { + return nil +} +func (*LocalSupervisor) Thaw(req *model.ThawRequest) error { + return nil +} +func (s *LocalSupervisor) Ping() error { + return nil +} + +func (s *LocalSupervisor) Events() (<-chan model.Event, error) { + return s.events, nil +} + +func convertTimeout(millis *uint64) *time.Duration { + var timeout *time.Duration + if millis != nil { + t := time.Duration(*millis) * time.Millisecond + timeout = &t + } + return timeout +} diff --git a/lambda/supervisor/local_supervisor_test.go b/lambda/supervisor/local_supervisor_test.go new file mode 100644 index 0000000..8b3336b --- /dev/null +++ b/lambda/supervisor/local_supervisor_test.go @@ -0,0 +1,215 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package supervisor + +import ( + "errors" + "fmt" + "syscall" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/supervisor/model" +) + +func TestRuntimeDomainExec(t *testing.T) { + supv := NewLocalSupervisor() + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + }) + + assert.Nil(t, err) +} + +func TestInvalidRuntimeDomainExec(t *testing.T) { + supv := NewLocalSupervisor() + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/none", + }) + + require.Error(t, err) +} + +func TestEvents(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + sync := make(chan struct{}) + go func() { + evt, ok := <-client.events + require.True(t, ok) + termination := evt.Event.ProcessTerminated() + require.NotNil(t, termination) + assert.Equal(t, "runtime", *termination.Domain) + assert.Equal(t, "agent", *termination.Name) + sync <- struct{}{} + }() + + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + }) + require.NoError(t, err) + <-sync +} + +func TestTerminate(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + Args: []string{"-c", "sleep 10s"}, + }) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + err = supv.Terminate(&model.TerminateRequest{ + Domain: "runtime", + Name: "agent", + }) + require.NoError(t, err) + // wait for process exit notification + ev := <-client.events + require.NotNil(t, ev.Event.ProcessTerminated()) + term := *ev.Event.ProcessTerminated() + require.Nil(t, term.Exited()) + require.NotNil(t, term.Signaled()) + require.EqualValues(t, syscall.SIGTERM, *term.Signo) +} + +// Termiante should not fail if the message is not delivered +func TestTerminateExited(t *testing.T) { + supv := NewLocalSupervisor() + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + }) + require.NoError(t, err) + // wait a short bit for bash to exit + time.Sleep(100 * time.Millisecond) + err = supv.Terminate(&model.TerminateRequest{ + Domain: "runtime", + Name: "agent", + }) + require.NoError(t, err) +} + +func TestKill(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + Args: []string{"-c", "sleep 10s"}, + }) + require.NoError(t, err) + err = supv.Kill(&model.KillRequest{ + Domain: "runtime", + Name: "agent", + }) + require.NoError(t, err) + timer := time.NewTimer(50 * time.Millisecond) + select { + case _, ok := <-client.events: + assert.True(t, ok) + case <-timer.C: + require.Fail(t, "Process should have exited by the time kill returns") + } +} + +func TestKillExited(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + }) + require.NoError(t, err) + //wait for natural exit event + <-client.events + err = supv.Kill(&model.KillRequest{ + Domain: "runtime", + Name: "agent", + }) + require.NoError(t, err, "Kill should succeed for exited processes") +} + +func TestKillUnknown(t *testing.T) { + supv := NewLocalSupervisor() + err := supv.Kill(&model.KillRequest{ + Domain: "runtime", + Name: "unknown", + }) + require.Error(t, err) + var supvError *model.SupervisorError + assert.True(t, errors.As(err, &supvError)) + assert.Equal(t, supvError.Kind, model.NoSuchEntity) +} + +func TestShutdown(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + log.Debug("hello") + // start a bunch of processes, some short running, some longer running + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent-0", + Path: "/bin/bash", + Args: []string{"-c", "sleep 1s"}, + }) + require.NoError(t, err) + + err = supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent-1", + Path: "/bin/bash", + }) + require.NoError(t, err) + + err = supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent-2", + Path: "/bin/bash", + Args: []string{"-c", "sleep 2s"}, + }) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + err = supv.Stop(&model.StopRequest{ + Domain: "runtime", + }) + require.NoError(t, err) + // Shutdown is expected to block untill all processes have exited + expected := map[string]struct{}{ + "agent-0": {}, + "agent-1": {}, + "agent-2": {}, + } + done := false + timer := time.NewTimer(200 * time.Millisecond) + for !done { + select { + case ev := <-client.events: + data := ev.Event.ProcessTerminated() + assert.NotNil(t, data) + _, ok := expected[*data.Name] + assert.True(t, ok) + delete(expected, *data.Name) + case <-timer.C: + fmt.Print(expected) + assert.Equal(t, 0, len(expected), "All process should terminate at shutdown") + done = true + } + } +} diff --git a/lambda/supervisor/model/model.go b/lambda/supervisor/model/model.go new file mode 100644 index 0000000..384726d --- /dev/null +++ b/lambda/supervisor/model/model.go @@ -0,0 +1,269 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "encoding/json" + "fmt" + "io" + "os" + "syscall" +) + +type Supervisor struct { + SupervisorClient + OperatorConfig DomainConfig + RuntimeConfig DomainConfig +} + +type DomainConfig struct { + // path to the root of the domain within the root mnt namespace + RootPath string +} + +type SupervisorClient interface { + Start(req *StartRequest) error + Configure(req *ConfigureRequest) error + Exec(req *ExecRequest) error + Terminate(req *TerminateRequest) error + Kill(req *KillRequest) error + Stop(req *StopRequest) error + Freeze(req *FreezeRequest) error + Thaw(req *ThawRequest) error + Ping() error + Events() (<-chan Event, error) +} + +type StartRequest struct { + Domain string `json:"domain"` + // name of the cgroup profile to start the domain in + CgroupProfile *string `json:"cgroup_profile,omitempty"` +} + +// Mount in lockhard::mnt is a Rust enum, an algebraic type, where each case has different set of fields. +// This models only the Mount::Drive case, the only one we need for now. +type DriveMount struct { + Source string `json:"source,omitempty"` + Destination string `json:"destination,omitempty"` + FsType string `json:"fs_type,omitempty"` + Options []string `json:"options,omitempty"` + Chowner []uint32 `json:"chowner,omitempty"` // array of two integers representing a tuple + Chmode uint32 `json:"chmode,omitempty"` + // Lockhard also expects a "type" field here, which in our case is constant, so we provide it upon serialization below +} + +// Adds the "type": "drive" to json +func (m *DriveMount) MarshalJSON() ([]byte, error) { + type driveMountAlias DriveMount + + return json.Marshal(&struct { + Type string `json:"type,omitempty"` + *driveMountAlias + }{ + Type: "drive", + driveMountAlias: (*driveMountAlias)(m), + }) +} + +type Capabilities struct { + Ambient []string `json:"ambient,omitempty"` + Bounding []string `json:"bounding,omitempty"` + Effective []string `json:"effective,omitempty"` + Inheritable []string `json:"inheritable,omitempty"` + Permitted []string `json:"permitted,omitempty"` +} + +type CgroupProfile struct { + Name string `json:"name"` + CPUPct *float64 `json:"cpu_pct,omitempty"` + MemMaxBytes *uint64 `json:"mem_max,omitempty"` +} + +type ExecUser struct { + UID *uint32 `json:"uid"` + GID *uint32 `json:"gid"` +} + +type ConfigureRequest struct { + // domain to configure + Domain string `json:"domain"` + Mounts []DriveMount `json:"mounts,omitempty"` + Capabilities *Capabilities `json:"capabilities,omitempty"` + SeccompFilters []string `json:"seccomp_filters,omitempty"` + // list of cgroup profiles available for the domain + // cgroup profiles are set on boot or thaw requests + CgroupProfiles []CgroupProfile `json:"cgroup_profiles,omitempty"` + // uid and gid of the user the spawned process runs as (w.r.t. the domain user namespace). + // If nil, Supervisor will use the ExecUser specified in the domain configuration file + ExecUser *ExecUser `json:"exec_user,omitempty"` + // additional hooks to execute on domain start + AdditionalStartHooks []Hook `json:"additional_start_hooks,omitempty"` +} + +type Event struct { + Time uint64 `json:"timestamp_millis"` + Event EventData `json:"event"` +} + +// EventData is a union type tagged by the "EventType" +// and "Cause" strings. +// you can use ProcessTermination() or EventLoss() to access +// the correct type of Event. +type EventData struct { + EvType string `json:"type"` + Domain *string `json:"domain"` + Name *string `json:"name"` + Cause *string `json:"cause"` + Signo *int32 `json:"signo"` + ExitStatus *int32 `json:"exit_status"` + Size *uint64 `json:"size"` +} + +// returns nil if the event is not a EventLoss event +// otherwise returns how many events were lost due to +// backpressure (slow reader) +func (d EventData) EventLoss() *uint64 { + return d.Size +} + +// Returns a ProcessTermination struct that describe the process +// which terminated. Use Signaled() or Exited() to check whether +// the process terminated because of a signal or exited on its own +func (d EventData) ProcessTerminated() *ProcessTermination { + if d.Signo != nil || d.ExitStatus != nil { + return &ProcessTermination{ + Domain: d.Domain, + Name: d.Name, + Signo: d.Signo, + ExitStatus: d.ExitStatus, + } + } + return nil +} + +// Event signalling that a process exited +type ProcessTermination struct { + Domain *string + Name *string + Signo *int32 + ExitStatus *int32 +} + +// If not nil, the process was terminated by an unhandled signal. +// The returned value is the number of the signal that terminated the process +func (t ProcessTermination) Signaled() *int32 { + return t.Signo +} + +// It not nil, the process exited (as opposed to killed by a signal). +// The returned value is the exit_status returned by the process +func (t ProcessTermination) Exited() *int32 { + return t.ExitStatus +} + +func (t ProcessTermination) Success() bool { + return t.ExitStatus != nil && *t.ExitStatus == 0 +} + +// Transform the process termination status in a string that +// is equal to what would be returned by golang exec.ExitError.Error() +// We used to rely on this format to report errors to customer (sigh) +// so we keep this for backwards compatibility +func (t ProcessTermination) String() string { + if t.ExitStatus != nil { + return fmt.Sprintf("exit status %d", *t.ExitStatus) + } + sig := syscall.Signal(*t.Signo) + return fmt.Sprintf("signal: %s", sig.String()) +} + +type Hook struct { + // Unique name identifying the hook + Name string `json:"name"` + // Path in the parent domain mount namespace that locates + // the executable to run as the hook + Path string `json:"path"` + // Args for the hook + Args []string `json:"args,omitempty"` + // Map of ENV variables to set when running the hook + Env *map[string]string `json:"envs,omitempty"` + // Maximum time for the hook to run. The hook will be considered failed + // if it takes more than this value (default 10_000) + TimeoutMillis *uint64 `json:"timeout_millis,omitempty"` +} + +type ExecRequest struct { + // Identifier that Supervisor will assign to the spawned process. + // The tuple (Domain,Name) must be unique. It is the caller's responsibility + // to generate the unique name + Name string `json:"name"` + Domain string `json:"domain"` + // Path pointing to the exectuable file within the domain's root filesystem + Path string `json:"path"` + Args []string `json:"args,omitempty"` + // If nil, root of the domain + Cwd *string `json:"cwd,omitempty"` + Env *map[string]string `json:"env,omitempty"` + // If not nil, points to the socket that Supervisor + // uses to get the processes stdout and stderr. + LogsSock *string `json:"logs_sock,omitempty"` + StdoutWriter io.Writer `json:"-"` + StderrWriter io.Writer `json:"-"` + ExtraFiles *[]*os.File `json:"-"` +} + +type ErrorKind string + +const ( + // operation on an unkown entity (e.g., domain process) + NoSuchEntity ErrorKind = "no_such_entity" + // operation not allowed in the current state (e.g., tried to exec a proces in a domain which is not booted) + InvalidState ErrorKind = "invalid_state" + // Serialization or derserialization issue in the communication + Serde ErrorKind = "serde" + // Unhandled Supervisor server error + Failure ErrorKind = "failure" +) + +type SupervisorError struct { + Kind ErrorKind `json:"error_kind"` + Message *string `json:"message"` +} + +func (e *SupervisorError) Error() string { + return string(e.Kind) +} + +// Send SIGETERM asynchrnously to a process +type TerminateRequest struct { + Name string `json:"name"` + Domain string `json:"domain"` +} + +// Force terminate a process (SIGKILL) +// Block until process is exited or timeout +// If timeout is 0 or nil, block forever +type KillRequest struct { + Name string `json:"name"` + Domain string `json:"domain"` + Timeout *uint64 `json:",omitempty"` +} + +// Stop the domain. Supervisor will first try to +// cleanly terminate the domain's init process. If unsuccessful, +// within Timeout seconds, it will send SIGKILL. +type StopRequest struct { + Domain string `json:"domain"` + Timeout *uint64 `json:",omitempty"` +} + +type FreezeRequest struct { + Domain string `json:"domain"` +} + +type ThawRequest struct { + Domain string `json:"domain"` + // if not nil, changes the cgroup profile of the domain upon thawing. + CgroupProfile *string `json:"cgroup_profile,omitempty"` +} diff --git a/lambda/telemetry/events_api.go b/lambda/telemetry/events_api.go index 132977e..e7c5c36 100644 --- a/lambda/telemetry/events_api.go +++ b/lambda/telemetry/events_api.go @@ -3,12 +3,136 @@ package telemetry +import ( + "time" + + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" + "go.amzn.com/lambda/rapi/model" +) + +type RuntimeDoneInvokeMetrics struct { + ProducedBytes int64 + DurationMs float64 +} + +func GetRuntimeDoneInvokeMetrics(invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics, runtimeDoneTime int64) *RuntimeDoneInvokeMetrics { + if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && invokeReceivedTime != 0 { + return &RuntimeDoneInvokeMetrics{ + ProducedBytes: invokeResponseMetrics.ProducedBytes, + // time taken from sending the invoke to the sandbox until the runtime calls GET /next + DurationMs: float64((runtimeDoneTime - invokeReceivedTime) / int64(time.Millisecond)), + } + } + + // when we get a reset before runtime called /response + if invokeReceivedTime != 0 { + return &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64((runtimeDoneTime - invokeReceivedTime) / int64(time.Millisecond)), + } + } + + // We didn't have time to register the invokeReceiveTime, which means we crash/reset very early, + // too early for the runtime to actual run. In such case, the runtimeDone event shouldn't be sent + // Not returning Nil even in this improbable case guarantees that we will always have some metrics to send to FluxPump + return &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(0), + } +} + +type InitRuntimeDoneData struct { + InitSource string + Status string +} + +type InvokeRuntimeDoneData struct { + Status string + Metrics *RuntimeDoneInvokeMetrics + InternalMetrics *interop.InvokeResponseMetrics + Tracing *TracingCtx + Spans []Span +} + +type Span struct { + Name string + Start string + DurationMs float64 +} + +func GetRuntimeDoneSpans(invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) []Span { + if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && invokeReceivedTime != 0 { + // time span from when the invoke is received in the sandbox to the moment the runtime calls PUT /response + responseLatencyMsSpan := Span{ + Name: "responseLatency", + Start: getEpochTimeInISO8601FormatFromMonotime(invokeReceivedTime), + DurationMs: float64((invokeResponseMetrics.StartReadingResponseMonoTimeMs - invokeReceivedTime) / int64(time.Millisecond)), + } + + // time span from when the runtime called PUT /response to the moment the body of the response is fully sent + responseDurationMsSpan := Span{ + Name: "responseDuration", + Start: getEpochTimeInISO8601FormatFromMonotime(invokeResponseMetrics.StartReadingResponseMonoTimeMs), + DurationMs: float64((invokeResponseMetrics.FinishReadingResponseMonoTimeMs - invokeResponseMetrics.StartReadingResponseMonoTimeMs) / int64(time.Millisecond)), + } + return []Span{responseLatencyMsSpan, responseDurationMsSpan} + } + + return []Span{} +} + +func getEpochTimeInISO8601FormatFromMonotime(monotime int64) string { + return time.Unix(0, metering.MonoToEpoch(monotime)).Format("2006-01-02T15:04:05.000Z") +} + +type TracingCtx struct { + SpanID string + Type model.TracingType + Value string +} + +func BuildTracingCtx(tracingType model.TracingType, traceID string, lambdaSegmentID string) *TracingCtx { + // it takes current tracing context and change its parent value with the provided lambda segment id + root, currentParent, sample := ParseTraceID(traceID) + if root == "" || sample != model.XRaySampled { + return nil + } + + return &TracingCtx{ + SpanID: currentParent, + Type: tracingType, + Value: BuildFullTraceID(root, lambdaSegmentID, sample), + } +} + +const ( + RuntimeDoneSuccess = "success" + RuntimeDoneFailure = "failure" +) + type EventsAPI interface { SetCurrentRequestID(requestID string) - SendRuntimeDone(status string) error + SendInitRuntimeDone(data *InitRuntimeDoneData) error + SendRestoreRuntimeDone(status string) error + SendRuntimeDone(data InvokeRuntimeDoneData) error + SendExtensionInit(agentName, state, errorType string, subscriptions []string) error + SendImageErrorLog(logline string) } type NoOpEventsAPI struct{} func (s *NoOpEventsAPI) SetCurrentRequestID(requestID string) {} -func (s *NoOpEventsAPI) SendRuntimeDone(status string) error { return nil } +func (s *NoOpEventsAPI) SendInitRuntimeDone(data *InitRuntimeDoneData) error { + return nil +} +func (s *NoOpEventsAPI) SendRestoreRuntimeDone(status string) error { + return nil +} +func (s *NoOpEventsAPI) SendRuntimeDone(data InvokeRuntimeDoneData) error { + return nil +} +func (s *NoOpEventsAPI) SendExtensionInit(agentName, state, errorType string, subscriptions []string) error { + return nil +} +func (s *NoOpEventsAPI) SendImageErrorLog(logline string) {} diff --git a/lambda/telemetry/events_api_test.go b/lambda/telemetry/events_api_test.go new file mode 100644 index 0000000..b943be9 --- /dev/null +++ b/lambda/telemetry/events_api_test.go @@ -0,0 +1,139 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" +) + +func TestGetRuntimeDoneInvokeMetrics(t *testing.T) { + now := metering.Monotime() + + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + ProducedBytes: int64(100), + RuntimeCalledResponse: true, + } + runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + + expected := &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(10), + } + + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, runtimeDoneTime)) +} + +func TestGetRuntimeDoneInvokeMetricsWhenRuntimeCalledError(t *testing.T) { + now := metering.Monotime() + + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + ProducedBytes: int64(100), + RuntimeCalledResponse: false, + } + runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + + expected := &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(10), + } + + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, runtimeDoneTime)) +} + +func TestGetRuntimeDoneInvokeMetricsWhenInvokeReceivedTimeIsZero(t *testing.T) { + now := int64(0) // January 1st, 1970 at 00:00:00 UTC + invokeReceivedTime := now + + runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + + expected := &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(0), + } + actual := GetRuntimeDoneInvokeMetrics(invokeReceivedTime, nil, runtimeDoneTime) + assert.Equal(t, expected, actual) +} + +func TestGetRuntimeDoneInvokeMetricsWhenInvokeResponseMetricsIsNil(t *testing.T) { + now := metering.Monotime() + invokeReceivedTime := now + + runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + + expected := &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(10), + } + + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, nil, runtimeDoneTime)) +} + +func TestGetRuntimeDoneSpans(t *testing.T) { + now := metering.Monotime() + startReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(5)) + finishReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(7)) + + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, + FinishReadingResponseMonoTimeMs: finishReadingResponseMonoTimeMs, + RuntimeCalledResponse: true, + } + + expectedResponseLatencyMsStartTime := getEpochTimeInISO8601FormatFromMonotime(now) + expectedResponseDurationMsStartTime := getEpochTimeInISO8601FormatFromMonotime(startReadingResponseMonoTimeMs) + expected := []Span{ + Span{ + Name: "responseLatency", + Start: expectedResponseLatencyMsStartTime, + DurationMs: 5, + }, + Span{ + Name: "responseDuration", + Start: expectedResponseDurationMsStartTime, + DurationMs: 2, + }, + } + + assert.Equal(t, expected, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) +} + +func TestGetRuntimeDoneSpansWhenRuntimeCalledError(t *testing.T) { + now := metering.Monotime() + startReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(5)) + finishReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(7)) + + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, + FinishReadingResponseMonoTimeMs: finishReadingResponseMonoTimeMs, + RuntimeCalledResponse: false, + } + + assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) +} + +func TestGetRuntimeDoneSpansWhenInvokeResponseMetricsNil(t *testing.T) { + invokeReceivedTime := metering.Monotime() + + assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, nil)) +} + +func TestGetRuntimeDoneSpansWhenInvokeReceivedTimeIsZero(t *testing.T) { + now := int64(0) // January 1st, 1970 at 00:00:00 UTC + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + StartReadingResponseMonoTimeMs: now + int64(time.Millisecond*time.Duration(5)), + FinishReadingResponseMonoTimeMs: now + int64(time.Millisecond*time.Duration(7)), + } + + assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) +} diff --git a/lambda/telemetry/logs_egress_api.go b/lambda/telemetry/logs_egress_api.go index ac9a754..7e84fe2 100644 --- a/lambda/telemetry/logs_egress_api.go +++ b/lambda/telemetry/logs_egress_api.go @@ -8,7 +8,12 @@ import ( "os" ) -type LogsEgressAPI interface { +// StdLogsEgressAPI is the interface that wraps the basic methods required to setup +// logs channels for Runtime's stdout/stderr and Extension's stdout/stderr. +// +// Implementation should return a Writer implementor for stdout and another for +// stderr on success and an error on failure. +type StdLogsEgressAPI interface { GetExtensionSockets() (io.Writer, io.Writer, error) GetRuntimeSockets() (io.Writer, io.Writer, error) } diff --git a/lambda/telemetry/logs_subscription_api.go b/lambda/telemetry/logs_subscription_api.go index 3ea7a20..6ee9490 100644 --- a/lambda/telemetry/logs_subscription_api.go +++ b/lambda/telemetry/logs_subscription_api.go @@ -10,28 +10,37 @@ import ( "go.amzn.com/lambda/interop" ) -// LogsSubscriptionAPI represents interface that implementations of Telemetry API have to satisfy to be RAPID-compatible -type LogsSubscriptionAPI interface { +// SubscriptionAPI represents interface that implementations of Telemetry API have to satisfy to be RAPID-compatible +type SubscriptionAPI interface { Subscribe(agentName string, body io.Reader, headers map[string][]string) (resp []byte, status int, respHeaders map[string][]string, err error) RecordCounterMetric(metricName string, count int) - FlushMetrics() interop.LogsAPIMetrics + FlushMetrics() interop.TelemetrySubscriptionMetrics Clear() TurnOff() + GetEndpointURL() string + GetServiceClosedErrorMessage() string + GetServiceClosedErrorType() string } -type NoOpLogsSubscriptionAPI struct{} +type NoOpSubscriptionAPI struct{} // Subscribe writes response to a shared memory -func (m *NoOpLogsSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { +func (m *NoOpSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { return []byte(`{}`), http.StatusOK, map[string][]string{}, nil } -func (m *NoOpLogsSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} +func (m *NoOpSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} -func (m *NoOpLogsSubscriptionAPI) FlushMetrics() interop.LogsAPIMetrics { - return interop.LogsAPIMetrics(map[string]int{}) +func (m *NoOpSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { + return interop.TelemetrySubscriptionMetrics(map[string]int{}) } -func (m *NoOpLogsSubscriptionAPI) Clear() {} +func (m *NoOpSubscriptionAPI) Clear() {} -func (m *NoOpLogsSubscriptionAPI) TurnOff() {} +func (m *NoOpSubscriptionAPI) TurnOff() {} + +func (m *NoOpSubscriptionAPI) GetEndpointURL() string { return "" } + +func (m *NoOpSubscriptionAPI) GetServiceClosedErrorMessage() string { return "" } + +func (m *NoOpSubscriptionAPI) GetServiceClosedErrorType() string { return "" } diff --git a/lambda/telemetry/tracer.go b/lambda/telemetry/tracer.go index 1ac8325..affca60 100644 --- a/lambda/telemetry/tracer.go +++ b/lambda/telemetry/tracer.go @@ -11,6 +11,7 @@ import ( "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapi/model" ) type traceContextKey int @@ -129,3 +130,24 @@ func ParseTraceID(fullTraceID string) (rootID, parentID, sample string) { } return } + +// BuildFullTraceID takes individual components of X-Ray trace header +// and puts them together into a formatted trace header. +// If root is empty, returns an empty string. +func BuildFullTraceID(root, parent, sample string) string { + if root == "" { + return "" + } + + parts := make([]string, 0, 3) + parts = append(parts, "Root="+root) + if parent != "" { + parts = append(parts, "Parent="+parent) + } + if sample == "" { + sample = model.XRayNonSampled + } + parts = append(parts, "Sampled="+sample) + + return strings.Join(parts, ";") +} diff --git a/lambda/telemetry/tracer_test.go b/lambda/telemetry/tracer_test.go index 9ac1260..c31653f 100644 --- a/lambda/telemetry/tracer_test.go +++ b/lambda/telemetry/tracer_test.go @@ -5,6 +5,8 @@ package telemetry import ( "testing" + + "go.amzn.com/lambda/rapi/model" ) var parserTests = []struct { @@ -35,3 +37,47 @@ func TestParseTraceID(t *testing.T) { }) } } + +func TestBuildFullTraceID(t *testing.T) { + specs := map[string]struct { + root string + parent string + sample string + expectedTraceID string + }{ + "all non-empty components, sampled": { + root: "1-5b3cc918-939afd635f8891ba6a9e1df6", + parent: "c88d77b0aef840e9", + sample: model.XRaySampled, + expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", + }, + "all non-empty components, non-sampled": { + root: "1-5b3cc918-939afd635f8891ba6a9e1df6", + parent: "c88d77b0aef840e9", + sample: model.XRayNonSampled, + expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=0", + }, + "root is non-empty, parent and sample are empty": { + root: "1-5b3cc918-939afd635f8891ba6a9e1df6", + expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Sampled=0", + }, + "root is empty": { + parent: "c88d77b0aef840e9", + expectedTraceID: "", + }, + "sample is empty": { + root: "1-5b3cc918-939afd635f8891ba6a9e1df6", + parent: "c88d77b0aef840e9", + expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=0", + }, + } + + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + actual := BuildFullTraceID(spec.root, spec.parent, spec.sample) + if actual != spec.expectedTraceID { + t.Errorf("got %q, wanted %q", actual, spec.expectedTraceID) + } + }) + } +} diff --git a/lambda/testdata/agents/bash_stderr.sh b/lambda/testdata/agents/bash_stderr.sh deleted file mode 100755 index 65c0ff1..0000000 --- a/lambda/testdata/agents/bash_stderr.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash - -printf "stderr line 1\n" >&2 -printf "stderr line 2\n" >&2 -printf "stderr line 3\n" >&2 diff --git a/lambda/testdata/agents/bash_stdout.sh b/lambda/testdata/agents/bash_stdout.sh deleted file mode 100755 index d0cb893..0000000 --- a/lambda/testdata/agents/bash_stdout.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash - -printf "stdout line 1\n" -printf "stdout line 2\n" -printf "stdout line 3\n" diff --git a/lambda/testdata/agents/bash_stdout_and_stderr.sh b/lambda/testdata/agents/bash_stdout_and_stderr.sh deleted file mode 100755 index cf87e60..0000000 --- a/lambda/testdata/agents/bash_stdout_and_stderr.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env bash - -printf "stdout line 1\n" -printf "stderr line 1\n" >&2 -printf "stdout line 2\n" -printf "stderr line 2\n" >&2 -printf "stdout line 3\n" -printf "stderr line 3\n" >&2 diff --git a/lambda/testdata/flowtesting.go b/lambda/testdata/flowtesting.go index ee163bb..c028d7c 100644 --- a/lambda/testdata/flowtesting.go +++ b/lambda/testdata/flowtesting.go @@ -8,28 +8,31 @@ import ( "io" "io/ioutil" "net/http" + "time" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" - "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/rendering" "go.amzn.com/lambda/telemetry" "go.amzn.com/lambda/testdata/mockthread" ) +const ( + contentTypeHeader = "Content-Type" + functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" +) + type MockInteropServer struct { - Response []byte - ErrorResponse *interop.ErrorResponse - ResponseContentType string - ActiveInvokeID string + Response []byte + ErrorResponse *interop.ErrorResponse + ResponseContentType string + FunctionResponseMode string + ActiveInvokeID string } -// StartAcceptingDirectInvokes -func (i *MockInteropServer) StartAcceptingDirectInvokes() error { return nil } - // SendResponse writes response to a shared memory. -func (i *MockInteropServer) SendResponse(invokeID string, contentType string, reader io.Reader) error { +func (i *MockInteropServer) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error { bytes, err := ioutil.ReadAll(reader) if err != nil { return err @@ -41,7 +44,8 @@ func (i *MockInteropServer) SendResponse(invokeID string, contentType string, re } } i.Response = bytes - i.ResponseContentType = contentType + i.ResponseContentType = headers[contentTypeHeader] + i.FunctionResponseMode = headers[functionResponseModeHeader] return nil } @@ -49,68 +53,35 @@ func (i *MockInteropServer) SendResponse(invokeID string, contentType string, re func (i *MockInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorResponse) error { i.ErrorResponse = response i.ResponseContentType = response.ContentType + i.FunctionResponseMode = response.FunctionResponseMode return nil } -func (i *MockInteropServer) GetCurrentInvokeID() string { - return i.ActiveInvokeID +// SendInitErrorResponse writes error response during init to a shared memory and sends GIRD FAULT. +func (i *MockInteropServer) SendInitErrorResponse(invokeID string, response *interop.ErrorResponse) error { + i.ErrorResponse = response + i.ResponseContentType = response.ContentType + return nil } -func (i *MockInteropServer) CommitResponse() error { return nil } - -// SendRunning sends GIRD RUNNING. -func (i *MockInteropServer) SendRunning(*interop.Running) error { return nil } - -// SendDone sends GIRD DONE. -func (i *MockInteropServer) SendDone(*interop.Done) error { return nil } - -// SendDoneFail sends GIRD DONEFAIL. -func (i *MockInteropServer) SendDoneFail(*interop.DoneFail) error { return nil } - -// StartChan returns Start emitter -func (i *MockInteropServer) StartChan() <-chan *interop.Start { return nil } - -// InvokeChan returns Invoke emitter -func (i *MockInteropServer) InvokeChan() <-chan *interop.Invoke { return nil } - -// ResetChan returns Reset emitter -func (i *MockInteropServer) ResetChan() <-chan *interop.Reset { return nil } - -// ShutdownChan returns Shutdown emitter -func (i *MockInteropServer) ShutdownChan() <-chan *interop.Shutdown { return nil } - -// TransportErrorChan emits errors if there was parsing/connection issue -func (i *MockInteropServer) TransportErrorChan() <-chan error { return nil } - -func (i *MockInteropServer) Clear() {} - -func (i *MockInteropServer) IsResponseSent() bool { - return !(i.Response == nil && i.ErrorResponse == nil) +func (i *MockInteropServer) GetCurrentInvokeID() string { + return i.ActiveInvokeID } func (i *MockInteropServer) SendRuntimeReady() error { return nil } -func (i *MockInteropServer) SetInternalStateGetter(isd interop.InternalStateGetter) {} - -func (m *MockInteropServer) Init(i *interop.Start, invokeTimeoutMs int64) {} - -func (m *MockInteropServer) Invoke(w http.ResponseWriter, i *interop.Invoke) error { return nil } - -func (m *MockInteropServer) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription { - return nil -} - // FlowTest provides configuration for tests that involve synchronization flows. type FlowTest struct { - AppCtx appctx.ApplicationContext - InitFlow core.InitFlowSynchronization - InvokeFlow core.InvokeFlowSynchronization - RegistrationService core.RegistrationService - RenderingService *rendering.EventRenderingService - Runtime *core.Runtime - InteropServer *MockInteropServer - LogsSubscriptionAPI *telemetry.NoOpLogsSubscriptionAPI - CredentialsService core.CredentialsService + AppCtx appctx.ApplicationContext + InitFlow core.InitFlowSynchronization + InvokeFlow core.InvokeFlowSynchronization + RegistrationService core.RegistrationService + RenderingService *rendering.EventRenderingService + Runtime *core.Runtime + InteropServer *MockInteropServer + TelemetrySubscription *telemetry.NoOpSubscriptionAPI + CredentialsService core.CredentialsService + EventsAPI telemetry.EventsAPI } // ConfigureForInit initialize synchronization gates and states for init. @@ -125,13 +96,13 @@ func (s *FlowTest) ConfigureForInvoke(ctx context.Context, invoke *interop.Invok s.RenderingService.SetRenderer(rendering.NewInvokeRenderer(ctx, invoke, telemetry.GetCustomerTracingHeader)) } -func (s *FlowTest) ConfigureForInitCaching(token, awsKey, awsSecret, awsSession string) { - s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession) +func (s *FlowTest) ConfigureForRestore() { + s.RenderingService.SetRenderer(rendering.NewRestoreRenderer()) } -func (s *FlowTest) ConfigureForBlockedInitCaching(token, awsKey, awsSecret, awsSession string) { - s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession) - s.CredentialsService.BlockService() +func (s *FlowTest) ConfigureForInitCaching(token, awsKey, awsSecret, awsSession string) { + credentialsExpiration := time.Now().Add(30 * time.Minute) + s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession, credentialsExpiration) } // NewFlowTest returns new FlowTest configuration. @@ -145,16 +116,18 @@ func NewFlowTest() *FlowTest { runtime := core.NewRuntime(initFlow, invokeFlow) runtime.ManagedThread = &mockthread.MockManagedThread{} interopServer := &MockInteropServer{} + eventsAPI := telemetry.NoOpEventsAPI{} appctx.StoreInteropServer(appCtx, interopServer) return &FlowTest{ - AppCtx: appCtx, - InitFlow: initFlow, - InvokeFlow: invokeFlow, - RegistrationService: registrationService, - RenderingService: renderingService, - LogsSubscriptionAPI: &telemetry.NoOpLogsSubscriptionAPI{}, - Runtime: runtime, - InteropServer: interopServer, - CredentialsService: credentialsService, + AppCtx: appCtx, + InitFlow: initFlow, + InvokeFlow: invokeFlow, + RegistrationService: registrationService, + RenderingService: renderingService, + TelemetrySubscription: &telemetry.NoOpSubscriptionAPI{}, + Runtime: runtime, + InteropServer: interopServer, + CredentialsService: credentialsService, + EventsAPI: &eventsAPI, } } diff --git a/test/integration/local_lambda/test_end_to_end.py b/test/integration/local_lambda/test_end_to_end.py index a85abce..c5c3e63 100644 --- a/test/integration/local_lambda/test_end_to_end.py +++ b/test/integration/local_lambda/test_end_to_end.py @@ -53,6 +53,7 @@ def tearDownClass(cls): "remaining_time_in_default_deadline", "pre-runtime-api", "assert-overwritten", + "port_override" ] for image in images_to_delete: @@ -264,6 +265,23 @@ def test_function_name_is_overriden(self, arch, port): ) self.assertEqual(b'"My lambda ran succesfully"', r.content) + @parameterized.expand([("x86_64", "8011"), ("arm64", "9011"), ("", "9061")]) + def test_port_override(self, arch, port): + image, rie, image_name = self.tagged_name("port_override", arch) + + # Use port 8081 inside the container instead of 8080 + cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8081 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.success_handler --runtime-interface-emulator-address 0.0.0.0:8081" + + Popen(cmd.split(" ")).communicate() + + # sleep 1s to give enough time for the endpoint to be up to curl + time.sleep(SLEEP_TIME) + + r = requests.post( + f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} + ) + self.assertEqual(b'"My lambda ran succesfully"', r.content) + if __name__ == "__main__": main()