From c1cf0c518004962baa79bd82d5c4c2776e15073f Mon Sep 17 00:00:00 2001 From: Renato Valenzuela <37676028+valerena@users.noreply.github.com> Date: Thu, 27 Apr 2023 10:46:51 -0700 Subject: [PATCH] feat: Pull upstream changes 2023/04 (#87) --- cmd/aws-lambda-rie/handlers.go | 31 +- cmd/aws-lambda-rie/http.go | 6 +- cmd/aws-lambda-rie/main.go | 51 +- lambda/agents/agent.go | 75 +- lambda/agents/agent_test.go | 221 ++--- lambda/agents/log_line_splitter.go | 40 - lambda/appctx/appctx.go | 14 + lambda/appctx/appctxutil.go | 26 +- lambda/appctx/appctxutil_test.go | 25 + .../core/bandwidthlimiter/bandwidthlimiter.go | 61 ++ .../bandwidthlimiter/bandwidthlimiter_test.go | 106 +++ lambda/core/bandwidthlimiter/throttler.go | 154 ++++ .../core/bandwidthlimiter/throttler_test.go | 215 +++++ lambda/core/bandwidthlimiter/util.go | 46 ++ lambda/core/bandwidthlimiter/util_test.go | 45 + lambda/core/credentials.go | 42 +- lambda/core/credentials_test.go | 54 +- lambda/core/directinvoke/directinvoke.go | 288 ++++++- lambda/core/directinvoke/directinvoke_test.go | 358 ++++++++ lambda/core/directinvoke/util.go | 84 ++ lambda/core/doc.go | 38 +- lambda/core/externalagent.go | 1 - lambda/core/flow.go | 17 + lambda/core/registrations.go | 21 + lambda/core/registrations_test.go | 2 +- lambda/core/runtime_state_names.go | 12 +- lambda/core/states.go | 78 +- lambda/core/states_test.go | 134 ++- lambda/core/watchdog.go | 102 --- lambda/core/watchdog_test.go | 50 -- lambda/fatalerror/fatalerror.go | 5 +- lambda/interop/bootstrap.go | 18 + lambda/interop/cancellable_request.go | 27 + lambda/interop/environment_variables.go | 14 + lambda/interop/model.go | 331 +++++--- lambda/interop/model_test.go | 27 + lambda/interop/sandbox_model.go | 171 +++- lambda/logging/doc.go | 23 +- lambda/logging/internal_log_test.go | 26 +- lambda/logging/platform_log.go | 65 -- lambda/logging/platform_log_test.go | 42 - lambda/logging/taillog.go | 52 -- lambda/logging/taillog_test.go | 29 - 
lambda/rapi/handler/agentiniterror_test.go | 10 +- lambda/rapi/handler/agentnext_test.go | 16 +- lambda/rapi/handler/agentregister.go | 11 +- lambda/rapi/handler/agentregister_test.go | 9 +- lambda/rapi/handler/constants.go | 1 - lambda/rapi/handler/credentials_test.go | 37 +- lambda/rapi/handler/initerror.go | 42 +- lambda/rapi/handler/initerror_test.go | 4 +- lambda/rapi/handler/invocationerror.go | 25 +- lambda/rapi/handler/invocationerror_test.go | 6 +- lambda/rapi/handler/invocationnext_test.go | 2 +- lambda/rapi/handler/invocationresponse.go | 36 +- .../rapi/handler/invocationresponse_test.go | 95 ++- lambda/rapi/handler/restorenext.go | 40 + lambda/rapi/handler/restorenext_test.go | 87 ++ lambda/rapi/handler/runtimelogs.go | 36 +- lambda/rapi/handler/runtimelogs_stub.go | 45 +- lambda/rapi/handler/runtimelogs_stub_test.go | 16 +- lambda/rapi/handler/runtimelogs_test.go | 130 ++- lambda/rapi/middleware/middleware_test.go | 6 +- lambda/rapi/model/tracing.go | 11 +- lambda/rapi/rendering/doc.go | 2 - lambda/rapi/rendering/render_json.go | 33 + lambda/rapi/rendering/rendering.go | 84 +- lambda/rapi/router.go | 39 +- lambda/rapi/router_test.go | 47 +- lambda/rapi/security_test.go | 6 +- lambda/rapi/server.go | 16 +- lambda/rapi/server_test.go | 3 +- lambda/rapid/bootstrap.go | 18 - lambda/rapid/exit.go | 147 ++-- lambda/rapid/graceful_shutdown.go | 198 ----- lambda/rapid/sandbox.go | 165 ++-- lambda/rapid/shutdown.go | 366 +++++++++ lambda/rapid/start.go | 776 +++++++++++------- lambda/rapid/start_test.go | 34 +- lambda/rapidcore/bootstrap.go | 79 +- lambda/rapidcore/bootstrap_test.go | 115 ++- lambda/rapidcore/env/environment.go | 18 +- lambda/rapidcore/env/environment_test.go | 41 +- lambda/rapidcore/errors.go | 6 +- lambda/rapidcore/sandbox.go | 259 ------ lambda/rapidcore/sandbox_api.go | 147 ++++ lambda/rapidcore/sandbox_builder.go | 217 +++++ lambda/rapidcore/sandbox_emulator_api.go | 52 ++ lambda/rapidcore/server.go | 621 +++++++++----- 
lambda/rapidcore/server_test.go | 370 ++++++--- .../standalone/directInvokeHandler.go | 16 +- lambda/rapidcore/standalone/executeHandler.go | 17 +- lambda/rapidcore/standalone/initHandler.go | 70 +- .../standalone/internalStateHandler.go | 4 +- lambda/rapidcore/standalone/invokeHandler.go | 34 +- lambda/rapidcore/standalone/pingHandler.go | 12 + lambda/rapidcore/standalone/reserveHandler.go | 12 +- lambda/rapidcore/standalone/resetHandler.go | 4 +- lambda/rapidcore/standalone/restoreHandler.go | 41 + lambda/rapidcore/standalone/router.go | 27 +- .../rapidcore/standalone/shutdownHandler.go | 6 +- lambda/rapidcore/standalone/util.go | 4 +- .../standalone/waitUntilInitializedHandler.go | 23 + .../standalone/waitUntilReleaseHandler.go | 4 +- lambda/rapidcore/telemetry/eventLog.go | 25 +- lambda/rapidcore/telemetry/events_api.go | 97 +++ lambda/runtimecmd/runtime_command.go | 57 -- lambda/runtimecmd/runtime_command_test.go | 51 -- lambda/supervisor/local_supervisor.go | 302 +++++++ lambda/supervisor/local_supervisor_test.go | 215 +++++ lambda/supervisor/model/model.go | 269 ++++++ lambda/telemetry/events_api.go | 128 ++- lambda/telemetry/events_api_test.go | 139 ++++ lambda/telemetry/logs_egress_api.go | 7 +- lambda/telemetry/logs_subscription_api.go | 29 +- lambda/telemetry/tracer.go | 22 + lambda/telemetry/tracer_test.go | 46 ++ lambda/testdata/agents/bash_stderr.sh | 5 - lambda/testdata/agents/bash_stdout.sh | 5 - .../testdata/agents/bash_stdout_and_stderr.sh | 8 - lambda/testdata/flowtesting.go | 123 ++- .../local_lambda/test_end_to_end.py | 18 + 122 files changed, 6845 insertions(+), 2726 deletions(-) delete mode 100644 lambda/agents/log_line_splitter.go create mode 100644 lambda/core/bandwidthlimiter/bandwidthlimiter.go create mode 100644 lambda/core/bandwidthlimiter/bandwidthlimiter_test.go create mode 100644 lambda/core/bandwidthlimiter/throttler.go create mode 100644 lambda/core/bandwidthlimiter/throttler_test.go create mode 100644 
lambda/core/bandwidthlimiter/util.go create mode 100644 lambda/core/bandwidthlimiter/util_test.go create mode 100644 lambda/core/directinvoke/directinvoke_test.go create mode 100644 lambda/core/directinvoke/util.go delete mode 100644 lambda/core/watchdog.go delete mode 100644 lambda/core/watchdog_test.go create mode 100644 lambda/interop/bootstrap.go create mode 100644 lambda/interop/cancellable_request.go create mode 100644 lambda/interop/environment_variables.go create mode 100644 lambda/interop/model_test.go delete mode 100644 lambda/logging/platform_log.go delete mode 100644 lambda/logging/platform_log_test.go delete mode 100644 lambda/logging/taillog.go delete mode 100644 lambda/logging/taillog_test.go create mode 100644 lambda/rapi/handler/restorenext.go create mode 100644 lambda/rapi/handler/restorenext_test.go create mode 100644 lambda/rapi/rendering/render_json.go delete mode 100644 lambda/rapid/bootstrap.go delete mode 100644 lambda/rapid/graceful_shutdown.go create mode 100644 lambda/rapid/shutdown.go delete mode 100644 lambda/rapidcore/sandbox.go create mode 100644 lambda/rapidcore/sandbox_api.go create mode 100644 lambda/rapidcore/sandbox_builder.go create mode 100644 lambda/rapidcore/sandbox_emulator_api.go create mode 100644 lambda/rapidcore/standalone/pingHandler.go create mode 100644 lambda/rapidcore/standalone/restoreHandler.go create mode 100644 lambda/rapidcore/standalone/waitUntilInitializedHandler.go create mode 100644 lambda/rapidcore/telemetry/events_api.go delete mode 100644 lambda/runtimecmd/runtime_command.go delete mode 100644 lambda/runtimecmd/runtime_command_test.go create mode 100644 lambda/supervisor/local_supervisor.go create mode 100644 lambda/supervisor/local_supervisor_test.go create mode 100644 lambda/supervisor/model/model.go create mode 100644 lambda/telemetry/events_api_test.go delete mode 100755 lambda/testdata/agents/bash_stderr.sh delete mode 100755 lambda/testdata/agents/bash_stdout.sh delete mode 100755 
lambda/testdata/agents/bash_stdout_and_stderr.sh diff --git a/cmd/aws-lambda-rie/handlers.go b/cmd/aws-lambda-rie/handlers.go index 39097fc..42032cf 100644 --- a/cmd/aws-lambda-rie/handlers.go +++ b/cmd/aws-lambda-rie/handlers.go @@ -14,8 +14,10 @@ import ( "strings" "time" + "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" + "go.amzn.com/lambda/rapidcore/env" "github.com/google/uuid" @@ -27,6 +29,19 @@ type Sandbox interface { Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error } +type InteropServer interface { + Init(i *interop.Init, invokeTimeoutMs int64) error + AwaitInitialized() error + FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error + Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) + Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) + AwaitRelease() (*statejson.InternalStateDescription, error) + Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription + InternalState() (*statejson.InternalStateDescription, error) + CurrentToken() *interop.Token + Restore(restore *interop.Restore) error +} + var initDone bool func GetenvWithDefault(key string, defaultValue string) string { @@ -57,7 +72,7 @@ func printEndReports(invokeId string, initDuration string, memorySize string, in invokeId, invokeDuration, math.Ceil(invokeDuration), memorySize, memorySize) } -func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { +func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs interop.Bootstrap) { log.Debugf("invoke: -> %s %s %v", r.Method, r.URL, r.Header) bodyBytes, err := ioutil.ReadAll(r.Body) if err != nil { @@ -80,7 +95,7 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { if !initDone { - initStart, initEnd := InitHandler(sandbox, functionVersion, timeout) + initStart, initEnd := InitHandler(sandbox, functionVersion, timeout, 
bs) // Calculate InitDuration initTimeMS := math.Min(float64(initEnd.Sub(initStart).Nanoseconds()), @@ -99,7 +114,6 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { TraceID: r.Header.Get("X-Amzn-Trace-Id"), LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), Payload: bytes.NewReader(bodyBytes), - CorrelationID: "invokeCorrelationID", } fmt.Println("START RequestId: " + invokePayload.ID + " Version: " + functionVersion) @@ -166,7 +180,7 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { w.Write(invokeResp.Body) } -func InitHandler(sandbox Sandbox, functionVersion string, timeout int64) (time.Time, time.Time) { +func InitHandler(sandbox Sandbox, functionVersion string, timeout int64, bs interop.Bootstrap) (time.Time, time.Time) { additionalFunctionEnvironmentVariables := map[string]string{} // Add default Env Vars if they were not defined. This is a required otherwise 1p Python2.7, Python3.6, and @@ -189,15 +203,20 @@ func InitHandler(sandbox Sandbox, functionVersion string, timeout int64) (time.T // pass to rapid sandbox.Init(&interop.Init{ Handler: GetenvWithDefault("AWS_LAMBDA_FUNCTION_HANDLER", os.Getenv("_HANDLER")), - CorrelationID: "initCorrelationID", AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"), AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"), AwsSession: os.Getenv("AWS_SESSION_TOKEN"), XRayDaemonAddress: "0.0.0.0:0", // TODO FunctionName: GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function"), FunctionVersion: functionVersion, - + RuntimeInfo: interop.RuntimeInfo{ + ImageJSON: "{}", + Arn: "", + Version: ""}, CustomerEnvironmentVariables: additionalFunctionEnvironmentVariables, + SandboxType: interop.SandboxClassic, + Bootstrap: bs, + EnvironmentVariables: env.NewEnvironment(), }, timeout*1000) initEnd := time.Now() return initStart, initEnd diff --git a/cmd/aws-lambda-rie/http.go b/cmd/aws-lambda-rie/http.go index be4002d..88bd39b 100644 --- a/cmd/aws-lambda-rie/http.go +++ 
b/cmd/aws-lambda-rie/http.go @@ -7,16 +7,18 @@ import ( "net/http" log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapidcore" ) -func startHTTPServer(ipport string, sandbox Sandbox) { +func startHTTPServer(ipport string, sandbox *rapidcore.SandboxBuilder, bs interop.Bootstrap) { srv := &http.Server{ Addr: ipport, } // Pass a channel http.HandleFunc("/2015-03-31/functions/function/invocations", func(w http.ResponseWriter, r *http.Request) { - InvokeHandler(w, r, sandbox) + InvokeHandler(w, r, sandbox.LambdaInvokeAPI(), bs) }) // go routine (main thread waits) diff --git a/cmd/aws-lambda-rie/main.go b/cmd/aws-lambda-rie/main.go index 3a87e46..65879c0 100644 --- a/cmd/aws-lambda-rie/main.go +++ b/cmd/aws-lambda-rie/main.go @@ -6,6 +6,7 @@ package main import ( "context" "fmt" + "net" "os" "runtime/debug" @@ -21,8 +22,11 @@ const ( ) type options struct { - LogLevel string `long:"log-level" default:"info" description:"log level"` + LogLevel string `long:"log-level" description:"The level of AWS Lambda Runtime Interface Emulator logs to display. Can also be set by the environment variable 'LOG_LEVEL'. 
Defaults to the value 'info'."` InitCachingEnabled bool `long:"enable-init-caching" description:"Enable support for Init Caching"` + // Do not have a default value so we do not need to keep it in sync with the default value in lambda/rapidcore/sandbox_builder.go + RuntimeAPIAddress string `long:"runtime-api-address" description:"The address of the AWS Lambda Runtime API to communicate with the Lambda execution environment."` + RuntimeInterfaceEmulatorAddress string `long:"runtime-interface-emulator-address" default:"0.0.0.0:8080" description:"The address for the AWS Lambda Runtime Interface Emulator to accept HTTP request upon."` } func main() { @@ -30,11 +34,37 @@ func main() { debug.SetGCPercent(33) opts, args := getCLIArgs() - rapidcore.SetLogLevel(opts.LogLevel) + + logLevel := "info" + + // If you specify an option by using a parameter on the CLI command line, it overrides any value from either the corresponding environment variable. + // + // https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html + if opts.LogLevel != "" { + logLevel = opts.LogLevel + } else if envLogLevel, envLogLevelSet := os.LookupEnv("LOG_LEVEL"); envLogLevelSet { + logLevel = envLogLevel + } + + rapidcore.SetLogLevel(logLevel) + + if opts.RuntimeAPIAddress != "" { + _, _, err := net.SplitHostPort(opts.RuntimeAPIAddress) + + if err != nil { + log.WithError(err).Fatalf("The command line value for \"--runtime-api-address\" is not a valid network address %q.", opts.RuntimeAPIAddress) + } + } + + _, _, err := net.SplitHostPort(opts.RuntimeInterfaceEmulatorAddress) + + if err != nil { + log.WithError(err).Fatalf("The command line value for \"--runtime-interface-emulator-address\" is not a valid network address %q.", opts.RuntimeInterfaceEmulatorAddress) + } bootstrap, handler := getBootstrap(args, opts) sandbox := rapidcore. - NewSandboxBuilder(bootstrap). + NewSandboxBuilder(). AddShutdownFunc(context.CancelFunc(func() { os.Exit(0) })). SetExtensionsFlag(true). 
SetInitCachingFlag(opts.InitCachingEnabled) @@ -43,10 +73,17 @@ func main() { sandbox.SetHandler(handler) } - go sandbox.Create() + if opts.RuntimeAPIAddress != "" { + sandbox.SetRuntimeAPIAddress(opts.RuntimeAPIAddress) + } + + sandboxContext, internalStateFn := sandbox.Create() + // Since we have not specified a custom interop server for standalone, we can + // directly reference the default interop server, which is a concrete type + sandbox.DefaultInteropServer().SetSandboxContext(sandboxContext) + sandbox.DefaultInteropServer().SetInternalStateGetter(internalStateFn) - testAPIipport := "0.0.0.0:8080" - startHTTPServer(testAPIipport, sandbox) + startHTTPServer(opts.RuntimeInterfaceEmulatorAddress, sandbox, bootstrap) } func getCLIArgs() (options, []string) { @@ -112,5 +149,5 @@ func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) { log.Panic("insufficient arguments: bootstrap not provided") } - return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir), handler + return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir, ""), handler } diff --git a/lambda/agents/agent.go b/lambda/agents/agent.go index 16625c2..b1f8563 100644 --- a/lambda/agents/agent.go +++ b/lambda/agents/agent.go @@ -4,77 +4,38 @@ package agents import ( - "fmt" - "io" - "io/ioutil" - "os/exec" + "os" "path" - "syscall" + "path/filepath" log "github.com/sirupsen/logrus" ) -// AgentProcess is the common interface exposed by both internal and external agent processes -type AgentProcess interface { - Name() string -} - -// ExternalAgentProcess represents an external agent process -type ExternalAgentProcess struct { - cmd *exec.Cmd -} - -// NewExternalAgentProcess returns a new external agent process -func NewExternalAgentProcess(path string, env []string, stdoutWriter io.Writer, stderrWriter io.Writer) ExternalAgentProcess { - command := exec.Command(path) - command.Env = env - - command.Stdout = NewNewlineSplitWriter(stdoutWriter) - 
command.Stderr = NewNewlineSplitWriter(stderrWriter) - command.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - - return ExternalAgentProcess{ - cmd: command, - } -} - -// Name returns the name of the agent -// For external agents is the executable name -func (a *ExternalAgentProcess) Name() string { - return path.Base(a.cmd.Path) -} - -func (a *ExternalAgentProcess) Pid() int { - return a.cmd.Process.Pid -} - -// Start starts an external agent process -func (a *ExternalAgentProcess) Start() error { - return a.cmd.Start() -} - -// Wait waits for the external agent process to exit -func (a *ExternalAgentProcess) Wait() error { - return a.cmd.Wait() -} - -// String is used to print values passed as an operand to any format that accepts a string or to an unformatted printer such as Print. -func (a *ExternalAgentProcess) String() string { - return fmt.Sprintf("%s (%s)", a.Name(), a.cmd.Path) -} - // ListExternalAgentPaths return a list of external agents found in a given directory -func ListExternalAgentPaths(root string) []string { +func ListExternalAgentPaths(dir string, root string) []string { var agentPaths []string - files, err := ioutil.ReadDir(root) + if !isCanonical(dir) || !isCanonical(root) { + log.Warningf("Agents base paths are not absolute and in canonical form: %s, %s", dir, root) + return agentPaths + } + fullDir := path.Join(root, dir) + files, err := os.ReadDir(fullDir) if err != nil { log.WithError(err).Warning("Cannot list external agents") return agentPaths } for _, file := range files { if !file.IsDir() { - agentPaths = append(agentPaths, path.Join(root, file.Name())) + // The returned path is absolute wrt to `root`. 
This allows + // to exec the agents in their own mount namespace + p := path.Join("/", dir, file.Name()) + agentPaths = append(agentPaths, p) } } return agentPaths } + +func isCanonical(path string) bool { + absPath, err := filepath.Abs(path) + return err == nil && absPath == path +} diff --git a/lambda/agents/agent_test.go b/lambda/agents/agent_test.go index d314a76..e6732ff 100644 --- a/lambda/agents/agent_test.go +++ b/lambda/agents/agent_test.go @@ -4,13 +4,12 @@ package agents import ( - "bytes" - "io/ioutil" "os" "path" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // - Test utilities @@ -50,14 +49,8 @@ func mkLink(name, target string) fileInfo { } } -// populate a temporary directory with a list of files and directories -// returns the name of the temporary root directory -func createFileTree(fs []fileInfo) (string, error) { - - root, err := ioutil.TempDir(os.TempDir(), "tmp-") - if err != nil { - return "", err - } +// populate a directory with a list of files and directories +func createFileTree(root string, fs []fileInfo) error { for _, info := range fs { filename := info.name @@ -65,67 +58,40 @@ func createFileTree(fs []fileInfo) (string, error) { name := path.Base(filename) err := os.MkdirAll(dir, 0775) if err != nil && !os.IsExist(err) { - return "", err + return err } if os.ModeDir == info.mode&os.ModeDir { err := os.Mkdir(path.Join(dir, name), info.mode&os.ModePerm) if err != nil { - return "", err + return err } } else if os.ModeSymlink == info.mode&os.ModeSymlink { target := path.Join(root, info.target) _, err = os.Stat(target) if err != nil { - return "", err + return err } err := os.Symlink(target, path.Join(dir, name)) if err != nil { - return "", err + return err } } else { file, err := os.OpenFile(path.Join(dir, name), os.O_RDWR|os.O_CREATE, info.mode&os.ModePerm) if err != nil { - return "", err + return err } file.Truncate(info.size) file.Close() } } - return root, nil -} - -// executes a given 
closure inside a temporary directory populated with the given FS tree -func within(fs []fileInfo, closure func()) error { - - var root string - var cwd string - var err error - - if root, err = createFileTree(fs); err != nil { - return err - } - - defer os.RemoveAll(root) - - if cwd, err = os.Getwd(); err != nil { - return err - } - - if err = os.Chdir(root); err != nil { - return err - } - - defer os.Chdir(cwd) - - closure() return nil } // - Actual tests // If the agents folder is empty it is not an error -func TestRootEmpty(t *testing.T) { +func TestBaseEmpty(t *testing.T) { assert := assert.New(t) @@ -133,34 +99,51 @@ func TestRootEmpty(t *testing.T) { mkDir("/opt/extensions", 0777), } - assert.NoError(within(fs, func() { - agents := ListExternalAgentPaths("opt/extensions") - assert.Equal(0, len(agents)) - })) + tmpDir, err := os.MkdirTemp("", "ext-") + require.NoError(t, err) + + createFileTree(tmpDir, fs) + defer os.RemoveAll(tmpDir) + + agents := ListExternalAgentPaths(path.Join(tmpDir, "/opt/extensions"), "/") + assert.Equal(0, len(agents)) } // Test that non-existant /opt/extensions is treated as if no agents were found -func TestRootNotExist(t *testing.T) { +func TestBaseNotExist(t *testing.T) { assert := assert.New(t) - agents := ListExternalAgentPaths("/path/which/does/not/exist") + agents := ListExternalAgentPaths("/path/which/does/not/exist", "/") + assert.Equal(0, len(agents)) +} + +// Test that non-existant root dir is teaded as if no agents were found +func TestChrootNotExist(t *testing.T) { + + assert := assert.New(t) + + agents := ListExternalAgentPaths("/bin", "/does/not/exist") assert.Equal(0, len(agents)) } // Test that non-directory /opt/extensions is treated as if no agents were found -func TestRootNotDir(t *testing.T) { +func TestBaseNotDir(t *testing.T) { assert := assert.New(t) fs := []fileInfo{ mkFile("/opt/extensions", 1, 0777), } + tmpDir, err := os.MkdirTemp("", "ext-") + require.NoError(t, err) + + createFileTree(tmpDir, fs) + defer 
os.RemoveAll(tmpDir) - assert.NoError(within(fs, func() { - agents := ListExternalAgentPaths("opt/extensions") - assert.Equal(0, len(agents)) - })) + path := path.Join(tmpDir, "/opt/extensions") + agents := ListExternalAgentPaths(path, "/") + assert.Equal(0, len(agents)) } // Test our ability to find agent bootstraps in the FS and return them sorted. @@ -188,99 +171,63 @@ func TestFindAgentMixed(t *testing.T) { fs := append([]fileInfo{}, listed...) fs = append(fs, unlisted...) - assert.NoError(within(fs, func() { - agentPaths := ListExternalAgentPaths("opt/extensions") - assert.Equal(len(listed), len(agentPaths)) - last := "" - for index := range listed { - if len(last) > 0 { - assert.GreaterOrEqual(agentPaths[index], last) - } - last = agentPaths[index] - } - })) -} - -// Test our ability to start agents -func TestAgentStart(t *testing.T) { - assert := assert.New(t) - agent := NewExternalAgentProcess("../testdata/agents/bash_true.sh", []string{}, &mockWriter{}, &mockWriter{}) - assert.Nil(agent.Start()) - assert.Nil(agent.Wait()) -} + tmpDir, err := os.MkdirTemp("", "ext-") + require.NoError(t, err) -// Test that execution of invalid agents is correctly reported -func TestInvalidAgentStart(t *testing.T) { - assert := assert.New(t) - agent := NewExternalAgentProcess("/bin/none", []string{}, &mockWriter{}, &mockWriter{}) - assert.True(os.IsNotExist(agent.Start())) -} + createFileTree(tmpDir, fs) + defer os.RemoveAll(tmpDir) -func TestAgentStdoutWriter(t *testing.T) { - // Given - assert := assert.New(t) - - stdout := &mockWriter{} - stderr := &mockWriter{} - expectedStdout := "stdout line 1\nstdout line 2\nstdout line 3\n" - expectedStderr := "" - - agent := NewExternalAgentProcess("../testdata/agents/bash_stdout.sh", []string{}, stdout, stderr) - - // When - assert.NoError(agent.Start()) - assert.NoError(agent.Wait()) - - // Then - assert.Equal(expectedStdout, string(bytes.Join(stdout.bytesReceived, []byte("")))) - assert.Equal(expectedStderr, 
string(bytes.Join(stderr.bytesReceived, []byte("")))) + path := path.Join(tmpDir, "/opt/extensions") + agentPaths := ListExternalAgentPaths(path, "/") + assert.Equal(len(listed), len(agentPaths)) + last := "" + for index := range listed { + if len(last) > 0 { + assert.GreaterOrEqual(agentPaths[index], last) + } + last = agentPaths[index] + } } -func TestAgentStderrWriter(t *testing.T) { - // Given - assert := assert.New(t) - - stdout := &mockWriter{} - stderr := &mockWriter{} - expectedStdout := "" - expectedStderr := "stderr line 1\nstderr line 2\nstderr line 3\n" - - agent := NewExternalAgentProcess("../testdata/agents/bash_stderr.sh", []string{}, stdout, stderr) - - // When - assert.NoError(agent.Start()) - assert.NoError(agent.Wait()) - - // Then - assert.Equal(expectedStdout, string(bytes.Join(stdout.bytesReceived, []byte("")))) - assert.Equal(expectedStderr, string(bytes.Join(stderr.bytesReceived, []byte("")))) -} +// Test our ability to find agent bootstraps in the FS and return them sorted, +// when using a different mount namespace root for the extensiosn domain. +// Even if not all files are valid as executable agents, +// ListExternalAgentPaths() is expected to return all of them. 
+func TestFindAgentMixedInChroot(t *testing.T) { -func TestAgentStdoutAndStderrSeperateWriters(t *testing.T) { - // Given assert := assert.New(t) - stdout := &mockWriter{} - stderr := &mockWriter{} - expectedStdout := "stdout line 1\nstdout line 2\nstdout line 3\n" - expectedStderr := "stderr line 1\nstderr line 2\nstderr line 3\n" + listed := []fileInfo{ + mkFile("/opt/extensions/ok2", 1, 0777), // this is ok + mkFile("/opt/extensions/ok1", 1, 0777), // this is ok as well + mkFile("/opt/extensions/not_exec", 1, 0666), // this is not executable + mkFile("/opt/extensions/not_read", 1, 0333), // this is not readable + mkFile("/opt/extensions/empty_file", 0, 0777), // this is empty + mkLink("/opt/extensions/link", "/opt/extensions/ok1"), // symlink must be ignored + } - agent := NewExternalAgentProcess("../testdata/agents/bash_stdout_and_stderr.sh", []string{}, stdout, stderr) + unlisted := []fileInfo{ + mkDir("/opt/extensions/empty_dir", 0777), // this is an empty directory + mkDir("/opt/extensions/nonempty_dir", 0777), // subdirs should not be listed + mkFile("/opt/extensions/nonempty_dir/notok", 1, 0777), // files in subdirs should not be listed + } - // When - assert.NoError(agent.Start()) - assert.NoError(agent.Wait()) + fs := append([]fileInfo{}, listed...) + fs = append(fs, unlisted...) 
- // Then - assert.Equal(expectedStdout, string(bytes.Join(stdout.bytesReceived, []byte("")))) - assert.Equal(expectedStderr, string(bytes.Join(stderr.bytesReceived, []byte("")))) -} + rootDir, err := os.MkdirTemp("", "rootfs") + require.NoError(t, err) -type mockWriter struct { - bytesReceived [][]byte -} + createFileTree(rootDir, fs) + defer os.RemoveAll(rootDir) -func (m *mockWriter) Write(bytes []byte) (int, error) { - m.bytesReceived = append(m.bytesReceived, bytes) - return len(bytes), nil + agentPaths := ListExternalAgentPaths("/opt/extensions", rootDir) + assert.Equal(len(listed), len(agentPaths)) + last := "" + for index := range listed { + if len(last) > 0 { + assert.GreaterOrEqual(agentPaths[index], last) + } + last = agentPaths[index] + } } diff --git a/lambda/agents/log_line_splitter.go b/lambda/agents/log_line_splitter.go deleted file mode 100644 index ac2c134..0000000 --- a/lambda/agents/log_line_splitter.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -package agents - -import ( - "bytes" - "io" -) - -// NewlineSplitWriter wraps an io.Writer and calls the underlying writer for each newline separated line -type NewlineSplitWriter struct { - writer io.Writer -} - -// NewNewlineSplitWriter returns an instance of NewlineSplitWriter -func NewNewlineSplitWriter(w io.Writer) *NewlineSplitWriter { - return &NewlineSplitWriter{ - writer: w, - } -} - -// Write splits the byte buffer by newline and calls the underlying writer for each line -func (nsw *NewlineSplitWriter) Write(buf []byte) (int, error) { - newBuf := make([]byte, len(buf)) - copy(newBuf, buf) - lines := bytes.SplitAfter(newBuf, []byte("\n")) - var bytesWritten int - for _, line := range lines { - if len(line) > 0 { - n, err := nsw.writer.Write(line) - bytesWritten += n - if err != nil { - return bytesWritten, err - } - } - } - - return bytesWritten, nil -} diff --git a/lambda/appctx/appctx.go b/lambda/appctx/appctx.go index 44776ab..6c81653 100644 --- a/lambda/appctx/appctx.go +++ b/lambda/appctx/appctx.go @@ -10,6 +10,8 @@ import ( // A Key type is used as a key for storing values in the application context. type Key int +type InitType int + const ( // AppCtxInvokeErrorResponseKey is used for storing deferred invoke error response. // Only used by xray. TODO refactor xray interface so it doesn't use appctx @@ -23,6 +25,18 @@ const ( // AppCtxFirstFatalErrorKey is used to store first unrecoverable error message encountered to propagate it to slicer with DONE(errortype) or DONEFAIL(errortype) AppCtxFirstFatalErrorKey + + // AppCtxInitType is used to store the init type (init caching or plain INIT) + AppCtxInitType + + // AppCtxSandbox type is used to store the sandbox type (SandboxClassic or SandboxPreWarmed) + AppCtxSandboxType +) + +// Possible values for AppCtxInitType key +const ( + Init InitType = iota + InitCaching ) // ApplicationContext is an application scope context. 
diff --git a/lambda/appctx/appctxutil.go b/lambda/appctx/appctxutil.go index a3e652f..a30677f 100644 --- a/lambda/appctx/appctxutil.go +++ b/lambda/appctx/appctxutil.go @@ -5,11 +5,12 @@ package appctx import ( "context" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" "net/http" "strings" + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/interop" + log "github.com/sirupsen/logrus" ) @@ -164,3 +165,24 @@ func LoadFirstFatalError(appCtx ApplicationContext) (errorType fatalerror.ErrorT } return v.(fatalerror.ErrorType), true } + +func StoreInitType(appCtx ApplicationContext, initCachingEnabled bool) { + if initCachingEnabled { + appCtx.Store(AppCtxInitType, InitCaching) + } else { + appCtx.Store(AppCtxInitType, Init) + } +} + +// Default Init Type is Init unless it's explicitly stored in ApplicationContext +func LoadInitType(appCtx ApplicationContext) InitType { + return appCtx.GetOrDefault(AppCtxInitType, Init).(InitType) +} + +func StoreSandboxType(appCtx ApplicationContext, sandboxType interop.SandboxType) { + appCtx.Store(AppCtxSandboxType, sandboxType) +} + +func LoadSandboxType(appCtx ApplicationContext) interop.SandboxType { + return appCtx.GetOrDefault(AppCtxSandboxType, interop.SandboxClassic).(interop.SandboxType) +} diff --git a/lambda/appctx/appctxutil_test.go b/lambda/appctx/appctxutil_test.go index a8a4761..b6df9aa 100644 --- a/lambda/appctx/appctxutil_test.go +++ b/lambda/appctx/appctxutil_test.go @@ -11,6 +11,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.amzn.com/lambda/fatalerror" + + "go.amzn.com/lambda/interop" ) func runTestRequestWithUserAgent(t *testing.T, userAgent string, expectedRuntimeRelease string) { @@ -200,3 +202,26 @@ func TestFirstFatalError(t *testing.T) { require.True(t, found) require.Equal(t, fatalerror.AgentCrash, v) } + +func TestStoreLoadInitType(t *testing.T) { + appCtx := NewApplicationContext() + + initType := LoadInitType(appCtx) + assert.Equal(t, Init, 
initType) + + StoreInitType(appCtx, true) + initType = LoadInitType(appCtx) + assert.Equal(t, InitCaching, initType) +} + +func TestStoreLoadSandboxType(t *testing.T) { + appCtx := NewApplicationContext() + + sandboxType := LoadSandboxType(appCtx) + assert.Equal(t, interop.SandboxClassic, sandboxType) + + StoreSandboxType(appCtx, interop.SandboxPreWarmed) + + sandboxType = LoadSandboxType(appCtx) + assert.Equal(t, interop.SandboxPreWarmed, sandboxType) +} diff --git a/lambda/core/bandwidthlimiter/bandwidthlimiter.go b/lambda/core/bandwidthlimiter/bandwidthlimiter.go new file mode 100644 index 0000000..05c600a --- /dev/null +++ b/lambda/core/bandwidthlimiter/bandwidthlimiter.go @@ -0,0 +1,61 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "io" + + "go.amzn.com/lambda/interop" +) + +func BandwidthLimitingCopy(dst *BandwidthLimitingWriter, src io.Reader) (written int64, err error) { + written, err = io.Copy(dst, src) + _ = dst.Close() + return +} + +func NewBandwidthLimitingWriter(w io.Writer, bucket *Bucket) (*BandwidthLimitingWriter, error) { + throttler, err := NewThrottler(bucket) + if err != nil { + return nil, err + } + return &BandwidthLimitingWriter{w: w, th: throttler}, nil +} + +type BandwidthLimitingWriter struct { + w io.Writer + th *Throttler +} + +func (w *BandwidthLimitingWriter) ChunkedWrite(p []byte) (n int, err error) { + i := NewChunkIterator(p, int(w.th.b.capacity)) + for { + buf := i.Next() + if buf == nil { + return + } + written, writeErr := w.th.bandwidthLimitingWrite(w.w, buf) + n += written + if writeErr != nil { + return n, writeErr + } + } +} + +func (w *BandwidthLimitingWriter) Write(p []byte) (n int, err error) { + w.th.start() + if int64(len(p)) > w.th.b.capacity { + return w.ChunkedWrite(p) + } + return w.th.bandwidthLimitingWrite(w.w, p) +} + +func (w *BandwidthLimitingWriter) Close() (err error) { + w.th.stop() + return +} + 
+func (w *BandwidthLimitingWriter) GetMetrics() (metrics *interop.InvokeResponseMetrics) { + return w.th.metrics +} diff --git a/lambda/core/bandwidthlimiter/bandwidthlimiter_test.go b/lambda/core/bandwidthlimiter/bandwidthlimiter_test.go new file mode 100644 index 0000000..7ede24b --- /dev/null +++ b/lambda/core/bandwidthlimiter/bandwidthlimiter_test.go @@ -0,0 +1,106 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "bytes" + "io" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBandwidthLimitingCopy(t *testing.T) { + var size10mb int64 = 10 * 1024 * 1024 + + inputBuffer := []byte(strings.Repeat("a", int(size10mb))) + reader := bytes.NewReader(inputBuffer) + + bucket, err := NewBucket(size10mb/2, size10mb/4, size10mb/2, time.Millisecond/2) + assert.NoError(t, err) + + internalWriter := bytes.NewBuffer(make([]byte, 0, size10mb)) + writer, err := NewBandwidthLimitingWriter(internalWriter, bucket) + assert.NoError(t, err) + + n, err := BandwidthLimitingCopy(writer, reader) + assert.Equal(t, size10mb, n) + assert.Equal(t, nil, err) + assert.Equal(t, inputBuffer, internalWriter.Bytes()) +} + +type ErrorBufferWriter struct { + w ByteBufferWriter + failAfter int +} + +func (w *ErrorBufferWriter) Write(p []byte) (n int, err error) { + if w.failAfter >= 1 { + w.failAfter-- + } + n, err = w.w.Write(p) + if w.failAfter == 0 { + return n, io.ErrUnexpectedEOF + } + return n, err +} + +func (w *ErrorBufferWriter) Bytes() []byte { + return w.w.Bytes() +} + +func TestNewBandwidthLimitingWriter(t *testing.T) { + type testCase struct { + refillNumber int64 + internalWriter ByteBufferWriter + inputBuffer []byte + expectedN int + expectedError error + } + testCases := []testCase{ + { + refillNumber: 2, + internalWriter: bytes.NewBuffer(make([]byte, 0, 36)), // buffer size greater than bucket size + inputBuffer: []byte(strings.Repeat("a", 
36)), + expectedN: 36, + expectedError: nil, + }, + { + refillNumber: 2, + internalWriter: bytes.NewBuffer(make([]byte, 0, 12)), // buffer size lesser than bucket size + inputBuffer: []byte(strings.Repeat("a", 12)), + expectedN: 12, + expectedError: nil, + }, + { + // buffer size greater than bucket size and error after two Write() invocations + refillNumber: 2, + internalWriter: &ErrorBufferWriter{w: bytes.NewBuffer(make([]byte, 0, 36)), failAfter: 2}, + inputBuffer: []byte(strings.Repeat("a", 36)), + expectedN: 32, + expectedError: io.ErrUnexpectedEOF, + }, + } + + for _, test := range testCases { + bucket, err := NewBucket(16, 8, test.refillNumber, 100*time.Millisecond) + assert.NoError(t, err) + + writer, err := NewBandwidthLimitingWriter(test.internalWriter, bucket) + assert.NoError(t, err) + assert.False(t, writer.th.running) + + n, err := writer.Write(test.inputBuffer) + assert.True(t, writer.th.running) + assert.Equal(t, test.expectedN, n) + assert.Equal(t, test.expectedError, err) + assert.Equal(t, test.inputBuffer[:n], test.internalWriter.Bytes()) + + err = writer.Close() + assert.Nil(t, err) + assert.False(t, writer.th.running) + } +} diff --git a/lambda/core/bandwidthlimiter/throttler.go b/lambda/core/bandwidthlimiter/throttler.go new file mode 100644 index 0000000..b3b57dd --- /dev/null +++ b/lambda/core/bandwidthlimiter/throttler.go @@ -0,0 +1,154 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "errors" + "fmt" + "io" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" +) + +var ErrBufferSizeTooLarge = errors.New("buffer size cannot be greater than bucket size") + +func NewBucket(capacity int64, initialTokenCount int64, refillNumber int64, refillInterval time.Duration) (*Bucket, error) { + if capacity <= 0 || initialTokenCount < 0 || refillNumber <= 0 || refillInterval <= 0 || + capacity < initialTokenCount { + errorMsg := fmt.Sprintf("invalid bucket parameters (capacity: %d, initialTokenCount: %d, refillNumber: %d,"+ + "refillInterval: %d)", capacity, initialTokenCount, refillInterval, refillInterval) + log.Error(errorMsg) + return nil, errors.New(errorMsg) + } + return &Bucket{ + capacity: capacity, + tokenCount: initialTokenCount, + refillNumber: refillNumber, + refillInterval: refillInterval, + mutex: sync.Mutex{}, + }, nil +} + +type Bucket struct { + capacity int64 + tokenCount int64 + refillNumber int64 + refillInterval time.Duration + mutex sync.Mutex +} + +func (b *Bucket) produceTokens() { + b.mutex.Lock() + defer b.mutex.Unlock() + if b.tokenCount < b.capacity { + b.tokenCount = min64(b.tokenCount+b.refillNumber, b.capacity) + } +} + +func (b *Bucket) consumeTokens(n int64) bool { + b.mutex.Lock() + defer b.mutex.Unlock() + if n <= b.tokenCount { + b.tokenCount -= n + return true + } + return false +} + +func (b *Bucket) getTokenCount() int64 { + b.mutex.Lock() + defer b.mutex.Unlock() + return b.tokenCount +} + +func NewThrottler(bucket *Bucket) (*Throttler, error) { + if bucket == nil { + errorMsg := "cannot create a throttler with nil bucket" + log.Error(errorMsg) + return nil, errors.New(errorMsg) + } + return &Throttler{ + b: bucket, + running: false, + produced: make(chan int64), + done: make(chan struct{}), + // FIXME: + // The runtime tells whether the function response mode is streaming 
or not. + // Ideally, we would want to use that value here. Since I'm just rebasing, I will leave + // as-is, but we should use that instead of relying on our memory to set this here + // because we "know" it's a streaming code path. + metrics: &interop.InvokeResponseMetrics{FunctionResponseMode: interop.FunctionResponseModeStreaming}, + }, nil +} + +type Throttler struct { + b *Bucket + running bool + produced chan int64 + done chan struct{} + metrics *interop.InvokeResponseMetrics +} + +func (th *Throttler) start() { + if th.running { + return + } + th.running = true + th.metrics.StartReadingResponseMonoTimeMs = metering.Monotime() + go func() { + ticker := time.NewTicker(th.b.refillInterval) + for { + select { + case <-ticker.C: + th.b.produceTokens() + select { + case th.produced <- metering.Monotime(): + default: + } + case <-th.done: + ticker.Stop() + return + } + } + }() +} + +func (th *Throttler) stop() { + if !th.running { + return + } + th.running = false + th.metrics.FinishReadingResponseMonoTimeMs = metering.Monotime() + durationMs := (th.metrics.FinishReadingResponseMonoTimeMs - th.metrics.StartReadingResponseMonoTimeMs) / int64(time.Millisecond) + if durationMs > 0 { + th.metrics.OutboundThroughputBps = (th.metrics.ProducedBytes / durationMs) * int64(time.Second/time.Millisecond) + } else { + th.metrics.OutboundThroughputBps = -1 + } + th.done <- struct{}{} +} + +func (th *Throttler) bandwidthLimitingWrite(w io.Writer, p []byte) (written int, err error) { + n := int64(len(p)) + if n > th.b.capacity { + return 0, ErrBufferSizeTooLarge + } + for { + if th.b.consumeTokens(n) { + written, err = w.Write(p) + th.metrics.ProducedBytes += int64(written) + return + } + waitStart := metering.Monotime() + elapsed := <-th.produced - waitStart + if elapsed > 0 { + th.metrics.TimeShapedNs += elapsed + } + } +} diff --git a/lambda/core/bandwidthlimiter/throttler_test.go b/lambda/core/bandwidthlimiter/throttler_test.go new file mode 100644 index 0000000..a88a14d --- 
/dev/null +++ b/lambda/core/bandwidthlimiter/throttler_test.go @@ -0,0 +1,215 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "bytes" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewBucket(t *testing.T) { + type testCase struct { + capacity int64 + initialTokenCount int64 + refillNumber int64 + refillInterval time.Duration + bucketCreated bool + } + testCases := []testCase{ + {capacity: 8, initialTokenCount: 6, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: true}, + {capacity: 8, initialTokenCount: 6, refillNumber: 2, refillInterval: -100 * time.Millisecond, bucketCreated: false}, + {capacity: 8, initialTokenCount: 6, refillNumber: -5, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + {capacity: 8, initialTokenCount: -2, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + {capacity: -2, initialTokenCount: 6, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + {capacity: 8, initialTokenCount: 10, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + } + + for _, test := range testCases { + bucket, err := NewBucket(test.capacity, test.initialTokenCount, test.refillNumber, test.refillInterval) + if test.bucketCreated { + assert.NoError(t, err) + assert.NotNil(t, bucket) + } else { + assert.Error(t, err) + assert.Nil(t, bucket) + } + } +} + +func TestBucket_produceTokens_consumeTokens(t *testing.T) { + var consumed bool + bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) + assert.NoError(t, err) + assert.Equal(t, int64(8), bucket.getTokenCount()) + + consumed = bucket.consumeTokens(5) + assert.Equal(t, int64(3), bucket.getTokenCount()) + assert.True(t, consumed) + + bucket.produceTokens() + assert.Equal(t, int64(9), bucket.getTokenCount()) + + bucket.produceTokens() + assert.Equal(t, 
int64(15), bucket.getTokenCount()) + + bucket.produceTokens() + assert.Equal(t, int64(16), bucket.getTokenCount()) + + bucket.produceTokens() + assert.Equal(t, int64(16), bucket.getTokenCount()) + + consumed = bucket.consumeTokens(18) + assert.Equal(t, int64(16), bucket.getTokenCount()) + assert.False(t, consumed) + + consumed = bucket.consumeTokens(16) + assert.Equal(t, int64(0), bucket.getTokenCount()) + assert.True(t, consumed) +} + +func TestNewThrottler(t *testing.T) { + bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) + assert.NoError(t, err) + + throttler, err := NewThrottler(bucket) + assert.NoError(t, err) + assert.NotNil(t, throttler) + + throttler, err = NewThrottler(nil) + assert.Error(t, err) + assert.Nil(t, throttler) +} + +func TestNewThrottler_start_stop(t *testing.T) { + bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) + assert.NoError(t, err) + + throttler, err := NewThrottler(bucket) + assert.NoError(t, err) + + assert.False(t, throttler.running) + + throttler.start() + assert.True(t, throttler.running) + + <-time.Tick(2 * throttler.b.refillInterval) + assert.LessOrEqual(t, int64(14), throttler.b.getTokenCount()) + assert.True(t, throttler.running) + + throttler.start() + assert.True(t, throttler.running) + <-time.Tick(2 * throttler.b.refillInterval) + assert.Equal(t, int64(16), throttler.b.getTokenCount()) + assert.True(t, throttler.running) + + throttler.stop() + assert.False(t, throttler.running) + + throttler.stop() + assert.False(t, throttler.running) + + throttler.start() + assert.True(t, throttler.running) + + throttler.stop() + assert.False(t, throttler.running) +} + +type ByteBufferWriter interface { + Write(p []byte) (n int, err error) + Bytes() []byte +} + +type FixedSizeBufferWriter struct { + buf []byte +} + +func (w *FixedSizeBufferWriter) Write(p []byte) (n int, err error) { + n = copy(w.buf, p) + return +} + +func (w *FixedSizeBufferWriter) Bytes() []byte { + return w.buf +} + +func 
TestNewThrottler_bandwidthLimitingWrite(t *testing.T) { + var size10mb int64 = 10 * 1024 * 1024 + + type testCase struct { + capacity int64 + initialTokenCount int64 + writer ByteBufferWriter + inputBuffer []byte + expectedN int + expectedError error + } + testCases := []testCase{ + { + capacity: 16, + initialTokenCount: 8, + writer: bytes.NewBuffer(make([]byte, 0, 14)), + inputBuffer: []byte(strings.Repeat("a", 12)), + expectedN: 12, + expectedError: nil, + }, + { + capacity: 16, + initialTokenCount: 8, + writer: bytes.NewBuffer(make([]byte, 0, 12)), + inputBuffer: []byte(strings.Repeat("a", 14)), + expectedN: 14, + expectedError: nil, + }, + { + capacity: size10mb, + initialTokenCount: size10mb, + writer: bytes.NewBuffer(make([]byte, 0, size10mb)), + inputBuffer: []byte(strings.Repeat("a", int(size10mb))), + expectedN: int(size10mb), + expectedError: nil, + }, + { + capacity: 16, + initialTokenCount: 8, + writer: bytes.NewBuffer(make([]byte, 0, 18)), + inputBuffer: []byte(strings.Repeat("a", 18)), + expectedN: 0, + expectedError: ErrBufferSizeTooLarge, + }, + { + capacity: 16, + initialTokenCount: 8, + writer: &FixedSizeBufferWriter{buf: make([]byte, 12)}, + inputBuffer: []byte(strings.Repeat("a", 14)), + expectedN: 12, + expectedError: nil, + }, + } + + for _, test := range testCases { + bucket, err := NewBucket(test.capacity, test.initialTokenCount, 2, 100*time.Millisecond) + assert.NoError(t, err) + + throttler, err := NewThrottler(bucket) + assert.NoError(t, err) + + writer := test.writer + throttler.start() + n, err := throttler.bandwidthLimitingWrite(writer, test.inputBuffer) + assert.Equal(t, test.expectedN, n) + assert.Equal(t, test.expectedError, err) + + if test.expectedError == nil { + assert.Equal(t, test.inputBuffer[:n], test.writer.Bytes()) + } else { + assert.Equal(t, []byte{}, test.writer.Bytes()) + } + throttler.stop() + } +} diff --git a/lambda/core/bandwidthlimiter/util.go b/lambda/core/bandwidthlimiter/util.go new file mode 100644 index 
0000000..7078d5d --- /dev/null +++ b/lambda/core/bandwidthlimiter/util.go @@ -0,0 +1,46 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func min64(a, b int64) int64 { + if a < b { + return a + } + return b +} + +func NewChunkIterator(buf []byte, chunkSize int) *ChunkIterator { + if buf == nil { + return nil + } + return &ChunkIterator{ + buf: buf, + chunkSize: chunkSize, + offset: 0, + } +} + +type ChunkIterator struct { + buf []byte + chunkSize int + offset int +} + +func (i *ChunkIterator) Next() []byte { + begin := i.offset + end := min(i.offset+i.chunkSize, len(i.buf)) + i.offset = end + + if begin == end { + return nil + } + return i.buf[begin:end] +} diff --git a/lambda/core/bandwidthlimiter/util_test.go b/lambda/core/bandwidthlimiter/util_test.go new file mode 100644 index 0000000..ed93c77 --- /dev/null +++ b/lambda/core/bandwidthlimiter/util_test.go @@ -0,0 +1,45 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewChunkIterator(t *testing.T) { + buf := []byte("abcdefghijk") + + type testCase struct { + buf []byte + chunkSize int + expectedResult [][]byte + } + testCases := []testCase{ + {buf: nil, chunkSize: 0, expectedResult: [][]byte{}}, + {buf: nil, chunkSize: 1, expectedResult: [][]byte{}}, + {buf: buf, chunkSize: 0, expectedResult: [][]byte{}}, + {buf: buf, chunkSize: 1, expectedResult: [][]byte{ + []byte("a"), []byte("b"), []byte("c"), []byte("d"), []byte("e"), []byte("f"), []byte("g"), []byte("h"), + []byte("i"), []byte("j"), []byte("k"), + }}, + {buf: buf, chunkSize: 4, expectedResult: [][]byte{[]byte("abcd"), []byte("efgh"), []byte("ijk")}}, + {buf: buf, chunkSize: 5, expectedResult: [][]byte{[]byte("abcde"), []byte("fghij"), []byte("k")}}, + {buf: buf, chunkSize: 11, expectedResult: [][]byte{[]byte("abcdefghijk")}}, + {buf: buf, chunkSize: 12, expectedResult: [][]byte{[]byte("abcdefghijk")}}, + } + + for _, test := range testCases { + iterator := NewChunkIterator(test.buf, test.chunkSize) + if test.buf == nil { + assert.Nil(t, iterator) + } else { + for _, expectedChunk := range test.expectedResult { + assert.Equal(t, expectedChunk, iterator.Next()) + } + assert.Nil(t, iterator.Next()) + } + } +} diff --git a/lambda/core/credentials.go b/lambda/core/credentials.go index 7b1bf14..ad152d0 100644 --- a/lambda/core/credentials.go +++ b/lambda/core/credentials.go @@ -7,8 +7,6 @@ import ( "fmt" "sync" "time" - - log "github.com/sirupsen/logrus" ) const ( @@ -26,11 +24,9 @@ type Credentials struct { } type CredentialsService interface { - SetCredentials(token, awsKey, awsSecret, awsSession string) + SetCredentials(token, awsKey, awsSecret, awsSession string, expiration time.Time) GetCredentials(token string) (*Credentials, error) - UpdateCredentials(awsKey, awsSecret, awsSession string) error - BlockService() - 
UnblockService() + UpdateCredentials(awsKey, awsSecret, awsSession string, expiration time.Time) error } type credentialsServiceImpl struct { @@ -51,7 +47,7 @@ func NewCredentialsService() CredentialsService { return credentialsService } -func (c *credentialsServiceImpl) SetCredentials(token, awsKey, awsSecret, awsSession string) { +func (c *credentialsServiceImpl) SetCredentials(token, awsKey, awsSecret, awsSession string, expiration time.Time) { c.contentMutex.Lock() defer c.contentMutex.Unlock() @@ -59,7 +55,7 @@ func (c *credentialsServiceImpl) SetCredentials(token, awsKey, awsSecret, awsSes AwsKey: awsKey, AwsSecret: awsSecret, AwsSession: awsSession, - Expiration: time.Now().Add(16 * time.Minute), + Expiration: expiration, } } @@ -77,33 +73,7 @@ func (c *credentialsServiceImpl) GetCredentials(token string) (*Credentials, err return nil, ErrCredentialsNotFound } -func (c *credentialsServiceImpl) BlockService() { - if c.currentState == BLOCKED { - return - } - log.Info("blocking the credentials service") - c.serviceMutex.Lock() - - c.contentMutex.Lock() - defer c.contentMutex.Unlock() - - c.currentState = BLOCKED -} - -func (c *credentialsServiceImpl) UnblockService() { - if c.currentState == UNBLOCKED { - return - } - log.Info("unblocking the credentials service") - - c.contentMutex.Lock() - defer c.contentMutex.Unlock() - - c.currentState = UNBLOCKED - c.serviceMutex.Unlock() -} - -func (c *credentialsServiceImpl) UpdateCredentials(awsKey, awsSecret, awsSession string) error { +func (c *credentialsServiceImpl) UpdateCredentials(awsKey, awsSecret, awsSession string, expiration time.Time) error { mapSize := len(c.credentials) if mapSize != 1 { return fmt.Errorf("there are %d set of credentials", mapSize) @@ -114,6 +84,6 @@ func (c *credentialsServiceImpl) UpdateCredentials(awsKey, awsSecret, awsSession token = key } - c.SetCredentials(token, awsKey, awsSecret, awsSession) + c.SetCredentials(token, awsKey, awsSecret, awsSession, expiration) return nil } diff 
--git a/lambda/core/credentials_test.go b/lambda/core/credentials_test.go index ab0b247..625ab8e 100644 --- a/lambda/core/credentials_test.go +++ b/lambda/core/credentials_test.go @@ -19,7 +19,8 @@ const ( func TestGetSetCredentialsHappy(t *testing.T) { credentialsService := NewCredentialsService() - credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession) + credentialsExpiration := time.Now().Add(15 * time.Minute) + credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession, credentialsExpiration) credentials, err := credentialsService.GetCredentials(Token) @@ -40,8 +41,12 @@ func TestGetCredentialsFail(t *testing.T) { func TestUpdateCredentialsHappy(t *testing.T) { credentialsService := NewCredentialsService() - credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession) - err := credentialsService.UpdateCredentials("sampleKey1", "sampleSecret1", "sampleSession1") + credentialsExpiration := time.Now().Add(15 * time.Minute) + credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession, credentialsExpiration) + + restoreCredentialsExpiration := time.Now().Add(10 * time.Hour) + + err := credentialsService.UpdateCredentials("sampleKey1", "sampleSecret1", "sampleSession1", restoreCredentialsExpiration) assert.NoError(t, err) credentials, err := credentialsService.GetCredentials(Token) @@ -50,49 +55,16 @@ func TestUpdateCredentialsHappy(t *testing.T) { assert.Equal(t, "sampleKey1", credentials.AwsKey) assert.Equal(t, "sampleSecret1", credentials.AwsSecret) assert.Equal(t, "sampleSession1", credentials.AwsSession) -} - -func TestUpdateCredentialsFail(t *testing.T) { - credentialsService := NewCredentialsService() - err := credentialsService.UpdateCredentials("unknownKey", "unknownSecret", "unknownSession") - - assert.Error(t, err) -} + nineHoursLater := time.Now().Add(9 * time.Hour) -func TestUpdateCredentialsOfBlockedService(t *testing.T) { - credentialsService := NewCredentialsService() - 
credentialsService.BlockService() - credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession) - err := credentialsService.UpdateCredentials("sampleKey1", "sampleSecret1", "sampleSession1") - assert.NoError(t, err) + assert.True(t, nineHoursLater.Before(credentials.Expiration)) } -func TestConsecutiveBlockService(t *testing.T) { +func TestUpdateCredentialsFail(t *testing.T) { credentialsService := NewCredentialsService() - timeout := time.After(1 * time.Second) - done := make(chan bool) - - go func() { - for i := 0; i < 10; i++ { - credentialsService.BlockService() - } - done <- true - }() - - select { - case <-timeout: - t.Fatal("BlockService should not block the calling thread.") - case <-done: - } -} - -// unlocking a mutex twice causes panic -// the assertion here is basically not having panic -func TestConsecutiveUnblockService(t *testing.T) { - credentialsService := NewCredentialsService() + err := credentialsService.UpdateCredentials("unknownKey", "unknownSecret", "unknownSession", time.Now()) - credentialsService.UnblockService() - credentialsService.UnblockService() + assert.Error(t, err) } diff --git a/lambda/core/directinvoke/directinvoke.go b/lambda/core/directinvoke/directinvoke.go index 1699121..8ef59ae 100644 --- a/lambda/core/directinvoke/directinvoke.go +++ b/lambda/core/directinvoke/directinvoke.go @@ -4,28 +4,38 @@ package directinvoke import ( + "context" "fmt" "io" "net/http" + "strconv" "github.com/go-chi/chi" + "go.amzn.com/lambda/core/bandwidthlimiter" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" + + log "github.com/sirupsen/logrus" ) const ( - InvokeIDHeader = "Invoke-Id" - InvokedFunctionArnHeader = "Invoked-Function-Arn" - VersionIDHeader = "Invoked-Function-Version" - ReservationTokenHeader = "Reservation-Token" - CustomerHeadersHeader = "Customer-Headers" - ContentTypeHeader = "Content-Type" + InvokeIDHeader = "Invoke-Id" + InvokedFunctionArnHeader = "Invoked-Function-Arn" + 
VersionIDHeader = "Invoked-Function-Version" + ReservationTokenHeader = "Reservation-Token" + CustomerHeadersHeader = "Customer-Headers" + ContentTypeHeader = "Content-Type" + MaxPayloadSizeHeader = "MaxPayloadSize" + ResponseBandwidthRateHeader = "ResponseBandwidthRate" + ResponseBandwidthBurstSizeHeader = "ResponseBandwidthBurstSize" + FunctionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" ErrorTypeHeader = "Error-Type" - EndOfResponseTrailer = "End-Of-Response" - - SandboxErrorType = "Error.Sandbox" + EndOfResponseTrailer = "End-Of-Response" + FunctionErrorTypeTrailer = "Lambda-Runtime-Function-Error-Type" + FunctionErrorBodyTrailer = "Lambda-Runtime-Function-Error-Body" ) const ( @@ -34,7 +44,14 @@ const ( EndOfResponseOversized = "Oversized" ) +var ResetReasonMap = map[string]fatalerror.ErrorType{ + "failure": fatalerror.SandboxFailure, + "timeout": fatalerror.SandboxTimeout, +} + var MaxDirectResponseSize int64 = interop.MaxPayloadSize // this is intentionally not a constant so we can configure it via CLI +var ResponseBandwidthRate int64 = interop.ResponseBandwidthRate +var ResponseBandwidthBurstSize int64 = interop.ResponseBandwidthBurstSize func renderBadRequest(w http.ResponseWriter, r *http.Request, errorType string) { w.Header().Set(ErrorTypeHeader, errorType) @@ -42,6 +59,12 @@ func renderBadRequest(w http.ResponseWriter, r *http.Request, errorType string) w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) } +func renderInternalServerError(w http.ResponseWriter, errorType string) { + w.Header().Set(ErrorTypeHeader, errorType) + w.WriteHeader(http.StatusInternalServerError) + w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) +} + // ReceiveDirectInvoke parses invoke and verifies it against Token message. 
Uses deadline provided by Token // Renders BadRequest in case of error func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.Token) (*interop.Invoke, error) { @@ -54,6 +77,47 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T } now := metering.Monotime() + + MaxDirectResponseSize = interop.MaxPayloadSize + if maxPayloadSize := r.Header.Get(MaxPayloadSizeHeader); maxPayloadSize != "" { + if n, err := strconv.ParseInt(maxPayloadSize, 10, 64); err == nil && n >= -1 { + MaxDirectResponseSize = n + } else { + log.Error("MaxPayloadSize header is not a valid number") + renderBadRequest(w, r, interop.ErrInvalidMaxPayloadSize.Error()) + return nil, interop.ErrInvalidMaxPayloadSize + } + } + + if MaxDirectResponseSize == -1 { + w.Header().Add("Trailer", FunctionErrorTypeTrailer) + w.Header().Add("Trailer", FunctionErrorBodyTrailer) + + ResponseBandwidthRate = interop.ResponseBandwidthRate + if responseBandwidthRate := r.Header.Get(ResponseBandwidthRateHeader); responseBandwidthRate != "" { + if n, err := strconv.ParseInt(responseBandwidthRate, 10, 64); err == nil && + interop.MinResponseBandwidthRate <= n && n <= interop.MaxResponseBandwidthRate { + ResponseBandwidthRate = n + } else { + log.Error("ResponseBandwidthRate header is not a valid number or is out of the allowed range") + renderBadRequest(w, r, interop.ErrInvalidResponseBandwidthRate.Error()) + return nil, interop.ErrInvalidResponseBandwidthRate + } + } + + ResponseBandwidthBurstSize = interop.ResponseBandwidthBurstSize + if responseBandwidthBurstSize := r.Header.Get(ResponseBandwidthBurstSizeHeader); responseBandwidthBurstSize != "" { + if n, err := strconv.ParseInt(responseBandwidthBurstSize, 10, 64); err == nil && + interop.MinResponseBandwidthBurstSize <= n && n <= interop.MaxResponseBandwidthBurstSize { + ResponseBandwidthBurstSize = n + } else { + log.Error("ResponseBandwidthBurstSize header is not a valid number or is out of the allowed range") + 
renderBadRequest(w, r, interop.ErrInvalidResponseBandwidthBurstSize.Error()) + return nil, interop.ErrInvalidResponseBandwidthBurstSize + } + } + } + inv := &interop.Invoke{ ID: r.Header.Get(InvokeIDHeader), ReservationToken: chi.URLParam(r, "reservationtoken"), @@ -66,7 +130,6 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T LambdaSegmentID: token.LambdaSegmentID, ClientContext: custHeaders.ClientContext, Payload: r.Body, - CorrelationID: "invokeCorrelationID", DeadlineNs: fmt.Sprintf("%d", now+token.FunctionTimeout.Nanoseconds()), NeedDebugLogs: token.NeedDebugLogs, InvokeReceivedTime: now, @@ -99,24 +162,215 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T return inv, nil } -func SendDirectInvokeResponse(additionalHeaders map[string]string, payload io.Reader, w http.ResponseWriter) error { - for k, v := range additionalHeaders { - w.Header().Add(k, v) +type CopyDoneResult struct { + Metrics *interop.InvokeResponseMetrics + Error error +} + +func getErrorTypeFromResetReason(resetReason string) fatalerror.ErrorType { + errorTypeTrailer, ok := ResetReasonMap[resetReason] + if !ok { + errorTypeTrailer = fatalerror.Unknown + } + return errorTypeTrailer +} + +func isErrorResponse(additionalHeaders map[string]string) (isErrorResponse bool) { + _, isErrorResponse = additionalHeaders[ErrorTypeHeader] + return +} + +func isStreamingInvoke() bool { + return MaxDirectResponseSize == -1 +} + +func asyncPayloadCopy(w http.ResponseWriter, payload io.Reader) (copyDone chan CopyDoneResult, cancel context.CancelFunc, err error) { + copyDone = make(chan CopyDoneResult) + streamedResponseWriter, cancel, err := NewStreamedResponseWriter(w) + if err != nil { + return nil, nil, &interop.ErrInternalPlatformError{} + } + go func() { // copy payload in a separate go routine + _, copyError := bandwidthlimiter.BandwidthLimitingCopy(streamedResponseWriter, payload) + if copyError != nil { + 
w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) + } else { + w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) + } + copyDoneResult := CopyDoneResult{ + Metrics: streamedResponseWriter.GetMetrics(), + Error: copyError, + } + copyDone <- copyDoneResult + cancel() // free resources + }() + return +} + +func sendStreamingInvokeResponse(payload io.Reader, trailers http.Header, w http.ResponseWriter, + interruptedResponseChan chan *interop.Reset, sendResponseChan chan *interop.InvokeResponseMetrics, + request *interop.CancellableRequest, runtimeCalledResponse bool) (err error) { + /* In case of /response, we copy the payload and, once copied, we attach: + * 1) 'Lambda-Runtime-Function-Error-Type' + * 2) 'Lambda-Runtime-Function-Error-Body' + * trailers. */ + copyDone, cancel, err := asyncPayloadCopy(w, payload) + if err != nil { + renderInternalServerError(w, err.Error()) + return err + } + + var errorTypeTrailer string + var errorBodyTrailer string + var copyDoneResult CopyDoneResult + select { + case copyDoneResult = <-copyDone: // copy finished + errorTypeTrailer = trailers.Get(FunctionErrorTypeTrailer) + errorBodyTrailer = trailers.Get(FunctionErrorBodyTrailer) + if copyDoneResult.Error != nil && errorTypeTrailer == "" { // truncated payload, error type not known + errorTypeTrailer = string(fatalerror.TruncatedResponse) + } + case reset := <-interruptedResponseChan: // reset initiated + cancel() + if request != nil { + // In case of reset: + // * to interrupt copying when runtime called /response (a potential stuck on Body.Read() operation), + // we close the underlying connection using .Close() method on the request object + // * for /error case, the whole body is already read in /error handler, so we don't need additional handling + // when sending streaming invoke error response + connErr := request.Cancel() + if connErr != nil { + log.Warnf("Failed to close underlying connection: %s", connErr) + } + } else { + log.Warn("Cannot close 
underlying connection. Request object is nil") + } + copyDoneResult = <-copyDone + reset.InvokeResponseMetrics = copyDoneResult.Metrics + interruptedResponseChan <- nil + errorTypeTrailer = string(getErrorTypeFromResetReason(reset.Reason)) + } + w.Header().Set(FunctionErrorTypeTrailer, errorTypeTrailer) + w.Header().Set(FunctionErrorBodyTrailer, errorBodyTrailer) + + copyDoneResult.Metrics.RuntimeCalledResponse = runtimeCalledResponse + sendResponseChan <- copyDoneResult.Metrics + + if copyDoneResult.Error != nil { + log.Errorf("Error while streaming response payload: %s", copyDoneResult.Error) + err = &interop.ErrTruncatedResponse{} + } + return +} + +func sendStreamingInvokeErrorResponse(payload io.Reader, w http.ResponseWriter, + interruptedResponseChan chan *interop.Reset, sendResponseChan chan *interop.InvokeResponseMetrics, runtimeCalledResponse bool) (err error) { + + copyDone, cancel, err := asyncPayloadCopy(w, payload) + if err != nil { + renderInternalServerError(w, err.Error()) + return err + } + + var copyDoneResult CopyDoneResult + select { + case copyDoneResult = <-copyDone: // copy finished + case reset := <-interruptedResponseChan: // reset initiated + cancel() + copyDoneResult = <-copyDone + reset.InvokeResponseMetrics = copyDoneResult.Metrics + interruptedResponseChan <- nil + } + + copyDoneResult.Metrics.RuntimeCalledResponse = runtimeCalledResponse + sendResponseChan <- copyDoneResult.Metrics + + if copyDoneResult.Error != nil { + log.Errorf("Error while streaming error response payload: %s", copyDoneResult.Error) + err = &interop.ErrTruncatedResponse{} + } + return +} + +// parseFunctionResponseMode fetches the mode from the header +// If the header is absent, it returns buffered mode. 
+func parseFunctionResponseMode(w http.ResponseWriter) (interop.FunctionResponseMode, error) { + headerValue := w.Header().Get(FunctionResponseModeHeader) + // the header is optional, so it's ok to be absent + if headerValue == "" { + return interop.FunctionResponseModeBuffered, nil + } + + return interop.ConvertToFunctionResponseMode(headerValue) +} + +func sendPayloadLimitedResponse(payload io.Reader, trailers http.Header, w http.ResponseWriter, sendResponseChan chan *interop.InvokeResponseMetrics, runtimeCalledResponse bool) (err error) { + functionResponseMode, err := parseFunctionResponseMode(w) + if err != nil { + return err + } + + // non-streaming invoke request but runtime is streaming: predefine Trailer headers + if functionResponseMode == interop.FunctionResponseModeStreaming { + w.Header().Add("Trailer", FunctionErrorTypeTrailer) + w.Header().Add("Trailer", FunctionErrorBodyTrailer) + } + + startReadingResponseMonoTimeMs := metering.Monotime() + written, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) // +1 because we do allow 10MB but not 10MB + 1 byte + + // non-streaming invoke request but runtime is streaming: set response trailers + if functionResponseMode == interop.FunctionResponseModeStreaming { + w.Header().Set(FunctionErrorTypeTrailer, trailers.Get(FunctionErrorTypeTrailer)) + w.Header().Set(FunctionErrorBodyTrailer, trailers.Get(FunctionErrorBodyTrailer)) } - n, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) // +1 because we do allow 10MB but not 10MB + 1 byte if err != nil { w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) - } else if n == MaxDirectResponseSize+1 { + err = &interop.ErrTruncatedResponse{} + } else if MaxDirectResponseSize != -1 && written == MaxDirectResponseSize+1 { w.Header().Set(EndOfResponseTrailer, EndOfResponseOversized) err = &interop.ErrorResponseTooLargeDI{ ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ - ResponseSize: int(n), + ResponseSize: int(written), 
MaxResponseSize: int(MaxDirectResponseSize), }, } } else { w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) } - return err + + sendResponseChan <- &interop.InvokeResponseMetrics{ + ProducedBytes: int64(written), + StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, + FinishReadingResponseMonoTimeMs: metering.Monotime(), + TimeShapedNs: int64(-1), + OutboundThroughputBps: int64(-1), + // FIXME: + // We should use InvokeResponseMode here, because only when it's streaming we're interested + // on it. If the invoke is buffered, we don't generate streaming metrics, even if the + // function response mode is streaming. + FunctionResponseMode: interop.FunctionResponseModeBuffered, + RuntimeCalledResponse: runtimeCalledResponse, + } + return +} + +func SendDirectInvokeResponse(additionalHeaders map[string]string, payload io.Reader, trailers http.Header, + w http.ResponseWriter, interruptedResponseChan chan *interop.Reset, + sendResponseChan chan *interop.InvokeResponseMetrics, request *interop.CancellableRequest, runtimeCalledResponse bool) error { + + for k, v := range additionalHeaders { + w.Header().Add(k, v) + } + + if isStreamingInvoke() { // unlimited payload; response streaming mode + if isErrorResponse(additionalHeaders) { // send streamed error response when runtime called /error + return sendStreamingInvokeErrorResponse(payload, w, interruptedResponseChan, sendResponseChan, runtimeCalledResponse) + } + // send streamed response when runtime called /response + return sendStreamingInvokeResponse(payload, trailers, w, interruptedResponseChan, sendResponseChan, request, runtimeCalledResponse) + } + + return sendPayloadLimitedResponse(payload, trailers, w, sendResponseChan, runtimeCalledResponse) } diff --git a/lambda/core/directinvoke/directinvoke_test.go b/lambda/core/directinvoke/directinvoke_test.go new file mode 100644 index 0000000..4e26161 --- /dev/null +++ b/lambda/core/directinvoke/directinvoke_test.go @@ -0,0 +1,358 @@ +// Copyright 
Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package directinvoke + +import ( + "bytes" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/interop" +) + +func NewResponseWriterWithoutFlushMethod() *ResponseWriterWithoutFlushMethod { + return &ResponseWriterWithoutFlushMethod{} +} + +type ResponseWriterWithoutFlushMethod struct{} + +func (*ResponseWriterWithoutFlushMethod) Header() http.Header { return http.Header{} } +func (*ResponseWriterWithoutFlushMethod) Write([]byte) (n int, err error) { return } +func (*ResponseWriterWithoutFlushMethod) WriteHeader(_ int) {} + +func NewSimpleResponseWriter() *SimpleResponseWriter { + return &SimpleResponseWriter{ + buffer: bytes.NewBuffer(nil), + trailers: make(http.Header), + } +} + +type SimpleResponseWriter struct { + buffer *bytes.Buffer + trailers http.Header +} + +func (w *SimpleResponseWriter) Header() http.Header { return w.trailers } +func (w *SimpleResponseWriter) Write(p []byte) (n int, err error) { return w.buffer.Write(p) } +func (*SimpleResponseWriter) WriteHeader(_ int) {} +func (*SimpleResponseWriter) Flush() {} + +func NewInterruptableResponseWriter(interruptAfter int) (*InterruptableResponseWriter, chan struct{}) { + interruptedTestWriterChan := make(chan struct{}) + return &InterruptableResponseWriter{ + buffer: bytes.NewBuffer(nil), + trailers: make(http.Header), + interruptAfter: interruptAfter, + interruptedTestWriterChan: interruptedTestWriterChan, + }, interruptedTestWriterChan +} + +type InterruptableResponseWriter struct { + buffer *bytes.Buffer + trailers http.Header + interruptAfter int // expect Writer to be interrupted after 'interruptAfter' number of writes + interruptedTestWriterChan chan struct{} +} + +func (w *InterruptableResponseWriter) Header() http.Header { return w.trailers } +func (w *InterruptableResponseWriter) Write(p []byte) (n int, err error) { + if 
w.interruptAfter >= 1 { + w.interruptAfter-- + } else if w.interruptAfter == 0 { + w.interruptedTestWriterChan <- struct{}{} // ready to be interrupted + <-w.interruptedTestWriterChan // wait until interrupted + } + n, err = w.buffer.Write(p) + return +} +func (*InterruptableResponseWriter) WriteHeader(_ int) {} +func (*InterruptableResponseWriter) Flush() {} + +// This is a simple reader implementing io.Reader interface. It's based on strings.Reader, but it doesn't have extra +// methods that allow faster copying such as .WriteTo() method. +func NewReader(s string) *Reader { return &Reader{s, 0, -1} } + +type Reader struct { + s string + i int64 // current reading index + prevRune int // index of previous rune; or < 0 +} + +func (r *Reader) Read(b []byte) (n int, err error) { + if r.i >= int64(len(r.s)) { + return 0, io.EOF + } + r.prevRune = -1 + n = copy(b, r.s[r.i:]) + r.i += int64(n) + return +} + +func TestSendDirectInvokeWithIncompatibleResponseWriter(t *testing.T) { + MaxDirectResponseSize = -1 + err := SendDirectInvokeResponse(nil, nil, nil, NewResponseWriterWithoutFlushMethod(), nil, nil, nil, false) + require.Error(t, err) + require.Equal(t, "ErrInternalPlatformError", err.Error()) +} + +func TestAsyncPayloadCopySuccess(t *testing.T) { + payloadString := strings.Repeat("a", 10*1024*1024) + writer := NewSimpleResponseWriter() + + expectedPayloadString := payloadString + + copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) + require.Nil(t, err) + + <-copyDone + require.Equal(t, expectedPayloadString, writer.buffer.String()) +} + +// We use an interruptable response writer which informs on a channel that it's ready to be interrupted after +// 'interruptAfter' number of writes, then it waits for interruption completion to resume the current write operation. +// For this test, after initiating copying, we wait for one chunk of 32 KiB to be copied. Then, we use cancel() to +// interrupt copying. 
At this point, only ongoing .Write() operations can be performed. We inform the writer about +// interruption completion, and the writer resumes the current .Write() operation, which gives us another 32 KiB chunk +// that is copied. After that, copying returns, and we receive a signal on <-copyDone channel. +func TestAsyncPayloadCopySuccessAfterCancel(t *testing.T) { + payloadString := strings.Repeat("a", 10*1024*1024) // 10 MiB + writer, interruptedTestWriterChan := NewInterruptableResponseWriter(1) + + expectedPayloadString := strings.Repeat("a", 64*1024) // 64 KiB (2 chunks) + + copyDone, cancel, err := asyncPayloadCopy(writer, NewReader(payloadString)) + require.Nil(t, err) + + <-interruptedTestWriterChan // wait for writing 'interruptAfter' number of chunks + cancel() // interrupt copying + interruptedTestWriterChan <- struct{}{} // inform test writer about interruption + + <-copyDone + require.Equal(t, expectedPayloadString, writer.buffer.String()) +} + +func TestAsyncPayloadCopyWithIncompatibleResponseWriter(t *testing.T) { + copyDone, cancel, err := asyncPayloadCopy(&ResponseWriterWithoutFlushMethod{}, nil) + require.Nil(t, copyDone) + require.Nil(t, cancel) + require.Error(t, err) + require.Equal(t, "ErrInternalPlatformError", err.Error()) +} + +func TestSendStreamingInvokeResponseSuccess(t *testing.T) { + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + trailers := http.Header{} + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := payloadString + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) + require.Nil(t, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, 
"", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) + <-testFinished +} + +func TestSendPayloadLimitedResponseWithinThresholdWithStreamingFunction(t *testing.T) { + payloadSize := 1 + payloadString := strings.Repeat("a", payloadSize) + payload := NewReader(payloadString) + trailers := http.Header{} + writer := NewSimpleResponseWriter() + writer.Header().Set("Lambda-Runtime-Function-Response-Mode", "streaming") + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + MaxDirectResponseSize = int64(payloadSize + 1) + + go func() { + err := sendPayloadLimitedResponse(payload, trailers, writer, sendResponseChan, true) + require.Nil(t, err) + testFinished <- struct{}{} + }() + + metrics := <-sendResponseChan + require.Equal(t, interop.FunctionResponseModeBuffered, metrics.FunctionResponseMode) + require.Equal(t, len(payloadString), len(writer.buffer.String())) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) + <-testFinished + + // Reset to its default value, just in case other tests use them + MaxDirectResponseSize = interop.MaxPayloadSize +} + +func TestSendPayloadLimitedResponseAboveThresholdWithStreamingFunction(t *testing.T) { + payloadSize := 2 + payloadString := strings.Repeat("a", payloadSize) + payload := NewReader(payloadString) + trailers := http.Header{} + writer := NewSimpleResponseWriter() + writer.Header().Set("Lambda-Runtime-Function-Response-Mode", "streaming") + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + MaxDirectResponseSize = int64(payloadSize - 1) + expectedError := &interop.ErrorResponseTooLargeDI{ + 
ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ + MaxResponseSize: int(MaxDirectResponseSize), + ResponseSize: payloadSize, + }, + } + + go func() { + err := sendPayloadLimitedResponse(payload, trailers, writer, sendResponseChan, true) + require.Equal(t, expectedError, err) + testFinished <- struct{}{} + }() + + metrics := <-sendResponseChan + require.Equal(t, interop.FunctionResponseModeBuffered, metrics.FunctionResponseMode) + require.Equal(t, len(payloadString), len(writer.buffer.String())) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Oversized", writer.Header().Get("End-Of-Response")) + <-testFinished + + // Reset to its default value, just in case other tests use them + MaxDirectResponseSize = interop.MaxPayloadSize +} + +func TestSendStreamingInvokeResponseSuccessWithTrailers(t *testing.T) { + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + trailers := http.Header{ + "Lambda-Runtime-Function-Error-Type": []string{"ErrorType"}, + "Lambda-Runtime-Function-Error-Body": []string{"ErrorBody"}, + } + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := payloadString + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) + require.Nil(t, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "ErrorType", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "ErrorBody", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) + <-testFinished 
+} + +func TestSendStreamingInvokeResponseReset(t *testing.T) { // Reset initiated after writing two chunks of 32 KiB + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + trailers := http.Header{} + writer, interruptedTestWriterChan := NewInterruptableResponseWriter(1) + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := strings.Repeat("a", 64*1024) // 64 KiB + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, true) + require.Error(t, err) + require.Equal(t, "ErrTruncatedResponse", err.Error()) + testFinished <- struct{}{} + }() + + reset := &interop.Reset{Reason: "timeout"} + require.Nil(t, reset.InvokeResponseMetrics) + + <-interruptedTestWriterChan // wait for writing 'interruptAfter' number of chunks + interruptedResponseChan <- reset // send reset + time.Sleep(10 * time.Millisecond) // wait for cancel() being called (first instruction after getting reset) + interruptedTestWriterChan <- struct{}{} // inform test writer about interruption + <-interruptedResponseChan // wait for copy done after interruption + require.NotNil(t, reset.InvokeResponseMetrics) + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "Sandbox.Timeout", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Truncated", writer.Header().Get("End-Of-Response")) + <-testFinished +} + +func TestSendStreamingInvokeErrorResponseSuccess(t *testing.T) { + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan 
*interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := payloadString + + go func() { + err := sendStreamingInvokeErrorResponse(payload, writer, interruptedResponseChan, sendResponseChan, false) + require.Nil(t, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) + <-testFinished +} + +func TestSendStreamingInvokeErrorResponseReset(t *testing.T) { // Reset initiated after writing two chunks of 32 KiB + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + writer, interruptedTestWriterChan := NewInterruptableResponseWriter(1) + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := strings.Repeat("a", 64*1024) // 64 KiB + + go func() { + err := sendStreamingInvokeErrorResponse(payload, writer, interruptedResponseChan, sendResponseChan, true) + require.Error(t, err) + require.Equal(t, "ErrTruncatedResponse", err.Error()) + testFinished <- struct{}{} + }() + + reset := &interop.Reset{Reason: "timeout"} + require.Nil(t, reset.InvokeResponseMetrics) + + <-interruptedTestWriterChan // wait for writing 'interruptAfter' number of chunks + interruptedResponseChan <- reset // send reset + time.Sleep(10 * time.Millisecond) // wait for cancel() being called (first instruction after getting reset) + interruptedTestWriterChan <- struct{}{} // inform test writer about interruption + <-interruptedResponseChan // wait for copy done after interruption + require.NotNil(t, reset.InvokeResponseMetrics) + + <-sendResponseChan + require.Equal(t, expectedPayloadString, 
writer.buffer.String()) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Truncated", writer.Header().Get("End-Of-Response")) + <-testFinished +} diff --git a/lambda/core/directinvoke/util.go b/lambda/core/directinvoke/util.go new file mode 100644 index 0000000..511d656 --- /dev/null +++ b/lambda/core/directinvoke/util.go @@ -0,0 +1,84 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package directinvoke + +import ( + "context" + "errors" + "go.amzn.com/lambda/core/bandwidthlimiter" + "io" + "net/http" + "time" + + log "github.com/sirupsen/logrus" +) + +const DefaultRefillIntervalMs = 125 // default refill interval in milliseconds + +func NewStreamedResponseWriter(w http.ResponseWriter) (*bandwidthlimiter.BandwidthLimitingWriter, context.CancelFunc, error) { + flushingWriter, err := NewFlushingWriter(w) // after writing a chunk we have to flush it to avoid additional buffering by ResponseWriter + if err != nil { + return nil, nil, err + } + cancellableWriter, cancel := NewCancellableWriter(flushingWriter) // cancelling prevents next calls to Write() from happening + + refillNumber := ResponseBandwidthRate * DefaultRefillIntervalMs / 1000 // refillNumber is calculated based on 'ResponseBandwidthRate' and bucket refill interval + refillInterval := DefaultRefillIntervalMs * time.Millisecond + + // Initial bucket for token bucket algorithm allows for a burst of up to 6 MiB, and an average transmission rate of 2 MiB/s + bucket, err := bandwidthlimiter.NewBucket(ResponseBandwidthBurstSize, ResponseBandwidthBurstSize, refillNumber, refillInterval) + if err != nil { + cancel() // free resources + return nil, nil, err + } + + bandwidthLimitingWriter, err := bandwidthlimiter.NewBandwidthLimitingWriter(cancellableWriter, bucket) + if err != nil { + cancel() // free resources + 
return nil, nil, err + } + + return bandwidthLimitingWriter, cancel, nil +} + +func NewFlushingWriter(w io.Writer) (*FlushingWriter, error) { + flusher, ok := w.(http.Flusher) + if !ok { + errorMsg := "expected http.ResponseWriter to be an http.Flusher" + log.Error(errorMsg) + return nil, errors.New(errorMsg) + } + return &FlushingWriter{ + w: w, + flusher: flusher, + }, nil +} + +type FlushingWriter struct { + w io.Writer + flusher http.Flusher +} + +func (w *FlushingWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + w.flusher.Flush() + return +} + +func NewCancellableWriter(w io.Writer) (*CancellableWriter, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + return &CancellableWriter{w: w, ctx: ctx}, cancel +} + +type CancellableWriter struct { + w io.Writer + ctx context.Context +} + +func (w *CancellableWriter) Write(p []byte) (int, error) { + if err := w.ctx.Err(); err != nil { + return 0, err + } + return w.w.Write(p) +} diff --git a/lambda/core/doc.go b/lambda/core/doc.go index 23c1539..4a7157f 100644 --- a/lambda/core/doc.go +++ b/lambda/core/doc.go @@ -2,26 +2,23 @@ // SPDX-License-Identifier: Apache-2.0 /* - Package core provides state objects and synchronization primitives for managing data flow in the system. - -States +# States Runtime and Agent implement state object design pattern. Runtime state interface: -type RuntimeState interface { - InitError() error - Ready() error - InvocationResponse() error - InvocationErrorResponse() error -} + type RuntimeState interface { + InitError() error + Ready() error + InvocationResponse() error + InvocationErrorResponse() error + } - -Gates +# Gates Gates provide synchornization primitives for managing data flow in the system. @@ -31,8 +28,9 @@ set of operations being performed in other threads completes. 
To better understand gates, consider two examples below: Example 1: main thread is awaiting registered threads to walk through the gate, - and after the last registered thread walked through the gate, gate - condition will be satisfied and main thread will proceed: + + and after the last registered thread walked through the gate, gate + condition will be satisfied and main thread will proceed: [main] // register threads with the gate and start threads ... [main] g.AwaitGateCondition() @@ -42,27 +40,25 @@ Example 1: main thread is awaiting registered threads to walk through the gate, [thread] // not blocked Example 2: main thread is awaiting registered threads to arrive at the gate, - and after the last registered thread arrives at the gate, gate - condition will be satisfied and main thread, along with registered - threads will proceed: + + and after the last registered thread arrives at the gate, gate + condition will be satisfied and main thread, along with registered + threads will proceed: [main] // register threads with the gate and start threads ... [main] g.AwaitGateCondition() [main] // blocked until gate condition is satisfied - -Flow +# Flow Flow wraps a set of specific gates required to implement specific data flow in the system. Example flows would be INIT, INVOKE and RESET. - -Registrations +# Registrations Registration service manages registrations, it maintains the mapping between registered parties are events they are registered. Parties not registered in the system will not be issued events. 
- */ package core diff --git a/lambda/core/externalagent.go b/lambda/core/externalagent.go index 792f356..cd367d2 100644 --- a/lambda/core/externalagent.go +++ b/lambda/core/externalagent.go @@ -22,7 +22,6 @@ type ExternalAgent struct { currentState ExternalAgentState stateLastModified time.Time - Pid int StartedState ExternalAgentState RegisteredState ExternalAgentState diff --git a/lambda/core/flow.go b/lambda/core/flow.go index 3c22b84..b2cb538 100644 --- a/lambda/core/flow.go +++ b/lambda/core/flow.go @@ -19,6 +19,9 @@ type InitFlowSynchronization interface { CancelWithError(error) + RuntimeRestoreReady() error + AwaitRuntimeRestoreReady() error + Clear() } @@ -26,6 +29,7 @@ type initFlowSynchronizationImpl struct { externalAgentsRegisteredGate Gate runtimeReadyGate Gate agentReadyGate Gate + runtimeRestoreReadyGate Gate } // SetExternalAgentsRegisterCount notifies init flow that N /extension/register calls should be done in future by external agents @@ -43,6 +47,11 @@ func (s *initFlowSynchronizationImpl) AwaitRuntimeReady() error { return s.runtimeReadyGate.AwaitGateCondition() } +// AwaitRuntimeRestoreReady awaits runtime restore ready state (/restore/next is called by runtime) +func (s *initFlowSynchronizationImpl) AwaitRuntimeRestoreReady() error { + return s.runtimeRestoreReadyGate.AwaitGateCondition() +} + // AwaitExternalAgentsRegistered awaits for all subscribed agents to report registered func (s *initFlowSynchronizationImpl) AwaitExternalAgentsRegistered() error { return s.externalAgentsRegisteredGate.AwaitGateCondition() @@ -58,6 +67,11 @@ func (s *initFlowSynchronizationImpl) RuntimeReady() error { return s.runtimeReadyGate.WalkThrough() } +// Ready called by runtime when restore is completed (i.e. 
/next is called after /restore/next) +func (s *initFlowSynchronizationImpl) RuntimeRestoreReady() error { + return s.runtimeRestoreReadyGate.WalkThrough() +} + // Ready called by agent when initialized func (s *initFlowSynchronizationImpl) AgentReady() error { return s.agentReadyGate.WalkThrough() @@ -73,6 +87,7 @@ func (s *initFlowSynchronizationImpl) CancelWithError(err error) { s.externalAgentsRegisteredGate.CancelWithError(err) s.runtimeReadyGate.CancelWithError(err) s.agentReadyGate.CancelWithError(err) + s.runtimeRestoreReadyGate.CancelWithError(err) } // Clear gates state @@ -80,6 +95,7 @@ func (s *initFlowSynchronizationImpl) Clear() { s.externalAgentsRegisteredGate.Clear() s.runtimeReadyGate.Clear() s.agentReadyGate.Clear() + s.runtimeRestoreReadyGate.Clear() } // NewInitFlowSynchronization returns new InitFlowSynchronization instance. @@ -88,6 +104,7 @@ func NewInitFlowSynchronization() InitFlowSynchronization { runtimeReadyGate: NewGate(1), externalAgentsRegisteredGate: NewGate(0), agentReadyGate: NewGate(maxAgentsLimit), + runtimeRestoreReadyGate: NewGate(1), } return initFlow } diff --git a/lambda/core/registrations.go b/lambda/core/registrations.go index dca9d90..f68612c 100644 --- a/lambda/core/registrations.go +++ b/lambda/core/registrations.go @@ -10,8 +10,11 @@ import ( "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/interop" "github.com/google/uuid" + + log "github.com/sirupsen/logrus" ) type registrationServiceState int @@ -70,6 +73,7 @@ type FunctionMetadata struct { FunctionName string FunctionVersion string Handler string + RuntimeInfo interop.RuntimeInfo } // RegistrationService keeps track of registered parties, including external agents, threads, and runtime. 
@@ -94,6 +98,7 @@ type RegistrationService interface { CountAgents() int Clear() AgentsInfo() []AgentInfo + CancelFlows(err error) } type registrationServiceImpl struct { @@ -105,6 +110,7 @@ type registrationServiceImpl struct { initFlow InitFlowSynchronization invokeFlow InvokeFlowSynchronization functionMetadata FunctionMetadata + cancelOnce sync.Once } func (s *registrationServiceImpl) Clear() { @@ -115,6 +121,7 @@ func (s *registrationServiceImpl) Clear() { s.internalAgents.Clear() s.externalAgents.Clear() s.state = registrationServiceOn + s.cancelOnce = sync.Once{} } func (s *registrationServiceImpl) InitFlow() InitFlowSynchronization { @@ -373,6 +380,19 @@ func (s *registrationServiceImpl) TurnOff() { s.state = registrationServiceOff } +// CancelFlows cancels init and invoke flows with error. +func (s *registrationServiceImpl) CancelFlows(err error) { + s.mutex.Lock() + defer s.mutex.Unlock() + // The following block protects us from overwriting the error + // which was first used to cancel flows. + s.cancelOnce.Do(func() { + log.Debugf("Canceling flows: %s", err) + s.initFlow.CancelWithError(err) + s.invokeFlow.CancelWithError(err) + }) +} + // NewRegistrationService returns new RegistrationService instance. 
func NewRegistrationService(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchronization) RegistrationService { return ®istrationServiceImpl{ @@ -382,5 +402,6 @@ func NewRegistrationService(initFlow InitFlowSynchronization, invokeFlow InvokeF externalAgents: NewExternalAgentsMap(), initFlow: initFlow, invokeFlow: invokeFlow, + cancelOnce: sync.Once{}, } } diff --git a/lambda/core/registrations_test.go b/lambda/core/registrations_test.go index d8857a6..5956ac3 100644 --- a/lambda/core/registrations_test.go +++ b/lambda/core/registrations_test.go @@ -63,7 +63,7 @@ func TestRegistrationServiceHappyPathDuringInit(t *testing.T) { assert.NoError(t, runtime.Ready()) }() - assert.NoError(t, initFlow.AwaitRuntimeReady()) + assert.NoError(t, initFlow.AwaitRuntimeRestoreReady()) registrationService.TurnOff() // Agents Ready diff --git a/lambda/core/runtime_state_names.go b/lambda/core/runtime_state_names.go index b20b9f8..b04ba5d 100644 --- a/lambda/core/runtime_state_names.go +++ b/lambda/core/runtime_state_names.go @@ -5,10 +5,14 @@ package core // String values of possibles runtime states const ( - RuntimeStartedStateName = "Started" - RuntimeInitErrorStateName = "InitError" - RuntimeReadyStateName = "Ready" - RuntimeRunningStateName = "Running" + RuntimeStartedStateName = "Started" + RuntimeInitErrorStateName = "InitError" + RuntimeReadyStateName = "Ready" + RuntimeRunningStateName = "Running" + // RuntimeStartedState -> RuntimeRestoreReadyState + RuntimeRestoreReadyStateName = "RestoreReady" + // RuntimeRestoreReadyState -> RuntimeRestoringState + RuntimeRestoringStateName = "Restoring" RuntimeInvocationResponseStateName = "InvocationResponse" RuntimeInvocationErrorResponseStateName = "InvocationErrorResponse" RuntimeResponseSentStateName = "RuntimeResponseSentState" diff --git a/lambda/core/states.go b/lambda/core/states.go index bc7359d..a5e2010 100644 --- a/lambda/core/states.go +++ b/lambda/core/states.go @@ -72,6 +72,7 @@ var ErrConcurrentStateModification 
= errors.New("Concurrent state modification") type RuntimeState interface { InitError() error Ready() error + RestoreReady() error InvocationResponse() error InvocationErrorResponse() error ResponseSent() error @@ -82,6 +83,7 @@ type disallowEveryTransitionByDefault struct{} func (s *disallowEveryTransitionByDefault) InitError() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) Ready() error { return ErrNotAllowed } +func (s *disallowEveryTransitionByDefault) RestoreReady() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) InvocationResponse() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) InvocationErrorResponse() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) ResponseSent() error { return ErrNotAllowed } @@ -92,13 +94,14 @@ type Runtime struct { currentState RuntimeState stateLastModified time.Time - Pid int responseTime time.Time RuntimeStartedState RuntimeState RuntimeInitErrorState RuntimeState RuntimeReadyState RuntimeState RuntimeRunningState RuntimeState + RuntimeRestoreReadyState RuntimeState + RuntimeRestoringState RuntimeState RuntimeInvocationResponseState RuntimeState RuntimeInvocationErrorResponseState RuntimeState RuntimeResponseSentState RuntimeState @@ -135,6 +138,12 @@ func (s *Runtime) Ready() error { return s.currentState.Ready() } +func (s *Runtime) RestoreReady() error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.RestoreReady() +} + // InvocationResponse delegates to state implementation. 
func (s *Runtime) InvocationResponse() error { s.ManagedThread.Lock() @@ -196,6 +205,8 @@ func NewRuntime(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchroni runtime.RuntimeInvocationResponseState = &RuntimeInvocationResponseState{runtime: runtime, invokeFlow: invokeFlow} runtime.RuntimeInvocationErrorResponseState = &RuntimeInvocationErrorResponseState{runtime: runtime, invokeFlow: invokeFlow} runtime.RuntimeResponseSentState = &RuntimeResponseSentState{runtime: runtime, invokeFlow: invokeFlow} + runtime.RuntimeRestoreReadyState = &RuntimeRestoreReadyState{} + runtime.RuntimeRestoringState = &RuntimeRestoringState{runtime: runtime, initFlow: initFlow} runtime.setStateUnsafe(runtime.RuntimeStartedState) return runtime @@ -211,7 +222,14 @@ type RuntimeStartedState struct { // Ready call when runtime init done. func (s *RuntimeStartedState) Ready() error { s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) - err := s.initFlow.RuntimeReady() + // runtime called /next without calling /restore/next + // that means it's not interested in restore phase + err := s.initFlow.RuntimeRestoreReady() + if err != nil { + return err + } + + err = s.initFlow.RuntimeReady() if err != nil { return err } @@ -225,6 +243,22 @@ func (s *RuntimeStartedState) Ready() error { return nil } +func (s *RuntimeStartedState) RestoreReady() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeRestoreReadyState) + err := s.initFlow.RuntimeRestoreReady() + if err != nil { + return err + } + + s.runtime.ManagedThread.SuspendUnsafe() + if s.runtime.currentState != s.runtime.RuntimeRestoreReadyState && s.runtime.currentState != s.runtime.RuntimeRestoringState { + return ErrConcurrentStateModification + } + + s.runtime.setStateUnsafe(s.runtime.RuntimeRestoringState) + return nil +} + // InitError move runtime to init error state. 
func (s *RuntimeStartedState) InitError() error { s.runtime.setStateUnsafe(s.runtime.RuntimeInitErrorState) @@ -236,6 +270,38 @@ func (s *RuntimeStartedState) Name() string { return RuntimeStartedStateName } +type RuntimeRestoringState struct { + disallowEveryTransitionByDefault + runtime *Runtime + initFlow InitFlowSynchronization +} + +// Runtime is healthy after restore and called /next +func (s *RuntimeRestoringState) Ready() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) + err := s.initFlow.RuntimeReady() + if err != nil { + return err + } + s.runtime.ManagedThread.SuspendUnsafe() + if s.runtime.currentState != s.runtime.RuntimeReadyState && s.runtime.currentState != s.runtime.RuntimeRunningState { + return ErrConcurrentStateModification + } + + s.runtime.setStateUnsafe(s.runtime.RuntimeRunningState) + return nil +} + +// Runtime has thrown an exception when executing restore hooks and called /init/error +func (s *RuntimeRestoringState) InitError() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeInitErrorState) + return nil +} + +func (s *RuntimeRestoringState) Name() string { + return RuntimeRestoringStateName +} + // RuntimeInitErrorState runtime started state. type RuntimeInitErrorState struct { disallowEveryTransitionByDefault @@ -297,6 +363,14 @@ func (s *RuntimeRunningState) Name() string { return RuntimeRunningStateName } +type RuntimeRestoreReadyState struct { + disallowEveryTransitionByDefault +} + +func (s *RuntimeRestoreReadyState) Name() string { + return RuntimeRestoreReadyStateName +} + // RuntimeInvocationResponseState runtime response is available. // Start state for runtime response submission. 
type RuntimeInvocationResponseState struct { diff --git a/lambda/core/states_test.go b/lambda/core/states_test.go index 4b01838..37f38e2 100644 --- a/lambda/core/states_test.go +++ b/lambda/core/states_test.go @@ -39,10 +39,7 @@ func TestRuntimeInitErrorAfterReady(t *testing.T) { } func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // Started assert.Equal(t, runtime.RuntimeStartedState, runtime.GetState()) // Started -> InitError @@ -53,6 +50,10 @@ func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { runtime.SetState(runtime.RuntimeStartedState) assert.NoError(t, runtime.Ready()) assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // Started -> RestoreReady + runtime.SetState(runtime.RuntimeStartedState) + assert.NoError(t, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) // Started -> ResponseSent runtime.SetState(runtime.RuntimeStartedState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -68,10 +69,7 @@ func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { } func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // InitError -> InitError runtime.SetState(runtime.RuntimeInitErrorState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -80,6 +78,10 @@ func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { runtime.SetState(runtime.RuntimeInitErrorState) assert.Equal(t, ErrNotAllowed, runtime.Ready()) assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) + // InitError -> 
RestoreReady + runtime.SetState(runtime.RuntimeInitErrorState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) // InitError -> ResponseSent runtime.SetState(runtime.RuntimeInitErrorState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -95,10 +97,7 @@ func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { } func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // Ready -> InitError runtime.SetState(runtime.RuntimeReadyState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -107,6 +106,10 @@ func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { runtime.SetState(runtime.RuntimeReadyState) assert.NoError(t, runtime.Ready()) assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // Ready -> RestoreReady + runtime.SetState(runtime.RuntimeReadyState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) // Ready -> ResponseSent runtime.SetState(runtime.RuntimeReadyState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -122,10 +125,7 @@ func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { } func TestRuntimeStateTransitionsFromRunningState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // Running -> InitError runtime.SetState(runtime.RuntimeRunningState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -134,6 +134,10 @@ func TestRuntimeStateTransitionsFromRunningState(t *testing.T) { runtime.SetState(runtime.RuntimeRunningState) 
assert.NoError(t, runtime.Ready()) assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // Running -> RestoreReady + runtime.SetState(runtime.RuntimeRunningState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) // Running -> ResponseSent runtime.SetState(runtime.RuntimeRunningState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -149,10 +153,7 @@ func TestRuntimeStateTransitionsFromRunningState(t *testing.T) { } func TestRuntimeStateTransitionsFromInvocationResponseState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // InvocationResponse -> InitError runtime.SetState(runtime.RuntimeInvocationResponseState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -161,6 +162,10 @@ func TestRuntimeStateTransitionsFromInvocationResponseState(t *testing.T) { runtime.SetState(runtime.RuntimeInvocationResponseState) assert.Equal(t, ErrNotAllowed, runtime.Ready()) assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) + // InvocationResponse -> RestoreReady + runtime.SetState(runtime.RuntimeInvocationResponseState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) // InvocationResponse -> ResponseSent runtime.SetState(runtime.RuntimeInvocationResponseState) assert.NoError(t, runtime.ResponseSent()) @@ -177,10 +182,7 @@ func TestRuntimeStateTransitionsFromInvocationResponseState(t *testing.T) { } func TestRuntimeStateTransitionsFromInvocationErrorResponseState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime 
:= newRuntime() // InvocationErrorResponse -> InitError runtime.SetState(runtime.RuntimeInvocationErrorResponseState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -189,6 +191,10 @@ func TestRuntimeStateTransitionsFromInvocationErrorResponseState(t *testing.T) { runtime.SetState(runtime.RuntimeInvocationErrorResponseState) assert.Equal(t, ErrNotAllowed, runtime.Ready()) assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) + // InvocationErrorResponse -> RestoreReady + runtime.SetState(runtime.RuntimeInvocationErrorResponseState) + assert.Equal(t, ErrNotAllowed, runtime.Ready()) + assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) // InvocationErrorResponse -> ResponseSent runtime.SetState(runtime.RuntimeInvocationErrorResponseState) assert.NoError(t, runtime.ResponseSent()) @@ -204,10 +210,7 @@ func TestRuntimeStateTransitionsFromInvocationErrorResponseState(t *testing.T) { } func TestRuntimeStateTransitionsFromResponseSentState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // ResponseSent -> InitError runtime.SetState(runtime.RuntimeResponseSentState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -216,6 +219,10 @@ func TestRuntimeStateTransitionsFromResponseSentState(t *testing.T) { runtime.SetState(runtime.RuntimeResponseSentState) assert.NoError(t, runtime.Ready()) assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // ResponseSent -> RestoreReady + runtime.SetState(runtime.RuntimeResponseSentState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) // ResponseSent -> ResponseSent runtime.SetState(runtime.RuntimeResponseSentState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -230,6 +237,71 @@ 
func TestRuntimeStateTransitionsFromResponseSentState(t *testing.T) { assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) } +func TestRuntimeStateTransitionsFromRestoreReadyState(t *testing.T) { + runtime := newRuntime() + // RestoreReady -> InitError + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> Ready + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.Ready()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> RestoreReady() + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> ResponseSent + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> InvocationResponse + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> InvocationErrorResponse + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) +} + +func TestRuntimeStateTransitionsFromRestoringState(t *testing.T) { + runtime := newRuntime() + // RestoreRunning -> InitError + runtime.SetState(runtime.RuntimeRestoringState) + assert.NoError(t, runtime.InitError()) + assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) + // RestoreRunning -> Ready + runtime.SetState(runtime.RuntimeRestoringState) + assert.NoError(t, runtime.Ready()) + assert.Equal(t, 
runtime.RuntimeRunningState, runtime.GetState()) + // RestoreRunning -> RestoreReady + runtime.SetState(runtime.RuntimeRestoringState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) + // RestoreRunning -> ResponseSent + runtime.SetState(runtime.RuntimeRestoringState) + assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) + // RestoreRunning -> InvocationResponse + runtime.SetState(runtime.RuntimeRestoringState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) + // RestoreRunning -> InvocationErrorResponse + runtime.SetState(runtime.RuntimeRestoringState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) +} + +func newRuntime() *Runtime { + initFlow := &mockInitFlowSynchronization{} + invokeFlow := &mockInvokeFlowSynchronization{} + runtime := NewRuntime(initFlow, invokeFlow) + runtime.ManagedThread = &mockthread.MockManagedThread{} + + return runtime +} + type mockInitFlowSynchronization struct { mock.Mock ReadyCond *sync.Cond @@ -272,6 +344,12 @@ func (s *mockInitFlowSynchronization) CancelWithError(err error) { s.Called(err) } func (s *mockInitFlowSynchronization) Clear() {} +func (s *mockInitFlowSynchronization) RuntimeRestoreReady() error { + return nil +} +func (s *mockInitFlowSynchronization) AwaitRuntimeRestoreReady() error { + return nil +} type mockInvokeFlowSynchronization struct{ mock.Mock } diff --git a/lambda/core/watchdog.go b/lambda/core/watchdog.go deleted file mode 100644 index bf57d01..0000000 --- a/lambda/core/watchdog.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "fmt" - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/fatalerror" - "sync" -) - -type WaitableProcess interface { - // Wait blocks until process exits and returns error in case of non-zero exit code - Wait() error - // Pid returnes process ID - Pid() int - // Name returnes process executable name (for logging) - Name() string -} - -// Watchdog watches started goroutines. -type Watchdog struct { - cancelOnce sync.Once - initFlow InitFlowSynchronization - invokeFlow InvokeFlowSynchronization - exitPidChan chan<- int - appCtx appctx.ApplicationContext - mutedMutex sync.Mutex - muted bool -} - -func (w *Watchdog) Mute() { - w.mutedMutex.Lock() - defer w.mutedMutex.Unlock() - w.muted = true -} - -func (w *Watchdog) Unmute() { - w.mutedMutex.Lock() - defer w.mutedMutex.Unlock() - w.muted = false -} - -func (w *Watchdog) Muted() bool { - w.mutedMutex.Lock() - defer w.mutedMutex.Unlock() - return w.muted -} - -// GoWait waits for process to complete in separate goroutine and handles the process termination -// Returns PID of the process -func (w *Watchdog) GoWait(p WaitableProcess, errorType fatalerror.ErrorType) int { - pid := p.Pid() - name := p.Name() - appCtx := w.appCtx - go func() { - err := p.Wait() - - if !w.Muted() { - appctx.StoreFirstFatalError(appCtx, errorType) - - if err == nil { - err = fmt.Errorf("exit code 0") - } - log.Warnf("Process %d(%s) exited: %s", pid, name, err) - } - - w.CancelFlows(err) - w.exitPidChan <- pid - }() - - return pid -} - -// CancelFlows cancels init and invoke flows with error. -func (w *Watchdog) CancelFlows(err error) { - // The following block protects us from overwriting the error - // which was first used to cancel flows. 
- w.cancelOnce.Do(func() { - log.Debugf("Canceling flows: %s", err) - w.initFlow.CancelWithError(err) - w.invokeFlow.CancelWithError(err) - }) -} - -// Clear watchdog state -func (w *Watchdog) Clear() { - w.cancelOnce = sync.Once{} -} - -// NewWatchdog returns new instance of a Watchdog. -func NewWatchdog(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchronization, exitPidChan chan<- int, appCtx appctx.ApplicationContext) *Watchdog { - return &Watchdog{ - initFlow: initFlow, - invokeFlow: invokeFlow, - exitPidChan: exitPidChan, - appCtx: appCtx, - mutedMutex: sync.Mutex{}, - } -} diff --git a/lambda/core/watchdog_test.go b/lambda/core/watchdog_test.go deleted file mode 100644 index 84f8342..0000000 --- a/lambda/core/watchdog_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "errors" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/fatalerror" - "testing" -) - -var errTest = errors.New("ErrTest") - -type MockProcess struct { -} - -func (s *MockProcess) Wait() error { return errTest } -func (s *MockProcess) Pid() int { return 0 } -func (s *MockProcess) Name() string { return "" } - -func TestWatchdogCallback(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - initFlow.On("CancelWithError", mock.Anything) - invokeFlow.On("CancelWithError", mock.Anything) - - pidChan := make(chan int) - appCtx := appctx.NewApplicationContext() - w := NewWatchdog(initFlow, invokeFlow, pidChan, appCtx) - - w.GoWait(&MockProcess{}, fatalerror.AgentExitError) - w.GoWait(&MockProcess{}, fatalerror.AgentExitError) - - <-pidChan - initFlow.AssertCalled(t, "CancelWithError", errTest) - initFlow.AssertNumberOfCalls(t, "CancelWithError", 1) - invokeFlow.AssertCalled(t, "CancelWithError", errTest) - 
invokeFlow.AssertNumberOfCalls(t, "CancelWithError", 1) - - <-pidChan - initFlow.AssertNumberOfCalls(t, "CancelWithError", 1) - invokeFlow.AssertNumberOfCalls(t, "CancelWithError", 1) - - err, found := appctx.LoadFirstFatalError(appCtx) - require.True(t, found) - require.Equal(t, err, fatalerror.AgentExitError) -} diff --git a/lambda/fatalerror/fatalerror.go b/lambda/fatalerror/fatalerror.go index 7292baf..bb8a86a 100644 --- a/lambda/fatalerror/fatalerror.go +++ b/lambda/fatalerror/fatalerror.go @@ -3,7 +3,7 @@ package fatalerror -// This package defines constant error types returned to slicer with DONE(failure) +// This package defines constant error types returned to slicer with DONE(failure), and also sandbox errors // Separate package for namespacing // ErrorType is returned to slicer inside DONE @@ -18,5 +18,8 @@ const ( InvalidEntrypoint ErrorType = "Runtime.InvalidEntrypoint" InvalidWorkingDir ErrorType = "Runtime.InvalidWorkingDir" InvalidTaskConfig ErrorType = "Runtime.InvalidTaskConfig" + TruncatedResponse ErrorType = "Runtime.TruncatedResponse" + SandboxFailure ErrorType = "Sandbox.Failure" + SandboxTimeout ErrorType = "Sandbox.Timeout" Unknown ErrorType = "Unknown" ) diff --git a/lambda/interop/bootstrap.go b/lambda/interop/bootstrap.go new file mode 100644 index 0000000..4a9b6af --- /dev/null +++ b/lambda/interop/bootstrap.go @@ -0,0 +1,18 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "os" + + "go.amzn.com/lambda/fatalerror" +) + +type Bootstrap interface { + Cmd() ([]string, error) // returns the args of bootstrap, where args[0] is the path to executable + Env(e EnvironmentVariables) map[string]string // returns the environment variables to be passed to the bootstrapped process + Cwd() (string, error) // returns the working directory of the bootstrap process + ExtraFiles() []*os.File // returns the extra file descriptors apart from 1 & 2 to be passed to runtime + CachedFatalError(err error) (fatalerror.ErrorType, string, bool) +} diff --git a/lambda/interop/cancellable_request.go b/lambda/interop/cancellable_request.go new file mode 100644 index 0000000..7e8fca5 --- /dev/null +++ b/lambda/interop/cancellable_request.go @@ -0,0 +1,27 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "net" + "net/http" +) + +type key int + +const ( + HTTPConnKey key = iota +) + +func GetConn(r *http.Request) net.Conn { + return r.Context().Value(HTTPConnKey).(net.Conn) +} + +type CancellableRequest struct { + Request *http.Request +} + +func (c *CancellableRequest) Cancel() error { + return GetConn(c.Request).Close() +} diff --git a/lambda/interop/environment_variables.go b/lambda/interop/environment_variables.go new file mode 100644 index 0000000..46bdf8b --- /dev/null +++ b/lambda/interop/environment_variables.go @@ -0,0 +1,14 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package interop + +type EnvironmentVariables interface { + AgentExecEnv() map[string]string + RuntimeExecEnv() map[string]string + SetHandler(handler string) + StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress string) + StoreEnvironmentVariablesFromInit(customerEnv map[string]string, + handler, awsKey, awsSecret, awsSession, funcName, funcVer string) + StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) +} diff --git a/lambda/interop/model.go b/lambda/interop/model.go index 5cdf63f..cc9c7d0 100644 --- a/lambda/interop/model.go +++ b/lambda/interop/model.go @@ -8,17 +8,99 @@ import ( "fmt" "io" "net/http" + "strings" "time" "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/supervisor/model" + + log "github.com/sirupsen/logrus" ) // MaxPayloadSize max event body size declared as LAMBDA_EVENT_BODY_SIZE -const MaxPayloadSize = 6*1024*1024 + 100 // 6 MiB + 100 bytes +const ( + MaxPayloadSize = 6*1024*1024 + 100 // 6 MiB + 100 bytes + + ResponseBandwidthRate = 2 * 1024 * 1024 // default average rate of 2 MiB/s + ResponseBandwidthBurstSize = 6 * 1024 * 1024 // default burst size of 6 MiB + + MinResponseBandwidthRate = 32 * 1024 // 32 KiB/s + MaxResponseBandwidthRate = 64 * 1024 * 1024 // 64 MiB/s + + MinResponseBandwidthBurstSize = 32 * 1024 // 32 KiB + MaxResponseBandwidthBurstSize = 64 * 1024 * 1024 // 64 MiB +) const functionResponseSizeTooLargeType = "Function.ResponseSizeTooLarge" +// ResponseMode are top-level constants used in combination with the various types of +// modes we have for responses, such as invoke's response mode and function's response mode. +// In the future we might have invoke's request mode or similar, so these help set the ground +// for consistency. 
+type ResponseMode string + +const ResponseModeBuffered = "Buffered" +const ResponseModeStreaming = "Streaming" + +type InvokeResponseMode string + +const InvokeResponseModeBuffered InvokeResponseMode = ResponseModeBuffered +const InvokeResponseModeStreaming InvokeResponseMode = ResponseModeStreaming + +var AllInvokeResponseModes = []string{ + string(InvokeResponseModeBuffered), string(InvokeResponseModeStreaming), +} + +// ConvertToInvokeResponseMode converts the given string to a InvokeResponseMode +// It is case insensitive and if there is no match, an error is thrown. +func ConvertToInvokeResponseMode(value string) (InvokeResponseMode, error) { + // buffered + if strings.EqualFold(value, string(InvokeResponseModeBuffered)) { + return InvokeResponseModeBuffered, nil + } + + // streaming + if strings.EqualFold(value, string(InvokeResponseModeStreaming)) { + return InvokeResponseModeStreaming, nil + } + + // unknown + allowedValues := strings.Join(AllInvokeResponseModes, ", ") + log.Errorf("Unlable to map %s to %s.", value, allowedValues) + return "", ErrInvalidInvokeResponseMode +} + +// FunctionResponseMode is passed by Runtime to tell whether the response should be +// streamed or not. +type FunctionResponseMode string + +const FunctionResponseModeBuffered FunctionResponseMode = ResponseModeBuffered +const FunctionResponseModeStreaming FunctionResponseMode = ResponseModeStreaming + +var AllFunctionResponseModes = []string{ + string(FunctionResponseModeBuffered), string(FunctionResponseModeStreaming), +} + +// ConvertToFunctionResponseMode converts the given string to a FunctionResponseMode +// It is case insensitive and if there is no match, an error is thrown. 
+func ConvertToFunctionResponseMode(value string) (FunctionResponseMode, error) { + // buffered + if strings.EqualFold(value, string(FunctionResponseModeBuffered)) { + return FunctionResponseModeBuffered, nil + } + + // streaming + if strings.EqualFold(value, string(FunctionResponseModeStreaming)) { + return FunctionResponseModeStreaming, nil + } + + // unknown + allowedValues := strings.Join(AllFunctionResponseModes, ", ") + log.Errorf("Unlable to map %s to %s.", value, allowedValues) + return "", ErrInvalidFunctionResponseMode +} + // Message is a generic interop message. type Message interface{} @@ -37,11 +119,10 @@ type Invoke struct { ContentType string Payload io.Reader NeedDebugLogs bool - CorrelationID string // internal use only ReservationToken string VersionID string InvokeReceivedTime int64 - ResyncState Resync + InvokeResponseMetrics *InvokeResponseMetrics } type Token struct { @@ -54,21 +135,13 @@ type Token struct { LambdaSegmentID string InvokeMetadata string NeedDebugLogs bool - ResyncState Resync -} - -type Resync struct { - IsResyncReceived bool - AwsKey string - AwsSecret string - AwsSession string - ReceivedTime time.Time } type ErrorResponse struct { // Payload sent via shared memory. - Payload []byte `json:"Payload,omitempty"` - ContentType string `json:"-"` + Payload []byte `json:"Payload,omitempty"` + ContentType string `json:"-"` + FunctionResponseMode string `json:"-"` // When error response body (Payload) is not provided, e.g. // not retrievable, error type and error message will be @@ -92,48 +165,80 @@ type SandboxType string const SandboxPreWarmed SandboxType = "PreWarmed" const SandboxClassic SandboxType = "Classic" -// Start message received from the slicer, part of the protocol. 
-type Start struct { - InvokeID string - Handler string - AwsKey string - AwsSecret string - AwsSession string - SuppressInit bool - XRayDaemonAddress string // only in standalone - FunctionName string // only in standalone - FunctionVersion string // only in standalone - CorrelationID string // internal use only - // TODO: define new Init type that has the Start fields as well as env vars below. - // In standalone mode, these env vars come from test/init but from environment otherwise. - CustomerEnvironmentVariables map[string]string - SandboxType SandboxType +// RuntimeInfo contains metadata about the runtime used by the Sandbox +type RuntimeInfo struct { + ImageJSON string // image config, e.g {\"layers\":[]} + Arn string // runtime ARN, e.g. arn:awstest:lambda:us-west-2::runtime:python3.8::alpha + Version string // human-readable runtime arn equivalent, e.g. python3.8.v999 } -// Running message is sent to the slicer, part of the protocol. -type Running struct { - WaitStartTimeNs int64 - WaitEndTimeNs int64 - PreLoadTimeNs int64 - PostLoadTimeNs int64 - ExtensionsEnabled bool +// Captures configuration of the operator and runtime domain +// that are only known after INIT is received +type DynamicDomainConfig struct { + // extra hooks to execute at domain start. Currently used for filesystem and network hooks. + // It can be empty. 
+ AdditionalStartHooks []model.Hook + Mounts []model.DriveMount + //TODO: other dynamic configurations for the domain go here } // Reset message is sent to rapid to initiate reset sequence type Reset struct { - Reason string - DeadlineNs int64 - CorrelationID string // internal use only + Reason string + DeadlineNs int64 + InvokeResponseMetrics *InvokeResponseMetrics + TraceID string + LambdaSegmentID string +} + +// Restore message is sent to rapid to restore runtime to make it ready for consecutive invokes +type Restore struct { + AwsKey string + AwsSecret string + AwsSession string + CredentialsExpiry time.Time +} + +type Resync struct { } // Shutdown message is sent to rapid to initiate graceful shutdown type Shutdown struct { - DeadlineNs int64 - CorrelationID string // internal use only + DeadlineNs int64 +} + +// Metrics for response status of LogsAPI/TelemetryAPI `/subscribe` calls +type TelemetrySubscriptionMetrics map[string]int + +func MergeSubscriptionMetrics(logsAPIMetrics TelemetrySubscriptionMetrics, telemetryAPIMetrics TelemetrySubscriptionMetrics) TelemetrySubscriptionMetrics { + metrics := make(map[string]int) + for metric, value := range logsAPIMetrics { + metrics[metric] = value + } + + for metric, value := range telemetryAPIMetrics { + metrics[metric] += value + } + return metrics +} + +// InvokeResponseMetrics are produced while sending streaming invoke response to WP +type InvokeResponseMetrics struct { + StartReadingResponseMonoTimeMs int64 + FinishReadingResponseMonoTimeMs int64 + TimeShapedNs int64 + ProducedBytes int64 + OutboundThroughputBps int64 // in bytes per second + FunctionResponseMode FunctionResponseMode + RuntimeCalledResponse bool } -// Metrics for response status of LogsAPI `/subscribe` calls -type LogsAPIMetrics map[string]int +func IsResponseStreamingMetrics(metrics *InvokeResponseMetrics) bool { + if metrics == nil { + return false + } + return metrics.FunctionResponseMode == FunctionResponseModeStreaming +} type 
DoneMetadata struct { NumActiveExtensions int @@ -141,25 +246,26 @@ type DoneMetadata struct { ExtensionNames string RuntimeRelease string // Metrics for response status of LogsAPI `/subscribe` calls - LogsAPIMetrics LogsAPIMetrics - InvokeRequestReadTimeNs int64 - InvokeRequestSizeBytes int64 - InvokeCompletionTimeNs int64 - InvokeReceivedTime int64 - RuntimeReadyTime int64 + LogsAPIMetrics TelemetrySubscriptionMetrics + InvokeRequestReadTimeNs int64 + InvokeRequestSizeBytes int64 + InvokeCompletionTimeNs int64 + InvokeReceivedTime int64 + RuntimeReadyTime int64 + RuntimeTimeThrottledMs int64 + RuntimeProducedBytes int64 + RuntimeOutboundThroughputBps int64 } type Done struct { - WaitForExit bool - ErrorType fatalerror.ErrorType - CorrelationID string // internal use only - Meta DoneMetadata + WaitForExit bool + ErrorType fatalerror.ErrorType + Meta DoneMetadata } type DoneFail struct { - ErrorType fatalerror.ErrorType - CorrelationID string // internal use only - Meta DoneMetadata + ErrorType fatalerror.ErrorType + Meta DoneMetadata } // ErrInvalidInvokeID is returned when invokeID provided in Invoke2 does not match one provided in Token @@ -171,6 +277,22 @@ var ErrInvalidReservationToken = fmt.Errorf("ErrInvalidReservationToken") // ErrInvalidFunctionVersion is returned when functionVersion provided in Invoke2 does not match one provided in Token var ErrInvalidFunctionVersion = fmt.Errorf("ErrInvalidFunctionVersion") +// ErrInvalidFunctionResponseMode is returned when the value sent by runtime during Invoke2 +// is not a constant of type interop.FunctionResponseMode +var ErrInvalidFunctionResponseMode = fmt.Errorf("ErrInvalidFunctionResponseMode") + +// ErrInvalidInvokeResponseMode is returned when optional InvokeResponseMode header provided in Invoke2 is not a constant of type interop.InvokeResponseMode +var ErrInvalidInvokeResponseMode = fmt.Errorf("ErrInvalidInvokeResponseMode") + +// ErrInvalidMaxPayloadSize is returned when optional MaxPayloadSize header 
provided in Invoke2 is invalid +var ErrInvalidMaxPayloadSize = fmt.Errorf("ErrInvalidMaxPayloadSize") + +// ErrInvalidResponseBandwidthRate is returned when optional ResponseBandwidthRate header provided in Invoke2 is invalid +var ErrInvalidResponseBandwidthRate = fmt.Errorf("ErrInvalidResponseBandwidthRate") + +// ErrInvalidResponseBandwidthBurstSize is returned when optional ResponseBandwidthBurstSize header provided in Invoke2 is invalid +var ErrInvalidResponseBandwidthBurstSize = fmt.Errorf("ErrInvalidResponseBandwidthBurstSize") + // ErrMalformedCustomerHeaders is returned when customer headers format is invalid var ErrMalformedCustomerHeaders = fmt.Errorf("ErrMalformedCustomerHeaders") @@ -180,6 +302,20 @@ var ErrResponseSent = fmt.Errorf("ErrResponseSent") // ErrReservationExpired is returned when invoke arrived after InvackDeadline var ErrReservationExpired = fmt.Errorf("ErrReservationExpired") +// ErrInternalPlatformError is returned when internal platform error occurred +type ErrInternalPlatformError struct{} + +func (s *ErrInternalPlatformError) Error() string { + return "ErrInternalPlatformError" +} + +// ErrTruncatedResponse is returned when response is truncated +type ErrTruncatedResponse struct{} + +func (s *ErrTruncatedResponse) Error() string { + return "ErrTruncatedResponse" +} + // ErrorResponseTooLarge is returned when response Payload exceeds shared memory buffer size type ErrorResponseTooLarge struct { MaxResponseSize int @@ -211,17 +347,21 @@ func (s *ErrorResponseTooLarge) AsInteropError() *ErrorResponse { return &resp } -// Server implements Slicer communication protocol. +// Server used for sending messages and sharing data between the Runtime API handlers and the +// internal platform facing servers. For example, +// +// responseCtx.SendResponse(...) 
+// +// will send the response payload and metadata provided by the runtime to the platform, through the internal +// protocol used by the specific implementation +// TODO: rename this to InvokeResponseContext, used to send responses from handlers to platform-facing server type Server interface { - // StartAcceptingDirectInvokes starts accepting on direct invoke socket (if one is available) - StartAcceptingDirectInvokes() error - - // SendErrorResponse sends response. + // SendResponse sends response. // Errors returned: // ErrInvalidInvokeID - validation error indicating that provided invokeID doesn't match current invokeID // ErrResponseSent - validation error indicating that response with given invokeID was already sent // Non-nil error - non-nil error indicating transport failure - SendResponse(invokeID string, contentType string, response io.Reader) error + SendResponse(invokeID string, headers map[string]string, response io.Reader, trailers http.Header, request *CancellableRequest) error // SendErrorResponse sends error response. // Errors returned: @@ -229,61 +369,36 @@ type Server interface { // ErrResponseSent - validation error indicating that response with given invokeID was already sent // Non-nil error - non-nil error indicating transport failure SendErrorResponse(invokeID string, response *ErrorResponse) error + SendInitErrorResponse(invokeID string, response *ErrorResponse) error // GetCurrentInvokeID returns current invokeID. // NOTE, in case of INIT, when invokeID is not known in advance (e.g. provisioned concurrency), // returned invokeID will contain empty value. GetCurrentInvokeID() string - // CommitMessage confirms that the message written through SendResponse and SendErrorResponse is complete. - CommitResponse() error - - // SendRunning sends GIRD RUNNING. - // Returns error on transport failure. 
- SendRunning(*Running) error - - // SendRuntimeReady sends GIRD RTREADY + // SendRuntimeReady sends a message indicating the runtime has called /invocation/next. + // The checkpoint allows us to compute the overhead due to Extensions by subtracting it + // from the time when all extensions have called /next. + // TODO: this method is a lifecycle event used only for metrics, and doesn't belong here SendRuntimeReady() error +} - // SendDone sends GIRD DONE. - // Returns error on transport failure. - SendDone(*Done) error - - // SendDone sends GIRD DONEFAIL. - // Returns error on transport failure. - SendDoneFail(*DoneFail) error - - // StartChan returns Start emitter - StartChan() <-chan *Start - - // InvokeChan returns Invoke emitter - InvokeChan() <-chan *Invoke - - // ResetChan returns Reset emitter - ResetChan() <-chan *Reset - - // ShutdownChan returns Shutdown emitter - ShutdownChan() <-chan *Shutdown - - // TransportErrorChan emits errors if there was parsing/connection issue - TransportErrorChan() <-chan error - - // Clear is called on rapid reset. 
It should leave server prepared for new invocations - Clear() - - // IsResponseSent exposes is response sent flag - IsResponseSent() bool - - // The following are used by standalone rapid only - // TODO refactor to decouple the interfaces +type InternalStateGetter func() statejson.InternalStateDescription - SetInternalStateGetter(cb InternalStateGetter) +const OnDemandInitTelemetrySource string = "on-demand" +const ProvisionedConcurrencyInitTelemetrySource string = "provisioned-concurrency" +const InitCachingInitTelemetrySource string = "snap-start" - Init(i *Start, invokeTimeoutMs int64) +func InferTelemetryInitSource(initCachingEnabled bool, sandboxType SandboxType) string { + initSource := OnDemandInitTelemetrySource - Invoke(responseWriter http.ResponseWriter, invoke *Invoke) error + // ToDo: Unify this selection of SandboxType by using the START message + // after having a roadmap on the combination of INIT modes + if initCachingEnabled { + initSource = InitCachingInitTelemetrySource + } else if sandboxType == SandboxPreWarmed { + initSource = ProvisionedConcurrencyInitTelemetrySource + } - Shutdown(shutdown *Shutdown) *statejson.InternalStateDescription + return initSource } - -type InternalStateGetter func() statejson.InternalStateDescription diff --git a/lambda/interop/model_test.go b/lambda/interop/model_test.go new file mode 100644 index 0000000..9ad4d17 --- /dev/null +++ b/lambda/interop/model_test.go @@ -0,0 +1,27 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMergeSubscriptionMetrics(t *testing.T) { + logsAPIMetrics := map[string]int{ + "server_error": 1, + "client_error": 2, + } + + telemetryAPIMetrics := map[string]int{ + "server_error": 1, + "success": 5, + } + + metrics := MergeSubscriptionMetrics(logsAPIMetrics, telemetryAPIMetrics) + assert.Equal(t, 5, metrics["success"]) + assert.Equal(t, 2, metrics["server_error"]) + assert.Equal(t, 2, metrics["client_error"]) +} diff --git a/lambda/interop/sandbox_model.go b/lambda/interop/sandbox_model.go index dddfcf2..b5d15b0 100644 --- a/lambda/interop/sandbox_model.go +++ b/lambda/interop/sandbox_model.go @@ -3,20 +3,183 @@ package interop -// Init represents an init message and is currently only used in standalone +import ( + "time" + + "go.amzn.com/lambda/fatalerror" +) + +// Init represents an init message +// In Rapid Shim, this is a START GirD message +// In Rapid Daemon, this is an INIT GirP message type Init struct { InvokeID string Handler string AwsKey string AwsSecret string AwsSession string + CredentialsExpiry time.Time SuppressInit bool + InvokeTimeoutMs int64 // timeout duration of whole invoke + InitTimeoutMs int64 // timeout duration for init only XRayDaemonAddress string // only in standalone FunctionName string // only in standalone FunctionVersion string // only in standalone - CorrelationID string // internal use only - // TODO: define new Init type that has the Start fields as well as env vars below. // In standalone mode, these env vars come from test/init but from environment otherwise. 
CustomerEnvironmentVariables map[string]string - SandboxType + SandboxType SandboxType + // there is no dynamic config at the moment for the runtime domain + OperatorDomainExtraConfig DynamicDomainConfig + RuntimeInfo RuntimeInfo + Bootstrap Bootstrap + EnvironmentVariables EnvironmentVariables // contains env vars for agents and runtime procs +} + +// InitStarted contains metadata about the initialized sandbox +// In Rapid Shim, this translates to a RUNNING GirD message to Slicer +// In Rapid Daemon, this is followed by a SANDBOX GirP message to MM +type InitStarted struct { + WaitStartTimeNs int64 + WaitEndTimeNs int64 + PreLoadTimeNs int64 + PostLoadTimeNs int64 + ExtensionsEnabled bool + Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent +} + +// InitSuccess indicates that runtime/extensions initialization completed successfully +// In Rapid Shim, this translates to a DONE GirD message to Slicer +// In Rapid Daemon, this is followed by a DONEDONE GirP message to MM +type InitSuccess struct { + NumActiveExtensions int // indicates number of active extensions + ExtensionNames string // file names of extensions in /opt/extensions + RuntimeRelease string + LogsAPIMetrics TelemetrySubscriptionMetrics // used if telemetry API enabled + Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent +} + +// InitFailure indicates that runtime/extensions initialization failed due to process exit or /error calls +// In Rapid Shim, this translates to either a DONE or a DONEFAIL GirD message to Slicer (depending on extensions mode) +// However, even on failure, the next invoke is expected to work with a suppressed init - i.e. 
we init again as part of the invoke +type InitFailure struct { + ResetReceived bool // indicates if failure happened due to a reset received + RequestReset bool // Indicates whether reset should be requested on init failure + ErrorType fatalerror.ErrorType + ErrorMessage error + NumActiveExtensions int + RuntimeRelease string // value of the User Agent HTTP header provided by runtime + LogsAPIMetrics TelemetrySubscriptionMetrics + Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent +} + +// ResponseMetrics groups metrics related to the response stream +type ResponseMetrics struct { + RuntimeTimeThrottledMs int64 + RuntimeProducedBytes int64 + RuntimeOutboundThroughputBps int64 +} + +// InvokeMetrics groups metrics related to the invoke phase +type InvokeMetrics struct { + InvokeRequestReadTimeNs int64 + InvokeRequestSizeBytes int64 + RuntimeReadyTime int64 +} + +// InvokeSuccess is the success response to invoke phase end +type InvokeSuccess struct { + RuntimeRelease string // value of the User Agent HTTP header provided by runtime + NumActiveExtensions int + ExtensionNames string + InvokeCompletionTimeNs int64 + InvokeReceivedTime int64 + LogsAPIMetrics TelemetrySubscriptionMetrics + ResponseMetrics ResponseMetrics + InvokeMetrics InvokeMetrics +} + +// InvokeFailure is the failure response to invoke phase end +type InvokeFailure struct { + ResetReceived bool // indicates if failure happened due to a reset received + RequestReset bool // indicates if reset must be requested after the failure + ErrorType fatalerror.ErrorType + ErrorMessage error + RuntimeRelease string // value of the User Agent HTTP header provided by runtime + NumActiveExtensions int + InvokeReceivedTime int64 + LogsAPIMetrics TelemetrySubscriptionMetrics + ResponseMetrics ResponseMetrics + InvokeMetrics InvokeMetrics + ExtensionNames string + DefaultErrorResponse *ErrorResponse // error resp constructed by platform during fn errors +} + +// ResetSuccess is 
the success response to reset request +type ResetSuccess struct { + ExtensionsResetMs int64 + ErrorType fatalerror.ErrorType + ResponseMetrics ResponseMetrics +} + +// ResetFailure is the failure response to reset request +type ResetFailure struct { + ExtensionsResetMs int64 + ErrorType fatalerror.ErrorType + ResponseMetrics ResponseMetrics +} + +// ShutdownSuccess is the response to a shutdown request +type ShutdownSuccess struct { + ErrorType fatalerror.ErrorType +} + +// SandboxInfoFromInit captures data from init request that +// is required during invoke (e.g. for suppressed init) +type SandboxInfoFromInit struct { + EnvironmentVariables EnvironmentVariables // contains agent env vars (creds, customer, platform) + SandboxType SandboxType // indicating Pre-Warmed, On-Demand etc + RuntimeBootstrap Bootstrap // contains the runtime bootstrap binary path, Cwd, Args, Env, Cmd +} + +// RapidContext expose methods for functionality of the Rapid Core library +type RapidContext interface { + HandleInit(i *Init, started chan<- InitStarted, success chan<- InitSuccess, failure chan<- InitFailure) + HandleInvoke(i *Invoke, sbMetadata SandboxInfoFromInit) (InvokeSuccess, *InvokeFailure) + HandleReset(reset *Reset, invokeReceivedTime int64, InvokeResponseMetrics *InvokeResponseMetrics) (ResetSuccess, *ResetFailure) + HandleShutdown(shutdown *Shutdown) ShutdownSuccess + HandleRestore(restore *Restore) error + Clear() +} + +// SandboxContext represents the sandbox lifecycle context +type SandboxContext interface { + Init(i *Init, timeoutMs int64) (InitStarted, InitContext) + Reset(reset *Reset) (ResetSuccess, *ResetFailure) + Shutdown(shutdown *Shutdown) ShutdownSuccess + Restore(restore *Restore) error + + // TODO: refactor this + // invokeReceivedTime and InvokeResponseMetrics are needed to compute the runtimeDone metrics + // in case of a Reset during an invoke (reset.reason=failure or reset.reason=timeout). 
+ // Ideally: + // - the InvokeContext will have a Reset method to deal with Reset during an invoke and will hold invokeReceivedTime and InvokeResponseMetrics + // - the SandboxContext will have its own Reset/Spindown method + SetInvokeReceivedTime(invokeReceivedTime int64) + SetInvokeResponseMetrics(metrics *InvokeResponseMetrics) +} + +// InitContext represents the lifecycle of a sandbox initialization +type InitContext interface { + Wait() (InitSuccess, *InitFailure) + Reserve() InvokeContext +} + +// InvokeContext represents the lifecycle of a sandbox reservation +type InvokeContext interface { + SendRequest(i *Invoke) + Wait() (InvokeSuccess, *InvokeFailure) +} + +// Restored message is sent to Slicer to inform Runtime Restore Hook execution was successful +type Restored struct { } diff --git a/lambda/logging/doc.go b/lambda/logging/doc.go index 92637a1..a1f7e95 100644 --- a/lambda/logging/doc.go +++ b/lambda/logging/doc.go @@ -2,24 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 /* - RAPID emits or proxies the following sources of logging: -1. Internal logs: RAPID's own application logs into stderr for operational use, visible only internally -2. Function stream-based logs: Runtime's stdout and stderr, read as newline separated lines -3. Function message-based logs: Stock runtimes communicate using a custom TLV protocol over a Unix pipe -4. Extension stream-based logs: Extension's stdout and stderr, read as newline separated lines -5. Platform logs: Logs that RAPID generates, but is visible in customer's logs. - - -It has the following log sinks, which may further be egressed to other sinks (e.g. CloudWatch) by external telemetry agents: - -1. Internal Log File (stderr): stderr is redirected to a file specified by Sandbox Factory via env-vars, and accessible via StreamQuery -2. Stdout: stream-based function logs are output to RAPID's stdout process, and read by a telemetry agent -3. 
Telemetry API MSG-verb events: function messages-based logs are written using GirP protocol into the console socket specified by Sandbox Factory env-vars -4. Telemetry API LOGX-verb events: extension stream-based logs are written using GirP protocol into the console socket specified by Sandbox Factory env-vars -5. Telemetry API LOGP-verb events: platform logs are written using GirP protocol into the console socket specified by Sandbox Factory env-vars -6. Tail logs: a truncated version of function stream-based and message-based logs are written along with the invocation response to the frontend when 'debug logging' is enabled - + 1. Internal logs: RAPID's own application logs into stderr for operational use, visible only internally + 2. Function stream-based logs: Runtime's stdout and stderr, read as newline separated lines + 3. Function message-based logs: Stock runtimes communicate using a custom TLV protocol over a Unix pipe + 4. Extension stream-based logs: Extension's stdout and stderr, read as newline separated lines + 5. Platform logs: Logs that RAPID generates, but is visible either in customer's logs or via Logs API + (e.g. 
EXTENSION, RUNTIME, RUNTIMEDONE, IMAGE) */ package logging diff --git a/lambda/logging/internal_log_test.go b/lambda/logging/internal_log_test.go index b94ac88..3ec537f 100644 --- a/lambda/logging/internal_log_test.go +++ b/lambda/logging/internal_log_test.go @@ -8,7 +8,7 @@ import ( "fmt" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "io/ioutil" + "io" "log" "testing" ) @@ -67,14 +67,14 @@ func TestInternalFormatter(t *testing.T) { } func BenchmarkLogPrint(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) for n := 0; n < b.N; n++ { log.Print(1, "two", true) } } func BenchmarkLogrusPrint(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) for n := 0; n < b.N; n++ { logrus.Print(1, "two", true) } @@ -83,21 +83,21 @@ func BenchmarkLogrusPrint(b *testing.B) { func BenchmarkLogrusPrintInternalFormatter(b *testing.B) { var l = logrus.New() l.SetFormatter(&InternalFormatter{}) - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) for n := 0; n < b.N; n++ { l.Print(1, "two", true) } } func BenchmarkLogPrintf(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) for n := 0; n < b.N; n++ { log.Printf("field:%v,field:%v,field:%v", 1, "two", true) } } func BenchmarkLogrusPrintf(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) for n := 0; n < b.N; n++ { logrus.Printf("field:%v,field:%v,field:%v", 1, "two", true) } @@ -106,14 +106,14 @@ func BenchmarkLogrusPrintf(b *testing.B) { func BenchmarkLogrusPrintfInternalFormatter(b *testing.B) { var l = logrus.New() l.SetFormatter(&InternalFormatter{}) - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) for n := 0; n < b.N; n++ { l.Printf("field:%v,field:%v,field:%v", 1, "two", true) } } func BenchmarkLogrusDebugLogLevelDisabled(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) logrus.SetLevel(logrus.InfoLevel) for n := 0; n < b.N; n++ { logrus.Debug(1, "two", true) @@ -122,7 +122,7 @@ func 
BenchmarkLogrusDebugLogLevelDisabled(b *testing.B) { func BenchmarkLogrusDebugLogLevelDisabledInternalFormatter(b *testing.B) { var l = logrus.New() - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) l.SetLevel(logrus.InfoLevel) for n := 0; n < b.N; n++ { l.Debug(1, "two", true) @@ -130,7 +130,7 @@ func BenchmarkLogrusDebugLogLevelDisabledInternalFormatter(b *testing.B) { } func BenchmarkLogrusDebugLogLevelEnabled(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) logrus.SetLevel(logrus.DebugLevel) for n := 0; n < b.N; n++ { logrus.Debug(1, "two", true) @@ -140,7 +140,7 @@ func BenchmarkLogrusDebugLogLevelEnabled(b *testing.B) { func BenchmarkLogrusDebugLogLevelEnabledInternalFormatter(b *testing.B) { var l = logrus.New() l.SetFormatter(&InternalFormatter{}) - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) l.SetLevel(logrus.DebugLevel) for n := 0; n < b.N; n++ { l.Debug(1, "two", true) @@ -148,7 +148,7 @@ func BenchmarkLogrusDebugLogLevelEnabledInternalFormatter(b *testing.B) { } func BenchmarkLogrusDebugWithFieldLogLevelDisabled(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) logrus.SetLevel(logrus.InfoLevel) for n := 0; n < b.N; n++ { logrus.WithField("field", "value").Debug(1, "two", true) @@ -158,7 +158,7 @@ func BenchmarkLogrusDebugWithFieldLogLevelDisabled(b *testing.B) { func BenchmarkLogrusDebugWithFieldLogLevelDisabledInternalFormatter(b *testing.B) { var l = logrus.New() l.SetFormatter(&InternalFormatter{}) - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) l.SetLevel(logrus.InfoLevel) for n := 0; n < b.N; n++ { l.WithField("field", "value").Debug(1, "two", true) diff --git a/lambda/logging/platform_log.go b/lambda/logging/platform_log.go deleted file mode 100644 index 5154f93..0000000 --- a/lambda/logging/platform_log.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "fmt" - "io" - "log" - "strings" -) - -// TODO PlatformLogger interface has this LogExtensionInitEvent() method so it's easier to assert against it in standalone tests; -// TODO However, this makes interface harder to maintain (you are supposed to add new method to PlatformLogger for each event type) -// TODO We need to remove those methods and make PlatformLogger just a log.Logger interface - -// PlatformLogger is a logger that logs platform lines to customers' logs -type PlatformLogger interface { - Printf(fmt string, args ...interface{}) - LogExtensionInitEvent(agentName, state, errorType string, subscriptions []string) -} - -// FormattedPlatformLogger formats and logs platform lines to customers' logs via Telemetry API -type FormattedPlatformLogger struct { - logger *log.Logger -} - -// NewPlatformLogger is a logger for logging Platform log lines into customers' logs -func NewPlatformLogger(output, tailLogWriter io.Writer) *FormattedPlatformLogger { - prefix, flags := "", 0 - return &FormattedPlatformLogger{ - logger: log.New(io.MultiWriter(output, tailLogWriter), prefix, flags), - } -} - -// LogExtensionInitEvent formats and logs a line containing agent info -func (l *FormattedPlatformLogger) LogExtensionInitEvent(agentName, state, errorType string, subscriptions []string) { - format := "EXTENSION\tName: %s\tState: %s\tEvents: [%s]" - line := fmt.Sprintf(format, agentName, state, strings.Join(subscriptions, ",")) - if len(errorType) > 0 { - line += fmt.Sprintf("\tError Type: %s", errorType) - } - l.logger.Println(line) -} - -func (l *FormattedPlatformLogger) Printf(fmt string, args ...interface{}) { - fmt += "\n" // we append newline to the logline because that's how they are separated on recepient - l.logger.Printf(fmt, args...) 
-} - -func SupernovaInvalidTaskConfigRepr(err error) func(error) string { - return func(unused error) string { - return fmt.Sprintf("IMAGE\tInvalid task config: %s", err) - } -} - -func SupernovaLaunchErrorRepr(entrypoint []string, cmd []string, workingDir string) func(error) string { - return func(err error) string { - return fmt.Sprintf("IMAGE\tLaunch error: %s\tEntrypoint: [%s]\tCmd: [%s]\tWorkingDir: [%s]", - err, - strings.Join(entrypoint, ","), - strings.Join(cmd, ","), - workingDir) - } -} diff --git a/lambda/logging/platform_log_test.go b/lambda/logging/platform_log_test.go deleted file mode 100644 index 8b01778..0000000 --- a/lambda/logging/platform_log_test.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestPlatformLogExtensionLine(t *testing.T) { - var buf bytes.Buffer - var tailLogBuf bytes.Buffer - logger := NewPlatformLogger(&buf, &tailLogBuf) - - logger.LogExtensionInitEvent("agentName", "Registered", "", []string{"INVOKE", "SHUTDOWN"}) - require.Equal(t, "EXTENSION\tName: agentName\tState: Registered\tEvents: [INVOKE,SHUTDOWN]\n", buf.String()) - require.Equal(t, "EXTENSION\tName: agentName\tState: Registered\tEvents: [INVOKE,SHUTDOWN]\n", tailLogBuf.String()) -} - -func TestPlatformLogExtensionLineWithError(t *testing.T) { - var buf bytes.Buffer - var tailLogBuf bytes.Buffer - logger := NewPlatformLogger(&buf, &tailLogBuf) - - errorType := "Extension.FooBar" - logger.LogExtensionInitEvent("agentName", "Registered", errorType, []string{"INVOKE", "SHUTDOWN"}) - require.Equal(t, "EXTENSION\tName: agentName\tState: Registered\tEvents: [INVOKE,SHUTDOWN]\tError Type: "+errorType+"\n", buf.String()) - require.Equal(t, "EXTENSION\tName: agentName\tState: Registered\tEvents: [INVOKE,SHUTDOWN]\tError Type: "+errorType+"\n", tailLogBuf.String()) -} - -func 
TestPlatformLogPrintf(t *testing.T) { - var buf bytes.Buffer - var tailLogBuf bytes.Buffer - logger := NewPlatformLogger(&buf, &tailLogBuf) - - logger.Printf("bebe %s %d", "as", 12) - require.Equal(t, "bebe as 12\n", buf.String()) - require.Equal(t, "bebe as 12\n", tailLogBuf.String()) -} diff --git a/lambda/logging/taillog.go b/lambda/logging/taillog.go deleted file mode 100644 index 9fe5352..0000000 --- a/lambda/logging/taillog.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "io" - "sync" -) - -// TailLogWriter writes tail/debug log to provided io.Writer -type TailLogWriter struct { - out io.Writer - enabled bool - mutex sync.Mutex -} - -// Enable enables log writer. -func (lw *TailLogWriter) Enable() { - lw.mutex.Lock() - defer lw.mutex.Unlock() - - lw.enabled = true -} - -// Disable disables log writer. -func (lw *TailLogWriter) Disable() { - lw.mutex.Lock() - defer lw.mutex.Unlock() - - lw.enabled = false -} - -// Writer wraps the basic io.Write method -func (lw *TailLogWriter) Write(p []byte) (n int, err error) { - lw.mutex.Lock() - defer lw.mutex.Unlock() - - if lw.enabled { - return lw.out.Write(p) - } - // Else returns a successful write so that MultiWriter won't stop - return len(p), nil -} - -// NewTailLogWriter returns a new invoke tail log writer, default output is discarded until output is configured. -func NewTailLogWriter(w io.Writer) *TailLogWriter { - return &TailLogWriter{ - out: w, - enabled: false, - } -} diff --git a/lambda/logging/taillog_test.go b/lambda/logging/taillog_test.go deleted file mode 100644 index 2bc444c..0000000 --- a/lambda/logging/taillog_test.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestDisableDebugLog(t *testing.T) { - buf := new(bytes.Buffer) - tailLogWriter := NewTailLogWriter(buf) - tailLogWriter.Disable() - - tailLogWriter.Write([]byte("hello_world")) - assert.Len(t, buf.String(), 0) -} - -func TestEnableDebugLog(t *testing.T) { - buf := new(bytes.Buffer) - tailLogWriter := NewTailLogWriter(buf) - tailLogWriter.Enable() - - tailLogWriter.Write([]byte("hello_world")) - assert.Equal(t, "hello_world", buf.String()) -} diff --git a/lambda/rapi/handler/agentiniterror_test.go b/lambda/rapi/handler/agentiniterror_test.go index 571031e..50b9143 100644 --- a/lambda/rapi/handler/agentiniterror_test.go +++ b/lambda/rapi/handler/agentiniterror_test.go @@ -6,7 +6,7 @@ package handler import ( "context" "encoding/json" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -60,7 +60,7 @@ func TestAgentInitErrorMissingErrorHeader(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, errAgentMissingHeader, errorResponse.ErrorType) } @@ -77,7 +77,7 @@ func TestAgentInitErrorUnknownAgent(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, errAgentIdentifierUnknown, errorResponse.ErrorType) } @@ -97,7 +97,7 @@ func TestAgentInitErrorAgentInvalidState(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := 
io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, errAgentInvalidState, errorResponse.ErrorType) } @@ -118,7 +118,7 @@ func TestAgentInitErrorRequestAccepted(t *testing.T) { assert.Equal(t, http.StatusAccepted, responseRecorder.Code) var response model.StatusResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, "OK", response.Status) diff --git a/lambda/rapi/handler/agentnext_test.go b/lambda/rapi/handler/agentnext_test.go index ef14e49..003c4b6 100644 --- a/lambda/rapi/handler/agentnext_test.go +++ b/lambda/rapi/handler/agentnext_test.go @@ -7,7 +7,7 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -52,7 +52,7 @@ func TestRenderAgentInvokeUnknownAgent(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, errAgentIdentifierUnknown, errorResponse.ErrorType) @@ -75,7 +75,7 @@ func TestRenderAgentInvokeInvalidAgentState(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, errAgentInvalidState, errorResponse.ErrorType) @@ -118,7 +118,7 @@ func TestRenderAgentInvokeNextHappy(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentInvokeEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := 
io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, agent.RunningState, agent.GetState()) @@ -167,7 +167,7 @@ func TestRenderAgentInternalInvokeNextHappy(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentInvokeEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, agent.RunningState, agent.GetState()) @@ -212,7 +212,7 @@ func TestRenderAgentInternalShutdownEvent(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentShutdownEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, agent.RunningState, agent.GetState()) @@ -254,7 +254,7 @@ func TestRenderAgentExternalShutdownEvent(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentShutdownEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, agent.RunningState, agent.GetState()) @@ -297,7 +297,7 @@ func TestRenderAgentInvokeNextHappyEmptyTraceID(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentInvokeEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Nil(t, response.Tracing) diff --git a/lambda/rapi/handler/agentregister.go b/lambda/rapi/handler/agentregister.go index 776ac28..8882965 100644 --- a/lambda/rapi/handler/agentregister.go +++ b/lambda/rapi/handler/agentregister.go @@ -6,10 +6,9 @@ package handler import ( 
"encoding/json" "errors" - "io/ioutil" + "io" "net/http" - "github.com/go-chi/render" log "github.com/sirupsen/logrus" "go.amzn.com/lambda/core" "go.amzn.com/lambda/rapi/model" @@ -26,7 +25,7 @@ type RegisterRequest struct { } func parseRegister(request *http.Request) (*RegisterRequest, error) { - body, err := ioutil.ReadAll(request.Body) + body, err := io.ReadAll(request.Body) if err != nil { return nil, err } @@ -70,7 +69,6 @@ func (h *agentRegisterHandler) ServeHTTP(writer http.ResponseWriter, request *ht } func (h *agentRegisterHandler) renderResponse(agentID string, writer http.ResponseWriter, request *http.Request) { - render.Status(request, http.StatusOK) writer.Header().Set(LambdaAgentIdentifier, agentID) metadata := h.registrationService.GetFunctionMetadata() @@ -81,7 +79,10 @@ func (h *agentRegisterHandler) renderResponse(agentID string, writer http.Respon Handler: metadata.Handler, } - render.JSON(writer, request, resp) + if err := rendering.RenderJSON(http.StatusOK, writer, request, resp); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(writer, err.Error(), http.StatusInternalServerError) + } } func (h *agentRegisterHandler) registerExternalAgent(agent *core.ExternalAgent, registerRequest *RegisterRequest, writer http.ResponseWriter, request *http.Request) { diff --git a/lambda/rapi/handler/agentregister_test.go b/lambda/rapi/handler/agentregister_test.go index 185f249..35456ee 100644 --- a/lambda/rapi/handler/agentregister_test.go +++ b/lambda/rapi/handler/agentregister_test.go @@ -7,7 +7,6 @@ import ( "bytes" "encoding/json" "io" - "io/ioutil" "net/http" "net/http/httptest" "testing" @@ -41,7 +40,7 @@ func TestRenderAgentRegisterInvalidAgentName(t *testing.T) { require.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) 
require.Equal(t, http.StatusForbidden, responseRecorder.Code) require.Equal(t, errAgentNameInvalid, errorResponse.ErrorType) @@ -63,7 +62,7 @@ func TestRenderAgentRegisterRegistrationClosed(t *testing.T) { require.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) require.Equal(t, http.StatusForbidden, responseRecorder.Code) require.Equal(t, errAgentRegistrationClosed, errorResponse.ErrorType) @@ -88,7 +87,7 @@ func TestRenderAgentRegisterInvalidAgentState(t *testing.T) { require.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) require.Equal(t, http.StatusForbidden, responseRecorder.Code) require.Equal(t, errAgentInvalidState, errorResponse.ErrorType) @@ -311,7 +310,7 @@ func TestRenderAgentResponse(t *testing.T) { require.Equal(t, http.StatusOK, responseRecorder.Code) registerResponse := ExtensionRegisterResponseWithConfig{} - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &registerResponse) assert.Equal(t, tt.expectedRegistrationResponse.FunctionName, registerResponse.FunctionName) assert.Equal(t, tt.expectedRegistrationResponse.FunctionVersion, registerResponse.FunctionVersion) diff --git a/lambda/rapi/handler/constants.go b/lambda/rapi/handler/constants.go index 01553f3..5912d71 100644 --- a/lambda/rapi/handler/constants.go +++ b/lambda/rapi/handler/constants.go @@ -20,7 +20,6 @@ const ( errAgentMissingHeader string = "Extension.MissingHeader" errTooManyExtensions string = "Extension.TooManyExtensions" errInvalidEventType string = "Extension.InvalidEventType" - errLogsSubscriptionClosed string =
"Logs.SubscriptionClosed" errInvalidRequestFormat string = "InvalidRequestFormat" StateTransitionFailedForExtensionMessageFormat string = "State transition from %s to %s failed for extension %s. Error: %s" diff --git a/lambda/rapi/handler/credentials_test.go b/lambda/rapi/handler/credentials_test.go index fa4a2bd..d5a1090 100644 --- a/lambda/rapi/handler/credentials_test.go +++ b/lambda/rapi/handler/credentials_test.go @@ -21,13 +21,11 @@ const InitCachingAwsKey = "sampleAwsKey" const InitCachingAwsSecret = "sampleAwsSecret" const InitCachingAwsSessionToken = "sampleAwsSessionToken" -func getRequestContext(isServiceBlocked bool) (http.Handler, *http.Request, *httptest.ResponseRecorder) { +func getRequestContext() (http.Handler, *http.Request, *httptest.ResponseRecorder) { flowTest := testdata.NewFlowTest() - if isServiceBlocked { - flowTest.ConfigureForBlockedInitCaching(InitCachingToken, InitCachingAwsKey, InitCachingAwsSecret, InitCachingAwsSessionToken) - } else { - flowTest.ConfigureForInitCaching(InitCachingToken, InitCachingAwsKey, InitCachingAwsSecret, InitCachingAwsSessionToken) - } + + flowTest.ConfigureForInitCaching(InitCachingToken, InitCachingAwsKey, InitCachingAwsSecret, InitCachingAwsSessionToken) + handler := NewCredentialsHandler(flowTest.CredentialsService) responseRecorder := httptest.NewRecorder() appCtx := flowTest.AppCtx @@ -38,14 +36,14 @@ func getRequestContext(isServiceBlocked bool) (http.Handler, *http.Request, *htt } func TestEmptyAuthorizationHeader(t *testing.T) { - handler, request, responseRecorder := getRequestContext(false) + handler, request, responseRecorder := getRequestContext() handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusNotFound, responseRecorder.Code) } func TestArbitraryAuthorizationHeader(t *testing.T) { - handler, request, responseRecorder := getRequestContext(false) + handler, request, responseRecorder := getRequestContext() request.Header.Set("Authorization", "randomAuthToken") 
handler.ServeHTTP(responseRecorder, request) @@ -53,7 +51,7 @@ func TestArbitraryAuthorizationHeader(t *testing.T) { } func TestSuccessfulGet(t *testing.T) { - handler, request, responseRecorder := getRequestContext(false) + handler, request, responseRecorder := getRequestContext() request.Header.Set("Authorization", InitCachingToken) handler.ServeHTTP(responseRecorder, request) @@ -67,25 +65,6 @@ func TestSuccessfulGet(t *testing.T) { expirationTime, err := time.Parse(time.RFC3339, responseMap["Expiration"]) assert.NoError(t, err) durationUntilExpiration := time.Until(expirationTime) - assert.True(t, durationUntilExpiration.Minutes() <= 16 && durationUntilExpiration.Minutes() > 15 && durationUntilExpiration.Hours() < 1) + assert.True(t, durationUntilExpiration.Minutes() <= 30 && durationUntilExpiration.Minutes() > 29 && durationUntilExpiration.Hours() < 1) log.Println(responseRecorder.Body.String()) } - -func TestBlockedGet(t *testing.T) { - handler, request, responseRecorder := getRequestContext(true) - request.Header.Set("Authorization", InitCachingToken) - - timeout := time.After(1 * time.Second) - done := make(chan bool) - - go func() { - handler.ServeHTTP(responseRecorder, request) - done <- true - }() - - select { - case <-done: - t.Fatal("Endpoint should be blocked!") - case <-timeout: - } -} diff --git a/lambda/rapi/handler/initerror.go b/lambda/rapi/handler/initerror.go index 4015a11..d28e2d4 100644 --- a/lambda/rapi/handler/initerror.go +++ b/lambda/rapi/handler/initerror.go @@ -5,11 +5,12 @@ package handler import ( "encoding/json" - "io/ioutil" + "io" "net/http" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/telemetry" "go.amzn.com/lambda/core" "go.amzn.com/lambda/rapi/rendering" @@ -19,6 +20,7 @@ import ( type initErrorHandler struct { registrationService core.RegistrationService + eventsAPI telemetry.EventsAPI } func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { @@ -30,6 +32,10 @@ 
func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.R } runtime := h.registrationService.GetRuntime() + + // the previousStateName is needed to define if the init/error is called for INIT or RESTORE + previousStateName := runtime.GetState().Name() + if err := runtime.InitError(); err != nil { log.Warn(err) rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, @@ -39,18 +45,24 @@ func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.R errorType := request.Header.Get("Lambda-Runtime-Function-Error-Type") - errorBody, err := ioutil.ReadAll(request.Body) + errorBody, err := io.ReadAll(request.Body) if err != nil { log.WithError(err).Warn("Failed to read error body") } + if previousStateName == core.RuntimeRestoringStateName { + h.sendRestoreRuntimeDoneLogEvent() + } else { + h.sendInitRuntimeDoneLogEvent(appCtx) + } + response := &interop.ErrorResponse{ ErrorType: errorType, Payload: errorBody, ContentType: determineJSONContentType(errorBody), } - if err := server.SendErrorResponse(server.GetCurrentInvokeID(), response); err != nil { + if err := server.SendInitErrorResponse(server.GetCurrentInvokeID(), response); err != nil { rendering.RenderInteropError(writer, request, err) return } @@ -62,9 +74,10 @@ func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.R // NewInitErrorHandler returns a new instance of http handler // for serving /runtime/init/error. 
-func NewInitErrorHandler(registrationService core.RegistrationService) http.Handler { +func NewInitErrorHandler(registrationService core.RegistrationService, eventsAPI telemetry.EventsAPI) http.Handler { return &initErrorHandler{ registrationService: registrationService, + eventsAPI: eventsAPI, } } @@ -74,3 +87,24 @@ func determineJSONContentType(body []byte) string { } return "application/octet-stream" } + +func (h *initErrorHandler) sendInitRuntimeDoneLogEvent(appCtx appctx.ApplicationContext) { + // ToDo: Convert this to an enum for the whole package to increase readability. + initCachingEnabled := appctx.LoadInitType(appCtx) == appctx.InitCaching + + initSource := interop.InferTelemetryInitSource(initCachingEnabled, appctx.LoadSandboxType(appCtx)) + runtimeDoneData := &telemetry.InitRuntimeDoneData{ + InitSource: initSource, + Status: telemetry.RuntimeDoneFailure, + } + + if err := h.eventsAPI.SendInitRuntimeDone(runtimeDoneData); err != nil { + log.Errorf("Failed to send INITRD: %s", err) + } +} + +func (h *initErrorHandler) sendRestoreRuntimeDoneLogEvent() { + if err := h.eventsAPI.SendRestoreRuntimeDone(telemetry.RuntimeDoneFailure); err != nil { + log.Errorf("Failed to send RESTRD: %s", err) + } +} diff --git a/lambda/rapi/handler/initerror_test.go b/lambda/rapi/handler/initerror_test.go index c2d3d89..c9a5a83 100644 --- a/lambda/rapi/handler/initerror_test.go +++ b/lambda/rapi/handler/initerror_test.go @@ -27,7 +27,7 @@ func runTestInitErrorHandler(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - handler := NewInitErrorHandler(flowTest.RegistrationService) + handler := NewInitErrorHandler(flowTest.RegistrationService, flowTest.EventsAPI) responseRecorder := httptest.NewRecorder() appCtx := flowTest.AppCtx @@ -49,7 +49,7 @@ func runTestInitErrorHandler(t *testing.T) { require.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", responseRecorder.Code, 
http.StatusAccepted) require.JSONEq(t, fmt.Sprintf("{\"status\":\"%s\"}\n", "OK"), responseRecorder.Body.String()) - require.Equal(t, "application/json; charset=utf-8", responseRecorder.Header().Get("Content-Type")) + require.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) // Validate init error persisted in the application context. errorResponse := flowTest.InteropServer.ErrorResponse diff --git a/lambda/rapi/handler/invocationerror.go b/lambda/rapi/handler/invocationerror.go index d60b5d6..170c0cb 100644 --- a/lambda/rapi/handler/invocationerror.go +++ b/lambda/rapi/handler/invocationerror.go @@ -6,7 +6,7 @@ package handler import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "go.amzn.com/lambda/interop" @@ -25,6 +25,11 @@ const errorWithCauseContentType = "application/vnd.aws.lambda.error.cause+json" const xrayErrorCauseHeaderName = "Lambda-Runtime-Function-XRay-Error-Cause" const invalidErrorBodyMessage = "Invalid error body" +const ( + contentTypeHeader = "Content-Type" + functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" +) + type invocationErrorHandler struct { registrationService core.RegistrationService } @@ -52,7 +57,7 @@ func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request * var contentType string var err error - switch request.Header.Get("Content-Type") { + switch request.Header.Get(contentTypeHeader) { case errorWithCauseContentType: errorBody, errorCause, err = h.getErrorBodyForErrorCauseContentType(request) contentType = "application/json" @@ -62,18 +67,20 @@ func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request * default: errorBody, err = h.getErrorBody(request) errorCause = h.getValidatedErrorCause(request.Header) - contentType = request.Header.Get("Content-Type") + contentType = request.Header.Get(contentTypeHeader) } + functionResponseMode := request.Header.Get(functionResponseModeHeader) if err != nil { log.WithError(err).Warn("Failed to 
parse error body") } response := &interop.ErrorResponse{ - ErrorType: errorType, - Payload: errorBody, - ErrorCause: errorCause, - ContentType: contentType, + ErrorType: errorType, + Payload: errorBody, + ErrorCause: errorCause, + ContentType: contentType, + FunctionResponseMode: functionResponseMode, } if err := server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), response); err != nil { @@ -95,7 +102,7 @@ func (h *invocationErrorHandler) getErrorType(headers http.Header) string { } func (h *invocationErrorHandler) getErrorBody(request *http.Request) ([]byte, error) { - errorBody, err := ioutil.ReadAll(request.Body) + errorBody, err := io.ReadAll(request.Body) if err != nil { return nil, fmt.Errorf("error reading request body: %s", err) } @@ -120,7 +127,7 @@ func (h *invocationErrorHandler) getValidatedErrorCause(headers http.Header) jso } func (h *invocationErrorHandler) getErrorBodyForErrorCauseContentType(request *http.Request) ([]byte, json.RawMessage, error) { - errorBody, err := ioutil.ReadAll(request.Body) + errorBody, err := io.ReadAll(request.Body) if err != nil { return nil, nil, fmt.Errorf("error reading request body: %s", err) } diff --git a/lambda/rapi/handler/invocationerror_test.go b/lambda/rapi/handler/invocationerror_test.go index 6defa14..2f177fe 100644 --- a/lambda/rapi/handler/invocationerror_test.go +++ b/lambda/rapi/handler/invocationerror_test.go @@ -77,7 +77,7 @@ func runTestInvocationErrorHandler(t *testing.T) { assert.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", responseRecorder.Code, http.StatusAccepted) assert.JSONEq(t, fmt.Sprintf("{\"status\":\"%s\"}\n", "OK"), responseRecorder.Body.String()) - assert.Equal(t, "application/json; charset=utf-8", responseRecorder.Header().Get("Content-Type")) + assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) errorResponse := flowTest.InteropServer.ErrorResponse assert.NotNil(t, errorResponse) 
@@ -268,7 +268,8 @@ func TestInvocationResponsePayloadIsDefaultErrorMessageWhenRequestParsingFailsFo invoke := &interop.Invoke{TraceID: "Root=TraceID;Parent=ParentID;Sampled=1", ID: "InvokeID"} request := httptest.NewRequest("POST", "/", bytes.NewReader(invalidRequestBody)) request = addInvocationID(request, invoke.ID) - request.Header.Set("Content-Type", errorWithCauseContentType) + request.Header.Set(contentTypeHeader, errorWithCauseContentType) + request.Header.Set(functionResponseModeHeader, "function-response-mode") // Corresponding invoke must be placed into appCtx. flowTest.ConfigureForInvoke(context.Background(), invoke) @@ -280,6 +281,7 @@ func TestInvocationResponsePayloadIsDefaultErrorMessageWhenRequestParsingFailsFo assert.NotNil(t, errorResponse) assert.Nil(t, flowTest.InteropServer.Response) assert.Equal(t, "application/octet-stream", flowTest.InteropServer.ResponseContentType) + assert.Equal(t, "function-response-mode", flowTest.InteropServer.FunctionResponseMode) invokeResponsePayload := errorResponse.Payload diff --git a/lambda/rapi/handler/invocationnext_test.go b/lambda/rapi/handler/invocationnext_test.go index beebd97..5bddb86 100644 --- a/lambda/rapi/handler/invocationnext_test.go +++ b/lambda/rapi/handler/invocationnext_test.go @@ -92,7 +92,7 @@ func TestRenderInvoke(t *testing.T) { assert.Equal(t, invokePayload, responseRecorder.Body.String()) } -//Cgo calls removed due to crashes while spawning threads under memory pressure. +// Cgo calls removed due to crashes while spawning threads under memory pressure. 
func TestRenderInvokeDoesNotCallCgo(t *testing.T) { cgoCallsBefore := runtime.NumCgoCall() TestRenderInvoke(t) diff --git a/lambda/rapi/handler/invocationresponse.go b/lambda/rapi/handler/invocationresponse.go index 7c15342..7e47d2e 100644 --- a/lambda/rapi/handler/invocationresponse.go +++ b/lambda/rapi/handler/invocationresponse.go @@ -15,7 +15,10 @@ import ( log "github.com/sirupsen/logrus" ) -const contentTypeOverrideHeaderName = "Content-Type" +const ( + StreamingFunctionResponseMode = "streaming" + ErrInvalidResponseModeHeader = "Runtime.InvalidResponseModeHeader" +) type invocationResponseHandler struct { registrationService core.RegistrationService @@ -39,9 +42,23 @@ func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, reques invokeID := chi.URLParam(request, "awsrequestid") - responseContentType := request.Header.Get(contentTypeOverrideHeaderName) + headers := map[string]string{contentTypeHeader: request.Header.Get(contentTypeHeader)} + if functionResponseMode := request.Header.Get(functionResponseModeHeader); functionResponseMode != "" { + switch functionResponseMode { + case StreamingFunctionResponseMode: + headers[functionResponseModeHeader] = functionResponseMode + default: + errorResponse := &interop.ErrorResponse{ + ErrorType: ErrInvalidResponseModeHeader, + ContentType: request.Header.Get(contentTypeHeader), + } + _ = server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), errorResponse) + rendering.RenderInvalidFunctionResponseMode(writer, request) + return + } + } - if err := server.SendResponse(invokeID, responseContentType, request.Body); err != nil { + if err := server.SendResponse(invokeID, headers, request.Body, request.Trailer, &interop.CancellableRequest{Request: request}); err != nil { switch err := err.(type) { case *interop.ErrorResponseTooLarge: if server.SendErrorResponse(invokeID, err.AsInteropError()) != nil { @@ -66,6 +83,19 @@ func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, 
reques rendering.RenderRequestEntityTooLarge(writer, request) return + + case *interop.ErrTruncatedResponse: + if err := runtime.ResponseSent(); err != nil { + log.Panic(err) + } + + rendering.RenderTruncatedHTTPRequestError(writer, request) + return + + case *interop.ErrInternalPlatformError: + rendering.RenderInternalServerError(writer, request) + return + default: rendering.RenderInteropError(writer, request, err) return diff --git a/lambda/rapi/handler/invocationresponse_test.go b/lambda/rapi/handler/invocationresponse_test.go index e40a5bf..7c0b220 100644 --- a/lambda/rapi/handler/invocationresponse_test.go +++ b/lambda/rapi/handler/invocationresponse_test.go @@ -8,7 +8,7 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -55,7 +55,7 @@ func TestResponseTooLarge(t *testing.T) { responseRecorder.Code, http.StatusRequestEntityTooLarge) expectedAPIResponse := fmt.Sprintf("{\"errorMessage\":\"Exceeded maximum allowed payload size (%d bytes).\",\"errorType\":\"RequestEntityTooLarge\"}\n", interop.MaxPayloadSize) - body, err := ioutil.ReadAll(responseRecorder.Body) + body, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) test.AssertJsonsEqual(t, []byte(expectedAPIResponse), body) @@ -98,7 +98,8 @@ func TestResponseAccepted(t *testing.T) { request := httptest.NewRequest("", "/", bytes.NewReader(responseBody)) request = addInvocationID(request, invoke.ID) - request.Header.Set(contentTypeOverrideHeaderName, "application/json") + request.Header.Set(contentTypeHeader, "application/json") + request.Header.Set(functionResponseModeHeader, "streaming") handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) // Assertions @@ -106,7 +107,7 @@ func TestResponseAccepted(t *testing.T) { responseRecorder.Code, http.StatusAccepted) expectedAPIResponse := "{\"status\":\"OK\"}\n" - body, err := ioutil.ReadAll(responseRecorder.Body) + body, err := io.ReadAll(responseRecorder.Body) 
assert.NoError(t, err) test.AssertJsonsEqual(t, []byte(expectedAPIResponse), body) @@ -114,6 +115,92 @@ func TestResponseAccepted(t *testing.T) { assert.NotNil(t, response) assert.Nil(t, flowTest.InteropServer.ErrorResponse) assert.Equal(t, "application/json", flowTest.InteropServer.ResponseContentType) + assert.Equal(t, "streaming", flowTest.InteropServer.FunctionResponseMode) assert.Equal(t, responseBody, response, "Persisted response data in app context must match the submitted.") } + +func TestResponseWithDifferentFunctionResponseModes(t *testing.T) { + type testCase struct { + providedFunctionResponseMode string + expectedFunctionResponseMode string + expectedAPIResponse string + expectedStatusCode int + expectedErrorResponse bool + } + testCases := []testCase{ + { + providedFunctionResponseMode: "", + expectedFunctionResponseMode: "", + expectedAPIResponse: "{\"status\":\"OK\"}\n", + expectedStatusCode: http.StatusAccepted, + expectedErrorResponse: false, + }, + { + providedFunctionResponseMode: "streaming", + expectedFunctionResponseMode: "streaming", + expectedAPIResponse: "{\"status\":\"OK\"}\n", + expectedStatusCode: http.StatusAccepted, + expectedErrorResponse: false, + }, + { + providedFunctionResponseMode: "invalid-mode", + expectedFunctionResponseMode: "", + expectedAPIResponse: "{\"errorMessage\":\"Invalid function response mode\", \"errorType\":\"InvalidFunctionResponseMode\"}\n", + expectedStatusCode: http.StatusBadRequest, + expectedErrorResponse: true, + }, + } + + for _, testCase := range testCases { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + flowTest.Runtime.Ready() + handler := NewInvocationResponseHandler(flowTest.RegistrationService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + // Invoke that we are sending response for must be placed into appCtx. 
+ invoke := &interop.Invoke{ + ID: "InvocationID1", + InvokedFunctionArn: "arn::dummy1", + CognitoIdentityID: "CognitoidentityID1", + CognitoIdentityPoolID: "CognitoidentityPollID1", + DeadlineNs: "deadlinens1", + ClientContext: "clientcontext1", + ContentType: "application/json", + Payload: strings.NewReader(`{"message": "hello"}`), + } + + flowTest.ConfigureForInvoke(context.Background(), invoke) + + // Invocation response submitted by runtime. + var responseBody = []byte("{'foo': 'bar'}") + + request := httptest.NewRequest("", "/", bytes.NewReader(responseBody)) + request = addInvocationID(request, invoke.ID) + request.Header.Set(functionResponseModeHeader, testCase.providedFunctionResponseMode) + handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) + + // Assertions + assert.Equal(t, testCase.expectedStatusCode, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", + responseRecorder.Code, testCase.expectedStatusCode) + + body, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + test.AssertJsonsEqual(t, []byte(testCase.expectedAPIResponse), body) + + if testCase.expectedErrorResponse { + assert.NotNil(t, flowTest.InteropServer.ErrorResponse) + assert.Nil(t, flowTest.InteropServer.Response) + assert.Equal(t, "Runtime.InvalidResponseModeHeader", flowTest.InteropServer.ErrorResponse.ErrorType) + } else { + assert.NotNil(t, flowTest.InteropServer.Response) + assert.Nil(t, flowTest.InteropServer.ErrorResponse) + assert.Equal(t, responseBody, flowTest.InteropServer.Response, + "Persisted response data in app context must match the submitted.") + } + + assert.Equal(t, testCase.expectedFunctionResponseMode, flowTest.InteropServer.FunctionResponseMode) + } +} diff --git a/lambda/rapi/handler/restorenext.go b/lambda/rapi/handler/restorenext.go new file mode 100644 index 0000000..ecff059 --- /dev/null +++ b/lambda/rapi/handler/restorenext.go @@ -0,0 +1,40 @@ +// Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "net/http" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/rapi/rendering" +) + +type restoreNextHandler struct { + registrationService core.RegistrationService + renderingService *rendering.EventRenderingService +} + +func (h *restoreNextHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + runtime := h.registrationService.GetRuntime() + err := runtime.RestoreReady() + if err != nil { + log.Warn(err) + rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, runtime.GetState().Name(), core.RuntimeReadyStateName, err) + return + } + err = h.renderingService.RenderRuntimeEvent(writer, request) + if err != nil { + log.Error(err) + rendering.RenderInternalServerError(writer, request) + return + } +} + +func NewRestoreNextHandler(registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { + return &restoreNextHandler{ + registrationService: registrationService, + renderingService: renderingService, + } +} diff --git a/lambda/rapi/handler/restorenext_test.go b/lambda/rapi/handler/restorenext_test.go new file mode 100644 index 0000000..7018d98 --- /dev/null +++ b/lambda/rapi/handler/restorenext_test.go @@ -0,0 +1,87 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/telemetry" + "go.amzn.com/lambda/testdata" +) + +func TestRenderRestoreNext(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + handler := NewRestoreNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + flowTest.ConfigureForRestore() + request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusOK, responseRecorder.Code) +} + +func TestBrokenRenderer(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + handler := NewRestoreNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + flowTest.ConfigureForRestore() + flowTest.RenderingService.SetRenderer(&mockBrokenRenderer{}) + request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) + + assert.JSONEq(t, `{"errorMessage":"Internal Server Error","errorType":"InternalServerError"}`, responseRecorder.Body.String()) +} + +func TestRenderRestoreAfterInvoke(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + handler := NewInvocationNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + deadlineNs := 12345 + invokePayload := "Payload" + invoke := &interop.Invoke{ + TraceID: "Root=RootID;Parent=LambdaFrontend;Sampled=1", + ID: "ID", + InvokedFunctionArn: 
"InvokedFunctionArn", + CognitoIdentityID: "CognitoIdentityId1", + CognitoIdentityPoolID: "CognitoIdentityPoolId1", + ClientContext: "ClientContext", + DeadlineNs: strconv.Itoa(deadlineNs), + ContentType: "image/png", + Payload: strings.NewReader(invokePayload), + } + + ctx := telemetry.NewTraceContext(context.Background(), "RootID", "InvocationSubegmentID") + flowTest.ConfigureForInvoke(ctx, invoke) + + request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + + restoreHandler := NewRestoreNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + restoreRequest := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + responseRecorder = httptest.NewRecorder() + restoreHandler.ServeHTTP(responseRecorder, restoreRequest) + + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) +} diff --git a/lambda/rapi/handler/runtimelogs.go b/lambda/rapi/handler/runtimelogs.go index 9b4e406..99941b0 100644 --- a/lambda/rapi/handler/runtimelogs.go +++ b/lambda/rapi/handler/runtimelogs.go @@ -7,7 +7,7 @@ import ( "bytes" "errors" "fmt" - "io/ioutil" + "io" "net/http" "go.amzn.com/lambda/core" @@ -20,8 +20,8 @@ import ( ) type runtimeLogsHandler struct { - registrationService core.RegistrationService - logsSubscriptionAPI telemetry.LogsSubscriptionAPI + registrationService core.RegistrationService + telemetrySubscription telemetry.SubscriptionAPI } func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { @@ -31,10 +31,10 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http switch err := err.(type) { case *ErrAgentIdentifierUnknown: rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension "+err.agentID.String()) - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + 
h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) } return } @@ -45,21 +45,21 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http if err != nil { log.Error(err) rendering.RenderInternalServerError(writer, request) - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) return } - respBody, status, headers, err := h.logsSubscriptionAPI.Subscribe(agentName, bytes.NewReader(body), request.Header) + respBody, status, headers, err := h.telemetrySubscription.Subscribe(agentName, bytes.NewReader(body), request.Header) if err != nil { log.Errorf("Telemetry API error: %s", err) switch err { case logsapi.ErrTelemetryServiceOff: rendering.RenderForbiddenWithTypeMsg(writer, request, - errLogsSubscriptionClosed, "Logs API subscription is closed already") - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.GetServiceClosedErrorType(), h.telemetrySubscription.GetServiceClosedErrorMessage()) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) } return } @@ -67,11 +67,11 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http rendering.RenderRuntimeLogsResponse(writer, respBody, status, headers) switch status / 100 { case 2: // 2xx - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeSuccess, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeSuccess, 1) case 4: // 4xx - 
h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) case 5: // 5xx - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) } } @@ -114,7 +114,7 @@ func (h *runtimeLogsHandler) getAgentName(agentID uuid.UUID) (string, bool) { } func (h *runtimeLogsHandler) getBody(writer http.ResponseWriter, request *http.Request) ([]byte, error) { - body, err := ioutil.ReadAll(request.Body) + body, err := io.ReadAll(request.Body) if err != nil { return nil, fmt.Errorf("Failed to read error body: %s", err) } @@ -122,11 +122,11 @@ func (h *runtimeLogsHandler) getBody(writer http.ResponseWriter, request *http.R return body, nil } -// NewRuntimeLogsHandler returns a new instance of http handler +// NewRuntimeTelemetrySubscriptionHandler returns a new instance of http handler // for serving /runtime/logs -func NewRuntimeLogsHandler(registrationService core.RegistrationService, logsSubscriptionAPI telemetry.LogsSubscriptionAPI) http.Handler { +func NewRuntimeTelemetrySubscriptionHandler(registrationService core.RegistrationService, telemetrySubscription telemetry.SubscriptionAPI) http.Handler { return &runtimeLogsHandler{ - registrationService: registrationService, - logsSubscriptionAPI: logsSubscriptionAPI, + registrationService: registrationService, + telemetrySubscription: telemetrySubscription, } } diff --git a/lambda/rapi/handler/runtimelogs_stub.go b/lambda/rapi/handler/runtimelogs_stub.go index 0ce472e..f540e9b 100644 --- a/lambda/rapi/handler/runtimelogs_stub.go +++ b/lambda/rapi/handler/runtimelogs_stub.go @@ -6,27 +6,48 @@ package handler import ( "net/http" + log "github.com/sirupsen/logrus" "go.amzn.com/lambda/rapi/model" - - "github.com/go-chi/render" + "go.amzn.com/lambda/rapi/rendering" ) const ( - telemetryAPIDisabledErrorType = "Logs.NotSupported" + logsAPIDisabledErrorType = 
"Logs.NotSupported" + telemetryAPIDisabledErrorType = "Telemetry.NotSupported" ) -type runtimeLogsStubHandler struct{} +type runtimeLogsStubAPIHandler struct{} -func (h *runtimeLogsStubHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - render.Status(request, http.StatusAccepted) - render.JSON(writer, request, &model.ErrorResponse{ - ErrorType: telemetryAPIDisabledErrorType, +func (h *runtimeLogsStubAPIHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if err := rendering.RenderJSON(http.StatusAccepted, writer, request, &model.ErrorResponse{ + ErrorType: logsAPIDisabledErrorType, ErrorMessage: "Logs API is not supported", - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(writer, err.Error(), http.StatusInternalServerError) + } +} + +// NewRuntimeLogsAPIStubHandler returns a new instance of http handler +// for serving /runtime/logs when a telemetry service implementation is absent +func NewRuntimeLogsAPIStubHandler() http.Handler { + return &runtimeLogsStubAPIHandler{} +} + +type runtimeTelemetryAPIStubHandler struct{} + +func (h *runtimeTelemetryAPIStubHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if err := rendering.RenderJSON(http.StatusAccepted, writer, request, &model.ErrorResponse{ + ErrorType: telemetryAPIDisabledErrorType, + ErrorMessage: "Telemetry API is not supported", + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(writer, err.Error(), http.StatusInternalServerError) + } } -// NewRuntimeLogsStubHandler returns a new instance of http handler +// NewRuntimeTelemetryAPIStubHandler returns a new instance of http handler // for serving /runtime/logs when a telemetry service implementation is absent -func NewRuntimeLogsStubHandler() http.Handler { - return &runtimeLogsStubHandler{} +func NewRuntimeTelemetryAPIStubHandler() http.Handler { + return &runtimeTelemetryAPIStubHandler{} } diff --git 
a/lambda/rapi/handler/runtimelogs_stub_test.go b/lambda/rapi/handler/runtimelogs_stub_test.go index 4826d12..5b27983 100644 --- a/lambda/rapi/handler/runtimelogs_stub_test.go +++ b/lambda/rapi/handler/runtimelogs_stub_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/assert" ) -func TestSuccessfulRuntimeLogsStub202Response(t *testing.T) { - handler := NewRuntimeLogsStubHandler() +func TestSuccessfulRuntimeLogsAPIStub202Response(t *testing.T) { + handler := NewRuntimeLogsAPIStubHandler() requestBody := []byte(`foobar`) request := httptest.NewRequest("PUT", "/logs", bytes.NewBuffer(requestBody)) responseRecorder := httptest.NewRecorder() @@ -23,3 +23,15 @@ func TestSuccessfulRuntimeLogsStub202Response(t *testing.T) { assert.Equal(t, http.StatusAccepted, responseRecorder.Code) assert.JSONEq(t, `{"errorMessage":"Logs API is not supported","errorType":"Logs.NotSupported"}`, responseRecorder.Body.String()) } + +func TestSuccessfulRuntimeTelemetryAPIStub202Response(t *testing.T) { + handler := NewRuntimeTelemetryAPIStubHandler() + requestBody := []byte(`foobar`) + request := httptest.NewRequest("PUT", "/telemetry", bytes.NewBuffer(requestBody)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.JSONEq(t, `{"errorMessage":"Telemetry API is not supported","errorType":"Telemetry.NotSupported"}`, responseRecorder.Body.String()) +} diff --git a/lambda/rapi/handler/runtimelogs_test.go b/lambda/rapi/handler/runtimelogs_test.go index b7db6df..892d61e 100644 --- a/lambda/rapi/handler/runtimelogs_test.go +++ b/lambda/rapi/handler/runtimelogs_test.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" "testing" @@ -23,30 +22,45 @@ import ( "go.amzn.com/lambda/rapidcore/telemetry/logsapi" ) -type mockLogsSubscriptionAPI struct{ mock.Mock } +type mockSubscriptionAPI struct{ mock.Mock } -func (s *mockLogsSubscriptionAPI) 
Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { +func (s *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { args := s.Called(agentName, body, headers) return args.Get(0).([]byte), args.Int(1), args.Get(2).(map[string][]string), args.Error(3) } -func (s *mockLogsSubscriptionAPI) RecordCounterMetric(metricName string, count int) { +func (s *mockSubscriptionAPI) RecordCounterMetric(metricName string, count int) { s.Called(metricName, count) } -func (s *mockLogsSubscriptionAPI) FlushMetrics() interop.LogsAPIMetrics { +func (s *mockSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { args := s.Called() - return args.Get(0).(interop.LogsAPIMetrics) + return args.Get(0).(interop.TelemetrySubscriptionMetrics) } -func (s *mockLogsSubscriptionAPI) Clear() { +func (s *mockSubscriptionAPI) Clear() { s.Called() } -func (s *mockLogsSubscriptionAPI) TurnOff() { +func (s *mockSubscriptionAPI) TurnOff() { s.Called() } +func (s *mockSubscriptionAPI) GetEndpointURL() string { + args := s.Called() + return args.Get(0).(string) +} + +func (s *mockSubscriptionAPI) GetServiceClosedErrorMessage() string { + args := s.Called() + return args.Get(0).(string) +} + +func (s *mockSubscriptionAPI) GetServiceClosedErrorType() string { + args := s.Called() + return args.Get(0).(string) +} + func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} respBody, respStatus, respHeaders := []byte(`barbaz`), http.StatusNotFound, map[string][]string{"K": []string{"V1", "V2"}} @@ -60,11 +74,11 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { agent, err := registrationService.CreateExternalAgent(agentName) assert.NoError(t, err) - logsSubscriptionAPI := &mockLogsSubscriptionAPI{} - 
logsSubscriptionAPI.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return(respBody, respStatus, respHeaders, nil) - logsSubscriptionAPI.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return(respBody, respStatus, respHeaders, nil) + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -77,10 +91,10 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - logsSubscriptionAPI.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders) - logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) - recordedBody, err := ioutil.ReadAll(responseRecorder.Body) + recordedBody, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) assert.Equal(t, respStatus, responseRecorder.Code) @@ -98,10 +112,10 @@ func TestErrorUnregisteredAgentID(t *testing.T) { core.NewInvokeFlowSynchronization(), ) - logsSubscriptionAPI := &mockLogsSubscriptionAPI{} - logsSubscriptionAPI.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) request := 
httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -114,16 +128,16 @@ func TestErrorUnregisteredAgentID(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - recordedBody, err := ioutil.ReadAll(responseRecorder.Body) + recordedBody, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) expectedErrorBody := fmt.Sprintf(`{"errorMessage":"Unknown extension %s","errorType":"Extension.UnknownExtensionIdentifier"}`+"\n", invalidAgentID.String()) - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json; charset=utf-8"}}) + expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, expectedErrorBody, string(recordedBody)) assert.Equal(t, expectedHeaders, responseRecorder.Header()) - logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) } func TestErrorTelemetryAPICallFailure(t *testing.T) { @@ -139,11 +153,11 @@ func TestErrorTelemetryAPICallFailure(t *testing.T) { agent, err := registrationService.CreateExternalAgent(agentName) assert.NoError(t, err) - logsSubscriptionAPI := &mockLogsSubscriptionAPI{} - logsSubscriptionAPI.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) - logsSubscriptionAPI.On("RecordCounterMetric", serverErrMetric, 1) + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("RecordCounterMetric", serverErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) + handler := 
NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -156,16 +170,16 @@ func TestErrorTelemetryAPICallFailure(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - recordedBody, err := ioutil.ReadAll(responseRecorder.Body) + recordedBody, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) expectedErrorBody := `{"errorMessage":"Internal Server Error","errorType":"InternalServerError"}` + "\n" - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json; charset=utf-8"}}) + expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) assert.Equal(t, expectedErrorBody, string(recordedBody)) assert.Equal(t, expectedHeaders, responseRecorder.Header()) - logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", serverErrMetric, 1) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", serverErrMetric, 1) } func TestRenderLogsSubscriptionClosed(t *testing.T) { @@ -181,11 +195,13 @@ func TestRenderLogsSubscriptionClosed(t *testing.T) { agent, err := registrationService.CreateExternalAgent(agentName) assert.NoError(t, err) - logsSubscriptionAPI := &mockLogsSubscriptionAPI{} - logsSubscriptionAPI.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) - logsSubscriptionAPI.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Logs API 
subscription is closed already") + telemetrySubscription.On("GetServiceClosedErrorType").Return("Logs.SubscriptionClosed") - handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -198,14 +214,58 @@ func TestRenderLogsSubscriptionClosed(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - recordedBody, err := ioutil.ReadAll(responseRecorder.Body) + recordedBody, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) expectedErrorBody := `{"errorMessage":"Logs API subscription is closed already","errorType":"Logs.SubscriptionClosed"}` + "\n" - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json; charset=utf-8"}}) + expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) + + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + assert.Equal(t, expectedErrorBody, string(recordedBody)) + assert.Equal(t, expectedHeaders, responseRecorder.Header()) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) +} + +func TestRenderTelemetrySubscriptionClosed(t *testing.T) { + agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} + apiError := logsapi.ErrTelemetryServiceOff + clientErrMetric := logsapi.SubscribeClientErr + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization(), + core.NewInvokeFlowSynchronization(), + ) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, 
map[string][]string{}, apiError) + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Telemetry API subscription is closed already") + telemetrySubscription.On("GetServiceClosedErrorType").Return("Telemetry.SubscriptionClosed") + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + expectedErrorBody := `{"errorMessage":"Telemetry API subscription is closed already","errorType":"Telemetry.SubscriptionClosed"}` + "\n" + expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, expectedErrorBody, string(recordedBody)) assert.Equal(t, expectedHeaders, responseRecorder.Header()) - logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) } diff --git a/lambda/rapi/middleware/middleware_test.go b/lambda/rapi/middleware/middleware_test.go index 7b37de9..a0b9134 100644 --- a/lambda/rapi/middleware/middleware_test.go +++ b/lambda/rapi/middleware/middleware_test.go @@ -7,7 +7,7 @@ import ( "bytes" "context" "encoding/json" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -58,7 +58,7 @@ func TestAgentUniqueIdentifierHeaderValidatorForbidden(t *testing.T) { responseRecorder := httptest.NewRecorder() router.ServeHTTP(responseRecorder, request) assert.Equal(t, 
http.StatusForbidden, responseRecorder.Code) - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, handler.ErrAgentIdentifierMissing, errorResponse.ErrorType) @@ -66,7 +66,7 @@ func TestAgentUniqueIdentifierHeaderValidatorForbidden(t *testing.T) { request.Header.Set(handler.LambdaAgentIdentifier, "invalid-unique-identifier") router.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - respBody, _ = ioutil.ReadAll(responseRecorder.Body) + respBody, _ = io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, handler.ErrAgentIdentifierInvalid, errorResponse.ErrorType) } diff --git a/lambda/rapi/model/tracing.go b/lambda/rapi/model/tracing.go index af90e8f..83f97e8 100644 --- a/lambda/rapi/model/tracing.go +++ b/lambda/rapi/model/tracing.go @@ -3,14 +3,21 @@ package model +type TracingType string + const ( // XRayTracingType represents an X-Ray Tracing object type - XRayTracingType = "X-Amzn-Trace-Id" + XRayTracingType TracingType = "X-Amzn-Trace-Id" +) + +const ( + XRaySampled = "1" + XRayNonSampled = "0" ) // Tracing object returned as part of agent Invoke event type Tracing struct { - Type string `json:"type"` + Type TracingType `json:"type"` XRayTracing } diff --git a/lambda/rapi/rendering/doc.go b/lambda/rapi/rendering/doc.go index 4573638..bc359a1 100644 --- a/lambda/rapi/rendering/doc.go +++ b/lambda/rapi/rendering/doc.go @@ -2,7 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 /* - Package rendering provides stateful event rendering service. 
State of the rendering service should be set from the main event dispatch thread @@ -17,6 +16,5 @@ Example of INVOKE event: [main] // release threads registered for INVOKE event [thread] // receives INVOKE event - */ package rendering diff --git a/lambda/rapi/rendering/render_json.go b/lambda/rapi/rendering/render_json.go new file mode 100644 index 0000000..8cea816 --- /dev/null +++ b/lambda/rapi/rendering/render_json.go @@ -0,0 +1,33 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rendering + +import ( + "bytes" + "encoding/json" + log "github.com/sirupsen/logrus" + "net/http" +) + +// RenderJSON: +// - marshals 'v' to JSON, automatically escaping HTML +// - sets the Content-Type as application/json +// - sets the HTTP response status code +// - returns an error if it occurred before writing to response +func RenderJSON(status int, w http.ResponseWriter, r *http.Request, v interface{}) error { + buf := &bytes.Buffer{} + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(true) + if err := enc.Encode(v); err != nil { + return err + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if _, err := w.Write(buf.Bytes()); err != nil { + log.WithError(err).Warn("Error while writing response body") + } + + return nil +} diff --git a/lambda/rapi/rendering/rendering.go b/lambda/rapi/rendering/rendering.go index c75d010..0edfb68 100644 --- a/lambda/rapi/rendering/rendering.go +++ b/lambda/rapi/rendering/rendering.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "strconv" "sync" @@ -19,7 +18,6 @@ import ( "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapi/model" - "github.com/go-chi/render" "github.com/google/uuid" log "github.com/sirupsen/logrus" ) @@ -33,6 +31,8 @@ const ( ErrorTypeInvalidRequestID = "InvalidRequestID" // ErrorTypeRequestEntityTooLarge error type for payload too large ErrorTypeRequestEntityTooLarge = "RequestEntityTooLarge" + 
// ErrorTypeTruncatedHTTPRequest error type for truncated HTTP request + ErrorTypeTruncatedHTTPRequest = "TruncatedHTTPRequest" ) // ErrRenderingServiceStateNotSet returned when state not set @@ -100,6 +100,9 @@ type InvokeRenderer struct { metrics InvokeRendererMetrics } +type RestoreRenderer struct { +} + // NewAgentInvokeEvent forms a new AgentInvokeEvent from INVOKE request func NewAgentInvokeEvent(req *interop.Invoke) (*model.AgentInvokeEvent, error) { deadlineMono, err := strconv.ParseInt(req.DeadlineNs, 10, 64) @@ -145,7 +148,7 @@ func (s *InvokeRenderer) bufferInvokeRequest() error { if nil == s.requestBuffer { reader := io.LimitReader(s.invoke.Payload, interop.MaxPayloadSize) start := time.Now() - s.requestBuffer, err = ioutil.ReadAll(reader) + s.requestBuffer, err = io.ReadAll(reader) s.metrics = InvokeRendererMetrics{ ReadTime: time.Since(start), SizeBytes: len(s.requestBuffer), @@ -193,6 +196,15 @@ func (s *InvokeRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request return nil } +func (s *RestoreRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { + writer.WriteHeader(http.StatusOK) + return nil +} + +func (s *RestoreRenderer) RenderAgentEvent(writer http.ResponseWriter, request *http.Request) error { + return nil +} + // NewInvokeRenderer returns new invoke event renderer func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, traceParser func(context.Context, *interop.Invoke) string) *InvokeRenderer { return &InvokeRenderer{ @@ -204,6 +216,10 @@ func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, traceParser } } +func NewRestoreRenderer() *RestoreRenderer { + return &RestoreRenderer{} +} + func (s *InvokeRenderer) GetMetrics() InvokeRendererMetrics { s.requestMutex.Lock() defer s.requestMutex.Unlock() @@ -283,46 +299,78 @@ func renderAgentInvokeHeaders(writer http.ResponseWriter, eventID uuid.UUID) { // RenderForbiddenWithTypeMsg method for rendering error response func 
RenderForbiddenWithTypeMsg(w http.ResponseWriter, r *http.Request, errorType string, format string, args ...interface{}) { - render.Status(r, http.StatusForbidden) - render.JSON(w, r, &model.ErrorResponse{ + if err := RenderJSON(http.StatusForbidden, w, r, &model.ErrorResponse{ ErrorType: errorType, ErrorMessage: fmt.Sprintf(format, args...), - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderInternalServerError method for rendering error response func RenderInternalServerError(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusInternalServerError) - render.JSON(w, r, &model.ErrorResponse{ + if err := RenderJSON(http.StatusInternalServerError, w, r, &model.ErrorResponse{ ErrorMessage: "Internal Server Error", ErrorType: ErrorTypeInternalServerError, - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderRequestEntityTooLarge method for rendering error response func RenderRequestEntityTooLarge(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusRequestEntityTooLarge) - render.JSON(w, r, &model.ErrorResponse{ + if err := RenderJSON(http.StatusRequestEntityTooLarge, w, r, &model.ErrorResponse{ ErrorMessage: fmt.Sprintf("Exceeded maximum allowed payload size (%d bytes).", interop.MaxPayloadSize), ErrorType: ErrorTypeRequestEntityTooLarge, - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderTruncatedHTTPRequestError method for rendering error response +func RenderTruncatedHTTPRequestError(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "HTTP request detected as truncated", + ErrorType: ErrorTypeTruncatedHTTPRequest, + }); err 
!= nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderInvalidRequestID renders invalid request ID error response func RenderInvalidRequestID(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusBadRequest) - render.JSON(w, r, &model.ErrorResponse{ + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ ErrorMessage: "Invalid request ID", ErrorType: "InvalidRequestID", - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderInvalidFunctionResponseMode renders invalid function response mode response +func RenderInvalidFunctionResponseMode(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "Invalid function response mode", + ErrorType: "InvalidFunctionResponseMode", + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderAccepted method for rendering accepted status response func RenderAccepted(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusAccepted) - render.JSON(w, r, &model.StatusResponse{ + if err := RenderJSON(http.StatusAccepted, w, r, &model.StatusResponse{ Status: "OK", - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderInteropError is a convenience method for interpreting interop errors diff --git a/lambda/rapi/router.go b/lambda/rapi/router.go index 1d2766a..5c2a56d 100644 --- a/lambda/rapi/router.go +++ b/lambda/rapi/router.go @@ -19,7 +19,7 @@ import ( // NewRouter returns a new instance of chi router implementing // Runtime API specification. 
-func NewRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { +func NewRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService, eventsAPI telemetry.EventsAPI) http.Handler { router := chi.NewRouter() router.Use(middleware.AppCtxMiddleware(appCtx)) @@ -46,7 +46,11 @@ func NewRouter(appCtx appctx.ApplicationContext, registrationService core.Regist handler.NewInvocationErrorHandler(registrationService)).ServeHTTP) router.Post("/runtime/init/error", - handler.NewInitErrorHandler(registrationService).ServeHTTP) + handler.NewInitErrorHandler(registrationService, eventsAPI).ServeHTTP) + + if appctx.LoadInitType(appCtx) == appctx.InitCaching { + router.Get("/runtime/restore/next", handler.NewRestoreNextHandler(registrationService, renderingService).ServeHTTP) + } return router } @@ -80,14 +84,14 @@ func ExtensionsRouter(appCtx appctx.ApplicationContext, registrationService core // LogsAPIRouter returns a new instance of chi router implementing // Logs API specification. 
-func LogsAPIRouter(registrationService core.RegistrationService, logsSubscriptionAPI telemetry.LogsSubscriptionAPI) http.Handler { +func LogsAPIRouter(registrationService core.RegistrationService, logsSubscriptionAPI telemetry.SubscriptionAPI) http.Handler { router := chi.NewRouter() router.Use(middleware.AccessLogMiddleware()) router.Use(middleware.AllowIfExtensionsEnabled) router.Put("/logs", middleware.AgentUniqueIdentifierHeaderValidator( - handler.NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI)).ServeHTTP) + handler.NewRuntimeTelemetrySubscriptionHandler(registrationService, logsSubscriptionAPI)).ServeHTTP) return router } @@ -98,7 +102,32 @@ func LogsAPIRouter(registrationService core.RegistrationService, logsSubscriptio func LogsAPIStubRouter() http.Handler { router := chi.NewRouter() - router.Put("/logs", handler.NewRuntimeLogsStubHandler().ServeHTTP) + router.Put("/logs", handler.NewRuntimeLogsAPIStubHandler().ServeHTTP) + + return router +} + +// TelemetryRouter returns a new instance of chi router implementing +// Telemetry API specification. 
+func TelemetryAPIRouter(registrationService core.RegistrationService, telemetrySubscriptionAPI telemetry.SubscriptionAPI) http.Handler { + router := chi.NewRouter() + router.Use(middleware.AccessLogMiddleware()) + router.Use(middleware.AllowIfExtensionsEnabled) + + router.Put("/telemetry", + middleware.AgentUniqueIdentifierHeaderValidator( + handler.NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscriptionAPI)).ServeHTTP) + + return router +} + +// TelemetryStubRouter returns a new instance of chi router implementing +// a stub of Telemetry API that always returns a non-committal response to +// prevent customer code from crashing when Telemetry API is disabled locally +func TelemetryAPIStubRouter() http.Handler { + router := chi.NewRouter() + + router.Put("/telemetry", handler.NewRuntimeTelemetryAPIStubHandler().ServeHTTP) return router } diff --git a/lambda/rapi/router_test.go b/lambda/rapi/router_test.go index f1cbde8..73cbde1 100644 --- a/lambda/rapi/router_test.go +++ b/lambda/rapi/router_test.go @@ -60,7 +60,7 @@ func assertResponseErrorType(t *testing.T, expectedErrorType string, response *h // rendered as JSON, regardless of the value provided // in "Accept" header. // -// When using render.Render(...), chi rendering library +// When using render.Render(...), rendering function // would attempt to render response using content type // specified in the "Accept" header. 
// @@ -69,7 +69,7 @@ func assertResponseErrorType(t *testing.T, expectedErrorType string, response *h func TestAcceptXML(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) responseRecorder := httptest.NewRecorder() request := httptest.NewRequest("POST", "/runtime/invocation/x-y-z/error", bytes.NewReader([]byte(""))) // Tell server that client side accepts "application/xml". @@ -90,7 +90,7 @@ func TestAcceptXML(t *testing.T) { func Test404PageNotFound(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/unsupported", bytes.NewReader([]byte("")))) assert.Equal(t, http.StatusNotFound, responseRecorder.Code) assert.Equal(t, "404 page not found\n", responseRecorder.Body.String()) @@ -99,7 +99,7 @@ func Test404PageNotFound(t *testing.T) { func Test405MethodNotAllowed(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("DELETE", "/runtime/invocation/ABC/error", bytes.NewReader([]byte("")))) assert.Equal(t, http.StatusMethodNotAllowed, responseRecorder.Code) } @@ -107,7 +107,7 @@ func Test405MethodNotAllowed(t *testing.T) { func TestInitErrorAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() 
flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/init/error", bytes.NewReader([]byte("{}")))) assert.Equal(t, http.StatusAccepted, responseRecorder.Code) } @@ -115,7 +115,7 @@ func TestInitErrorAccepted(t *testing.T) { func TestInitErrorForbidden(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -126,7 +126,7 @@ func TestInitErrorForbidden(t *testing.T) { func TestInvokeResponseAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -137,7 +137,7 @@ func TestInvokeResponseAccepted(t *testing.T) { func TestInvokeErrorResponseAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, 
flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -148,7 +148,7 @@ func TestInvokeErrorResponseAccepted(t *testing.T) { func TestInvokeNextTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -159,7 +159,7 @@ func TestInvokeNextTwice(t *testing.T) { func TestInvokeResponseInvalidRequestID(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -171,7 +171,7 @@ func TestInvokeResponseInvalidRequestID(t *testing.T) { func TestInvokeErrorResponseInvalidRequestID(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) 
flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -183,7 +183,7 @@ func TestInvokeErrorResponseInvalidRequestID(t *testing.T) { func TestInvokeResponseTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -197,7 +197,7 @@ func TestInvokeResponseTwice(t *testing.T) { func TestInvokeErrorResponseTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -211,7 +211,7 @@ func TestInvokeErrorResponseTwice(t *testing.T) { func TestInvokeResponseAfterErrorResponse(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := 
makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -225,7 +225,7 @@ func TestInvokeResponseAfterErrorResponse(t *testing.T) { func TestInvokeErrorResponseAfterResponse(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -239,7 +239,7 @@ func TestInvokeErrorResponseAfterResponse(t *testing.T) { func TestMoreThanOneInvoke(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) var responseRecorder *httptest.ResponseRecorder for _, id := range []string{"A", "B", "C"} { flowTest.ConfigureForInvoke(context.Background(), createInvoke(id)) @@ -250,12 +250,25 @@ func TestMoreThanOneInvoke(t *testing.T) { } } +func TestInitCachingAPIDisabledForPlainInit(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + var responseRecorder *httptest.ResponseRecorder + + responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/restore/next", nil)) + assert.Equal(t, http.StatusNotFound, responseRecorder.Code) + + responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/credentials", nil)) + assert.Equal(t, 
http.StatusNotFound, responseRecorder.Code) +} + func benchmarkInvoke(b *testing.B, payload []byte) { b.StopTimer() b.ReportAllocs() flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) for i := 0; i < b.N; i++ { id := uuid.New().String() flowTest.ConfigureForInvoke(context.Background(), createInvoke(id)) diff --git a/lambda/rapi/security_test.go b/lambda/rapi/security_test.go index 3f869d5..5312b43 100644 --- a/lambda/rapi/security_test.go +++ b/lambda/rapi/security_test.go @@ -20,7 +20,7 @@ func TestInvokeValidId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) @@ -53,7 +53,7 @@ func TestSecurityInvokeResponseBadRequestId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) @@ -100,7 +100,7 @@ func TestSecurityInvokeErrorBadRequestId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) diff --git a/lambda/rapi/server.go 
b/lambda/rapi/server.go index e2c6ad4..dd027f4 100644 --- a/lambda/rapi/server.go +++ b/lambda/rapi/server.go @@ -13,6 +13,7 @@ import ( "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/rendering" "go.amzn.com/lambda/telemetry" @@ -23,6 +24,7 @@ const version20180601 = "/2018-06-01" const version20200101 = "/2020-01-01" const version20200815 = "/2020-08-15" const version20210423 = "/2021-04-23" +const version20220701 = "/2022-07-01" // Server is a Runtime API server type Server struct { @@ -33,6 +35,10 @@ type Server struct { exit chan error } +func SaveConnInContext(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, interop.HTTPConnKey, c) +} + // NewServer creates a new Runtime API Server // // Unlike net/http server's ListenAndServe, we separate Listen() @@ -44,28 +50,30 @@ func NewServer(host string, port int, appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService, telemetryAPIEnabled bool, - logsSubscriptionAPI telemetry.LogsSubscriptionAPI, initCachingEnabled bool, credentialsService core.CredentialsService) *Server { + logsSubscriptionAPI telemetry.SubscriptionAPI, telemetrySubscriptionAPI telemetry.SubscriptionAPI, credentialsService core.CredentialsService, eventsAPI telemetry.EventsAPI) *Server { exitErrors := make(chan error, 1) router := chi.NewRouter() - router.Mount(version20180601, NewRouter(appCtx, registrationService, renderingService)) + router.Mount(version20180601, NewRouter(appCtx, registrationService, renderingService, eventsAPI)) router.Mount(version20200101, ExtensionsRouter(appCtx, registrationService, renderingService)) if telemetryAPIEnabled { router.Mount(version20200815, LogsAPIRouter(registrationService, logsSubscriptionAPI)) + router.Mount(version20220701, TelemetryAPIRouter(registrationService, telemetrySubscriptionAPI)) } else { router.Mount(version20200815, 
LogsAPIStubRouter()) + router.Mount(version20220701, TelemetryAPIStubRouter()) } - if initCachingEnabled { + if appctx.LoadInitType(appCtx) == appctx.InitCaching { router.Mount(version20210423, CredentialsAPIRouter(credentialsService)) } return &Server{ host: host, port: port, - server: &http.Server{Handler: router}, + server: &http.Server{Handler: router, ConnContext: SaveConnInContext}, listener: nil, exit: exitErrors, } diff --git a/lambda/rapi/server_test.go b/lambda/rapi/server_test.go index ce6e4e6..cf31fab 100644 --- a/lambda/rapi/server_test.go +++ b/lambda/rapi/server_test.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "testing" "time" @@ -52,7 +51,7 @@ func TestServerReturnsSuccessfulResponse(t *testing.T) { if err != nil { assert.FailNowf(t, "Failed to get response", err.Error()) } - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { assert.FailNowf(t, "Failed to read response body", err.Error()) } diff --git a/lambda/rapid/bootstrap.go b/lambda/rapid/bootstrap.go deleted file mode 100644 index e82ec6c..0000000 --- a/lambda/rapid/bootstrap.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -package rapid - -import ( - "os" - - "go.amzn.com/lambda/fatalerror" -) - -type Bootstrap interface { - Cmd() ([]string, error) // returns the args of bootstrap, where args[0] is the path to executable - Env(e EnvironmentVariables) []string // returns the environment variables to be passed to the bootstrapped process - Cwd() (string, error) // returns the working directory of the bootstrap process - ExtraFiles() []*os.File // returns the extra file descriptors apart from 1 & 2 to be passed to runtime - CachedFatalError(err error) (fatalerror.ErrorType, string, bool) -} diff --git a/lambda/rapid/exit.go b/lambda/rapid/exit.go index af5cb72..e45f3a4 100644 --- a/lambda/rapid/exit.go +++ b/lambda/rapid/exit.go @@ -5,8 +5,9 @@ package rapid import ( "fmt" - "os" + "time" + "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/extensions" "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" @@ -15,102 +16,116 @@ import ( log "github.com/sirupsen/logrus" ) -func checkInteropError(format string, err error) { - if err == interop.ErrInvalidInvokeID || err == interop.ErrResponseSent { - log.Warnf(format, err) - } else { - log.Panicf(format, err) +func handleInvokeError(execCtx *rapidContext, invokeRequest *interop.Invoke, invokeMx *invokeMetrics, err error) *interop.InvokeFailure { + invokeFailure := newInvokeFailureMsg(execCtx, invokeRequest, invokeMx, err) + resp := model.ErrorResponse{ + ErrorType: string(invokeFailure.ErrorType), + ErrorMessage: fmt.Sprintf("Error: %v", invokeFailure.ErrorMessage), } -} -func trySendDefaultErrorResponse(interopServer interop.Server, invokeID string, errorType fatalerror.ErrorType, err error) { - resp := model.ErrorResponse{ - ErrorType: string(errorType), - ErrorMessage: fmt.Sprintf("Error: %v", err), + if invokeRequest.ID != "" { + resp.ErrorMessage = fmt.Sprintf("RequestId: %s Error: %v", invokeRequest.ID, invokeFailure.ErrorMessage) } - if invokeID != "" { - resp.ErrorMessage = 
fmt.Sprintf("RequestId: %s Error: %v", invokeID, err) + // This is the default error response that gets sent back as the function response in failure cases + invokeFailure.DefaultErrorResponse = resp.AsInteropError() + + // Invoke with extensions disabled maintains behaviour parity with pre-extensions rapid + if !extensions.AreEnabled() { + invokeFailure.RequestReset = false + return invokeFailure } - if err := interopServer.SendErrorResponse(invokeID, resp.AsInteropError()); err != nil { - checkInteropError("Failed to send default error response: %s", err) + if err == errResetReceived { + // errResetReceived is returned when execution flow was interrupted by the Reset message, + // hence this error deserves special handling and we yield to main receive loop to handle it + invokeFailure.ResetReceived = true + return invokeFailure } + + invokeFailure.RequestReset = true + return invokeFailure } -func reportErrorAndExit(doneFailMsg *interop.DoneFail, invokeID string, interopServer interop.Server, err error) { - // This function maintains compatibility of exit sequence behaviour - // with Sandbox Factory in the absence of extensions - - // NOTE this check will prevent us from sending FAULT message in case - // response (positive or negative) has already been sent. This is done - // to maintain legacy behavior of RAPID. - // ALSO NOTE, this works in case of positive response because this will - // be followed by RAPID exit. 
- if !interopServer.IsResponseSent() { - trySendDefaultErrorResponse(interopServer, invokeID, doneFailMsg.ErrorType, err) +func newInvokeFailureMsg(execCtx *rapidContext, invokeRequest *interop.Invoke, invokeMx *invokeMetrics, err error) *interop.InvokeFailure { + errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) + if !found { + errorType = fatalerror.Unknown } - if err := interopServer.CommitResponse(); err != nil { - checkInteropError("Failed to commit error response: %s", err) + invokeFailure := &interop.InvokeFailure{ + ErrorType: errorType, + ErrorMessage: err, + RequestReset: true, + ResetReceived: false, + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + InvokeReceivedTime: invokeRequest.InvokeReceivedTime, } - // old behavior: no DoneFails - doneMsg := &interop.Done{ - WaitForExit: true, - CorrelationID: doneFailMsg.CorrelationID, // required for standalone mode - Meta: doneFailMsg.Meta, + if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { + invokeFailure.ResponseMetrics.RuntimeTimeThrottledMs = invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) + invokeFailure.ResponseMetrics.RuntimeProducedBytes = invokeRequest.InvokeResponseMetrics.ProducedBytes + invokeFailure.ResponseMetrics.RuntimeOutboundThroughputBps = invokeRequest.InvokeResponseMetrics.OutboundThroughputBps } - if err := interopServer.SendDone(doneMsg); err != nil { - checkInteropError("Failed to send DONE during exit: %s", err) + if invokeMx != nil { + invokeFailure.InvokeMetrics.InvokeRequestReadTimeNs = invokeMx.rendererMetrics.ReadTime.Nanoseconds() + invokeFailure.InvokeMetrics.InvokeRequestSizeBytes = int64(invokeMx.rendererMetrics.SizeBytes) + invokeFailure.InvokeMetrics.RuntimeReadyTime = int64(invokeMx.runtimeReadyTime) + invokeFailure.ExtensionNames = execCtx.GetExtensionNames() } - os.Exit(1) -} - -func 
reportErrorAndRequestReset(doneFailMsg *interop.DoneFail, invokeID string, interopServer interop.Server, err error) { - if err == errResetReceived { - // errResetReceived is returned when execution flow was interrupted by the Reset message, - // hence this error deserves special handling and we yield to main receive loop to handle it - return + if execCtx.telemetryAPIEnabled { + invokeFailure.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } - trySendDefaultErrorResponse(interopServer, invokeID, doneFailMsg.ErrorType, err) - - if err := interopServer.CommitResponse(); err != nil { - checkInteropError("Failed to commit error response: %s", err) - } + return invokeFailure +} - if err := interopServer.SendDoneFail(doneFailMsg); err != nil { - checkInteropError("Failed to send DONEFAIL: %s", err) +func generateInitFailureMsg(execCtx *rapidContext, err error) interop.InitFailure { + errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) + if !found { + errorType = fatalerror.Unknown } -} -func handleInitError(doneFailMsg *interop.DoneFail, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error) { - if execCtx.standaloneMode { - reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) - return + initFailureMsg := interop.InitFailure{ + RequestReset: true, + ErrorType: errorType, + ErrorMessage: err, + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + Ack: make(chan struct{}), } - if !execCtx.HasActiveExtensions() { - // we don't expect Slicer to send RESET during INIT, that's why we Exit here - reportErrorAndExit(doneFailMsg, invokeID, interopServer, err) + if execCtx.telemetryAPIEnabled { + initFailureMsg.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } - 
reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) + return initFailureMsg } -func handleInvokeError(doneFailMsg *interop.DoneFail, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error) { - if execCtx.standaloneMode { - reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) +func handleInitError(execCtx *rapidContext, invokeID string, err error, initFailureResponse chan<- interop.InitFailure) { + log.WithError(err).WithField("InvokeID", invokeID).Error("Init failed") + initFailureMsg := generateInitFailureMsg(execCtx, err) + + if err == errResetReceived { + // errResetReceived is returned when execution flow was interrupted by the Reset message, + // hence this error deserves special handling and we yield to main receive loop to handle it + initFailureMsg.ResetReceived = true + initFailureResponse <- initFailureMsg + <-initFailureMsg.Ack return } - // Invoke with extensions disabled maintains behaviour parity with pre-extensions rapid - if !extensions.AreEnabled() { - reportErrorAndExit(doneFailMsg, invokeID, interopServer, err) + if !execCtx.HasActiveExtensions() && !execCtx.standaloneMode { + // different behaviour when no extensions are present, + // for compatibility with previous implementations + initFailureMsg.RequestReset = false + } else { + initFailureMsg.RequestReset = true } - reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) + initFailureResponse <- initFailureMsg + <-initFailureMsg.Ack } diff --git a/lambda/rapid/graceful_shutdown.go b/lambda/rapid/graceful_shutdown.go deleted file mode 100644 index 5ad1326..0000000 --- a/lambda/rapid/graceful_shutdown.go +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -package rapid - -import ( - "syscall" - "time" - - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapi/model" - "go.amzn.com/lambda/rapi/rendering" - - log "github.com/sirupsen/logrus" -) - -func sigkillProcessGroup(pid int, sigkilledPids map[int]bool) map[int]bool { - pgid, err := syscall.Getpgid(pid) - if err == nil { - syscall.Kill(-pgid, 9) // Negative pid sends signal to all in process group - } else { - syscall.Kill(pid, 9) - } - sigkilledPids[pid] = true - - return sigkilledPids -} - -func awaitSigkilledProcessesToExit(exitPidChan chan int, processesExited, sigkilledPidsToAwait map[int]bool) { - for pid := range processesExited { - delete(sigkilledPidsToAwait, pid) - } - - for len(sigkilledPidsToAwait) != 0 { - pid := <-exitPidChan - _, found := sigkilledPidsToAwait[pid] - if !found { - log.Warnf("Unexpected process %d exited while waiting for sigkilled processes to exit", pid) - } else { - delete(sigkilledPidsToAwait, pid) - } - } -} - -func gracefulShutdown(execCtx *rapidContext, watchdog *core.Watchdog, profiler *metering.ExtensionsResetDurationProfiler, deadlineNs int64, killAgents bool, reason string) { - watchdog.Mute() - defer watchdog.Unmute() - - if execCtx.registrationService.CountAgents() == 0 { - // We do not spend any compute time on runtime graceful shutdown if there are no agents - if runtime := execCtx.registrationService.GetRuntime(); runtime != nil && runtime.Pid != 0 { - sigkilledPids := sigkillProcessGroup(runtime.Pid, map[int]bool{}) - if execCtx.standaloneMode { - processesExited := map[int]bool{} - awaitSigkilledProcessesToExit(execCtx.exitPidChan, processesExited, sigkilledPids) - } - } - return - } - - mono := metering.Monotime() - - availableNs := deadlineNs - mono - - if availableNs < 0 { - log.Warnf("Deadline is in the past: %v, %v, %v", mono, deadlineNs, availableNs) - availableNs = 0 - } - - profiler.AvailableNs = availableNs - - start := time.Now() - 
profiler.Start() - - runtimeDeadline := start.Add(time.Duration(float64(availableNs) * runtimeDeadlineShare)) - agentsDeadline := start.Add(time.Duration(availableNs)) - - sigkilledPids := make(map[int]bool) // Track process ids that were sent sigkill - processesExited := make(map[int]bool) // Don't send sigkill to processes that exit out of order - - processesExited, sigkilledPids = shutdownRuntime(execCtx, start, runtimeDeadline, processesExited, sigkilledPids) - processesExited, sigkilledPids = shutdownAgents(execCtx, start, profiler, agentsDeadline, killAgents, reason, processesExited, sigkilledPids) - if execCtx.standaloneMode { - awaitSigkilledProcessesToExit(execCtx.exitPidChan, processesExited, sigkilledPids) - } - - profiler.Stop() -} - -func shutdownRuntime(execCtx *rapidContext, start time.Time, deadline time.Time, processesExited, sigkilledPids map[int]bool) (map[int]bool, map[int]bool) { - // If runtime is started: - // 1. SIGTERM and wait until timeout - // 2. SIGKILL on timeout - - log.Debug("shutdown runtime") - runtime := execCtx.registrationService.GetRuntime() - if runtime == nil || runtime.Pid == 0 { - log.Warn("Runtime not started") - return processesExited, sigkilledPids - } - - syscall.Kill(runtime.Pid, syscall.SIGTERM) - - runtimeTimeout := deadline.Sub(start) - runtimeTimer := time.NewTimer(runtimeTimeout) - - for { - select { - case pid := <-execCtx.exitPidChan: - processesExited[pid] = true - if pid == runtime.Pid { - log.Info("runtime exited") - return processesExited, sigkilledPids - } - - log.Warnf("Process %d exited unexpectedly", pid) - case <-runtimeTimer.C: - log.Warnf("Timeout: no SIGCHLD from Runtime after %d ms; dispatching SIGKILL to runtime process group", int64(runtimeTimeout/time.Millisecond)) - sigkilledPids = sigkillProcessGroup(runtime.Pid, sigkilledPids) - return processesExited, sigkilledPids - } - } -} - -func shutdownAgents(execCtx *rapidContext, start time.Time, profiler *metering.ExtensionsResetDurationProfiler, 
deadline time.Time, killAgents bool, reason string, processesExited, sigkilledPids map[int]bool) (map[int]bool, map[int]bool) { - // For each external agent, if agent is launched: - // 1. Send Shutdown event if subscribed for it, else send SIGKILL to process group - // 2. Wait for all Shutdown-subscribed agents to exit with timeout - // 3. Send SIGKILL to process group for Shutdown-subscribed agents on timeout - - log.Debug("shutdown agents") - execCtx.renderingService.SetRenderer( - &rendering.ShutdownRenderer{ - AgentEvent: model.AgentShutdownEvent{ - AgentEvent: &model.AgentEvent{ - EventType: "SHUTDOWN", - DeadlineMs: deadline.UnixNano() / (1000 * 1000), - }, - ShutdownReason: reason, - }, - }) - - pidsToShutdown := make(map[int]*core.ExternalAgent) - for _, a := range execCtx.registrationService.GetExternalAgents() { - if a.Pid == 0 { - log.Warnf("Agent %s failed not launched; skipping shutdown", a) - continue - } - if a.IsSubscribed(core.ShutdownEvent) { - pidsToShutdown[a.Pid] = a - a.Release() - } else { - if !processesExited[a.Pid] { - sigkilledPids = sigkillProcessGroup(a.Pid, sigkilledPids) - } - } - } - profiler.NumAgentsRegisteredForShutdown = len(pidsToShutdown) - - var timerChan <-chan time.Time // default timerChan - if killAgents { - timerChan = time.NewTimer(deadline.Sub(start)).C // timerChan with deadline - } - - timeoutExceeded := false - for !timeoutExceeded && len(pidsToShutdown) != 0 { - select { - case pid := <-execCtx.exitPidChan: - processesExited[pid] = true - a, found := pidsToShutdown[pid] - if !found { - log.Warnf("Process %d exited unexpectedly", pid) - } else { - if err := a.Exited(); err != nil { - log.Warnf("%s failed to transition to EXITED: %s (current state: %s)", a.String(), err, a.GetState().Name()) - } - delete(pidsToShutdown, pid) - } - case <-timerChan: - timeoutExceeded = true - } - } - - if len(pidsToShutdown) != 0 { - for pid, agent := range pidsToShutdown { - if err := agent.ShutdownFailed(); err != nil { - 
log.Warnf("%s failed to transition to ShutdownFailed: %s (current state: %s)", agent, err, agent.GetState().Name()) - } - log.Warnf("Killing agent %s which failed to shutdown", agent) - if !processesExited[pid] { - sigkilledPids = sigkillProcessGroup(pid, sigkilledPids) - } - } - } - - return processesExited, sigkilledPids -} diff --git a/lambda/rapid/sandbox.go b/lambda/rapid/sandbox.go index a5614b0..9259514 100644 --- a/lambda/rapid/sandbox.go +++ b/lambda/rapid/sandbox.go @@ -7,49 +7,43 @@ import ( "context" "fmt" "io" + "sync" + "time" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/logging" "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapi" "go.amzn.com/lambda/rapi/rendering" + supvmodel "go.amzn.com/lambda/supervisor/model" "go.amzn.com/lambda/telemetry" -) -type EnvironmentVariables interface { - AgentExecEnv() []string - RuntimeExecEnv() []string - SetHandler(handler string) - StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress string) - StoreEnvironmentVariablesFromInit(customerEnv map[string]string, - handler, awsKey, awsSecret, awsSession, funcName, funcVer string) - StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) -} + log "github.com/sirupsen/logrus" +) type Sandbox struct { - EnableTelemetryAPI bool - StandaloneMode bool - Bootstrap Bootstrap - InteropServer interop.Server - Tracer telemetry.Tracer - LogsSubscriptionAPI telemetry.LogsSubscriptionAPI - LogsEgressAPI telemetry.LogsEgressAPI - Environment EnvironmentVariables - DebugTailLogger *logging.TailLogWriter - PlatformLogger logging.PlatformLogger - RuntimeStdoutWriter io.Writer - RuntimeStderrWriter io.Writer - PreLoadTimeNs int64 - Handler string - SignalCtx context.Context - EventsAPI telemetry.EventsAPI - InitCachingEnabled bool + EnableTelemetryAPI bool + StandaloneMode bool + InteropServer interop.Server + Tracer telemetry.Tracer + 
LogsSubscriptionAPI telemetry.SubscriptionAPI + TelemetrySubscriptionAPI telemetry.SubscriptionAPI + LogsEgressAPI telemetry.StdLogsEgressAPI + RuntimeStdoutWriter io.Writer + RuntimeStderrWriter io.Writer + PreLoadTimeNs int64 + Handler string + SignalCtx context.Context + EventsAPI telemetry.EventsAPI + InitCachingEnabled bool + Supervisor supvmodel.Supervisor + RuntimeAPIHost string + RuntimeAPIPort int } // Start is a public version of start() that exports only configurable parameters -func Start(s *Sandbox) { +func Start(s *Sandbox) (interop.RapidContext, interop.InternalStateGetter, string) { appCtx := appctx.NewApplicationContext() initFlow := core.NewInitFlowSynchronization() invokeFlow := core.NewInvokeFlowSynchronization() @@ -57,19 +51,18 @@ func Start(s *Sandbox) { renderingService := rendering.NewRenderingService() credentialsService := core.NewCredentialsService() - if s.StandaloneMode { - s.InteropServer.SetInternalStateGetter(registrationService.GetInternalStateDescriptor(appCtx)) - } - server := rapi.NewServer(RuntimeAPIHost, RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, s.LogsSubscriptionAPI, s.InitCachingEnabled, credentialsService) - - postLoadTimeNs := metering.Monotime() + appctx.StoreInitType(appCtx, s.InitCachingEnabled) + server := rapi.NewServer(s.RuntimeAPIHost, s.RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, s.LogsSubscriptionAPI, s.TelemetrySubscriptionAPI, credentialsService, s.EventsAPI) runtimeAPIAddr := fmt.Sprintf("%s:%d", server.Host(), server.Port()) - s.Environment.StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddr) + postLoadTimeNs := metering.Monotime() + + // TODO: pass this directly down to HTTP servers and handlers, instead of using + // global state to share the interop server implementation appctx.StoreInteropServer(appCtx, s.InteropServer) - start(s.SignalCtx, &rapidContext{ + execCtx := &rapidContext{ server: server, appCtx: appCtx, 
postLoadTimeNs: postLoadTimeNs, @@ -78,24 +71,86 @@ func Start(s *Sandbox) { invokeFlow: invokeFlow, registrationService: registrationService, renderingService: renderingService, - exitPidChan: make(chan int), - resetChan: make(chan *interop.Reset), credentialsService: credentialsService, - telemetryAPIEnabled: s.EnableTelemetryAPI, - logsSubscriptionAPI: s.LogsSubscriptionAPI, - logsEgressAPI: s.LogsEgressAPI, - bootstrap: s.Bootstrap, - interopServer: s.InteropServer, - xray: s.Tracer, - environment: s.Environment, - standaloneMode: s.StandaloneMode, - debugTailLogger: s.DebugTailLogger, - platformLogger: s.PlatformLogger, - runtimeStdoutWriter: s.RuntimeStdoutWriter, - runtimeStderrWriter: s.RuntimeStderrWriter, - preLoadTimeNs: s.PreLoadTimeNs, - eventsAPI: s.EventsAPI, - initCachingEnabled: s.InitCachingEnabled, - }) + telemetryAPIEnabled: s.EnableTelemetryAPI, + logsSubscriptionAPI: s.LogsSubscriptionAPI, + telemetrySubscriptionAPI: s.TelemetrySubscriptionAPI, + logsEgressAPI: s.LogsEgressAPI, + interopServer: s.InteropServer, + xray: s.Tracer, + standaloneMode: s.StandaloneMode, + preLoadTimeNs: s.PreLoadTimeNs, + eventsAPI: s.EventsAPI, + initCachingEnabled: s.InitCachingEnabled, + signalCtx: s.SignalCtx, + supervisor: s.Supervisor, + executionMutex: sync.Mutex{}, + shutdownContext: newShutdownContext(), + } + + // We call /ping on Supervisor before starting Rapid, since Rapid + // depends on Supervisor setting up networking dependencies + var startupErr error + for retries := 1; retries <= 5; retries++ { + if startupErr = s.Supervisor.Ping(); startupErr == nil { + break + } + // Retry timeout: 5s, same order-of-mag as test client PING retries + // TODO: revisit retry timeout, identify appropriate value for prod. 
+ time.Sleep(1000 * time.Millisecond) + } + + if startupErr != nil { + log.Panicf("Application ping to Supervisor failed, terminating Rapid Startup: %s", startupErr) + } + + go start(s.SignalCtx, execCtx) + + return execCtx, registrationService.GetInternalStateDescriptor(appCtx), runtimeAPIAddr +} + +func (r *rapidContext) HandleInit(init *interop.Init, initStartedResponseChan chan<- interop.InitStarted, initSuccessResponseChan chan<- interop.InitSuccess, initFailureResponseChan chan<- interop.InitFailure) { + r.executionMutex.Lock() + defer r.executionMutex.Unlock() + handleInit(r, init, initStartedResponseChan, initSuccessResponseChan, initFailureResponseChan) +} + +func (r *rapidContext) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { + r.executionMutex.Lock() + defer r.executionMutex.Unlock() + // Clear the context used by the last invoke + r.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) + return handleInvoke(r, invoke, sbInfoFromInit) +} + +func (r *rapidContext) HandleReset(reset *interop.Reset, invokeReceivedTime int64, InvokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { + // In the event of a Reset during init/invoke, CancelFlows cancels execution + // flows and return with the errResetReceived err - this error is special-cased + // and not handled by the init/invoke (unexpected) error handling functions + r.registrationService.CancelFlows(errResetReceived) + + // Wait until invoke error handling has returned before continuing execution + r.executionMutex.Lock() + defer r.executionMutex.Unlock() + + // Clear the context used by the last invoke, i.e. error message etc. 
+ r.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) + return handleReset(r, reset, invokeReceivedTime, InvokeResponseMetrics) +} + +func (r *rapidContext) HandleShutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { + // Wait until invoke error handling has returned before continuing execution + r.executionMutex.Lock() + defer r.executionMutex.Unlock() + // Shutdown doesn't cancel flows, so it can block forever + return handleShutdown(r, shutdown, standaloneShutdownReason) +} + +func (r *rapidContext) HandleRestore(restore *interop.Restore) error { + return handleRestore(r, restore) +} + +func (r *rapidContext) Clear() { + reinitialize(r) } diff --git a/lambda/rapid/shutdown.go b/lambda/rapid/shutdown.go new file mode 100644 index 0000000..fe23a9f --- /dev/null +++ b/lambda/rapid/shutdown.go @@ -0,0 +1,366 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package rapid implements synchronous event dispatch loop. +package rapid + +import ( + "fmt" + "sync" + "time" + + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/metering" + "go.amzn.com/lambda/rapi/model" + "go.amzn.com/lambda/rapi/rendering" + supvmodel "go.amzn.com/lambda/supervisor/model" + + log "github.com/sirupsen/logrus" +) + +const ( + // supervisor shutdown and kill operations block until the exit status of the + // interested process has been collected, or until the specified timeout + // expires (in which case the operation fails). + // Note that this timeout is mainly relevant when any of the domain + // processes are in uninterruptible sleep state (notable examples: syscall + // to read/write a networked driver) + // + // We set a non nil value for these timeouts so that RAPID doesn't block + // forever in one of the cases above. 
+ supervisorBlockingMaxMillis = 9000 + runtimeDeadlineShare = 0.3 +) + +type shutdownContext struct { + // Adding a mutex around shuttingDown because there may be concurrent reads/writes. + // Because the code in shutdown() and the separate go routine created in setupEventsWatcher() + // could be concurrently accessing the field shuttingDown. + shuttingDownMutex sync.Mutex + shuttingDown bool + agentsAwaitingExit map[string]*core.ExternalAgent + // Adding a mutex around runtimeDomainExited because there may be concurrent reads/writes. + // The first reason this can be caused is by different go routines reading/writing different keys. + // The second reason this can be caused is between the code shutting down the runtime/extensions and + // handleProcessExit in a separate go routine, reading and writing to the same key. Caused by + // unexpected exits. + runtimeDomainExitedMutex sync.Mutex + // used to synchronize on processes exits. We create the channel when a + // process is started and we close it upon exit notification from + // supervisor. Closing the channel is basically a persistent broadcast of process exit. 
+ // We never write anything to the channels + runtimeDomainExited map[string]chan struct{} +} + +func newShutdownContext() *shutdownContext { + return &shutdownContext{ + shuttingDownMutex: sync.Mutex{}, + shuttingDown: false, + agentsAwaitingExit: make(map[string]*core.ExternalAgent), + runtimeDomainExited: make(map[string]chan struct{}), + runtimeDomainExitedMutex: sync.Mutex{}, + } +} + +func (s *shutdownContext) isShuttingDown() bool { + s.shuttingDownMutex.Lock() + defer s.shuttingDownMutex.Unlock() + return s.shuttingDown +} + +func (s *shutdownContext) setShuttingDown(value bool) { + s.shuttingDownMutex.Lock() + defer s.shuttingDownMutex.Unlock() + s.shuttingDown = value +} + +func (s *shutdownContext) handleProcessExit(termination supvmodel.ProcessTermination) { + + name := *termination.Name + agent, found := s.agentsAwaitingExit[name] + + // If it is an agent registered to receive a shutdown event. + if found { + log.Debugf("Handling termination for %s", name) + exitStatus := termination.Exited() + if exitStatus != nil && *exitStatus == 0 { + // If the agent exited by itself after receiving the shutdown event. + stateErr := agent.Exited() + if stateErr != nil { + log.Warnf("%s failed to transition to EXITED: %s (current state: %s)", agent.String(), stateErr, agent.GetState().Name()) + } + } else { + // If the agent did not exit by itself, had to be SIGKILLed (only in standalone mode). 
+ stateErr := agent.ShutdownFailed() + if stateErr != nil { + log.Warnf("%s failed to transition to ShutdownFailed: %s (current state: %s)", agent, stateErr, agent.GetState().Name()) + } + } + } + + exitedChannel, found := s.getExitedChannel(name) + + if !found { + log.Panicf("Unable to find an exitedChannel for '%s', it should have been created just after it was execed.", name) + } + // we close the channel so that whoever is blocked on it + // or will try to block on it in the future unblocks immediately + close(exitedChannel) +} + +func (s *shutdownContext) getExitedChannel(name string) (chan struct{}, bool) { + s.runtimeDomainExitedMutex.Lock() + defer s.runtimeDomainExitedMutex.Unlock() + exitedChannel, found := s.runtimeDomainExited[name] + return exitedChannel, found +} + +func (s *shutdownContext) createExitedChannel(name string) { + s.runtimeDomainExitedMutex.Lock() + defer s.runtimeDomainExitedMutex.Unlock() + + _, found := s.runtimeDomainExited[name] + + if found { + log.Panicf("Tried to create an exited channel for '%s' but one already exists.", name) + } + s.runtimeDomainExited[name] = make(chan struct{}) +} + +// Blocks until all the processes in the runtime domain generation have exited. +// This helps us have a nice sync point on Shutdown where we know for sure that +// all the processes have exited and the state has been cleared. 
+// +// It is OK not to hold the lock because we know that this is called only during +// shutdown and nobody will start a new process during shutdown +func (s *shutdownContext) clearExitedChannel() { + s.runtimeDomainExitedMutex.Lock() + mapLen := len(s.runtimeDomainExited) + channels := make([]chan struct{}, 0, mapLen) + for _, v := range s.runtimeDomainExited { + channels = append(channels, v) + } + s.runtimeDomainExitedMutex.Unlock() + + for _, v := range channels { + <-v + } + + s.runtimeDomainExitedMutex.Lock() + s.runtimeDomainExited = make(map[string]chan struct{}, mapLen) + s.runtimeDomainExitedMutex.Unlock() +} + +func (s *shutdownContext) shutdownRuntime(execCtx *rapidContext, start time.Time, deadline time.Time) { + // If runtime is started: + // 1. SIGTERM and wait until timeout + // 2. SIGKILL on timeout + log.Debug("Shutting down the runtime.") + name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) + exitedChannel, found := s.getExitedChannel(name) + + if found { + + err := execCtx.supervisor.Terminate(&supvmodel.TerminateRequest{ + Domain: RuntimeDomain, + Name: name, + }) + if err != nil { + // We are not reporting the error upstream because we will anyway + // shut the domain out at the end of the shutdown sequence + log.WithError(err).Warn("Failed sending Termination signal to runtime") + } + + runtimeTimeout := deadline.Sub(start) + log.Tracef("The runtime timeout is %v.", runtimeTimeout) + runtimeTimer := time.NewTimer(runtimeTimeout) + select { + case <-runtimeTimer.C: + log.Warnf("Timeout: The runtime did not exit after %d ms; Killing it.", int64(runtimeTimeout/time.Millisecond)) + supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) + err = execCtx.supervisor.Kill(&supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Timeout: &supervisorBlockingMaxMillis, + }) + + if err != nil { + // We are not reporting the error upstream because we will anyway + // shut the domain out at the end of the 
shutdown sequence + log.WithError(err).Warn("Failed sending Kill signal to runtime") + } + case <-exitedChannel: + } + } else { + log.Warn("The runtime was not started.") + } + log.Debug("Shutdown the runtime.") +} + +func (s *shutdownContext) shutdownAgents(execCtx *rapidContext, start time.Time, deadline time.Time, reason string) { + // For each external agent, if agent is launched: + // 1. Send Shutdown event if subscribed for it, else send SIGKILL to process group + // 2. Wait for all Shutdown-subscribed agents to exit with timeout + // 3. Send SIGKILL to process group for Shutdown-subscribed agents on timeout + + log.Debug("Shutting down the agents.") + execCtx.renderingService.SetRenderer( + &rendering.ShutdownRenderer{ + AgentEvent: model.AgentShutdownEvent{ + AgentEvent: &model.AgentEvent{ + EventType: "SHUTDOWN", + DeadlineMs: deadline.UnixNano() / (1000 * 1000), + }, + ShutdownReason: reason, + }, + }) + + var wg sync.WaitGroup + + // clear agentsAwaitingExit from last shutdownAgents + s.agentsAwaitingExit = make(map[string]*core.ExternalAgent) + + for _, a := range execCtx.registrationService.GetExternalAgents() { + name := fmt.Sprintf("extension-%s-%d", a.Name, execCtx.runtimeDomainGeneration) + exitedChannel, found := s.getExitedChannel(name) + supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) + + if !found { + log.Warnf("Agent %s failed to launch, therefore skipping shutting it down.", a) + continue + } + + wg.Add(1) + + if a.IsSubscribed(core.ShutdownEvent) { + log.Debugf("Agent %s is registered for the shutdown event.", a) + s.agentsAwaitingExit[name] = a + + go func(name string, agent *core.ExternalAgent) { + defer wg.Done() + + agent.Release() + + agentTimeout := deadline.Sub(start) + var agentTimeoutChan <-chan time.Time + if execCtx.standaloneMode { + agentTimeoutChan = time.NewTimer(agentTimeout).C + } + + select { + case <-agentTimeoutChan: + log.Warnf("Timeout: the agent %s did not exit after %d ms; Killing it.", name, 
int64(agentTimeout/time.Millisecond)) + err := execCtx.supervisor.Kill(&supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Timeout: &supervisorBlockingMaxMillis, + }) + if err != nil { + // We are not reporting the error upstream because we will anyway + // shut the domain out at the end of the shutdown sequence + log.WithError(err).Warn("Failed sending Kill signal to runtime") + } + case <-exitedChannel: + } + }(name, a) + } else { + log.Debugf("Agent %s is not registered for the shutdown event, so just killing it.", a) + + go func(name string) { + defer wg.Done() + + execCtx.supervisor.Kill(&supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Timeout: &supervisorBlockingMaxMillis, + }) + }(name) + } + } + + // Wait on the agents subscribed to the shutdown event to voluntary shutting down after receiving the shutdown event or be sigkilled. + // In addition to waiting on the agents not subscribed to the shutdown event being sigkilled. + wg.Wait() + log.Debug("Shutdown the agents.") +} + +func (s *shutdownContext) shutdown(execCtx *rapidContext, deadlineNs int64, reason string) (int64, bool, error) { + var err error + s.setShuttingDown(true) + defer s.setShuttingDown(false) + + // Fatal errors such as Runtime exit and Extension.Crash + // are ignored by the events watcher when shutting down + execCtx.appCtx.Delete(appctx.AppCtxFirstFatalErrorKey) + + runtimeDomainProfiler := &metering.ExtensionsResetDurationProfiler{} + supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) + + // We do not spend any compute time on runtime graceful shutdown if there are no agents + if execCtx.registrationService.CountAgents() == 0 { + name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) + + _, found := s.getExitedChannel(name) + + if found { + log.Debug("SIGKILLing the runtime as no agents are registered.") + err = execCtx.supervisor.Kill(&supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Timeout: 
&supervisorBlockingMaxMillis, + }) + if err != nil { + // We are not reporting the error upstream because we will anyway + // shut the domain out at the end of the shutdown sequence + log.WithError(err).Warn("Failed sending Kill signal to runtime") + } + } else { + log.Debugf("Could not find runtime process %s in processes map. Already exited/never started", name) + } + } else { + mono := metering.Monotime() + availableNs := deadlineNs - mono + + if availableNs < 0 { + log.Warnf("Deadline is in the past: %v, %v, %v", mono, deadlineNs, availableNs) + availableNs = 0 + } + + start := time.Now() + + runtimeDeadline := start.Add(time.Duration(float64(availableNs) * runtimeDeadlineShare)) + agentsDeadline := start.Add(time.Duration(availableNs)) + + runtimeDomainProfiler.AvailableNs = availableNs + runtimeDomainProfiler.Start() + + s.shutdownRuntime(execCtx, start, runtimeDeadline) + s.shutdownAgents(execCtx, start, agentsDeadline, reason) + + runtimeDomainProfiler.NumAgentsRegisteredForShutdown = len(s.agentsAwaitingExit) + } + log.Info("Stopping runtime domain") + err = execCtx.supervisor.Stop(&supvmodel.StopRequest{ + Domain: RuntimeDomain, + Timeout: &supervisorBlockingMaxMillis, + }) + if err != nil { + log.WithError(err).Error("Failed shutting runtime domain down") + } else { + log.Info("Waiting for runtime domain processes termination") + s.clearExitedChannel() + log.Info("Stopping operator domain") + err = execCtx.supervisor.Stop(&supvmodel.StopRequest{ + Domain: OperatorDomain, + Timeout: &supervisorBlockingMaxMillis, + }) + if err != nil { + log.WithError(err).Error("Failed shutting operator domain down") + } + } + + runtimeDomainProfiler.Stop() + extensionsRestMs, timeout := runtimeDomainProfiler.CalculateExtensionsResetMs() + return extensionsRestMs, timeout, err +} diff --git a/lambda/rapid/start.go b/lambda/rapid/start.go index 087ef13..76337af 100644 --- a/lambda/rapid/start.go +++ b/lambda/rapid/start.go @@ -7,9 +7,11 @@ package rapid import ( "context" 
"errors" - "io" + "fmt" "os" + "path" "strings" + "sync" "time" "go.amzn.com/lambda/agents" @@ -18,11 +20,11 @@ import ( "go.amzn.com/lambda/extensions" "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/logging" "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapi" + "go.amzn.com/lambda/rapi/model" "go.amzn.com/lambda/rapi/rendering" - "go.amzn.com/lambda/runtimecmd" + supvmodel "go.amzn.com/lambda/supervisor/model" "go.amzn.com/lambda/telemetry" "github.com/google/uuid" @@ -31,11 +33,11 @@ import ( ) const ( - RuntimeAPIHost = "127.0.0.1" - RuntimeAPIPort = 9001 + RuntimeDomain = "runtime" + OperatorDomain = "operator" defaultAgentLocation = "/opt/extensions" - runtimeDeadlineShare = 0.3 disableExtensionsFile = "/opt/disable-extensions-jwigqn8j" + runtimeProcessName = "runtime" ) const ( @@ -47,37 +49,38 @@ const ( var errResetReceived = errors.New("errResetReceived") type rapidContext struct { - bootstrap Bootstrap - interopServer interop.Server - server *rapi.Server - appCtx appctx.ApplicationContext - preLoadTimeNs int64 - postLoadTimeNs int64 - startRequest *interop.Start - initDone bool - initFlow core.InitFlowSynchronization - invokeFlow core.InvokeFlowSynchronization - registrationService core.RegistrationService - renderingService *rendering.EventRenderingService - telemetryAPIEnabled bool - logsSubscriptionAPI telemetry.LogsSubscriptionAPI - logsEgressAPI telemetry.LogsEgressAPI - xray telemetry.Tracer - exitPidChan chan int - resetChan chan *interop.Reset - environment EnvironmentVariables - standaloneMode bool - debugTailLogger *logging.TailLogWriter - platformLogger logging.PlatformLogger - runtimeStdoutWriter io.Writer - runtimeStderrWriter io.Writer - eventsAPI telemetry.EventsAPI - initCachingEnabled bool - credentialsService core.CredentialsService + interopServer interop.Server + server *rapi.Server + appCtx appctx.ApplicationContext + preLoadTimeNs int64 + postLoadTimeNs int64 + initDone bool + supervisor 
supvmodel.Supervisor + runtimeDomainGeneration uint32 + initFlow core.InitFlowSynchronization + invokeFlow core.InvokeFlowSynchronization + registrationService core.RegistrationService + renderingService *rendering.EventRenderingService + telemetryAPIEnabled bool + logsSubscriptionAPI telemetry.SubscriptionAPI + telemetrySubscriptionAPI telemetry.SubscriptionAPI + logsEgressAPI telemetry.StdLogsEgressAPI + xray telemetry.Tracer + standaloneMode bool + eventsAPI telemetry.EventsAPI + initCachingEnabled bool + credentialsService core.CredentialsService + signalCtx context.Context + executionMutex sync.Mutex + shutdownContext *shutdownContext } +// Validate interface compliance +var _ interop.RapidContext = (*rapidContext)(nil) + type invokeMetrics struct { - rendererMetrics rendering.InvokeRendererMetrics + rendererMetrics rendering.InvokeRendererMetrics + runtimeReadyTime int64 } @@ -102,7 +105,7 @@ func (c *rapidContext) GetExtensionNames() string { func logAgentsInitStatus(execCtx *rapidContext) { for _, agent := range execCtx.registrationService.AgentsInfo() { - execCtx.platformLogger.LogExtensionInitEvent(agent.Name, agent.State, agent.ErrorType, agent.Subscriptions) + execCtx.eventsAPI.SendExtensionInit(agent.Name, agent.State, agent.ErrorType, agent.Subscriptions) } } @@ -113,8 +116,7 @@ func agentLaunchError(agent *core.ExternalAgent, appCtx appctx.ApplicationContex appctx.StoreFirstFatalError(appCtx, fatalerror.AgentLaunchError) } -func doInitExtensions(execCtx *rapidContext, watchdog *core.Watchdog) error { - agentPaths := agents.ListExternalAgentPaths(defaultAgentLocation) +func doInitExtensions(domain string, agentPaths []string, execCtx *rapidContext, env interop.EnvironmentVariables) error { initFlow := execCtx.registrationService.InitFlow() // we don't bring it into the loop below because we don't want unnecessary broadcasts on agent gate @@ -123,38 +125,42 @@ func doInitExtensions(execCtx *rapidContext, watchdog *core.Watchdog) error { } for _, 
agentPath := range agentPaths { - env := execCtx.environment.AgentExecEnv() - - agentStdoutWriter, agentStderrWriter, err := execCtx.logsEgressAPI.GetExtensionSockets() + // Using path.Base(agentPath) not agentName because the agent name is contact, as standalone can get the internal state. + agent, err := execCtx.registrationService.CreateExternalAgent(path.Base(agentPath)) if err != nil { return err } - // Compose debug log writer with all log sinks. Debug log writer w - // will not write logs when disabled by invoke parameter - agentStdoutWriter = io.MultiWriter(execCtx.debugTailLogger, agentStdoutWriter) - agentStderrWriter = io.MultiWriter(execCtx.debugTailLogger, agentStderrWriter) + if execCtx.registrationService.CountAgents() > core.MaxAgentsAllowed { + agentLaunchError(agent, execCtx.appCtx, core.ErrTooManyExtensions) + return core.ErrTooManyExtensions + } - agentProc := agents.NewExternalAgentProcess(agentPath, env, agentStdoutWriter, agentStderrWriter) + env := env.AgentExecEnv() - agent, err := execCtx.registrationService.CreateExternalAgent(agentProc.Name()) + agentStdoutWriter, agentStderrWriter, err := execCtx.logsEgressAPI.GetExtensionSockets() if err != nil { return err } + agentName := fmt.Sprintf("extension-%s-%d", path.Base(agentPath), execCtx.runtimeDomainGeneration) - if execCtx.registrationService.CountAgents() > core.MaxAgentsAllowed { - agentLaunchError(agent, execCtx.appCtx, core.ErrTooManyExtensions) - return core.ErrTooManyExtensions - } + err = execCtx.supervisor.Exec(&supvmodel.ExecRequest{ + Domain: domain, + Name: agentName, + Path: agentPath, + Env: &env, + StdoutWriter: agentStdoutWriter, + StderrWriter: agentStderrWriter, + }) - if err := agentProc.Start(); err != nil { + if err != nil { agentLaunchError(agent, execCtx.appCtx, err) return err } - agent.Pid = watchdog.GoWait(&agentProc, fatalerror.AgentCrash) + execCtx.shutdownContext.createExitedChannel(agentName) } if err := initFlow.AwaitExternalAgentsRegistered(); err != nil { 
@@ -164,20 +170,154 @@ func doInitExtensions(execCtx *rapidContext, watchdog *core.Watchdog) error { return nil } -func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) error { +func doRuntimeBootstrap(execCtx *rapidContext, sbInfoFromInit interop.SandboxInfoFromInit) ([]string, map[string]string, string, []*os.File, error) { + env := sbInfoFromInit.EnvironmentVariables + runtimeBootstrap := sbInfoFromInit.RuntimeBootstrap + bootstrapCmd, err := runtimeBootstrap.Cmd() + if err != nil { + if fatalError, formattedLog, hasError := runtimeBootstrap.CachedFatalError(err); hasError { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) + execCtx.eventsAPI.SendImageErrorLog(formattedLog) + } else { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) + } + return []string{}, map[string]string{}, "", []*os.File{}, err + } + + bootstrapEnv := runtimeBootstrap.Env(env) + bootstrapCwd, err := runtimeBootstrap.Cwd() + if err != nil { + if fatalError, formattedLog, hasError := runtimeBootstrap.CachedFatalError(err); hasError { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) + execCtx.eventsAPI.SendImageErrorLog(formattedLog) + } else { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidWorkingDir) + } + return []string{}, map[string]string{}, "", []*os.File{}, err + } + + bootstrapExtraFiles := runtimeBootstrap.ExtraFiles() + + return bootstrapCmd, bootstrapEnv, bootstrapCwd, bootstrapExtraFiles, nil +} + +func (c *rapidContext) setupEventsWatcher(events <-chan supvmodel.Event) { + go func() { + for event := range events { + var err error = nil + log.Debugf("The events handler received the event %+v.", event) + if loss := event.Event.EventLoss(); loss != nil { + log.Panicf("Lost %d events from supervisor", *loss) + } + termination := event.Event.ProcessTerminated() + + // If we are not shutting down then we care if an unexpected exit happens. 
+ if !c.shutdownContext.isShuttingDown() { + runtimeProcessName := fmt.Sprintf("%s-%d", runtimeProcessName, c.runtimeDomainGeneration) + + // If event from the runtime. + if *termination.Name == runtimeProcessName { + if termination.Success() { + err = fmt.Errorf("Runtime exited without providing a reason") + } else { + err = fmt.Errorf("Runtime exited with error: %s", termination.String()) + } + appctx.StoreFirstFatalError(c.appCtx, fatalerror.RuntimeExit) + } else { + if termination.Success() { + err = fmt.Errorf("exit code 0") + } else { + err = fmt.Errorf(termination.String()) + } + + appctx.StoreFirstFatalError(c.appCtx, fatalerror.AgentCrash) + } + + log.Warnf("Process %s exited: %+v", *termination.Name, termination) + } + + // At the moment we only get termination events. + // When there are other event types then we would need to be selective + // about what we send to handleProcessExit(). + c.shutdownContext.handleProcessExit(*termination) + c.registrationService.CancelFlows(err) + } + }() +} + +func doOperatorDomainInit(ctx context.Context, execCtx *rapidContext, operatorDomainExtraConfig interop.DynamicDomainConfig) error { + events, err := execCtx.supervisor.Events() + if err != nil { + log.WithError(err).Panic("Could not get events stream from supervsior") + } + execCtx.setupEventsWatcher(events) + + log.Info("Configuring and starting Operator Domain") + conf := operatorDomainExtraConfig + err = execCtx.supervisor.Configure(&supvmodel.ConfigureRequest{ + Domain: OperatorDomain, + AdditionalStartHooks: conf.AdditionalStartHooks, + Mounts: conf.Mounts, + }) + + if err != nil { + log.WithError(err).Error("Failed to configure operator domain") + return err + } + + err = execCtx.supervisor.Start(&supvmodel.StartRequest{ + Domain: OperatorDomain, + }) + + if err != nil { + log.WithError(err).Error("Failed to start operator domain") + return err + } + + // we configure the runtime domain only once and not at + // every init phase (e.g., suppressed or 
reset). + err = execCtx.supervisor.Configure(&supvmodel.ConfigureRequest{ + Domain: RuntimeDomain, + }) + + if err != nil { + log.WithError(err).Error("Failed to configure operator domain") + return err + } + + return nil + +} + +func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromInit interop.SandboxInfoFromInit) error { execCtx.xray.RecordInitStartTime() defer execCtx.xray.RecordInitEndTime() - if extensions.AreEnabled() { - defer func() { + defer func() { + if extensions.AreEnabled() { logAgentsInitStatus(execCtx) - }() + } + }() + + log.Info("Starting runtime domain") + err := execCtx.supervisor.Start(&supvmodel.StartRequest{ + Domain: RuntimeDomain, + }) + if err != nil { + log.WithError(err).Panic("Failed configuring runtime domain") + } + execCtx.runtimeDomainGeneration++ - if err := doInitExtensions(execCtx, watchdog); err != nil { + if extensions.AreEnabled() { + runtimeExtensions := agents.ListExternalAgentPaths(defaultAgentLocation, + execCtx.supervisor.RuntimeConfig.RootPath) + if err := doInitExtensions(RuntimeDomain, runtimeExtensions, execCtx, sbInfoFromInit.EnvironmentVariables); err != nil { return err } } + appctx.StoreSandboxType(execCtx.appCtx, sbInfoFromInit.SandboxType) + initFlow := execCtx.registrationService.InitFlow() // Runtime state machine @@ -188,56 +328,66 @@ func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) // runtime is implicitly subscribed for certain lifecycle events. 
log.Debug("Preregister runtime") registrationService := execCtx.registrationService - if err := registrationService.PreregisterRuntime(runtime); err != nil { + err = registrationService.PreregisterRuntime(runtime) + + if err != nil { return err } - bootstrap := execCtx.bootstrap - bootstrapCmd, err := bootstrap.Cmd() + bootstrapCmd, bootstrapEnv, bootstrapCwd, bootstrapExtraFiles, err := doRuntimeBootstrap(execCtx, sbInfoFromInit) + if err != nil { - if fatalError, formattedLog, hasError := bootstrap.CachedFatalError(err); hasError { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.platformLogger.Printf("%s", formattedLog) - } else { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) - } return err } - bootstrapEnv := bootstrap.Env(execCtx.environment) - bootstrapCwd, err := bootstrap.Cwd() + runtimeStdoutWriter, runtimeStderrWriter, err := execCtx.logsEgressAPI.GetRuntimeSockets() + if err != nil { - if fatalError, formattedLog, hasError := bootstrap.CachedFatalError(err); hasError { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.platformLogger.Printf("%s", formattedLog) - } else { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidWorkingDir) - } return err } - bootstrapExtraFiles := bootstrap.ExtraFiles() - runtimeCmd := runtimecmd.NewCustomRuntimeCmd(ctx, bootstrapCmd, bootstrapCwd, bootstrapEnv, execCtx.runtimeStdoutWriter, execCtx.runtimeStderrWriter, bootstrapExtraFiles) - log.Debug("Start runtime") - err = runtimeCmd.Start() + checkCredentials(execCtx, bootstrapEnv) + name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) + err = execCtx.supervisor.Exec(&supvmodel.ExecRequest{ + Domain: RuntimeDomain, + Name: name, + Cwd: &bootstrapCwd, + Path: bootstrapCmd[0], + Args: bootstrapCmd[1:], + Env: &bootstrapEnv, + StdoutWriter: runtimeStdoutWriter, + StderrWriter: runtimeStderrWriter, + ExtraFiles: &bootstrapExtraFiles, + }) + + runtimeDoneStatus := 
telemetry.RuntimeDoneSuccess + + defer func() { + sendInitRuntimeDoneLogEvent(execCtx, sbInfoFromInit.SandboxType, runtimeDoneStatus) + }() + if err != nil { - if fatalError, formattedLog, hasError := bootstrap.CachedFatalError(err); hasError { + if fatalError, formattedLog, hasError := sbInfoFromInit.RuntimeBootstrap.CachedFatalError(err); hasError { appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.platformLogger.Printf("%s", formattedLog) + execCtx.eventsAPI.SendImageErrorLog(formattedLog) } else { appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) } + runtimeDoneStatus = telemetry.RuntimeDoneFailure return err } - registrationService.GetRuntime().Pid = watchdog.GoWait(runtimeCmd, fatalerror.RuntimeExit) + execCtx.shutdownContext.createExitedChannel(name) - if err := initFlow.AwaitRuntimeReady(); err != nil { + if err := initFlow.AwaitRuntimeRestoreReady(); err != nil { + runtimeDoneStatus = telemetry.RuntimeDoneFailure return err } + runtimeDoneStatus = telemetry.RuntimeDoneSuccess + // Registration phase finished for agents - no more agents can be registered with the system registrationService.TurnOff() if extensions.AreEnabled() { @@ -253,22 +403,17 @@ func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) // Logs API subscription phase finished for agents - no more agents can be subscribed to the Logs API if execCtx.telemetryAPIEnabled { execCtx.logsSubscriptionAPI.TurnOff() + execCtx.telemetrySubscriptionAPI.TurnOff() } execCtx.initDone = true + return nil } -func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, invokeRequest *interop.Invoke, mx *invokeMetrics) error { +func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop.Invoke, mx *invokeMetrics, sbInfoFromInit interop.SandboxInfoFromInit) error { execCtx.eventsAPI.SetCurrentRequestID(invokeRequest.ID) appCtx := execCtx.appCtx - appctx.StoreErrorResponse(appCtx, nil) - - if 
invokeRequest.NeedDebugLogs { - execCtx.debugTailLogger.Enable() - } else { - execCtx.debugTailLogger.Disable() - } xray := execCtx.xray xray.Configure(invokeRequest) @@ -277,11 +422,11 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo if !execCtx.initDone { // do inline init if err := xray.CaptureInitSubsegment(ctx, func(ctx context.Context) error { - return doInit(ctx, execCtx, watchdog) + return doRuntimeDomainInit(ctx, execCtx, sbInfoFromInit) }); err != nil { return err } - } else if execCtx.startRequest.SandboxType != interop.SandboxPreWarmed { + } else if sbInfoFromInit.SandboxType != interop.SandboxPreWarmed { xray.SendInitSubsegmentWithRecordedTimesOnce(ctx) } @@ -317,16 +462,20 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo if extensions.AreEnabled() { log.Debug("Release agents conditions") for _, agent := range extAgents { + //TODO handle Supervisors listening channel agent.Release() } for _, agent := range intAgents { + //TODO handle Supervisors listening channel agent.Release() } } log.Debug("Release runtime condition") + //TODO handle Supervisors listening channel runtime.Release() log.Debug("Await runtime response") + //TODO handle Supervisors listening channel return invokeFlow.AwaitRuntimeResponse() })); err != nil { return err @@ -335,12 +484,21 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo // Runtime overhead if err := xray.CaptureOverheadSubsegment(ctx, func(ctx context.Context) error { log.Debug("Await runtime ready") + //TODO handle Supervisors listening channel return invokeFlow.AwaitRuntimeReady() }); err != nil { return err } mx.runtimeReadyTime = metering.Monotime() - if err := execCtx.eventsAPI.SendRuntimeDone("success"); err != nil { + + runtimeDoneEventData := telemetry.InvokeRuntimeDoneData{ + Status: telemetry.RuntimeDoneSuccess, + Metrics: telemetry.GetRuntimeDoneInvokeMetrics(invokeRequest.InvokeReceivedTime, 
invokeRequest.InvokeResponseMetrics, mx.runtimeReadyTime), + InternalMetrics: invokeRequest.InvokeResponseMetrics, + Tracing: telemetry.BuildTracingCtx(model.XRayTracingType, invokeRequest.TraceID, invokeRequest.LambdaSegmentID), + Spans: telemetry.GetRuntimeDoneSpans(invokeRequest.InvokeReceivedTime, invokeRequest.InvokeResponseMetrics), + } + if err := execCtx.eventsAPI.SendRuntimeDone(runtimeDoneEventData); err != nil { log.Errorf("Failed to send RUNDONE: %s", err) } @@ -348,6 +506,7 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo if execCtx.HasActiveExtensions() { execCtx.interopServer.SendRuntimeReady() log.Debug("Await agents ready") + //TODO handle Supervisors listening channel if err := invokeFlow.AwaitAgentsReady(); err != nil { log.Warnf("AwaitAgentsReady() = %s", err) return err @@ -364,177 +523,148 @@ func extensionsDisabledByLayer() bool { return err == nil } -// acceptStartRequest is a second initialization phase, performed after receiving START +// acceptInitRequest is a second initialization phase, performed after receiving START // initialized entities: _HANDLER, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN -func (c *rapidContext) acceptStartRequest(startRequest *interop.Start) { - c.startRequest = startRequest - c.environment.StoreEnvironmentVariablesFromInit( - startRequest.CustomerEnvironmentVariables, - startRequest.Handler, - startRequest.AwsKey, - startRequest.AwsSecret, - startRequest.AwsSession, - startRequest.FunctionName, - startRequest.FunctionVersion) +func (c *rapidContext) acceptInitRequest(initRequest *interop.Init) *interop.Init { + initRequest.EnvironmentVariables.StoreEnvironmentVariablesFromInit( + initRequest.CustomerEnvironmentVariables, + initRequest.Handler, + initRequest.AwsKey, + initRequest.AwsSecret, + initRequest.AwsSession, + initRequest.FunctionName, + initRequest.FunctionVersion) c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ - FunctionName: 
startRequest.FunctionName, - FunctionVersion: startRequest.FunctionVersion, - Handler: startRequest.Handler, + FunctionName: initRequest.FunctionName, + FunctionVersion: initRequest.FunctionVersion, + Handler: initRequest.Handler, + RuntimeInfo: initRequest.RuntimeInfo, }) if extensionsDisabledByLayer() { extensions.Disable() } + + return initRequest } -func (c *rapidContext) acceptStartRequestForInitCaching(startRequest *interop.Start) error { +func (c *rapidContext) acceptInitRequestForInitCaching(initRequest *interop.Init) (*interop.Init, error) { log.Info("Configure environment for Init Caching.") - c.startRequest = startRequest randomUUID, err := uuid.NewRandom() if err != nil { - return err + return initRequest, err } initCachingToken := randomUUID.String() - c.environment.StoreEnvironmentVariablesFromInitForInitCaching( - RuntimeAPIHost, - RuntimeAPIPort, - startRequest.CustomerEnvironmentVariables, - startRequest.Handler, - startRequest.FunctionName, - startRequest.FunctionVersion, + initRequest.EnvironmentVariables.StoreEnvironmentVariablesFromInitForInitCaching( + c.server.Host(), + c.server.Port(), + initRequest.CustomerEnvironmentVariables, + initRequest.Handler, + initRequest.FunctionName, + initRequest.FunctionVersion, initCachingToken) c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ - FunctionName: startRequest.FunctionName, - FunctionVersion: startRequest.FunctionVersion, - Handler: startRequest.Handler, + FunctionName: initRequest.FunctionName, + FunctionVersion: initRequest.FunctionVersion, + Handler: initRequest.Handler, }) - c.credentialsService.SetCredentials(initCachingToken, startRequest.AwsKey, startRequest.AwsSecret, startRequest.AwsSession) + c.credentialsService.SetCredentials(initCachingToken, initRequest.AwsKey, initRequest.AwsSecret, initRequest.AwsSession, initRequest.CredentialsExpiry) if extensionsDisabledByLayer() { extensions.Disable() } - return nil + return initRequest, nil } -func handleStart(ctx 
context.Context, execCtx *rapidContext, watchdog *core.Watchdog, startRequest *interop.Start) { +func handleInit(execCtx *rapidContext, initRequest *interop.Init, + initStartedResponse chan<- interop.InitStarted, + initSuccessResponse chan<- interop.InitSuccess, + initFailureResponse chan<- interop.InitFailure) { + ctx := execCtx.signalCtx + if execCtx.initCachingEnabled { - if err := execCtx.acceptStartRequestForInitCaching(startRequest); err != nil { - handleStartError(execCtx, startRequest.InvokeID, startRequest.CorrelationID, err) + var err error + if initRequest, err = execCtx.acceptInitRequestForInitCaching(initRequest); err != nil { + // TODO: call handleInitError only after sending the RUNNING, since + // Slicer will fail receiving DONEFAIL here as it is expecting RUNNING + handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) return } - - execCtx.credentialsService.UnblockService() - defer execCtx.credentialsService.BlockService() } else { - execCtx.acceptStartRequest(startRequest) + initRequest = execCtx.acceptInitRequest(initRequest) } - interopServer, appCtx := execCtx.interopServer, execCtx.appCtx - - if err := interopServer.SendRunning(&interop.Running{ + initStartedMsg := interop.InitStarted{ PreLoadTimeNs: execCtx.preLoadTimeNs, PostLoadTimeNs: execCtx.postLoadTimeNs, WaitStartTimeNs: execCtx.postLoadTimeNs, WaitEndTimeNs: metering.Monotime(), ExtensionsEnabled: extensions.AreEnabled(), - }); err != nil { - log.Panic(err) + Ack: make(chan struct{}), } - if !startRequest.SuppressInit { - if err := doInit(ctx, execCtx, watchdog); err != nil { - handleStartError(execCtx, startRequest.InvokeID, startRequest.CorrelationID, err) - return - } - } + initStartedResponse <- initStartedMsg + <-initStartedMsg.Ack - doneMsg := &interop.Done{ - CorrelationID: startRequest.CorrelationID, - Meta: interop.DoneMetadata{ - RuntimeRelease: appctx.GetRuntimeRelease(appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - 
ExtensionNames: execCtx.GetExtensionNames(), - }, - } - if execCtx.telemetryAPIEnabled { - doneMsg.Meta.LogsAPIMetrics = execCtx.logsSubscriptionAPI.FlushMetrics() - } - if err := interopServer.SendDone(doneMsg); err != nil { - log.Panic(err) - } - - if err := interopServer.StartAcceptingDirectInvokes(); err != nil { - log.Panic(err) - } -} - -func handleStartError(execCtx *rapidContext, invokeID string, correlationID string, err error) { - log.WithError(err).WithField("InvokeID", invokeID).Error("Init failed") - doneFailMsg := generateDoneFail(execCtx, correlationID, nil, 0) - handleInitError(doneFailMsg, execCtx, invokeID, execCtx.interopServer, err) -} - -func generateDoneFail(execCtx *rapidContext, correlationID string, invokeMx *invokeMetrics, invokeReceivedTime int64) *interop.DoneFail { - errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) - if !found { - errorType = fatalerror.Unknown + // Operator domain init happens only once, it's never suppressed, + // and it's terminal in case of failures + if err := doOperatorDomainInit(ctx, execCtx, initRequest.OperatorDomainExtraConfig); err != nil { + // TODO: I believe we need to handle this specially, because we want + // to consider any failure here as terminal + handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) + return } - doneFailMsg := &interop.DoneFail{ - ErrorType: errorType, - CorrelationID: correlationID, // required for standalone mode - Meta: interop.DoneMetadata{ - RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - InvokeReceivedTime: invokeReceivedTime, - }, + if !initRequest.SuppressInit { + // doRuntimeDomainInit() is used in both init/invoke, so the signature requires sbInfo arg + sbInfo := interop.SandboxInfoFromInit{ + EnvironmentVariables: initRequest.EnvironmentVariables, + SandboxType: initRequest.SandboxType, + RuntimeBootstrap: initRequest.Bootstrap, + } + if err := 
doRuntimeDomainInit(ctx, execCtx, sbInfo); err != nil { + handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) + return + } } - if invokeMx != nil { - doneFailMsg.Meta.InvokeRequestReadTimeNs = invokeMx.rendererMetrics.ReadTime.Nanoseconds() - doneFailMsg.Meta.InvokeRequestSizeBytes = int64(invokeMx.rendererMetrics.SizeBytes) - doneFailMsg.Meta.RuntimeReadyTime = int64(invokeMx.runtimeReadyTime) - doneFailMsg.Meta.ExtensionNames = execCtx.GetExtensionNames() + initSuccessMsg := interop.InitSuccess{ + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + ExtensionNames: execCtx.GetExtensionNames(), + Ack: make(chan struct{}), } if execCtx.telemetryAPIEnabled { - doneFailMsg.Meta.LogsAPIMetrics = execCtx.logsSubscriptionAPI.FlushMetrics() + initSuccessMsg.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } - return doneFailMsg + initSuccessResponse <- initSuccessMsg + <-initSuccessMsg.Ack } -func handleInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, invokeRequest *interop.Invoke) { - interopServer, appCtx := execCtx.interopServer, execCtx.appCtx - +func handleInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { + ctx := execCtx.signalCtx invokeMx := invokeMetrics{} - if invokeRequest.ResyncState.IsResyncReceived { - err := execCtx.credentialsService.UpdateCredentials(invokeRequest.ResyncState.AwsKey, invokeRequest.ResyncState.AwsSecret, invokeRequest.ResyncState.AwsSession) - execCtx.credentialsService.UnblockService() - - if err != nil { - log.WithError(err).WithField("InvokeID", invokeRequest.ID).Error("Resync for Invoke failed") - doneFailMsg := generateDoneFail(execCtx, invokeRequest.CorrelationID, &invokeMx, invokeRequest.InvokeReceivedTime) - 
handleInvokeError(doneFailMsg, execCtx, invokeRequest.ID, interopServer, err) - } - } - - if err := doInvoke(ctx, execCtx, watchdog, invokeRequest, &invokeMx); err != nil { + if err := doInvoke(ctx, execCtx, invokeRequest, &invokeMx, sbInfoFromInit); err != nil { log.WithError(err).WithField("InvokeID", invokeRequest.ID).Error("Invoke failed") - doneFailMsg := generateDoneFail(execCtx, invokeRequest.CorrelationID, &invokeMx, invokeRequest.InvokeReceivedTime) - handleInvokeError(doneFailMsg, execCtx, invokeRequest.ID, interopServer, err) - return - } + invokeFailure := handleInvokeError(execCtx, invokeRequest, &invokeMx, err) - if err := execCtx.interopServer.CommitResponse(); err != nil { - log.Panic(err) + if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { + invokeFailure.ResponseMetrics = interop.ResponseMetrics{ + RuntimeTimeThrottledMs: invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond), + RuntimeProducedBytes: invokeRequest.InvokeResponseMetrics.ProducedBytes, + RuntimeOutboundThroughputBps: invokeRequest.InvokeResponseMetrics.OutboundThroughputBps, + } + } + return interop.InvokeSuccess{}, invokeFailure } var invokeCompletionTimeNs int64 @@ -542,30 +672,35 @@ func handleInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Wat invokeCompletionTimeNs = time.Now().UnixNano() - responseTimeNs } - doneMsg := &interop.Done{ - CorrelationID: invokeRequest.CorrelationID, - Meta: interop.DoneMetadata{ - RuntimeRelease: appctx.GetRuntimeRelease(appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - ExtensionNames: execCtx.GetExtensionNames(), + invokeSuccessMsg := interop.InvokeSuccess{ + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + ExtensionNames: execCtx.GetExtensionNames(), + InvokeMetrics: interop.InvokeMetrics{ InvokeRequestReadTimeNs: 
invokeMx.rendererMetrics.ReadTime.Nanoseconds(), InvokeRequestSizeBytes: int64(invokeMx.rendererMetrics.SizeBytes), - InvokeCompletionTimeNs: invokeCompletionTimeNs, - InvokeReceivedTime: invokeRequest.InvokeReceivedTime, RuntimeReadyTime: invokeMx.runtimeReadyTime, }, + InvokeCompletionTimeNs: invokeCompletionTimeNs, + InvokeReceivedTime: invokeRequest.InvokeReceivedTime, } - if execCtx.telemetryAPIEnabled { - doneMsg.Meta.LogsAPIMetrics = execCtx.logsSubscriptionAPI.FlushMetrics() + + if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { + invokeSuccessMsg.ResponseMetrics = interop.ResponseMetrics{ + RuntimeTimeThrottledMs: invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond), + RuntimeProducedBytes: invokeRequest.InvokeResponseMetrics.ProducedBytes, + RuntimeOutboundThroughputBps: invokeRequest.InvokeResponseMetrics.OutboundThroughputBps, + } } - if err := interopServer.SendDone(doneMsg); err != nil { - log.Panic(err) + if execCtx.telemetryAPIEnabled { + invokeSuccessMsg.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } + + return invokeSuccessMsg, nil } -func reinitialize(execCtx *rapidContext, watchdog *core.Watchdog) { - execCtx.interopServer.Clear() +func reinitialize(execCtx *rapidContext) { execCtx.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) execCtx.appCtx.Delete(appctx.AppCtxRuntimeReleaseKey) execCtx.appCtx.Delete(appctx.AppCtxFirstFatalErrorKey) @@ -576,90 +711,125 @@ func reinitialize(execCtx *rapidContext, watchdog *core.Watchdog) { execCtx.invokeFlow.Clear() if execCtx.telemetryAPIEnabled { execCtx.logsSubscriptionAPI.Clear() + execCtx.telemetrySubscriptionAPI.Clear() } - watchdog.Clear() -} - -func blockForever() { - select {} } // handle notification of reset -func handleReset(execCtx *rapidContext, watchdog *core.Watchdog, reset *interop.Reset) { - 
log.Warnf("Reset initiated: %s", reset.Reason) - if execCtx.initCachingEnabled { - execCtx.credentialsService.UnblockService() +func handleReset(execCtx *rapidContext, resetEvent *interop.Reset, invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { + log.Warnf("Reset initiated: %s", resetEvent.Reason) + + // Only send RuntimeDone event if we get a reset during an Invoke + if resetEvent.Reason == "failure" || resetEvent.Reason == "timeout" { + runtimeDoneEventData := telemetry.InvokeRuntimeDoneData{ + Status: resetEvent.Reason, + InternalMetrics: invokeResponseMetrics, + Metrics: telemetry.GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, metering.Monotime()), + Tracing: telemetry.BuildTracingCtx(model.XRayTracingType, resetEvent.TraceID, resetEvent.LambdaSegmentID), + Spans: telemetry.GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics), + } + if err := execCtx.eventsAPI.SendRuntimeDone(runtimeDoneEventData); err != nil { + log.Errorf("Failed to send RUNDONE: %s", err) + } } - if err := execCtx.eventsAPI.SendRuntimeDone(reset.Reason); err != nil { - log.Errorf("Failed to send RUNDONE: %s", err) + extensionsResetMs, resetTimeout, _ := execCtx.shutdownContext.shutdown(execCtx, resetEvent.DeadlineNs, resetEvent.Reason) + + log.Info("Starting runtime domain") + err := execCtx.supervisor.Start(&supvmodel.StartRequest{ + Domain: RuntimeDomain, + }) + if err != nil { + log.WithError(err).Panic("Failed booting runtime domain") } + execCtx.runtimeDomainGeneration++ - profiler := metering.ExtensionsResetDurationProfiler{} - gracefulShutdown(execCtx, watchdog, &profiler, reset.DeadlineNs, execCtx.standaloneMode, reset.Reason) + // Only used by standalone for more indepth assertions. 
+ var fatalErrorType fatalerror.ErrorType - extensionsResetMs, resetTimeout := profiler.CalculateExtensionsResetMs() + if execCtx.standaloneMode { + fatalErrorType, _ = appctx.LoadFirstFatalError(execCtx.appCtx) + } - meta := interop.DoneMetadata{ - ExtensionsResetMs: extensionsResetMs, + var responseMetrics interop.ResponseMetrics + if resetEvent.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(resetEvent.InvokeResponseMetrics) { + responseMetrics.RuntimeTimeThrottledMs = resetEvent.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) + responseMetrics.RuntimeProducedBytes = resetEvent.InvokeResponseMetrics.ProducedBytes + responseMetrics.RuntimeOutboundThroughputBps = resetEvent.InvokeResponseMetrics.OutboundThroughputBps } - if !execCtx.standaloneMode { - // GIRP interopServer implementation sends GIRP RSTFAIL/RSTDONE - if resetTimeout { - // TODO: DoneFail must contain a reset timeout ErrorType for rapid local to distinguish errors - doneFail := &interop.DoneFail{CorrelationID: reset.CorrelationID, Meta: meta} - if err := execCtx.interopServer.SendDoneFail(doneFail); err != nil { - log.Panicf("Failed to SendDoneFail: %s", err) - } - } else { - done := &interop.Done{CorrelationID: reset.CorrelationID, Meta: meta} - if err := execCtx.interopServer.SendDone(done); err != nil { - log.Panicf("Failed to SendDone: %s", err) - } + if resetTimeout { + return interop.ResetSuccess{}, &interop.ResetFailure{ + ExtensionsResetMs: extensionsResetMs, + ErrorType: fatalErrorType, + ResponseMetrics: responseMetrics, } - - os.Exit(0) } - reinitialize(execCtx, watchdog) + return interop.ResetSuccess{ + ExtensionsResetMs: extensionsResetMs, + ErrorType: fatalErrorType, + ResponseMetrics: responseMetrics, + }, nil +} + +// handle notification of shutdown +func handleShutdown(execCtx *rapidContext, shutdownEvent *interop.Shutdown, reason string) interop.ShutdownSuccess { + log.Warnf("Shutdown initiated: %s", reason) + // TODO Handle shutdown error + _, _, _ 
= execCtx.shutdownContext.shutdown(execCtx, shutdownEvent.DeadlineNs, reason) - fatalErrorType, _ := appctx.LoadFirstFatalError(execCtx.appCtx) + // Only used by standalone for more indepth assertions. + var fatalErrorType fatalerror.ErrorType - if resetTimeout { - doneFail := &interop.DoneFail{CorrelationID: reset.CorrelationID, ErrorType: fatalErrorType, Meta: meta} - if err := execCtx.interopServer.SendDoneFail(doneFail); err != nil { - log.Panicf("Failed to SendDoneFail: %s", err) - } - } else { - done := &interop.Done{CorrelationID: reset.CorrelationID, ErrorType: fatalErrorType, Meta: meta} - if err := execCtx.interopServer.SendDone(done); err != nil { - log.Panicf("Failed to SendDone: %s", err) - } + if execCtx.standaloneMode { + fatalErrorType, _ = appctx.LoadFirstFatalError(execCtx.appCtx) } + + return interop.ShutdownSuccess{ErrorType: fatalErrorType} } -// handle notification of shutdown -func handleShutdown(execCtx *rapidContext, watchdog *core.Watchdog, shutdown *interop.Shutdown, reason string) { - log.Warnf("Shutdown initiated") +func handleRestore(execCtx *rapidContext, restore *interop.Restore) error { + err := execCtx.credentialsService.UpdateCredentials(restore.AwsKey, restore.AwsSecret, restore.AwsSession, restore.CredentialsExpiry) + restoreStatus := telemetry.RuntimeDoneSuccess + + defer func() { + sendRestoreRuntimeDoneLogEvent(execCtx, restoreStatus) + }() + + if err != nil { + return fmt.Errorf("error when updating credentials: %s", err) + } + renderer := rendering.NewRestoreRenderer() + execCtx.renderingService.SetRenderer(renderer) + + registrationService := execCtx.registrationService + runtime := registrationService.GetRuntime() + + // If runtime has not called /restore/next then just return + // instead of releasing the Runtime since there is no need to release. 
+ // Then the runtime should be released only during Invoke + if runtime.GetState() != runtime.RuntimeRestoreReadyState { + restoreStatus = telemetry.RuntimeDoneSuccess + log.Infof("Runtime is in state: %s just returning", runtime.GetState().Name()) + return nil + } - gracefulShutdown(execCtx, watchdog, &metering.ExtensionsResetDurationProfiler{}, shutdown.DeadlineNs, true, reason) + runtime.Release() - fatalErrorType, _ := appctx.LoadFirstFatalError(execCtx.appCtx) + initFlow := execCtx.initFlow + err = initFlow.AwaitRuntimeReady() - if err := execCtx.interopServer.SendDone(&interop.Done{CorrelationID: shutdown.CorrelationID, ErrorType: fatalErrorType}); err != nil { - log.Panicf("Failed to SendDone: %s", err) + if err != nil { + restoreStatus = telemetry.RuntimeDoneFailure + } else { + restoreStatus = telemetry.RuntimeDoneSuccess } - // Shutdown induces a terminal state and no further messages will be processed - blockForever() + return err } func start(signalCtx context.Context, execCtx *rapidContext) { - watchdog := core.NewWatchdog(execCtx.registrationService.InitFlow(), execCtx.invokeFlow, execCtx.exitPidChan, execCtx.appCtx) - - interopServer := execCtx.interopServer - // Start Runtime API Server err := execCtx.server.Listen() if err != nil { @@ -670,30 +840,40 @@ func start(signalCtx context.Context, execCtx *rapidContext) { // Note, most of initialization code should run before blocking to receive START, // code before START runs in parallel with code downloads. 
+} - go func() { - for { - reset := <-interopServer.ResetChan() - // In the event of a Reset during init/invoke, CancelFlows cancels execution - // flows and return with the errResetReceived err - this error is special-cased - // and not handled by the init/invoke (unexpected) error handling functions - watchdog.CancelFlows(errResetReceived) - execCtx.resetChan <- reset - } - }() +func sendRestoreRuntimeDoneLogEvent(execCtx *rapidContext, status string) { + if err := execCtx.eventsAPI.SendRestoreRuntimeDone(status); err != nil { + log.Errorf("Failed to send RESTRD: %s", err) + } +} + +func sendInitRuntimeDoneLogEvent(execCtx *rapidContext, sandboxType interop.SandboxType, status string) { + initSource := interop.InferTelemetryInitSource(execCtx.initCachingEnabled, sandboxType) + + runtimeDoneData := &telemetry.InitRuntimeDoneData{ + InitSource: initSource, + Status: status, + } - for { - select { - case start := <-interopServer.StartChan(): - handleStart(signalCtx, execCtx, watchdog, start) - case invoke := <-interopServer.InvokeChan(): - handleInvoke(signalCtx, execCtx, watchdog, invoke) - case err := <-interopServer.TransportErrorChan(): - log.Panicf("Transport error emitted by interop server: %s", err) - case reset := <-execCtx.resetChan: - handleReset(execCtx, watchdog, reset) - case shutdown := <-interopServer.ShutdownChan(): // only in standalone mode - handleShutdown(execCtx, watchdog, shutdown, standaloneShutdownReason) + if err := execCtx.eventsAPI.SendInitRuntimeDone(runtimeDoneData); err != nil { + log.Errorf("Failed to send INITRD: %s", err) + } +} + +// This function will log a line if AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, or AWS_SESSION_TOKEN is missing +// This is expected to happen in cases when credentials provider is not needed +func checkCredentials(execCtx *rapidContext, bootstrapEnv map[string]string) { + credentialsKeys := []string{"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"} + missingCreds := []string{} + + for _, 
credEnvVar := range credentialsKeys { + if val, keyExists := bootstrapEnv[credEnvVar]; !keyExists || val == "" { + missingCreds = append(missingCreds, credEnvVar) } } + + if len(missingCreds) > 0 { + log.Infof("Starting runtime without %s , Expected?: %t", strings.Join(missingCreds[:], ", "), execCtx.initCachingEnabled) + } } diff --git a/lambda/rapid/start_test.go b/lambda/rapid/start_test.go index 2363705..ffb446f 100644 --- a/lambda/rapid/start_test.go +++ b/lambda/rapid/start_test.go @@ -7,7 +7,7 @@ import ( "context" "fmt" "go.amzn.com/lambda/core" - "io/ioutil" + "io" "net/http" "regexp" "strconv" @@ -142,7 +142,7 @@ func TestListen(t *testing.T) { ctx := context.Background() telemetryAPIEnabled := true - server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.LogsSubscriptionAPI, false, flowTest.CredentialsService) + server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.TelemetrySubscription, flowTest.TelemetrySubscription, flowTest.CredentialsService, flowTest.EventsAPI) err := server.Listen() assert.NoError(t, err) @@ -161,7 +161,7 @@ func TestListen(t *testing.T) { resp, err1 := http.Get(fmt.Sprintf("http://%s:%d/2018-06-01/runtime/invocation/next", server.Host(), server.Port())) assert.Nil(t, err1) - body, err2 := ioutil.ReadAll(resp.Body) + body, err2 := io.ReadAll(resp.Body) assert.Nil(t, err2) assert.Equal(t, "MyTest", string(body)) @@ -171,3 +171,31 @@ func TestListen(t *testing.T) { <-done } + +func TestInferSandboxInitTypeOnDemand(t *testing.T) { + initCachingEnabled := false + sandboxType := interop.SandboxClassic + initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) + assert.Equal(t, "on-demand", initSource) +} + +func TestInferSandboxInitTypeProvisionedConcurrency(t *testing.T) { + initCachingEnabled := false + sandboxType := 
interop.SandboxPreWarmed + initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) + assert.Equal(t, "provisioned-concurrency", initSource) +} + +func TestInferSandboxInitTypeInitCaching(t *testing.T) { + initCachingEnabled := true + sandboxType := interop.SandboxClassic + initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) + assert.Equal(t, "snap-start", initSource) +} + +func TestInferSandboxInitTypeInitCachingWithPC(t *testing.T) { + initCachingEnabled := true + sandboxType := interop.SandboxPreWarmed + initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) + assert.Equal(t, "snap-start", initSource) +} diff --git a/lambda/rapidcore/bootstrap.go b/lambda/rapidcore/bootstrap.go index 9faf518..165f532 100644 --- a/lambda/rapidcore/bootstrap.go +++ b/lambda/rapidcore/bootstrap.go @@ -6,11 +6,12 @@ package rapidcore import ( "fmt" "os" + "path" "path/filepath" + "strings" "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/logging" - "go.amzn.com/lambda/rapid" + "go.amzn.com/lambda/interop" log "github.com/sirupsen/logrus" ) @@ -21,6 +22,7 @@ type BootstrapError func() (fatalerror.ErrorType, LogFormatter) // Bootstrap represents a list of executable bootstrap // candidates in order of priority and exec metadata type Bootstrap struct { + runtimeDomainRoot string orderedLookupPaths []string validCmd []string workingDir string @@ -29,8 +31,11 @@ type Bootstrap struct { bootstrapError BootstrapError } +// Validate interface compliance +var _ interop.Bootstrap = (*Bootstrap)(nil) + // NewBootstrap returns an instance of bootstrap defined by given params -func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string) *Bootstrap { +func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string, runtimeDomainRoot string) *Bootstrap { var orderedLookupBootstrapPaths []string for _, args := range cmdCandidates { // Empty args is an error, but we want to detect it later (in Cmd() call) 
when we are able to report a descriptive error @@ -44,23 +49,32 @@ func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string) *Bootstrap currentWorkingDir = "/" } + if runtimeDomainRoot == "" { + runtimeDomainRoot = "/" + } + return &Bootstrap{ orderedLookupPaths: orderedLookupBootstrapPaths, workingDir: currentWorkingDir, cmdCandidates: cmdCandidates, + runtimeDomainRoot: runtimeDomainRoot, } } -func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string) *Bootstrap { +func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string, runtimeDomainRoot string) *Bootstrap { if currentWorkingDir == "" { // use the root directory as the default working directory currentWorkingDir = "/" } + if runtimeDomainRoot == "" { + runtimeDomainRoot = "/" + } // a single candidate command makes it automatically valid return &Bootstrap{ - validCmd: cmd, - workingDir: currentWorkingDir, + validCmd: cmd, + workingDir: currentWorkingDir, + runtimeDomainRoot: runtimeDomainRoot, } } @@ -68,16 +82,28 @@ func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string) *Bootstrap { // actual bootstrap, given a list of possible files func (b *Bootstrap) locateBootstrap() error { for i, bootstrapCandidate := range b.orderedLookupPaths { - if file, err := os.Stat(bootstrapCandidate); !os.IsNotExist(err) && !file.IsDir() { - b.validCmd = b.cmdCandidates[i] - return nil + // validate path relatively to the domain's root + candidatPath := path.Join(b.runtimeDomainRoot, bootstrapCandidate) + file, err := os.Stat(candidatPath) + if err != nil { + if !os.IsNotExist(err) { + log.WithError(err).Warnf("Could not validate %s. Ignoring it.", bootstrapCandidate) + } + continue } + if file.IsDir() { + log.Warnf("%s is a directory. 
Ignoring it", bootstrapCandidate) + continue + } + b.validCmd = b.cmdCandidates[i] + return nil } log.WithField("bootstrapPathsChecked", b.orderedLookupPaths).Warn("Couldn't find valid bootstrap(s)") return fmt.Errorf("Couldn't find valid bootstrap(s): %s", b.orderedLookupPaths) } -// Cmd returns the args of bootstrap, where args[0] +// Cmd returns the args of bootstrap, relative to the +// chroot idenfied by `root`, where args[0] // is the path to executable func (b *Bootstrap) Cmd() ([]string, error) { if len(b.validCmd) > 0 { @@ -94,16 +120,21 @@ func (b *Bootstrap) Cmd() ([]string, error) { // Env returns the environment variables available to // the bootstrap process -func (b *Bootstrap) Env(e rapid.EnvironmentVariables) []string { +func (b *Bootstrap) Env(e interop.EnvironmentVariables) map[string]string { return e.RuntimeExecEnv() } // Cwd returns the working directory of the bootstrap process +// The path is validated against the chroot identified by `root` func (b *Bootstrap) Cwd() (string, error) { if !filepath.IsAbs(b.workingDir) { return "", fmt.Errorf("the working directory '%s' is invalid, it needs to be an absolute path", b.workingDir) - } else if _, err := os.Stat(b.workingDir); os.IsNotExist(err) { - return "", fmt.Errorf("the working directory doesn't exist: %s", b.workingDir) + } + + // evaluate the path relatively to the domain's mnt namespace root + domainPath := path.Join(b.runtimeDomainRoot, b.workingDir) + if _, err := os.Stat(domainPath); os.IsNotExist(err) { + return "", fmt.Errorf("the working directory doesn't exist: %s", domainPath) } return b.workingDir, nil @@ -140,19 +171,35 @@ func (b *Bootstrap) SetCachedFatalError(bootstrapErrFn BootstrapError) { // BootstrapErrInvalidLCISTaskConfig represents an error while parsing LCIS task config func BootstrapErrInvalidLCISTaskConfig(err error) BootstrapError { return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidTaskConfig, 
logging.SupernovaInvalidTaskConfigRepr(err) + return fatalerror.InvalidTaskConfig, SupernovaInvalidTaskConfigRepr(err) } } // BootstrapErrInvalidLCISEntrypoint represents an invalid LCIS entrypoint error func BootstrapErrInvalidLCISEntrypoint(entrypoint []string, cmd []string, workingdir string) BootstrapError { return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidEntrypoint, logging.SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) + return fatalerror.InvalidEntrypoint, SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) } } func BootstrapErrInvalidLCISWorkingDir(entrypoint []string, cmd []string, workingdir string) BootstrapError { return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidWorkingDir, logging.SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) + return fatalerror.InvalidWorkingDir, SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) + } +} + +func SupernovaInvalidTaskConfigRepr(err error) func(error) string { + return func(unused error) string { + return fmt.Sprintf("IMAGE\tInvalid task config: %s", err) + } +} + +func SupernovaLaunchErrorRepr(entrypoint []string, cmd []string, workingDir string) func(error) string { + return func(err error) string { + return fmt.Sprintf("IMAGE\tLaunch error: %s\tEntrypoint: [%s]\tCmd: [%s]\tWorkingDir: [%s]", + err, + strings.Join(entrypoint, ","), + strings.Join(cmd, ","), + workingDir) } } diff --git a/lambda/rapidcore/bootstrap_test.go b/lambda/rapidcore/bootstrap_test.go index 4700130..b43520d 100644 --- a/lambda/rapidcore/bootstrap_test.go +++ b/lambda/rapidcore/bootstrap_test.go @@ -4,8 +4,10 @@ package rapidcore import ( - "io/ioutil" "os" + "path" + "path/filepath" + "reflect" "testing" "go.amzn.com/lambda/rapidcore/env" @@ -14,11 +16,11 @@ import ( ) func TestBootstrap(t *testing.T) { - tmpDir, err := ioutil.TempDir("", "lcis-test-invalid-bootstrap") + tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") assert.NoError(t, err) defer 
os.RemoveAll(tmpDir) - tmpFile, err := ioutil.TempFile("", "lcis-test-bootstrap") + tmpFile, err := os.CreateTemp("", "lcis-test-bootstrap") assert.NoError(t, err) defer os.Remove(tmpFile.Name()) @@ -38,11 +40,57 @@ func TestBootstrap(t *testing.T) { environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") // Test - b := NewBootstrap(cmdCandidates, cwd) + b := NewBootstrap(cmdCandidates, cwd, "") bCwd, err := b.Cwd() assert.NoError(t, err) assert.Equal(t, cwd, bCwd) - assert.ElementsMatch(t, environment.RuntimeExecEnv(), b.Env(environment)) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) + + cmd, err := b.Cmd() + assert.NoError(t, err) + assert.Equal(t, file, cmd) +} + +// When running bootstraps in separate mount namespaces +// we want to verify and discover paths relative to +// a root different from "/" +func TestBootstrapChroot(t *testing.T) { + tmpRoot, err := os.MkdirTemp(os.TempDir(), "domain-root") + assert.NoError(t, err) + defer os.RemoveAll(tmpRoot) + tmpDir, err := os.MkdirTemp(tmpRoot, "lcis-test-invalid-bootstrap") + assert.NoError(t, err) + defer os.RemoveAll(tmpDir) + + tmpFile, err := os.CreateTemp(tmpRoot, "lcis-test-bootstrap") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // Setup cmd candidates + nonExistent := []string{"/foo/bar/baz"} + baseName := filepath.Base(tmpDir) + dir := []string{"/" + baseName, "--arg1", "foo"} + baseName = filepath.Base(tmpFile.Name()) + file := []string{"/" + baseName, "--arg1 s", "foo"} + cmdCandidates := [][]string{nonExistent, dir, file} + + // Setup working dir + cwd, err := os.MkdirTemp(tmpRoot, "cwd") + assert.NoError(t, err) + defer os.RemoveAll(cwd) + + // Setup environment + environment := env.NewEnvironment() + environment.StoreRuntimeAPIEnvironmentVariable("host:port") + environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") + + // Test + baseName = filepath.Base(cwd) + b := 
NewBootstrap(cmdCandidates, "/"+baseName, tmpRoot) + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, cwd, path.Join(tmpRoot, bCwd)) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) cmd, err := b.Cmd() assert.NoError(t, err) @@ -53,17 +101,24 @@ func TestBootstrapEmptyCandidate(t *testing.T) { // we expect newBootstrap to succeed and bootstrap.Cmd() to fail. // We want to postpone the failure to be able to propagate error description to slicer and write it to customer log invalidBootstrapCandidate := []string{} - bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/") + bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/", "") + _, err := bs.Cmd() + assert.Error(t, err) +} + +func TestBootstrapChrootNonExistingRoot(t *testing.T) { + invalidBootstrapCandidate := []string{"/bin/bash", "-c"} + bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/", "/does_not_exist") _, err := bs.Cmd() assert.Error(t, err) } func TestBootstrapSingleCmd(t *testing.T) { - tmpDir, err := ioutil.TempDir("", "lcis-test-invalid-bootstrap") + tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") assert.NoError(t, err) defer os.RemoveAll(tmpDir) - tmpFile, err := ioutil.TempFile("", "lcis-test-bootstrap") + tmpFile, err := os.CreateTemp("", "lcis-test-bootstrap") assert.NoError(t, err) defer os.Remove(tmpFile.Name()) @@ -81,11 +136,11 @@ func TestBootstrapSingleCmd(t *testing.T) { environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") // Test - b := NewBootstrapSingleCmd(cmdCandidate, cwd) + b := NewBootstrapSingleCmd(cmdCandidate, cwd, "") bCwd, err := b.Cwd() assert.NoError(t, err) assert.Equal(t, cwd, bCwd) - assert.ElementsMatch(t, environment.RuntimeExecEnv(), b.Env(environment)) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) cmd, err := b.Cmd() assert.NoError(t, err) @@ -93,7 +148,7 @@ func TestBootstrapSingleCmd(t *testing.T) { } 
func TestBootstrapSingleCmdNonExistingCandidate(t *testing.T) { - tmpDir, err := ioutil.TempDir("", "lcis-test-invalid-bootstrap") + tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") assert.NoError(t, err) defer os.RemoveAll(tmpDir) @@ -111,11 +166,11 @@ func TestBootstrapSingleCmdNonExistingCandidate(t *testing.T) { environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") // Test - b := NewBootstrapSingleCmd(cmdCandidate, cwd) + b := NewBootstrapSingleCmd(cmdCandidate, cwd, "") bCwd, err := b.Cwd() assert.NoError(t, err) assert.Equal(t, cwd, bCwd) - assert.ElementsMatch(t, environment.RuntimeExecEnv(), b.Env(environment)) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) // No validations run against single candidates cmd, err := b.Cmd() @@ -125,100 +180,100 @@ func TestBootstrapSingleCmdNonExistingCandidate(t *testing.T) { // Test our ability to locate bootstrap files in the file system func TestFindCustomRuntimeIfExists(t *testing.T) { - tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp-") + tmpFile, err := os.CreateTemp(os.TempDir(), "tmp-") if err != nil { t.Fatal("Cannot create temporary file", err) } defer os.Remove(tmpFile.Name()) - tmpFile2, err := ioutil.TempFile(os.TempDir(), "tmp-") + tmpFile2, err := os.CreateTemp(os.TempDir(), "tmp-") if err != nil { t.Fatal("Cannot create temporary file", err) } defer os.Remove(tmpFile2.Name()) // one bootstrap argument was given and it exists - bootstrap := NewBootstrap([][]string{[]string{tmpFile.Name()}}, "/") + bootstrap := NewBootstrap([][]string{{tmpFile.Name()}}, "/", "") cmd, err := bootstrap.Cmd() assert.NoError(t, err) assert.Equal(t, []string{tmpFile.Name()}, cmd) assert.Nil(t, err) // two bootstrap arguments given, both exist but first one is returned - bootstrap = NewBootstrap([][]string{[]string{tmpFile.Name()}, []string{tmpFile2.Name()}}, "/") + bootstrap = NewBootstrap([][]string{{tmpFile.Name()}, {tmpFile2.Name()}}, 
"/", "") cmd, err = bootstrap.Cmd() assert.NoError(t, err) assert.Equal(t, []string{tmpFile.Name()}, cmd) assert.Nil(t, err) // two bootstrap arguments given, first one does not exist, second exists and is returned - bootstrap = NewBootstrap([][]string{[]string{"mk"}, []string{tmpFile2.Name()}}, "/") + bootstrap = NewBootstrap([][]string{{"mk"}, {tmpFile2.Name()}}, "/", "") cmd, err = bootstrap.Cmd() assert.NoError(t, err) assert.Equal(t, []string{tmpFile2.Name()}, cmd) assert.Nil(t, err) // two bootstrap arguments given, none exists - bootstrap = NewBootstrap([][]string{[]string{"mk"}, []string{"mk2"}}, "/") + bootstrap = NewBootstrap([][]string{{"mk"}, {"mk2"}}, "/", "") cmd, err = bootstrap.Cmd() assert.EqualError(t, err, "Couldn't find valid bootstrap(s): [mk mk2]") assert.Equal(t, []string{}, cmd) } func TestCwdIsAbsolute(t *testing.T) { - tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp-") + tmpFile, err := os.CreateTemp(os.TempDir(), "tmp-") if err != nil { t.Fatal("Cannot create temporary file", err) } defer os.Remove(tmpFile.Name()) - cmdCandidates := [][]string{[]string{tmpFile.Name()}} + cmdCandidates := [][]string{{tmpFile.Name()}} // no errors when currentWorkingDir is absolute - bootstrap := NewBootstrap(cmdCandidates, "/tmp") + bootstrap := NewBootstrap(cmdCandidates, "/tmp", "") cwd, err := bootstrap.Cwd() assert.Nil(t, err) assert.Equal(t, "/tmp", cwd) - bootstrap = NewBootstrap(cmdCandidates, "tmp") + bootstrap = NewBootstrap(cmdCandidates, "tmp", "") _, err = bootstrap.Cwd() assert.EqualError(t, err, "the working directory 'tmp' is invalid, it needs to be an absolute path") - bootstrap = NewBootstrap(cmdCandidates, "./") + bootstrap = NewBootstrap(cmdCandidates, "./", "") _, err = bootstrap.Cwd() assert.EqualError(t, err, "the working directory './' is invalid, it needs to be an absolute path") } func TestBootstrapMissingWorkingDirectory(t *testing.T) { - tmpFile, err := ioutil.TempFile(os.TempDir(), "cwd-test-bootstrap") + tmpFile, err := 
os.CreateTemp(os.TempDir(), "cwd-test-bootstrap") assert.NoError(t, err) defer os.Remove(tmpFile.Name()) - tmpDir, err := ioutil.TempDir("", "cwd-test") + tmpDir, err := os.MkdirTemp("", "cwd-test") assert.NoError(t, err) defer os.RemoveAll(tmpDir) // cwd argument exists - bootstrap := NewBootstrap([][]string{[]string{tmpFile.Name()}}, tmpDir) + bootstrap := NewBootstrap([][]string{{tmpFile.Name()}}, tmpDir, "") cwd, err := bootstrap.Cwd() assert.Equal(t, cwd, tmpDir) assert.NoError(t, err) // cwd argument doesn't exist - bootstrap = NewBootstrap([][]string{[]string{tmpFile.Name()}}, "/foo") + bootstrap = NewBootstrap([][]string{{tmpFile.Name()}}, "/foo", "") _, err = bootstrap.Cwd() assert.EqualError(t, err, "the working directory doesn't exist: /foo") } func TestDefaultWorkeringDirectory(t *testing.T) { - bootstrap := NewBootstrap([][]string{[]string{}}, "") + bootstrap := NewBootstrap([][]string{{}}, "", "") cwd, err := bootstrap.Cwd() assert.NoError(t, err) assert.Equal(t, "/", cwd) } func TestBootstrapSingleCmdDefaultWorkingDir(t *testing.T) { - b := NewBootstrapSingleCmd([]string{}, "") + b := NewBootstrapSingleCmd([]string{}, "", "") bCwd, err := b.Cwd() assert.NoError(t, err) assert.Equal(t, "/", bCwd) diff --git a/lambda/rapidcore/env/environment.go b/lambda/rapidcore/env/environment.go index 699abda..be0584c 100644 --- a/lambda/rapidcore/env/environment.go +++ b/lambda/rapidcore/env/environment.go @@ -149,26 +149,24 @@ func (e *Environment) mergeCustomerEnvironmentVariables(envVars map[string]strin // RuntimeExecEnv returns the key=value strings of all environment variables // passed to runtime process on exec() -func (e *Environment) RuntimeExecEnv() []string { +func (e *Environment) RuntimeExecEnv() map[string]string { if !e.initEnvVarsSet || !e.runtimeAPISet { log.Fatal("credentials, customer and runtime API address must be set") } - return asEnviron(mapUnion(e.Customer, e.PlatformUnreserved, e.Credentials, e.Runtime, e.Platform)) + return 
mapUnion(e.Customer, e.PlatformUnreserved, e.Credentials, e.Runtime, e.Platform) } // AgentExecEnv returns the key=value strings of all environment variables // passed to agent process on exec() -func (e *Environment) AgentExecEnv() []string { +func (e *Environment) AgentExecEnv() map[string]string { if !e.initEnvVarsSet || !e.runtimeAPISet { log.Fatal("credentials, customer and runtime API address must be set") } excludedKeys := extensionExcludedKeys() excludeCondition := func(key string) bool { return excludedKeys[key] || strings.HasPrefix(key, "_") } - environ := asEnviron(mapExclude(mapUnion(e.Customer, e.Credentials, e.Platform), excludeCondition)) - - return environ + return mapExclude(mapUnion(e.Customer, e.Credentials, e.Platform), excludeCondition) } // RAPIDInternalConfig returns the rapid config parsed from environment vars @@ -249,14 +247,6 @@ func mapUnion(maps ...map[string]string) map[string]string { return union } -func asEnviron(m map[string]string) []string { - keySepValArray := []string{} - for key, val := range m { - keySepValArray = append(keySepValArray, key+"="+val) - } - return keySepValArray -} - func mapExclude(m map[string]string, excludeCondition func(string) bool) map[string]string { res := map[string]string{} for key, val := range m { diff --git a/lambda/rapidcore/env/environment_test.go b/lambda/rapidcore/env/environment_test.go index cdfef24..ed3043c 100644 --- a/lambda/rapidcore/env/environment_test.go +++ b/lambda/rapidcore/env/environment_test.go @@ -6,12 +6,21 @@ package env import ( "fmt" "os" - "strings" "testing" "github.com/stretchr/testify/assert" ) +func envToSlice(env map[string]string) []string { + ret := make([]string, len(env)) + i := 0 + for key, val := range env { + ret[i] = key + "=" + val + i++ + } + return ret +} + func TestRAPIDInternalConfig(t *testing.T) { os.Clearenv() os.Setenv("_LAMBDA_SB_ID", "sbid") @@ -121,34 +130,35 @@ func TestRuntimeExecEnvironmentVariables(t *testing.T) { rapidEnvVars := 
env.RuntimeExecEnv() var rapidEnvKeys []string - for _, keyval := range rapidEnvVars { - key := strings.Split(keyval, "=")[0] + for key := range rapidEnvVars { rapidEnvKeys = append(rapidEnvKeys, key) } + rapidEnvVarsSlice := envToSlice(rapidEnvVars) + for key := range env.RAPID { assert.NotContains(t, rapidEnvKeys, key) } for key, val := range env.Runtime { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } for key, val := range env.Platform { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } for key, val := range env.PlatformUnreserved { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) assert.NotContains(t, env.Customer, key) } for key, val := range env.Credentials { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } for key, val := range env.Customer { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) assert.NotContains(t, env.PlatformUnreserved, key) } } @@ -191,7 +201,7 @@ func TestRuntimeExecEnvironmentVariablesPriority(t *testing.T) { assert.Equal(t, len(predefinedInternalEnvVarKeys()), len(env.RAPID)) assert.Equal(t, len(predefinedRuntimeEnvVarKeys()), len(env.Runtime)) - rapidEnvVars := env.RuntimeExecEnv() + rapidEnvVars := envToSlice(env.RuntimeExecEnv()) // Customer env vars cannot override platform/internal ones assert.NotContains(t, rapidEnvVars, conflictPlatformKeyFromInit+"="+customerEnvVal) @@ -224,7 +234,7 @@ func TestCustomerEnvironmentVariablesFromInitCanOverrideEnvironmentVariablesFrom assert.Equal(t, env.Customer["LCIS_ARG1"], lcisCLIArgEnvVal) assert.Equal(t, env.Customer["MY_FOOBAR_ENV_1"], customerEnvVal) - rapidEnvVars := env.RuntimeExecEnv() + rapidEnvVars := envToSlice(env.RuntimeExecEnv()) assert.Contains(t, rapidEnvVars, "LCIS_ARG1="+lcisCLIArgEnvVal) assert.Contains(t, 
rapidEnvVars, "MY_FOOBAR_ENV_1="+customerEnvVal) @@ -250,17 +260,18 @@ func TestAgentExecEnvironmentVariables(t *testing.T) { agentEnvVars := env.AgentExecEnv() var agentEnvKeys []string - for _, keyval := range agentEnvVars { - key := strings.Split(keyval, "=")[0] + for key := range agentEnvVars { agentEnvKeys = append(agentEnvKeys, key) } + agentEnvVarsSlice := envToSlice(agentEnvVars) + for key := range env.RAPID { assert.NotContains(t, agentEnvKeys, key) } for key, val := range env.Runtime { - assert.NotContains(t, agentEnvKeys, key+"="+val) + assert.NotContains(t, agentEnvVarsSlice, key+"="+val) } for key := range env.Platform { @@ -272,10 +283,10 @@ func TestAgentExecEnvironmentVariables(t *testing.T) { } for key, val := range env.Credentials { - assert.Contains(t, agentEnvVars, key+"="+val) + assert.Contains(t, agentEnvVarsSlice, key+"="+val) } - assert.Contains(t, agentEnvVars, runtimeAPIAddressKey+"="+env.Platform[runtimeAPIAddressKey]) + assert.Contains(t, agentEnvVarsSlice, runtimeAPIAddressKey+"="+env.Platform[runtimeAPIAddressKey]) } func TestStoreEnvironmentVariablesFromInitCaching(t *testing.T) { diff --git a/lambda/rapidcore/errors.go b/lambda/rapidcore/errors.go index 06a4830..7f35ca8 100644 --- a/lambda/rapidcore/errors.go +++ b/lambda/rapidcore/errors.go @@ -5,9 +5,9 @@ package rapidcore import "errors" -var ErrInitAlreadyDone = errors.New("InitAlreadyDone") var ErrInitDoneFailed = errors.New("InitDoneFailed") -var ErrInitError = errors.New("InitError") +var ErrInitNotStarted = errors.New("InitNotStarted") +var ErrInitResetReceived = errors.New("InitResetReceived") var ErrNotReserved = errors.New("NotReserved") var ErrAlreadyReserved = errors.New("AlreadyReserved") @@ -23,5 +23,3 @@ var ErrReleaseReservationDone = errors.New("ReleaseReservationDone") var ErrInternalServerError = errors.New("InternalServerError") var ErrInvokeTimeout = errors.New("InvokeTimeout") - -var ErrTerminated = errors.New("SandboxTerminated") // sent to signal a process 
exit diff --git a/lambda/rapidcore/sandbox.go b/lambda/rapidcore/sandbox.go deleted file mode 100644 index 7d5a8a9..0000000 --- a/lambda/rapidcore/sandbox.go +++ /dev/null @@ -1,259 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "context" - "io" - "io/ioutil" - "net/http" - "os" - "os/signal" - "syscall" - - "go.amzn.com/lambda/core/statejson" - "go.amzn.com/lambda/extensions" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/logging" - "go.amzn.com/lambda/rapid" - "go.amzn.com/lambda/rapidcore/env" - "go.amzn.com/lambda/telemetry" - - log "github.com/sirupsen/logrus" -) - -const ( - defaultSigtermResetTimeoutMs = int64(2000) -) - -type Sandbox interface { - Init(i *interop.Init, invokeTimeoutMs int64) - Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error - InteropServer() InteropServer -} - -type ReserveResponse struct { - Token interop.Token - InternalState *statejson.InternalStateDescription -} - -type InteropServer interface { - FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error - Reserve(id string, traceID, lambdaSegmentID string) (*ReserveResponse, error) - Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) - AwaitRelease() (*statejson.InternalStateDescription, error) - Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription - InternalState() (*statejson.InternalStateDescription, error) - CurrentToken() *interop.Token -} - -type SandboxBuilder struct { - sandbox *rapid.Sandbox - defaultInteropServer *Server - useCustomInteropServer bool - shutdownFuncs []context.CancelFunc - debugTailLogWriter io.Writer - platformLogWriter io.Writer -} - -type logSink int - -const ( - RuntimeLogSink logSink = iota - ExtensionLogSink -) - -func NewSandboxBuilder(bootstrap *Bootstrap) *SandboxBuilder { - defaultInteropServer := NewServer(context.Background()) - signalCtx, cancelSignalCtx 
:= context.WithCancel(context.Background()) - logsEgressAPI := &telemetry.NoOpLogsEgressAPI{} - runtimeStdoutWriter, runtimeStderrWriter, _ := logsEgressAPI.GetRuntimeSockets() - - b := &SandboxBuilder{ - sandbox: &rapid.Sandbox{ - Bootstrap: bootstrap, - PreLoadTimeNs: 0, // TODO - StandaloneMode: true, - RuntimeStdoutWriter: runtimeStdoutWriter, - RuntimeStderrWriter: runtimeStderrWriter, - LogsEgressAPI: logsEgressAPI, - EnableTelemetryAPI: false, - Environment: env.NewEnvironment(), - Tracer: telemetry.NewNoOpTracer(), - SignalCtx: signalCtx, - EventsAPI: &telemetry.NoOpEventsAPI{}, - InitCachingEnabled: false, - }, - defaultInteropServer: defaultInteropServer, - shutdownFuncs: []context.CancelFunc{}, - debugTailLogWriter: ioutil.Discard, - platformLogWriter: ioutil.Discard, - } - - b.AddShutdownFunc(context.CancelFunc(func() { - log.Info("Shutting down...") - defaultInteropServer.Reset("SandboxTerminated", defaultSigtermResetTimeoutMs) - cancelSignalCtx() - })) - - return b -} - -func (b *SandboxBuilder) SetInteropServer(interopServer interop.Server) *SandboxBuilder { - b.sandbox.InteropServer = interopServer - b.useCustomInteropServer = true - return b -} - -func (b *SandboxBuilder) SetEventsAPI(eventsAPI telemetry.EventsAPI) *SandboxBuilder { - b.sandbox.EventsAPI = eventsAPI - return b -} - -func (b *SandboxBuilder) SetTracer(tracer telemetry.Tracer) *SandboxBuilder { - b.sandbox.Tracer = tracer - return b -} - -func (b *SandboxBuilder) DisableStandaloneMode() *SandboxBuilder { - b.sandbox.StandaloneMode = false - return b -} - -func (b *SandboxBuilder) SetExtensionsFlag(extensionsEnabled bool) *SandboxBuilder { - if extensionsEnabled { - extensions.Enable() - } else { - extensions.Disable() - } - return b -} - -func (b *SandboxBuilder) SetInitCachingFlag(initCachingEnabled bool) *SandboxBuilder { - b.sandbox.InitCachingEnabled = initCachingEnabled - return b -} - -func (b *SandboxBuilder) SetPreLoadTimeNs(preLoadTimeNs int64) *SandboxBuilder { - 
b.sandbox.PreLoadTimeNs = preLoadTimeNs - return b -} - -func (b *SandboxBuilder) SetEnvironmentVariables(environment *env.Environment) *SandboxBuilder { - b.sandbox.Environment = environment - return b -} - -func (b *SandboxBuilder) SetPlatformLogOutput(w io.Writer) *SandboxBuilder { - b.platformLogWriter = w - return b -} - -func (b *SandboxBuilder) SetTailLogOutput(w io.Writer) *SandboxBuilder { - b.debugTailLogWriter = w - return b -} - -func (b *SandboxBuilder) SetLogsSubscriptionAPI(logsSubscriptionAPI telemetry.LogsSubscriptionAPI) *SandboxBuilder { - b.sandbox.EnableTelemetryAPI = true - b.sandbox.LogsSubscriptionAPI = logsSubscriptionAPI - return b -} - -func (b *SandboxBuilder) SetLogsEgressAPI(logsEgressAPI telemetry.LogsEgressAPI) *SandboxBuilder { - runtimeStdoutWriter, runtimeStderrWriter, err := logsEgressAPI.GetRuntimeSockets() - - if err != nil { - log.WithError(err).Fatal("failed to get the Runtime sockets from the logs egress API") - } - - b.sandbox.LogsEgressAPI = logsEgressAPI - b.sandbox.RuntimeStdoutWriter = runtimeStdoutWriter - b.sandbox.RuntimeStderrWriter = runtimeStderrWriter - return b -} - -func (b *SandboxBuilder) SetHandler(handler string) *SandboxBuilder { - b.sandbox.Handler = handler - return b -} - -func (b *SandboxBuilder) AddShutdownFunc(shutdownFunc context.CancelFunc) *SandboxBuilder { - b.shutdownFuncs = append(b.shutdownFuncs, shutdownFunc) - return b -} - -func (b *SandboxBuilder) setupLoggingWithDebugLogs() { - // Compose debug log writer with all log sinks. 
Debug log writer w - // will not write logs when disabled by invoke parameter - b.sandbox.DebugTailLogger = logging.NewTailLogWriter(b.debugTailLogWriter) - b.sandbox.PlatformLogger = logging.NewPlatformLogger(b.platformLogWriter, b.sandbox.DebugTailLogger) - b.sandbox.RuntimeStdoutWriter = io.MultiWriter(b.sandbox.DebugTailLogger, b.sandbox.RuntimeStdoutWriter) - b.sandbox.RuntimeStderrWriter = io.MultiWriter(b.sandbox.DebugTailLogger, b.sandbox.RuntimeStderrWriter) -} - -func (b *SandboxBuilder) Create() { - if len(b.sandbox.Handler) > 0 { - b.sandbox.Environment.SetHandler(b.sandbox.Handler) - } - - if !b.useCustomInteropServer { - b.sandbox.InteropServer = b.defaultInteropServer - } - - b.setupLoggingWithDebugLogs() - - go signalHandler(b.shutdownFuncs) - - rapid.Start(b.sandbox) -} - -func (b *SandboxBuilder) Init(i *interop.Init, timeoutMs int64) { - b.sandbox.InteropServer.Init(&interop.Start{ - Handler: i.Handler, - CorrelationID: i.CorrelationID, - AwsKey: i.AwsKey, - AwsSecret: i.AwsSecret, - AwsSession: i.AwsSession, - XRayDaemonAddress: i.XRayDaemonAddress, - FunctionName: i.FunctionName, - FunctionVersion: i.FunctionVersion, - CustomerEnvironmentVariables: i.CustomerEnvironmentVariables, - }, timeoutMs) -} - -func (b *SandboxBuilder) Invoke(w http.ResponseWriter, i *interop.Invoke) error { - return b.sandbox.InteropServer.Invoke(w, i) -} - -func (b *SandboxBuilder) InteropServer() InteropServer { - return b.defaultInteropServer -} - -// SetLogLevel sets the log level for internal logging. Needs to be called very -// early during startup to configure logs emitted during initialization -func SetLogLevel(logLevel string) { - level, err := log.ParseLevel(logLevel) - if err != nil { - log.WithError(err).Fatal("Failed to set log level. 
Valid log levels are:", log.AllLevels) - } - - log.SetLevel(level) - log.SetFormatter(&logging.InternalFormatter{}) -} - -func SetInternalLogOutput(w io.Writer) { - logging.SetOutput(w) -} - -// Trap SIGINT and SIGTERM signals and call shutdown function -func signalHandler(shutdownFuncs []context.CancelFunc) { - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) - sigReceived := <-sig - log.WithField("signal", sigReceived.String()).Info("Received signal") - for _, shutdownFunc := range shutdownFuncs { - shutdownFunc() - } -} diff --git a/lambda/rapidcore/sandbox_api.go b/lambda/rapidcore/sandbox_api.go new file mode 100644 index 0000000..0c7052e --- /dev/null +++ b/lambda/rapidcore/sandbox_api.go @@ -0,0 +1,147 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "go.amzn.com/lambda/interop" +) + +// SandboxContext and other structs form the implementation of the SandboxAPI +// interface defined in interop/sandbox_model.go, using the implementation of +// Init, Invoke and Reset handlers in rapid/sandbox.go +type SandboxContext struct { + rapidCtx interop.RapidContext + handler string + runtimeAPIAddress string + + InvokeReceivedTime int64 + InvokeResponseMetrics *interop.InvokeResponseMetrics +} + +type initContext struct { + initSuccessChan chan interop.InitSuccess + initFailureChan chan interop.InitFailure + rapidCtx interop.RapidContext + sbInfoFromInit interop.SandboxInfoFromInit // contains data that needs to be persisted from init for suppressed inits during invoke +} + +type invokeContext struct { + rapidCtx interop.RapidContext + invokeRequestChan chan *interop.Invoke + invokeSuccessChan chan interop.InvokeSuccess + invokeFailureChan chan interop.InvokeFailure +} + +// Validate interface compliance +var _ interop.SandboxContext = (*SandboxContext)(nil) +var _ interop.InitContext = (*initContext)(nil) +var _ interop.InvokeContext 
= (*invokeContext)(nil) + +func (s SandboxContext) Init(init *interop.Init, timeoutMs int64) (interop.InitStarted, interop.InitContext) { + initStartedResponseChan := make(chan interop.InitStarted) + initSuccessResponseChan := make(chan interop.InitSuccess) + initFailureResponseChan := make(chan interop.InitFailure) + + if len(s.handler) > 0 { + init.EnvironmentVariables.SetHandler(s.handler) + } + + init.EnvironmentVariables.StoreRuntimeAPIEnvironmentVariable(s.runtimeAPIAddress) + + go s.rapidCtx.HandleInit(init, initStartedResponseChan, initSuccessResponseChan, initFailureResponseChan) + initStarted := <-initStartedResponseChan + + sbMetadata := interop.SandboxInfoFromInit{ + EnvironmentVariables: init.EnvironmentVariables, + SandboxType: init.SandboxType, + RuntimeBootstrap: init.Bootstrap, + } + return initStarted, newInitContext(s.rapidCtx, sbMetadata, initSuccessResponseChan, initFailureResponseChan) +} + +func (s SandboxContext) Reset(reset *interop.Reset) (interop.ResetSuccess, *interop.ResetFailure) { + defer s.rapidCtx.Clear() + return s.rapidCtx.HandleReset(reset, s.InvokeReceivedTime, s.InvokeResponseMetrics) +} + +func (s SandboxContext) Shutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { + return s.rapidCtx.HandleShutdown(shutdown) +} + +func (s SandboxContext) Restore(restore *interop.Restore) error { + return s.rapidCtx.HandleRestore(restore) +} + +func (s *SandboxContext) SetInvokeReceivedTime(invokeReceivedTime int64) { + s.InvokeReceivedTime = invokeReceivedTime +} + +func (s *SandboxContext) SetInvokeResponseMetrics(metrics *interop.InvokeResponseMetrics) { + s.InvokeResponseMetrics = metrics +} + +func newInitContext(r interop.RapidContext, sbMetadata interop.SandboxInfoFromInit, + initSuccessChan chan interop.InitSuccess, initFailureChan chan interop.InitFailure) initContext { + return initContext{ + initSuccessChan: initSuccessChan, + initFailureChan: initFailureChan, + rapidCtx: r, + sbInfoFromInit: sbMetadata, + } +} + +func (i 
initContext) Wait() (interop.InitSuccess, *interop.InitFailure) { + select { + case initSuccess, isOpen := <-i.initSuccessChan: + if !isOpen { + // If init has already suceeded, we return quickly + return interop.InitSuccess{}, nil + } + return initSuccess, nil + case initFailure, isOpen := <-i.initFailureChan: + if !isOpen { + // If init has already failed, we return quickly for init to be suppressed + return interop.InitSuccess{}, &initFailure + } + return interop.InitSuccess{}, &initFailure + } +} + +func (i initContext) Reserve() interop.InvokeContext { + + invokeRequestChan := make(chan *interop.Invoke) + invokeSuccessChan := make(chan interop.InvokeSuccess) + invokeFailureChan := make(chan interop.InvokeFailure) + + go func() { + invoke := <-invokeRequestChan + // For suppressed inits, invoke needs the runtime and agent env vars + invokeSuccess, invokeFailure := i.rapidCtx.HandleInvoke(invoke, i.sbInfoFromInit) + if invokeFailure != nil { + invokeFailureChan <- *invokeFailure + } else { + invokeSuccessChan <- invokeSuccess + } + }() + + return invokeContext{ + rapidCtx: i.rapidCtx, + invokeRequestChan: invokeRequestChan, + invokeSuccessChan: invokeSuccessChan, + invokeFailureChan: invokeFailureChan, + } +} + +func (invCtx invokeContext) SendRequest(i *interop.Invoke) { + invCtx.invokeRequestChan <- i +} + +func (invCtx invokeContext) Wait() (interop.InvokeSuccess, *interop.InvokeFailure) { + select { + case invokeSuccess := <-invCtx.invokeSuccessChan: + return invokeSuccess, nil + case invokeFailure := <-invCtx.invokeFailureChan: + return interop.InvokeSuccess{}, &invokeFailure + } +} diff --git a/lambda/rapidcore/sandbox_builder.go b/lambda/rapidcore/sandbox_builder.go new file mode 100644 index 0000000..ce016a0 --- /dev/null +++ b/lambda/rapidcore/sandbox_builder.go @@ -0,0 +1,217 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "context" + "io" + "net" + "os" + "os/signal" + "strconv" + "syscall" + + "go.amzn.com/lambda/extensions" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/logging" + "go.amzn.com/lambda/rapid" + "go.amzn.com/lambda/supervisor" + supvmodel "go.amzn.com/lambda/supervisor/model" + "go.amzn.com/lambda/telemetry" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultSigtermResetTimeoutMs = int64(2000) +) + +type SandboxBuilder struct { + sandbox *rapid.Sandbox + sandboxContext interop.SandboxContext + lambdaInvokeAPI LambdaInvokeAPI + defaultInteropServer *Server + useCustomInteropServer bool + shutdownFuncs []context.CancelFunc + handler string +} + +type logSink int + +const ( + RuntimeLogSink logSink = iota + ExtensionLogSink +) + +func NewSandboxBuilder() *SandboxBuilder { + defaultInteropServer := NewServer(context.Background()) + signalCtx, cancelSignalCtx := context.WithCancel(context.Background()) + + b := &SandboxBuilder{ + sandbox: &rapid.Sandbox{ + PreLoadTimeNs: 0, // TODO + StandaloneMode: true, + LogsEgressAPI: &telemetry.NoOpLogsEgressAPI{}, + EnableTelemetryAPI: false, + Tracer: telemetry.NewNoOpTracer(), + SignalCtx: signalCtx, + EventsAPI: &telemetry.NoOpEventsAPI{}, + InitCachingEnabled: false, + Supervisor: supervisor.NewLocalSupervisor(), + RuntimeAPIHost: "127.0.0.1", + RuntimeAPIPort: 9001, + }, + defaultInteropServer: defaultInteropServer, + shutdownFuncs: []context.CancelFunc{}, + lambdaInvokeAPI: NewEmulatorAPI(defaultInteropServer), + } + + b.AddShutdownFunc(context.CancelFunc(func() { + log.Info("Shutting down...") + defaultInteropServer.Reset("SandboxTerminated", defaultSigtermResetTimeoutMs) + cancelSignalCtx() + })) + + return b +} + +func (b *SandboxBuilder) SetSupervisor(supervisor supvmodel.Supervisor) *SandboxBuilder { + b.sandbox.Supervisor = supervisor + return b +} + +func (b *SandboxBuilder) SetRuntimeAPIAddress(runtimeAPIAddress string) 
*SandboxBuilder { + host, port, err := net.SplitHostPort(runtimeAPIAddress) + if err != nil { + log.WithError(err).Warnf("Failed to parse RuntimeApiAddress: %s:", runtimeAPIAddress) + return b + } + + portInt, err := strconv.Atoi(port) + if err != nil { + log.WithError(err).Warnf("Failed to parse RuntimeApiPort: %s:", port) + return b + } + + b.sandbox.RuntimeAPIHost = host + b.sandbox.RuntimeAPIPort = portInt + return b +} + +func (b *SandboxBuilder) SetInteropServer(interopServer interop.Server) *SandboxBuilder { + b.sandbox.InteropServer = interopServer + b.useCustomInteropServer = true + return b +} + +func (b *SandboxBuilder) SetEventsAPI(eventsAPI telemetry.EventsAPI) *SandboxBuilder { + b.sandbox.EventsAPI = eventsAPI + return b +} + +func (b *SandboxBuilder) SetTracer(tracer telemetry.Tracer) *SandboxBuilder { + b.sandbox.Tracer = tracer + return b +} + +func (b *SandboxBuilder) DisableStandaloneMode() *SandboxBuilder { + b.sandbox.StandaloneMode = false + return b +} + +func (b *SandboxBuilder) SetExtensionsFlag(extensionsEnabled bool) *SandboxBuilder { + if extensionsEnabled { + extensions.Enable() + } else { + extensions.Disable() + } + return b +} + +func (b *SandboxBuilder) SetInitCachingFlag(initCachingEnabled bool) *SandboxBuilder { + b.sandbox.InitCachingEnabled = initCachingEnabled + return b +} + +func (b *SandboxBuilder) SetPreLoadTimeNs(preLoadTimeNs int64) *SandboxBuilder { + b.sandbox.PreLoadTimeNs = preLoadTimeNs + return b +} + +func (b *SandboxBuilder) SetTelemetrySubscription(logsSubscriptionAPI telemetry.SubscriptionAPI, telemetrySubscriptionAPI telemetry.SubscriptionAPI) *SandboxBuilder { + b.sandbox.EnableTelemetryAPI = true + b.sandbox.LogsSubscriptionAPI = logsSubscriptionAPI + b.sandbox.TelemetrySubscriptionAPI = telemetrySubscriptionAPI + return b +} + +func (b *SandboxBuilder) SetLogsEgressAPI(logsEgressAPI telemetry.StdLogsEgressAPI) *SandboxBuilder { + b.sandbox.LogsEgressAPI = logsEgressAPI + return b +} + +func (b 
*SandboxBuilder) SetHandler(handler string) *SandboxBuilder { + b.handler = handler + return b +} + +func (b *SandboxBuilder) AddShutdownFunc(shutdownFunc context.CancelFunc) *SandboxBuilder { + b.shutdownFuncs = append(b.shutdownFuncs, shutdownFunc) + return b +} + +func (b *SandboxBuilder) Create() (interop.SandboxContext, interop.InternalStateGetter) { + if !b.useCustomInteropServer { + b.sandbox.InteropServer = b.defaultInteropServer + } + + go signalHandler(b.shutdownFuncs) + + rapidCtx, internalStateFn, runtimeAPIAddr := rapid.Start(b.sandbox) + + b.sandboxContext = &SandboxContext{ + rapidCtx: rapidCtx, + handler: b.handler, + runtimeAPIAddress: runtimeAPIAddr, + InvokeReceivedTime: int64(0), + InvokeResponseMetrics: nil, + } + + return b.sandboxContext, internalStateFn +} + +func (b *SandboxBuilder) DefaultInteropServer() *Server { + return b.defaultInteropServer +} + +func (b *SandboxBuilder) LambdaInvokeAPI() LambdaInvokeAPI { + return b.lambdaInvokeAPI +} + +// SetLogLevel sets the log level for internal logging. Needs to be called very +// early during startup to configure logs emitted during initialization +func SetLogLevel(logLevel string) { + level, err := log.ParseLevel(logLevel) + if err != nil { + log.WithError(err).Fatal("Failed to set log level. 
Valid log levels are:", log.AllLevels) + } + + log.SetLevel(level) + log.SetFormatter(&logging.InternalFormatter{}) +} + +func SetInternalLogOutput(w io.Writer) { + logging.SetOutput(w) +} + +// Trap SIGINT and SIGTERM signals and call shutdown function +func signalHandler(shutdownFuncs []context.CancelFunc) { + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) + sigReceived := <-sig + log.WithField("signal", sigReceived.String()).Info("Received signal") + for _, shutdownFunc := range shutdownFuncs { + shutdownFunc() + } +} diff --git a/lambda/rapidcore/sandbox_emulator_api.go b/lambda/rapidcore/sandbox_emulator_api.go new file mode 100644 index 0000000..6737631 --- /dev/null +++ b/lambda/rapidcore/sandbox_emulator_api.go @@ -0,0 +1,52 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "go.amzn.com/lambda/interop" + + "net/http" +) + +// LambdaInvokeAPI are the methods used by the Runtime Interface Emulator +type LambdaInvokeAPI interface { + Init(i *interop.Init, invokeTimeoutMs int64) + Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error +} + +// EmulatorAPI wraps the standalone interop server to provide a convenient interface +// for Rapid Standalone +type EmulatorAPI struct { + server *Server +} + +// Validate interface compliance +var _ LambdaInvokeAPI = (*EmulatorAPI)(nil) + +func NewEmulatorAPI(s *Server) *EmulatorAPI { + return &EmulatorAPI{s} +} + +// Init method is only used by the Runtime interface emulator +func (l *EmulatorAPI) Init(i *interop.Init, timeoutMs int64) { + l.server.Init(&interop.Init{ + Handler: i.Handler, + AwsKey: i.AwsKey, + AwsSecret: i.AwsSecret, + AwsSession: i.AwsSession, + XRayDaemonAddress: i.XRayDaemonAddress, + FunctionName: i.FunctionName, + FunctionVersion: i.FunctionVersion, + CustomerEnvironmentVariables: i.CustomerEnvironmentVariables, + RuntimeInfo: i.RuntimeInfo, + 
SandboxType: i.SandboxType, + Bootstrap: i.Bootstrap, + EnvironmentVariables: i.EnvironmentVariables, + }, timeoutMs) +} + +// Invoke method is only used by the Runtime interface emulator +func (l *EmulatorAPI) Invoke(w http.ResponseWriter, i *interop.Invoke) error { + return l.server.Invoke(w, i) +} diff --git a/lambda/rapidcore/server.go b/lambda/rapidcore/server.go index e3e01b6..e652130 100644 --- a/lambda/rapidcore/server.go +++ b/lambda/rapidcore/server.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "math" "net/http" "sync" @@ -17,6 +16,7 @@ import ( "go.amzn.com/lambda/core/directinvoke" "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" @@ -33,6 +33,12 @@ const ( resetDefaultTimeoutMs = 2000 ) +const ( + contentTypeHeader = "Content-Type" + errorTypeHeader = "Error-Type" + functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" +) + type rapidPhase int const ( @@ -46,7 +52,6 @@ type runtimeState int const ( runtimeNotStarted = iota - runtimeInitStarted runtimeInitError runtimeInitComplete runtimeInitFailed @@ -76,14 +81,11 @@ type InvokeContext struct { type Server struct { InternalStateGetter interop.InternalStateGetter - invokeChanOut chan *interop.Invoke - startChanOut chan *interop.Start - resetChanOut chan *interop.Reset - shutdownChanOut chan *interop.Shutdown - errorChanOut chan error + initChanOut chan *interop.Init + interruptedResponseChan chan *interop.Reset - sendRunningChan chan *interop.Running - sendResponseChan chan struct{} + sendRunningChan chan *interop.InitStarted + sendResponseChan chan *interop.InvokeResponseMetrics doneChan chan *interop.Done InitDoneChan chan DoneWithState @@ -100,12 +102,17 @@ type Server struct { rapidPhase rapidPhase runtimeState runtimeState -} -func (s *Server) StartAcceptingDirectInvokes() error { - return nil + sandboxContext interop.SandboxContext + initContext interop.InitContext + invoker 
interop.InvokeContext + initFailures chan interop.InitFailure + cachedInitErrorResponse *interop.ErrorResponse } +// Validate interface compliance +var _ interop.Server = (*Server)(nil) + func (s *Server) setRapidPhase(phase rapidPhase) { s.mutex.Lock() defer s.mutex.Unlock() @@ -185,6 +192,11 @@ func (s *Server) setNewInvokeContext(invokeID string, traceID, lambdaSegmentID s return resp, nil } +type ReserveResponse struct { + Token interop.Token + InternalState *statejson.InternalStateDescription +} + // Reserve allocates invoke context func (s *Server) Reserve(id string, traceID, lambdaSegmentID string) (*ReserveResponse, error) { invokeID := uuid.New().String() @@ -196,37 +208,28 @@ func (s *Server) Reserve(id string, traceID, lambdaSegmentID string) (*ReserveRe return nil, err } - resp.InternalState, err = s.waitInit() + // The two errors reserve returns in standalone mode are INIT timeout + // and INIT failure (two types of failure: runtime exit, /init/error). Both require suppressed + // initialization, so we succeed the reservation. 
+ invCtx := s.initContext.Reserve() + s.invoker = invCtx + resp.InternalState, err = s.InternalState() + return resp, err } -func (s *Server) waitInit() (*statejson.InternalStateDescription, error) { - for { - select { - - case doneWithState, chanOpen := <-s.InitDoneChan: - if !chanOpen { - // init only happens once - return nil, ErrInitAlreadyDone - } - - close(s.InitDoneChan) // this was first call to reserve - - if s.getRuntimeState() == runtimeInitFailed { - return &doneWithState.State, ErrInitError - } - - if len(doneWithState.ErrorType) > 0 { - log.Errorf("INIT DONE failed: %s", doneWithState.ErrorType) - return &doneWithState.State, ErrInitDoneFailed - } - - return &doneWithState.State, nil - - case <-s.reservationContext.Done(): - return nil, ErrReserveReservationDone - } +func (s *Server) awaitInitCompletion() { + initSuccess, initFailure := s.initContext.Wait() + if initFailure != nil { + // In standalone, we don't have to block rapid start() goroutine until init failure is consumed + // because there is no channel back to the invoker until an invoke arrives via a Reserve() + initFailure.Ack <- struct{}{} + s.initFailures <- *initFailure + } else { + initSuccess.Ack <- struct{}{} } + // always closing the channel makes this method idempotent + close(s.initFailures) } func (s *Server) setReplyStream(w http.ResponseWriter, direct bool) (string, error) { @@ -263,6 +266,8 @@ func (s *Server) Release() error { s.reservationCancel() } + s.sandboxContext.SetInvokeReceivedTime(0) + s.sandboxContext.SetInvokeResponseMetrics(nil) s.invokeCtx = nil return nil } @@ -279,37 +284,18 @@ func (s *Server) GetCurrentInvokeID() string { return s.invokeCtx.Token.InvokeID } +// SetSandboxContext is used to set the sandbox context after intiialization of interop server. +// After refactoring all messages, this needs to be removed and made an struct parameter on initialization. 
+func (s *Server) SetSandboxContext(sbCtx interop.SandboxContext) { + s.sandboxContext = sbCtx +} + // SetInternalStateGetter is used to set callback which returnes internal state for /test/internalState request func (s *Server) SetInternalStateGetter(cb interop.InternalStateGetter) { s.InternalStateGetter = cb } -// StartChan returns Start emitter -func (s *Server) StartChan() <-chan *interop.Start { - return s.startChanOut -} - -// InvokeChan returns Invoke emitter -func (s *Server) InvokeChan() <-chan *interop.Invoke { - return s.invokeChanOut -} - -// ResetChan returns Reset emitter -func (s *Server) ResetChan() <-chan *interop.Reset { - return s.resetChanOut -} - -// ShutdownChan returns Shutdown emitter -func (s *Server) ShutdownChan() <-chan *interop.Shutdown { - return s.shutdownChanOut -} - -// InvalidMessageChan emits errors if there was something we could not parse -func (s *Server) TransportErrorChan() <-chan error { - return s.errorChanOut -} - -func (s *Server) sendResponseUnsafe(invokeID string, contentType string, status int, payload io.Reader) error { +func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[string]string, status int, payload io.Reader, trailers http.Header, request *interop.CancellableRequest, runtimeCalledResponse bool) error { if s.invokeCtx == nil || invokeID != s.invokeCtx.Token.InvokeID { return interop.ErrInvalidInvokeID } @@ -322,26 +308,15 @@ func (s *Server) sendResponseUnsafe(invokeID string, contentType string, status return fmt.Errorf("ReplyStream not available") } - // TODO: earlier, status was set to 500 if runtime called /invocation/error. However, the integration - // tests do not differentiate between /invocation/error and /invocation/response, but they check the error type: - // To identify user-errors, we should also allowlist custom errortypes and propagate them via headers. 
- - // s.invokeCtx.ReplyStream.WriteHeader(status) - var reportedErr error if s.invokeCtx.Direct { - if err := directinvoke.SendDirectInvokeResponse(map[string]string{"Content-Type": contentType}, payload, s.invokeCtx.ReplyStream); err != nil { + if err := directinvoke.SendDirectInvokeResponse(additionalHeaders, payload, trailers, s.invokeCtx.ReplyStream, s.interruptedResponseChan, s.sendResponseChan, request, runtimeCalledResponse); err != nil { // TODO: Do we need to drain the reader in case of a large payload and connection reuse? log.Errorf("Failed to write response to %s: %s", invokeID, err) - flusher, ok := s.invokeCtx.ReplyStream.(http.Flusher) - if !ok { - log.Error("Failed to flush response") - } - flusher.Flush() reportedErr = err } } else { - data, err := ioutil.ReadAll(payload) + data, err := io.ReadAll(payload) if err != nil { return fmt.Errorf("Failed to read response on %s: %s", invokeID, err) } @@ -352,73 +327,103 @@ func (s *Server) sendResponseUnsafe(invokeID string, contentType string, status } } - s.invokeCtx.ReplyStream.Header().Add("Content-Type", contentType) - if _, err := s.invokeCtx.ReplyStream.Write(data); err != nil { + startReadingResponseMonoTimeMs := metering.Monotime() + s.invokeCtx.ReplyStream.Header().Add(contentTypeHeader, additionalHeaders[contentTypeHeader]) + written, err := s.invokeCtx.ReplyStream.Write(data) + if err != nil { return fmt.Errorf("Failed to write response to %s: %s", invokeID, err) } + + s.sendResponseChan <- &interop.InvokeResponseMetrics{ + ProducedBytes: int64(written), + StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, + FinishReadingResponseMonoTimeMs: metering.Monotime(), + TimeShapedNs: int64(-1), + OutboundThroughputBps: int64(-1), + // FIXME: + // The runtime tells whether the function response mode is streaming or not. + // Ideally, we would want to use that value here. 
Since I'm just rebasing, I will leave + // as-is, but we should use that instead of relying on our memory to set this here + // because we "know" it's a streaming code path. + FunctionResponseMode: interop.FunctionResponseModeBuffered, + RuntimeCalledResponse: runtimeCalledResponse, + } } - s.sendResponseChan <- struct{}{} s.invokeCtx.ReplySent = true s.invokeCtx.Direct = false return reportedErr } -func (s *Server) SendResponse(invokeID string, contentType string, reader io.Reader) error { +func (s *Server) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error { s.setRuntimeState(runtimeInvokeResponseSent) s.mutex.Lock() defer s.mutex.Unlock() - return s.sendResponseUnsafe(invokeID, contentType, http.StatusOK, reader) -} - -func (s *Server) CommitResponse() error { return nil } - -func (s *Server) SendRunning(run *interop.Running) error { - s.setRuntimeState(runtimeInitStarted) - s.sendRunningChan <- run - return nil + runtimeCalledResponse := true + return s.sendResponseUnsafe(invokeID, headers, http.StatusOK, reader, trailers, request, runtimeCalledResponse) } -func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorResponse) error { - switch s.getRapidPhase() { - case phaseInitializing: - s.setRuntimeState(runtimeInitError) - return nil - case phaseInvoking: - // This branch can also occur during a suppressed init error, which is reported as invoke error - s.setRuntimeState(runtimeInvokeError) - s.mutex.Lock() - defer s.mutex.Unlock() - return s.sendResponseUnsafe(invokeID, resp.ContentType, http.StatusInternalServerError, bytes.NewReader(resp.Payload)) - default: - panic("received unexpected error response outside invoke or init phases") +func (s *Server) SendInitErrorResponse(invokeID string, resp *interop.ErrorResponse) error { + log.Debugf("Sending Init Error Response: %s", resp.ErrorType) + if s.getRapidPhase() == phaseInvoking { + // This branch occurs 
during suppressed init + return s.SendErrorResponse(invokeID, resp) } -} -func (s *Server) SendDone(done *interop.Done) error { - s.doneChan <- done + // Handle an /init/error outside of the invoke phase + s.setCachedInitErrorResponse(resp) + s.setRuntimeState(runtimeInitError) return nil } -func (s *Server) SendDoneFail(doneFail *interop.DoneFail) error { - s.doneChan <- &interop.Done{ - ErrorType: doneFail.ErrorType, - CorrelationID: doneFail.CorrelationID, // filipovi: correlationID is required to dispatch message into correct channel - Meta: doneFail.Meta, +func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorResponse) error { + log.Debugf("Sending Error Response: %s", resp.ErrorType) + s.setRuntimeState(runtimeInvokeError) + s.mutex.Lock() + defer s.mutex.Unlock() + additionalHeaders := map[string]string{contentTypeHeader: resp.ContentType, errorTypeHeader: resp.ErrorType} + if functionResponseMode := resp.FunctionResponseMode; functionResponseMode != "" { + additionalHeaders[functionResponseModeHeader] = functionResponseMode } - return nil + runtimeCalledResponse := false // we are sending an error here, so runtime called /error or crashed/timeout + return s.sendResponseUnsafe(invokeID, additionalHeaders, http.StatusInternalServerError, bytes.NewReader(resp.Payload), nil, nil, runtimeCalledResponse) } func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) { // pass reset to rapid - s.resetChanOut <- &interop.Reset{ - Reason: reason, - DeadlineNs: deadlineNsFromTimeoutMs(timeoutMs), - CorrelationID: "resetCorrelationID", + reset := &interop.Reset{ + Reason: reason, + DeadlineNs: deadlineNsFromTimeoutMs(timeoutMs), } + go func() { + select { + case s.interruptedResponseChan <- reset: + <-s.interruptedResponseChan // wait for response streaming metrics being added to reset struct + s.sandboxContext.SetInvokeResponseMetrics(reset.InvokeResponseMetrics) + default: + } + + resetSuccess, resetFailure := 
s.sandboxContext.Reset(reset) + s.Clear() // clear server state to prepare for new invokes + s.setRapidPhase(phaseIdle) + s.setRuntimeState(runtimeNotStarted) + + var meta interop.DoneMetadata + if reset.InvokeResponseMetrics != nil { + meta.RuntimeTimeThrottledMs = reset.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) + meta.RuntimeProducedBytes = reset.InvokeResponseMetrics.ProducedBytes + meta.RuntimeOutboundThroughputBps = reset.InvokeResponseMetrics.OutboundThroughputBps + } + + if resetFailure != nil { + meta.ExtensionsResetMs = resetFailure.ExtensionsResetMs + s.ResetDoneChan <- &interop.Done{ErrorType: resetFailure.ErrorType, Meta: meta} + } else { + meta.ExtensionsResetMs = resetSuccess.ExtensionsResetMs + s.ResetDoneChan <- &interop.Done{ErrorType: resetSuccess.ErrorType, Meta: meta} + } + }() - // TODO do not block on reset, instead consume ResetDoneChan in waitForRelease handler, - // this will get us more aligned on async reset notification handling. done := <-s.ResetDoneChan s.Release() @@ -431,14 +436,11 @@ func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescript func NewServer(ctx context.Context) *Server { s := &Server{ - startChanOut: make(chan *interop.Start), - invokeChanOut: make(chan *interop.Invoke), - errorChanOut: make(chan error), - resetChanOut: make(chan *interop.Reset), - shutdownChanOut: make(chan *interop.Shutdown), - - sendRunningChan: make(chan *interop.Running), - sendResponseChan: make(chan struct{}), + initChanOut: make(chan *interop.Init), + interruptedResponseChan: make(chan *interop.Reset), + + sendRunningChan: make(chan *interop.InitStarted), + sendResponseChan: make(chan *interop.InvokeResponseMetrics), doneChan: make(chan *interop.Done), // These two channels are buffered, because they are depleted asynchronously (by reserve and waitUntilRelease) and we don't want to block in SendDone until they are called @@ -449,47 +451,9 @@ func NewServer(ctx context.Context) *Server { 
ShutdownDoneChan: make(chan *interop.Done), } - go s.dispatchDone() - return s } -func (s *Server) setInitDoneRuntimeState(done *interop.Done) { - if len(done.ErrorType) > 0 { - s.setRuntimeState(runtimeInitFailed) // donefail - } else { - s.setRuntimeState(runtimeInitComplete) // done - } -} - -// Note, the dispatch loop below has potential to block, when -// channel is not drained. E.g. if test assumes sandbox init -// completion before dispatching reset, then reset will block -// until init channel is drained. -func (s *Server) dispatchDone() { - for { - done := <-s.doneChan - log.Debug("Dispatching DONE:", done.CorrelationID) - internalState := s.InternalStateGetter() - s.setRapidPhase(phaseIdle) - if done.CorrelationID == "initCorrelationID" { - s.setInitDoneRuntimeState(done) - s.InitDoneChan <- DoneWithState{Done: done, State: internalState} - } else if done.CorrelationID == "invokeCorrelationID" { - s.setRuntimeState(runtimeInvokeComplete) - s.InvokeDoneChan <- DoneWithState{Done: done, State: internalState} - } else if done.CorrelationID == "resetCorrelationID" { - s.setRuntimeState(runtimeNotStarted) - s.ResetDoneChan <- done - } else if done.CorrelationID == "shutdownCorrelationID" { - s.setRuntimeState(runtimeNotStarted) - s.ShutdownDoneChan <- done - } else { - panic("Received DONE without correlation ID") - } - } -} - func drainChannel(c chan DoneWithState) { for { select { @@ -509,10 +473,6 @@ func (s *Server) Clear() { s.Release() } -func (s *Server) IsResponseSent() bool { - panic("unexpected call to unimplemented method in rapidcore: IsResponseSent()") -} - func (s *Server) SendRuntimeReady() error { // only called when extensions are enabled s.setRuntimeState(runtimeReady) @@ -524,16 +484,34 @@ func deadlineNsFromTimeoutMs(timeoutMs int64) int64 { return mono + timeoutMs*1000*1000 } -func (s *Server) Init(i *interop.Start, invokeTimeoutMs int64) { - s.SetInvokeTimeout(time.Duration(invokeTimeoutMs) * time.Millisecond) +func (s *Server) 
setInitFailuresChan() { + s.mutex.Lock() + defer s.mutex.Unlock() + s.initFailures = make(chan interop.InitFailure) +} + +func (s *Server) getInitFailuresChan() chan interop.InitFailure { + s.mutex.Lock() + defer s.mutex.Unlock() + return s.initFailures +} - s.startChanOut <- i +func (s *Server) Init(i *interop.Init, invokeTimeoutMs int64) error { + s.SetInvokeTimeout(time.Duration(invokeTimeoutMs) * time.Millisecond) s.setRapidPhase(phaseInitializing) - <-s.sendRunningChan - log.Debug("Received RUNNING") + s.setInitFailuresChan() + initStarted, initCtx := s.sandboxContext.Init(i, invokeTimeoutMs) + initStarted.Ack <- struct{}{} + + s.initContext = initCtx + go s.awaitInitCompletion() + + log.Debugf("Received RUNNING: %v", initStarted) + return nil } func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error { + s.sandboxContext.SetInvokeReceivedTime(i.InvokeReceivedTime) invokeID, err := s.setReplyStream(w, direct) if err != nil { return err @@ -544,16 +522,55 @@ func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct boo i.ID = invokeID select { - case s.invokeChanOut <- i: - break case <-s.sendResponseChan: // we didn't pass invoke to rapid yet, but rapid already has written some response // It can happend if runtime/agent crashed even before we passed invoke to it return ErrInvokeResponseAlreadyWritten + default: } + go func() { + if s.invoker == nil { + // Reset occurred, do not send invoke request + s.InvokeDoneChan <- DoneWithState{State: s.InternalStateGetter()} + s.setRuntimeState(runtimeInvokeComplete) + return + } + s.invoker.SendRequest(i) + invokeSuccess, invokeFailure := s.invoker.Wait() + if invokeFailure != nil { + if invokeFailure.ResetReceived { + return + } + + // Rapid constructs a response body itself when invoke fails, with error type. + // These are on the handleInvokeError path, may occur during timeout resets, + // failure reset (proc exit). 
It is expected to be non-nil on all invoke failures. + if invokeFailure.DefaultErrorResponse == nil { + log.Panicf("default error response was nil for invoke failure, %v", invokeFailure) + } + + if cachedInitError := s.getCachedInitErrorResponse(); cachedInitError != nil { + // /init/error was called + s.trySendDefaultErrorResponse(cachedInitError) + } else { + // sent only if /error and /response not called + s.trySendDefaultErrorResponse(invokeFailure.DefaultErrorResponse) + } + doneFail := doneFailFromInvokeFailure(invokeFailure) + s.InvokeDoneChan <- DoneWithState{ + Done: &interop.Done{ErrorType: doneFail.ErrorType, Meta: doneFail.Meta}, + State: s.InternalStateGetter(), + } + } else { + done := doneFromInvokeSuccess(invokeSuccess) + s.InvokeDoneChan <- DoneWithState{Done: done, State: s.InternalStateGetter()} + } + }() + select { - case <-s.sendResponseChan: + case i.InvokeResponseMetrics = <-s.sendResponseChan: + s.sandboxContext.SetInvokeResponseMetrics(i.InvokeResponseMetrics) break case <-s.reservationContext.Done(): return ErrInvokeReservationDone @@ -562,6 +579,26 @@ func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct boo return nil } +func (s *Server) setCachedInitErrorResponse(errResp *interop.ErrorResponse) { + s.mutex.Lock() + defer s.mutex.Unlock() + s.cachedInitErrorResponse = errResp +} + +func (s *Server) getCachedInitErrorResponse() *interop.ErrorResponse { + s.mutex.Lock() + defer s.mutex.Unlock() + return s.cachedInitErrorResponse +} + +func (s *Server) trySendDefaultErrorResponse(resp *interop.ErrorResponse) { + if err := s.SendErrorResponse(s.GetCurrentInvokeID(), resp); err != nil { + if err != interop.ErrResponseSent { + log.Panicf("Failed to send default error response: %s", err) + } + } +} + func (s *Server) CurrentToken() *interop.Token { s.mutex.Lock() defer s.mutex.Unlock() @@ -582,77 +619,158 @@ func (s *Server) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invo go func() { select { case 
<-time.After(s.GetInvokeTimeout()): + log.Debug("Invoke() timeout") timeoutChan <- ErrInvokeTimeout - s.Reset(autoresetReasonTimeout, resetDefaultTimeoutMs) case <-resetCtx.Done(): log.Debugf("execute finished, autoreset cancelled") } }() - reserveResp, err := s.Reserve(invoke.ID, "", "") - if err != nil { - switch err { - case ErrInitError: - // Simulate 'Suppressed Init' scenario - s.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) - reserveResp, err = s.Reserve("", "", "") - if err == ErrInitAlreadyDone { - break - } - return err - case ErrInitDoneFailed, ErrTerminated: - s.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) - return err - - case ErrInitAlreadyDone: - // This is a valid response (e.g. for 2nd and 3rd invokes) - // TODO: switch on ReserveResponse status instead of err, - // since these are valid values - if s.InternalStateGetter == nil { - responseWriter.Write([]byte("error: internal state callback not set")) - return ErrInternalServerError - } + initFailures := s.getInitFailuresChan() + if initFailures == nil { + return ErrInitNotStarted + } - default: - return err + releaseErrChan := make(chan error) + releaseSuccessChan := make(chan struct{}) + go func() { + // This thread can block in one of two method calls Reserve() & AwaitRelease(), + // corresponding to Init and Invoke phase. + // FastInvoke is intended to be 'async' response stream copying. + // When a timeout occurs, we send a 'Reset' with the timeout reason + // When a Reset is sent, the reset handler in rapid lib cancels existing flows, + // including init/invoke. This causes either initFailure/invokeFailure, and then + // the Reset is handled and processed. 
+ // TODO: however, ideally Reserve() does not block on init, but FastInvoke does + // The logic would be almost identical, except that init failures could manifest + // through return values of FastInvoke and not Reserve() + + reserveResp, err := s.Reserve("", "", "") + if err != nil { + log.Infof("ReserveFailed: %s", err) } - } - invoke.DeadlineNs = fmt.Sprintf("%d", metering.Monotime()+reserveResp.Token.FunctionTimeout.Nanoseconds()) + invoke.DeadlineNs = fmt.Sprintf("%d", metering.Monotime()+reserveResp.Token.FunctionTimeout.Nanoseconds()) + go func() { + if initCompletionResp, err := s.awaitInitialized(); err != nil { + switch err { + case ErrInitResetReceived, ErrInitDoneFailed: + // For init failures, cache the response so they can be checked later + // We check if they have not already been set by a call to /init/error by runtime + if s.getCachedInitErrorResponse() == nil { + errType, errMsg := string(initCompletionResp.InitErrorType), initCompletionResp.InitErrorMessage.Error() + s.setCachedInitErrorResponse(&interop.ErrorResponse{ErrorType: errType, ErrorMessage: errMsg}) + } + } + } - invokeChan := make(chan error) - go func() { - if err := s.FastInvoke(responseWriter, invoke, false); err != nil { - invokeChan <- err + if err := s.FastInvoke(responseWriter, invoke, false); err != nil { + log.Debugf("FastInvoke() error: %s", err) + } + }() + + _, err = s.AwaitRelease() + if err != nil && err != ErrReleaseReservationDone { + log.Debugf("AwaitRelease() error: %s", err) + switch err { + case ErrReleaseReservationDone: // not an error, expected return value when Reset is called + if s.getCachedInitErrorResponse() != nil { + // For Init failures, AwaitRelease returns ErrReleaseReservationDone + // because the Reset calls Release & cancels the release context + // We rename the error to ErrInitDoneFailed + releaseErrChan <- ErrInitDoneFailed + } + case ErrInitDoneFailed, ErrInvokeDoneFailed: + // Reset when either init or invoke failrues occur, i.e. 
+ // init/error, invocation/error, Runtime.ExitError, Extension.ExitError + s.Reset(autoresetReasonReleaseFail, resetDefaultTimeoutMs) + releaseErrChan <- err + default: + releaseErrChan <- err + } + return } - }() - releaseChan := make(chan error) - go func() { - _, err := s.AwaitRelease() - releaseChan <- err + releaseSuccessChan <- struct{}{} }() - // TODO: verify the order of channel receives. When timeouts happen, Reset() - // is called first, which also does Release() => this may signal a type - // Err<*>ReservationDone error to the non-timeout channels. This is currently - // handled by the http handler, which returns GatewayTimeout for reservation errors - // too. However, Timeouts should ideally be only represented by ErrInvokeTimeout. + var err error select { - case err = <-invokeChan: - case err = <-timeoutChan: - case err = <-releaseChan: - if err != nil { - s.Reset(autoresetReasonReleaseFail, resetDefaultTimeoutMs) + case timeoutErr := <-timeoutChan: + s.Reset(autoresetReasonTimeout, resetDefaultTimeoutMs) + select { + case releaseErr := <-releaseErrChan: // when AwaitRelease() has errors + log.Debugf("Invoke() release error on Execute() timeout: %s", releaseErr) + case <-releaseSuccessChan: // when AwaitRelease() finishes cleanly } + err = timeoutErr + case err = <-releaseErrChan: + log.Debug("Invoke() release error") + case <-releaseSuccessChan: + s.Release() + log.Debug("Invoke() success") } return err } +type initCompletionResponse struct { + InitErrorType fatalerror.ErrorType + InitErrorMessage error +} + +func (s *Server) awaitInitialized() (initCompletionResponse, error) { + initFailure, awaitingInitStatus := <-s.getInitFailuresChan() + resp := initCompletionResponse{} + + if initFailure.ResetReceived { + // Resets during Init are only received in standalone + // during an invoke timeout + s.setRuntimeState(runtimeInitFailed) + resp.InitErrorType = initFailure.ErrorType + resp.InitErrorMessage = initFailure.ErrorMessage + return resp, 
ErrInitResetReceived + } + + if awaitingInitStatus { + // channel not closed, received init failure + // Sandbox can be reserved even if init failed (due to function errors) + s.setRuntimeState(runtimeInitFailed) + resp.InitErrorType = initFailure.ErrorType + resp.InitErrorMessage = initFailure.ErrorMessage + return resp, ErrInitDoneFailed + } + + // not awaiting init status (channel closed) + return resp, nil +} + +// AwaitInitialized waits until init is complete. It must be idempotent, +// since it can be called twice when a caller wants to wait until init is complete +func (s *Server) AwaitInitialized() error { + if _, err := s.awaitInitialized(); err != nil { + if releaseErr := s.Release(); err != nil { + log.Infof("Error releasing after init failure %s: %s", err, releaseErr) + } + s.setRuntimeState(runtimeInitFailed) + return err + } + s.setRuntimeState(runtimeInitComplete) + return nil +} + func (s *Server) AwaitRelease() (*statejson.InternalStateDescription, error) { + defer func() { + s.setRapidPhase(phaseIdle) + s.setRuntimeState(runtimeInvokeComplete) + }() + select { case doneWithState := <-s.InvokeDoneChan: + if len(doneWithState.ErrorType) > 0 && string(doneWithState.ErrorType) == ErrInitDoneFailed.Error() { + return nil, ErrInitDoneFailed + } + if len(doneWithState.ErrorType) > 0 { log.Errorf("Invoke DONE failed: %s", doneWithState.ErrorType) return nil, ErrInvokeDoneFailed @@ -667,8 +785,13 @@ func (s *Server) AwaitRelease() (*statejson.InternalStateDescription, error) { } func (s *Server) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription { - s.shutdownChanOut <- shutdown - <-s.ShutdownDoneChan + shutdownSuccess := s.sandboxContext.Shutdown(shutdown) + if len(shutdownSuccess.ErrorType) > 0 { + log.Errorf("Shutdown first fatal error: %s", shutdownSuccess.ErrorType) + } + + s.setRapidPhase(phaseIdle) + s.setRuntimeState(runtimeNotStarted) state := s.InternalStateGetter() return &state @@ -682,3 +805,49 @@ func (s *Server) 
InternalState() (*statejson.InternalStateDescription, error) { state := s.InternalStateGetter() return &state, nil } + +func (s *Server) Restore(restore *interop.Restore) error { + return s.sandboxContext.Restore(restore) +} + +func doneFromInvokeSuccess(successMsg interop.InvokeSuccess) *interop.Done { + return &interop.Done{ + Meta: interop.DoneMetadata{ + RuntimeRelease: successMsg.RuntimeRelease, + NumActiveExtensions: successMsg.NumActiveExtensions, + ExtensionNames: successMsg.ExtensionNames, + InvokeRequestReadTimeNs: successMsg.InvokeMetrics.InvokeRequestReadTimeNs, + InvokeRequestSizeBytes: successMsg.InvokeMetrics.InvokeRequestSizeBytes, + RuntimeReadyTime: successMsg.InvokeMetrics.RuntimeReadyTime, + + InvokeCompletionTimeNs: successMsg.InvokeCompletionTimeNs, + InvokeReceivedTime: successMsg.InvokeReceivedTime, + RuntimeTimeThrottledMs: successMsg.ResponseMetrics.RuntimeTimeThrottledMs, + RuntimeProducedBytes: successMsg.ResponseMetrics.RuntimeProducedBytes, + RuntimeOutboundThroughputBps: successMsg.ResponseMetrics.RuntimeOutboundThroughputBps, + LogsAPIMetrics: successMsg.LogsAPIMetrics, + }, + } +} + +func doneFailFromInvokeFailure(failureMsg *interop.InvokeFailure) *interop.DoneFail { + return &interop.DoneFail{ + ErrorType: failureMsg.ErrorType, + Meta: interop.DoneMetadata{ + RuntimeRelease: failureMsg.RuntimeRelease, + NumActiveExtensions: failureMsg.NumActiveExtensions, + InvokeReceivedTime: failureMsg.InvokeReceivedTime, + + RuntimeTimeThrottledMs: failureMsg.ResponseMetrics.RuntimeTimeThrottledMs, + RuntimeProducedBytes: failureMsg.ResponseMetrics.RuntimeProducedBytes, + RuntimeOutboundThroughputBps: failureMsg.ResponseMetrics.RuntimeOutboundThroughputBps, + + InvokeRequestReadTimeNs: failureMsg.InvokeMetrics.InvokeRequestReadTimeNs, + InvokeRequestSizeBytes: failureMsg.InvokeMetrics.InvokeRequestSizeBytes, + RuntimeReadyTime: failureMsg.InvokeMetrics.RuntimeReadyTime, + + ExtensionNames: failureMsg.ExtensionNames, + LogsAPIMetrics: 
failureMsg.LogsAPIMetrics, + }, + } +} diff --git a/lambda/rapidcore/server_test.go b/lambda/rapidcore/server_test.go index 416304c..88eea3f 100644 --- a/lambda/rapidcore/server_test.go +++ b/lambda/rapidcore/server_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapidcore/env" ) func waitForChanWithTimeout(channel <-chan error, timeout time.Duration) error { @@ -26,15 +27,68 @@ func waitForChanWithTimeout(channel <-chan error, timeout time.Duration) error { } } +func sendInitStartedResponse(responseChannel chan<- interop.InitStarted, msg interop.InitStarted) { + msg.Ack = make(chan struct{}) + responseChannel <- msg + <-msg.Ack +} + +func sendInitSuccessResponse(responseChannel chan<- interop.InitSuccess, msg interop.InitSuccess) { + msg.Ack = make(chan struct{}) + responseChannel <- msg + <-msg.Ack +} + +func sendInitFailureResponse(responseChannel chan<- interop.InitFailure, msg interop.InitFailure) { + msg.Ack = make(chan struct{}) + responseChannel <- msg + <-msg.Ack +} + +type mockRapidCtx struct { + initHandler func(start chan<- interop.InitStarted, success chan<- interop.InitSuccess, fail chan<- interop.InitFailure) + invokeHandler func() (interop.InvokeSuccess, *interop.InvokeFailure) + resetHandler func() (interop.ResetSuccess, *interop.ResetFailure) +} + +func (r *mockRapidCtx) HandleInit(init *interop.Init, startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + r.initHandler(startResp, successResp, failureResp) +} + +func (r *mockRapidCtx) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { + return r.invokeHandler() +} + +func (r *mockRapidCtx) HandleReset(reset *interop.Reset, invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, 
*interop.ResetFailure) { + return r.resetHandler() +} + +func (r *mockRapidCtx) HandleShutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { + return interop.ShutdownSuccess{} +} + +func (r *mockRapidCtx) HandleRestore(restore *interop.Restore) error { + return nil +} + +func (r *mockRapidCtx) Clear() {} + func TestReserveDoesNotDeadlockWhenCalledMultipleTimes(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { <-srv.StartChan() }() - go srv.SendRunning(&interop.Running{}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ + initHandler, + func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, + func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, + }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - go srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"}) _, err := srv.Reserve("", "", "") // reserve successfully require.NoError(t, err) @@ -61,89 +115,120 @@ func TestInitSuccess(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) - }() + 
initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ + initHandler, + func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, + func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, + }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) _, err := srv.Reserve("", "", "") require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitComplete), srv.getRuntimeState()) } func TestInitErrorBeforeReserve(t *testing.T) { + // Rapid thread sending init failure should not be blocked even if reserve hasn't arrived srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) initErrorResponseSent := make(chan error) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "foobar"})) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, 
srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) initErrorResponseSent <- errors.New("initErrorResponseSent") - }() + } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ + initHandler, + func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, + func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, + }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) if msg := waitForChanWithTimeout(initErrorResponseSent, 1*time.Second); msg == nil { require.Fail(t, "Timed out waiting for init error response to be sent") } resp, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) + require.NoError(t, err) require.True(t, len(resp.Token.InvokeID) > 0) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) + + _, err = srv.AwaitRelease() + require.Error(t, err, ErrReleaseReservationDone) + require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) } func TestInitErrorDuringReserve(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: 
"foobar"})) - }() + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) + } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ + initHandler, + func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, + func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, + }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) resp, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) + require.NoError(t, err) require.True(t, len(resp.Token.InvokeID) > 0) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) + + _, err = srv.AwaitRelease() + require.Error(t, err, ErrReleaseReservationDone) + require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) } func TestInvokeSuccess(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) + releaseRuntimeInit := make(chan struct{}) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, 
failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + <-releaseRuntimeInit + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } - <-srv.InvokeChan() - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), "application/json", bytes.NewReader([]byte("response")))) + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader([]byte("response")), nil, nil)) require.NoError(t, srv.SendRuntimeReady()) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) - }() + return interop.InvokeSuccess{}, nil + } - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) + releaseRuntimeInit <- struct{}{} _, err := srv.Reserve("", "", "") require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitComplete), srv.getRuntimeState()) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) // Reserve does not wait for init completion + + awaitInitErr := srv.AwaitInitialized() + require.NoError(t, awaitInitErr) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "response", 
responseRecorder.Body.String()) require.Equal(t, "application/json", responseRecorder.Result().Header.Get("Content-Type")) @@ -157,28 +242,35 @@ func TestInvokeError(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) - - <-srv.InvokeChan() + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }"), ContentType: "application/json"})) require.NoError(t, srv.SendRuntimeReady()) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) - }() + return interop.InvokeSuccess{}, nil + } - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil + } + + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) _, err := srv.Reserve("", "", "") require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitComplete), srv.getRuntimeState()) + require.Equal(t, 
phaseInitializing, srv.getRapidPhase()) + + awaitInitErr := srv.AwaitInitialized() + require.NoError(t, awaitInitErr) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "{ 'errorType': 'A.B' }", responseRecorder.Body.String()) require.Equal(t, "application/json", responseRecorder.Result().Header.Get("Content-Type")) @@ -203,43 +295,49 @@ func TestInvokeWithSuppressedInitSuccess(t *testing.T) { srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) initErrorCompleted := make(chan error) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "foobar"})) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) initErrorCompleted <- errors.New("initErrorSequenceCompleted") + } - <-srv.ResetChan() - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"})) + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), nil, bytes.NewReader([]byte("response")), nil, nil)) + return interop.InvokeSuccess{}, nil + } - <-srv.InvokeChan() // run only after FastInvoke is called - 
require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), "", bytes.NewReader([]byte("response")))) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) - }() + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil + } + + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) if msg := waitForChanWithTimeout(initErrorCompleted, 1*time.Second); msg == nil { require.Fail(t, "Timed out waiting for init error sequence to be called") } - _, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + resp, err := srv.Reserve("", "", "") + require.NoError(t, err) + require.True(t, len(resp.Token.InvokeID) > 0) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for suppressed init require.NoError(t, err) _, err = srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitAlreadyDone.Error()) + require.NoError(t, err) responseRecorder := httptest.NewRecorder() successChan := make(chan error) go func() { directInvoke := false - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, directInvoke) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, directInvoke) require.NoError(t, invokeErr) successChan <- errors.New("invokeResponseWritten") }() @@ -261,39 +359,45 @@ func TestInvokeWithSuppressedInitErrorDueToInitError(t *testing.T) 
{ srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) + } + releaseChan := make(chan error) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "A.B"})) + releaseChan <- nil + return interop.InvokeSuccess{}, &interop.InvokeFailure{ErrorType: "A.B", RequestReset: true, DefaultErrorResponse: &interop.ErrorResponse{}} + } - <-srv.ResetChan() - srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"}) + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil + } - <-srv.InvokeChan() - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - releaseChan <- nil - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "invokeCorrelationID", ErrorType: "A.B"})) - }() + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + 
srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - _, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + resp, err := srv.Reserve("", "", "") + require.NoError(t, err) + require.True(t, len(resp.Token.InvokeID) > 0) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for invoke with suppressed init require.NoError(t, err) require.Equal(t, phaseIdle, srv.getRapidPhase()) _, err = srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitAlreadyDone.Error()) + require.NoError(t, err) require.Equal(t, phaseIdle, srv.getRapidPhase()) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "{ 'errorType': 'A.B' }", responseRecorder.Body.String()) require.Equal(t, phaseInvoking, srv.getRapidPhase()) @@ -310,39 +414,43 @@ func TestInvokeWithSuppressedInitErrorDueToInvokeError(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "A.B"})) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- 
interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) + } + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'B.C' }")})) + require.NoError(t, srv.SendRuntimeReady()) + return interop.InvokeSuccess{}, nil + } - <-srv.ResetChan() - srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"}) + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil + } - <-srv.InvokeChan() - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'B.C' }")})) - require.NoError(t, srv.SendRuntimeReady()) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) - }() + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) - _, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + resp, err := srv.Reserve("", "", "") + require.NoError(t, err) + require.True(t, len(resp.Token.InvokeID) > 0) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) _, err = 
srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for invoke with suppressed init require.NoError(t, err) require.Equal(t, phaseIdle, srv.getRapidPhase()) _, err = srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitAlreadyDone.Error()) + require.NoError(t, err) require.Equal(t, phaseIdle, srv.getRapidPhase()) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "{ 'errorType': 'B.C' }", responseRecorder.Body.String()) @@ -356,39 +464,43 @@ func TestMultipleInvokeSuccess(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) - }() - - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) - - invokeFunc := func(i int) { - <-srv.InvokeChan() - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), "", bytes.NewReader([]byte("response-"+fmt.Sprint(i))))) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } + i := 0 + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), nil, bytes.NewReader([]byte("response-"+fmt.Sprint(i))), nil, nil)) require.NoError(t, srv.SendRuntimeReady()) - require.NoError(t, 
srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) + i++ + return interop.InvokeSuccess{}, nil + } + + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil } - go func() { - for i := 0; i < 3; i++ { - invokeFunc(i) - } - }() + + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) for i := 0; i < 3; i++ { _, err := srv.Reserve("", "", "") - require.Contains(t, []error{nil, ErrInitAlreadyDone}, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) + require.NoError(t, err) + + awaitInitErr := srv.AwaitInitialized() + require.NoError(t, awaitInitErr) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "response-"+fmt.Sprint(i), responseRecorder.Body.String()) + require.Equal(t, phaseInvoking, srv.getRapidPhase()) _, err = srv.AwaitRelease() require.NoError(t, err) + require.Equal(t, phaseIdle, srv.getRapidPhase()) require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) } } diff --git a/lambda/rapidcore/standalone/directInvokeHandler.go b/lambda/rapidcore/standalone/directInvokeHandler.go index a485deb..1c7e7cb 100644 --- a/lambda/rapidcore/standalone/directInvokeHandler.go +++ b/lambda/rapidcore/standalone/directInvokeHandler.go @@ -4,13 +4,15 @@ package standalone import ( - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/core/directinvoke" "go.amzn.com/lambda/rapidcore" + "net/http" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/core/directinvoke" ) -func 
DirectInvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func DirectInvokeHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { tok := s.CurrentToken() if tok == nil { log.Errorf("Attempt to call directInvoke without Reserve") @@ -24,6 +26,14 @@ func DirectInvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.Int return } + if err := s.AwaitInitialized(); err != nil { + w.WriteHeader(DoneFailedHTTPCode) + if state, err := s.InternalState(); err == nil { + w.Write(state.AsJSON()) + } + return + } + if err := s.FastInvoke(w, invoke, true); err != nil { switch err { case rapidcore.ErrNotReserved: diff --git a/lambda/rapidcore/standalone/executeHandler.go b/lambda/rapidcore/standalone/executeHandler.go index 36c257a..9bac400 100644 --- a/lambda/rapidcore/standalone/executeHandler.go +++ b/lambda/rapidcore/standalone/executeHandler.go @@ -8,16 +8,17 @@ import ( log "github.com/sirupsen/logrus" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapidcore" ) -func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandbox) { +func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.LambdaInvokeAPI) { invokePayload := &interop.Invoke{ - TraceID: r.Header.Get("X-Amzn-Trace-Id"), - LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: r.Body, - CorrelationID: "invokeCorrelationID", + TraceID: r.Header.Get("X-Amzn-Trace-Id"), + LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), + Payload: r.Body, + InvokeReceivedTime: metering.Monotime(), } // If we write to 'w' directly and waitUntilRelease fails, we won't be able to propagate error anymore @@ -38,17 +39,17 @@ func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandbox) case rapidcore.ErrInvokeResponseAlreadyWritten: return - case rapidcore.ErrInvokeTimeout: + case rapidcore.ErrInvokeTimeout, rapidcore.ErrInitResetReceived: w.WriteHeader(http.StatusGatewayTimeout) // DONE failures: 
- case rapidcore.ErrTerminated, rapidcore.ErrInitDoneFailed, rapidcore.ErrInvokeDoneFailed: + case rapidcore.ErrInvokeDoneFailed: copyHeaders(invokeResp, w) w.WriteHeader(DoneFailedHTTPCode) w.Write(invokeResp.Body) return // Reservation canceled errors - case rapidcore.ErrReserveReservationDone, rapidcore.ErrInvokeReservationDone, rapidcore.ErrReleaseReservationDone: + case rapidcore.ErrReserveReservationDone, rapidcore.ErrInvokeReservationDone, rapidcore.ErrReleaseReservationDone, rapidcore.ErrInitNotStarted: w.WriteHeader(http.StatusGatewayTimeout) } diff --git a/lambda/rapidcore/standalone/initHandler.go b/lambda/rapidcore/standalone/initHandler.go index d006b81..d60ec6f 100644 --- a/lambda/rapidcore/standalone/initHandler.go +++ b/lambda/rapidcore/standalone/initHandler.go @@ -7,21 +7,33 @@ import ( "fmt" "net/http" "os" + "time" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore" "go.amzn.com/lambda/rapidcore/env" ) +type RuntimeInfo struct { + ImageJSON string `json:"runtimeImageJSON,omitempty"` + Arn string `json:"runtimeArn,omitempty"` + Version string `json:"runtimeVersion,omitempty"` +} + // TODO: introduce suppress init flag type InitBody struct { - Handler string `json:"handler"` - FunctionName string `json:"functionName"` - FunctionVersion string `json:"functionVersion"` - InvokeTimeoutMs int64 `json:"invokeTimeoutMs"` + Handler string `json:"handler"` + FunctionName string `json:"functionName"` + FunctionVersion string `json:"functionVersion"` + InvokeTimeoutMs int64 `json:"invokeTimeoutMs"` + RuntimeInfo RuntimeInfo `json:"runtimeInfo"` Customer struct { Environment map[string]string `json:"environment"` } `json:"customer"` + AwsKey *string `json:"awskey"` + AwsSecret *string `json:"awssecret"` + AwsSession *string `json:"awssession"` + CredentialsExpiry time.Time `json:"credentialsExpiry"` + Throttled bool `json:"throttled"` } type InitRequest struct { @@ -44,7 +56,7 @@ func (c *InitBody) Validate() error { return nil } -func 
InitHandler(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandbox) { +func InitHandler(w http.ResponseWriter, r *http.Request, sandbox InteropServer, bs interop.Bootstrap) { init := InitBody{} if lerr := readBodyAndUnmarshalJSON(r, &init); lerr != nil { lerr.Send(w, r) @@ -61,20 +73,54 @@ func InitHandler(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandb // logic consistent across standalone-mode and girp-mode os.Setenv(envKey, envVal) } - // TODO generate CorrelationID + + awsKey, awsSecret, awsSession := getCredentials(init) + + sandboxType := interop.SandboxClassic + + if init.Throttled { + sandboxType = interop.SandboxPreWarmed + } // pass to rapid sandbox.Init(&interop.Init{ Handler: init.Handler, - CorrelationID: "initCorrelationID", - AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"), - AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"), - AwsSession: os.Getenv("AWS_SESSION_TOKEN"), + AwsKey: awsKey, + AwsSecret: awsSecret, + AwsSession: awsSession, + CredentialsExpiry: init.CredentialsExpiry, XRayDaemonAddress: "0.0.0.0:0", // TODO FunctionName: init.FunctionName, FunctionVersion: init.FunctionVersion, - + RuntimeInfo: interop.RuntimeInfo{ + ImageJSON: init.RuntimeInfo.ImageJSON, + Arn: init.RuntimeInfo.Arn, + Version: init.RuntimeInfo.Version}, CustomerEnvironmentVariables: env.CustomerEnvironmentVariables(), + SandboxType: sandboxType, + Bootstrap: bs, + EnvironmentVariables: env.NewEnvironment(), }, init.InvokeTimeoutMs) +} + +func getCredentials(init InitBody) (string, string, string) { + // ToDo(guvfatih): I think instead of passing and getting these credentials values via environment variables + // we need to make StandaloneTests passing these via the Init request to be compliant with the existing protocol. 
+ awsKey := os.Getenv("AWS_ACCESS_KEY_ID") + awsSecret := os.Getenv("AWS_SECRET_ACCESS_KEY") + awsSession := os.Getenv("AWS_SESSION_TOKEN") + + if init.AwsKey != nil { + awsKey = *init.AwsKey + } + + if init.AwsSecret != nil { + awsSecret = *init.AwsSecret + } + + if init.AwsSession != nil { + awsSession = *init.AwsSession + } + return awsKey, awsSecret, awsSession } diff --git a/lambda/rapidcore/standalone/internalStateHandler.go b/lambda/rapidcore/standalone/internalStateHandler.go index ff1335a..cb40c1c 100644 --- a/lambda/rapidcore/standalone/internalStateHandler.go +++ b/lambda/rapidcore/standalone/internalStateHandler.go @@ -5,11 +5,9 @@ package standalone import ( "net/http" - - "go.amzn.com/lambda/rapidcore" ) -func InternalStateHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func InternalStateHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { state, err := s.InternalState() if err != nil { http.Error(w, "internal state callback not set", http.StatusInternalServerError) diff --git a/lambda/rapidcore/standalone/invokeHandler.go b/lambda/rapidcore/standalone/invokeHandler.go index 0d89f1c..3e9768c 100644 --- a/lambda/rapidcore/standalone/invokeHandler.go +++ b/lambda/rapidcore/standalone/invokeHandler.go @@ -14,7 +14,7 @@ import ( log "github.com/sirupsen/logrus" ) -func InvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func InvokeHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { tok := s.CurrentToken() if tok == nil { log.Errorf("Attempt to call directInvoke without Reserve") @@ -22,28 +22,20 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropSe return } - isResyncReceivedFlag := false - - awsKey := r.Header.Get("ResyncAwsKey") - awsSecret := r.Header.Get("ResyncAwsSecret") - awsSession := r.Header.Get("ResyncAwsSession") - - if len(awsKey) > 0 && len(awsSecret) > 0 && len(awsSession) > 0 { - isResyncReceivedFlag = true + 
invokePayload := &interop.Invoke{ + TraceID: r.Header.Get("X-Amzn-Trace-Id"), + LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), + Payload: r.Body, + DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), + InvokeReceivedTime: metering.Monotime(), } - invokePayload := &interop.Invoke{ - TraceID: r.Header.Get("X-Amzn-Trace-Id"), - LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: r.Body, - CorrelationID: "invokeCorrelationID", - DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), - ResyncState: interop.Resync{ - IsResyncReceived: isResyncReceivedFlag, - AwsKey: awsKey, - AwsSecret: awsSecret, - AwsSession: awsSession, - }, + if err := s.AwaitInitialized(); err != nil { + w.WriteHeader(DoneFailedHTTPCode) + if state, err := s.InternalState(); err == nil { + w.Write(state.AsJSON()) + } + return } if err := s.FastInvoke(w, invokePayload, false); err != nil { diff --git a/lambda/rapidcore/standalone/pingHandler.go b/lambda/rapidcore/standalone/pingHandler.go new file mode 100644 index 0000000..c6cb021 --- /dev/null +++ b/lambda/rapidcore/standalone/pingHandler.go @@ -0,0 +1,12 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package standalone + +import ( + "net/http" +) + +func PingHandler(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("pong")) +} diff --git a/lambda/rapidcore/standalone/reserveHandler.go b/lambda/rapidcore/standalone/reserveHandler.go index d3e0b9f..52b51cd 100644 --- a/lambda/rapidcore/standalone/reserveHandler.go +++ b/lambda/rapidcore/standalone/reserveHandler.go @@ -24,24 +24,14 @@ func tokenToHeaders(w http.ResponseWriter, token interop.Token) { w.Header().Set(directinvoke.VersionIDHeader, token.VersionID) } -func ReserveHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func ReserveHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { reserveResp, err := s.Reserve("", r.Header.Get("X-Amzn-Trace-Id"), r.Header.Get("X-Amzn-Segment-Id")) if err != nil { switch err { - case rapidcore.ErrInitAlreadyDone: - // init already happened before, just provide internal state and return - tokenToHeaders(w, reserveResp.Token) - InternalStateHandler(w, r, s) case rapidcore.ErrReserveReservationDone: // TODO use http.StatusBadGateway w.WriteHeader(http.StatusGatewayTimeout) - case rapidcore.ErrInitDoneFailed, rapidcore.ErrInitError: - w.WriteHeader(DoneFailedHTTPCode) - w.Write(reserveResp.InternalState.AsJSON()) - case rapidcore.ErrTerminated: - w.WriteHeader(DoneFailedHTTPCode) - w.Write(reserveResp.InternalState.AsJSON()) default: log.Errorf("Failed to reserve: %s", err) w.WriteHeader(400) diff --git a/lambda/rapidcore/standalone/resetHandler.go b/lambda/rapidcore/standalone/resetHandler.go index 1a719ff..4f2ca2e 100644 --- a/lambda/rapidcore/standalone/resetHandler.go +++ b/lambda/rapidcore/standalone/resetHandler.go @@ -5,8 +5,6 @@ package standalone import ( "net/http" - - "go.amzn.com/lambda/rapidcore" ) type resetAPIRequest struct { @@ -14,7 +12,7 @@ type resetAPIRequest struct { TimeoutMs int64 `json:"timeoutMs"` } -func ResetHandler(w http.ResponseWriter, r *http.Request, s 
rapidcore.InteropServer) { +func ResetHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { reset := resetAPIRequest{} if lerr := readBodyAndUnmarshalJSON(r, &reset); lerr != nil { lerr.Send(w, r) diff --git a/lambda/rapidcore/standalone/restoreHandler.go b/lambda/rapidcore/standalone/restoreHandler.go new file mode 100644 index 0000000..190b6d8 --- /dev/null +++ b/lambda/rapidcore/standalone/restoreHandler.go @@ -0,0 +1,41 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package standalone + +import ( + "net/http" + "time" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/interop" +) + +type RestoreBody struct { + AwsKey string `json:"awskey"` + AwsSecret string `json:"awssecret"` + AwsSession string `json:"awssession"` + CredentialsExpiry time.Time `json:"credentialsExpiry"` +} + +func RestoreHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { + restoreRequest := RestoreBody{} + if lerr := readBodyAndUnmarshalJSON(r, &restoreRequest); lerr != nil { + lerr.Send(w, r) + return + } + + restore := &interop.Restore{ + AwsKey: restoreRequest.AwsKey, + AwsSecret: restoreRequest.AwsSecret, + AwsSession: restoreRequest.AwsSession, + CredentialsExpiry: restoreRequest.CredentialsExpiry, + } + + err := s.Restore(restore) + + if err != nil { + log.Errorf("Failed to restore: %s", err) + w.WriteHeader(http.StatusBadGateway) + } +} diff --git a/lambda/rapidcore/standalone/router.go b/lambda/rapidcore/standalone/router.go index 5a4ae7c..f1712ea 100644 --- a/lambda/rapidcore/standalone/router.go +++ b/lambda/rapidcore/standalone/router.go @@ -7,18 +7,35 @@ import ( "context" "net/http" + "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" "go.amzn.com/lambda/rapidcore/telemetry" "github.com/go-chi/chi" ) -func NewHTTPRouter(sandbox rapidcore.Sandbox, eventLog *telemetry.EventLog, shutdownFunc context.CancelFunc) *chi.Mux { - ipcSrv 
:= sandbox.InteropServer() +type InteropServer interface { + Init(i *interop.Init, invokeTimeoutMs int64) error + AwaitInitialized() error + FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error + Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) + Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) + AwaitRelease() (*statejson.InternalStateDescription, error) + Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription + InternalState() (*statejson.InternalStateDescription, error) + CurrentToken() *interop.Token + Restore(restore *interop.Restore) error +} + +func NewHTTPRouter(ipcSrv InteropServer, lambdaInvokeAPI rapidcore.LambdaInvokeAPI, eventLog *telemetry.EventLog, shutdownFunc context.CancelFunc, bs interop.Bootstrap) *chi.Mux { r := chi.NewRouter() r.Use(standaloneAccessLogDecorator) - r.Post("/2015-03-31/functions/*/invocations", func(w http.ResponseWriter, r *http.Request) { Execute(w, r, sandbox) }) - r.Post("/test/init", func(w http.ResponseWriter, r *http.Request) { InitHandler(w, r, sandbox) }) + + r.Post("/2015-03-31/functions/*/invocations", func(w http.ResponseWriter, r *http.Request) { Execute(w, r, lambdaInvokeAPI) }) + r.Get("/test/ping", func(w http.ResponseWriter, r *http.Request) { PingHandler(w, r) }) + r.Post("/test/init", func(w http.ResponseWriter, r *http.Request) { InitHandler(w, r, ipcSrv, bs) }) + r.Post("/test/waitUntilInitialized", func(w http.ResponseWriter, r *http.Request) { WaitUntilInitializedHandler(w, r, ipcSrv) }) r.Post("/test/reserve", func(w http.ResponseWriter, r *http.Request) { ReserveHandler(w, r, ipcSrv) }) r.Post("/test/invoke", func(w http.ResponseWriter, r *http.Request) { InvokeHandler(w, r, ipcSrv) }) r.Post("/test/waitUntilRelease", func(w http.ResponseWriter, r *http.Request) { WaitUntilReleaseHandler(w, r, ipcSrv) }) @@ -27,6 +44,6 @@ func NewHTTPRouter(sandbox rapidcore.Sandbox, eventLog *telemetry.EventLog, shut 
r.Post("/test/directInvoke/{reservationtoken}", func(w http.ResponseWriter, r *http.Request) { DirectInvokeHandler(w, r, ipcSrv) }) r.Get("/test/internalState", func(w http.ResponseWriter, r *http.Request) { InternalStateHandler(w, r, ipcSrv) }) r.Get("/test/eventLog", func(w http.ResponseWriter, r *http.Request) { EventLogHandler(w, r, eventLog) }) - + r.Post("/test/restore", func(w http.ResponseWriter, r *http.Request) { RestoreHandler(w, r, ipcSrv) }) return r } diff --git a/lambda/rapidcore/standalone/shutdownHandler.go b/lambda/rapidcore/standalone/shutdownHandler.go index ee91277..8085541 100644 --- a/lambda/rapidcore/standalone/shutdownHandler.go +++ b/lambda/rapidcore/standalone/shutdownHandler.go @@ -9,14 +9,13 @@ import ( "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapidcore" ) type shutdownAPIRequest struct { TimeoutMs int64 `json:"timeoutMs"` } -func ShutdownHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer, shutdownFunc context.CancelFunc) { +func ShutdownHandler(w http.ResponseWriter, r *http.Request, s InteropServer, shutdownFunc context.CancelFunc) { shutdown := shutdownAPIRequest{} if lerr := readBodyAndUnmarshalJSON(r, &shutdown); lerr != nil { lerr.Send(w, r) @@ -24,8 +23,7 @@ func ShutdownHandler(w http.ResponseWriter, r *http.Request, s rapidcore.Interop } internalState := s.Shutdown(&interop.Shutdown{ - DeadlineNs: metering.Monotime() + int64(shutdown.TimeoutMs*1000*1000), - CorrelationID: "shutdownCorrelationID", + DeadlineNs: metering.Monotime() + int64(shutdown.TimeoutMs*1000*1000), }) w.Write(internalState.AsJSON()) diff --git a/lambda/rapidcore/standalone/util.go b/lambda/rapidcore/standalone/util.go index 21ee08f..7ba7420 100644 --- a/lambda/rapidcore/standalone/util.go +++ b/lambda/rapidcore/standalone/util.go @@ -6,7 +6,7 @@ package standalone import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" log "github.com/sirupsen/logrus" @@ -58,7 +58,7 @@ func (w 
*ResponseWriterProxy) IsError() bool { } func readBodyAndUnmarshalJSON(r *http.Request, dst interface{}) *ErrorReply { - bodyBytes, err := ioutil.ReadAll(r.Body) + bodyBytes, err := io.ReadAll(r.Body) if err != nil { return newErrorReply(ClientInvalidRequest, fmt.Sprintf("Failed to read full body: %s", err)) } diff --git a/lambda/rapidcore/standalone/waitUntilInitializedHandler.go b/lambda/rapidcore/standalone/waitUntilInitializedHandler.go new file mode 100644 index 0000000..95d64ac --- /dev/null +++ b/lambda/rapidcore/standalone/waitUntilInitializedHandler.go @@ -0,0 +1,23 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package standalone + +import ( + "net/http" + + "go.amzn.com/lambda/rapidcore" +) + +func WaitUntilInitializedHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { + err := s.AwaitInitialized() + if err != nil { + switch err { + case rapidcore.ErrInitDoneFailed: + w.WriteHeader(DoneFailedHTTPCode) + case rapidcore.ErrInitResetReceived: + w.WriteHeader(DoneFailedHTTPCode) + } + } + w.WriteHeader(http.StatusOK) +} diff --git a/lambda/rapidcore/standalone/waitUntilReleaseHandler.go b/lambda/rapidcore/standalone/waitUntilReleaseHandler.go index 9aab644..0a756dd 100644 --- a/lambda/rapidcore/standalone/waitUntilReleaseHandler.go +++ b/lambda/rapidcore/standalone/waitUntilReleaseHandler.go @@ -9,7 +9,7 @@ import ( "go.amzn.com/lambda/rapidcore" ) -func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { internalState, err := s.AwaitRelease() if err != nil { switch err { @@ -20,7 +20,7 @@ func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s rapidcore // TODO use http.StatusOK w.WriteHeader(http.StatusGatewayTimeout) return - case rapidcore.ErrTerminated: + case rapidcore.ErrInitDoneFailed: w.WriteHeader(DoneFailedHTTPCode) 
w.Write(internalState.AsJSON()) return diff --git a/lambda/rapidcore/telemetry/eventLog.go b/lambda/rapidcore/telemetry/eventLog.go index c66672c..2f809fa 100644 --- a/lambda/rapidcore/telemetry/eventLog.go +++ b/lambda/rapidcore/telemetry/eventLog.go @@ -9,6 +9,8 @@ import ( "time" ) +// TODO: Refactor to represent event structs below as a form of Events API entity + type XrayEvent struct { Msg string `json:"msg"` TraceID string `json:"traceID"` @@ -32,20 +34,13 @@ type FunctionLogEvent struct{} type ExtensionLogEvent struct{} type EventLog struct { + Events []SandboxEvent `json:"events,omitempty"` // populated by the StandaloneEventLog object Xray []XrayEvent `json:"xray,omitempty"` PlatformLog []PlatformLogEvent `json:"platformLogs,omitempty"` Logs []string `json:"rawLogs,omitempty"` mutex sync.Mutex } -func (p *EventLog) LogXrayEvent(msg string, traceID string, segmentName string, segmentID string) { - p.Xray = append(p.Xray, XrayEvent{Msg: msg, TraceID: traceID, SegmentName: segmentName, SegmentID: segmentID, Timestamp: time.Now().UnixNano() / int64(time.Millisecond)}) -} - -func (p *EventLog) LogExtensionInitEvent(agentName string, state string, subscriptions string, errorType string) { - p.PlatformLog = append(p.PlatformLog, PlatformLogEvent{agentName, state, errorType, strings.Split(subscriptions, ",")}) -} - func parseLogString(s string) []string { elems := strings.Split(s, "\t")[1:] for i, e := range elems { @@ -62,19 +57,7 @@ func (p *EventLog) dispatchLogEvent(logStr string) { if strings.HasPrefix(logStr, "XRAY") { // format: 'XRAY\tMessage: %s\tTraceID: %s\tSegmentName: %s\tSegmentID: %s' msg, traceID, segmentName, segmentID := elems[0], elems[1], elems[2], elems[3] - p.LogXrayEvent(msg, traceID, segmentName, segmentID) - } - - if strings.HasPrefix(logStr, "EXTENSION") && strings.Contains(logStr, "Error Type") { - // format: 'EXTENSION\tName: %s\tState: %s\tEvents: [%s]\tError Type: %s' - agentName, state, subscriptions, errorType := elems[0], 
elems[1], elems[2], elems[3] - p.LogExtensionInitEvent(agentName, state, subscriptions, errorType) - } - - if strings.HasPrefix(logStr, "EXTENSION") && !strings.Contains(logStr, "Error Type") { - // format: 'EXTENSION\tName: %s\tState: %s\tEvents: [%s]' - agentName, state, subscriptions, errorType := elems[0], elems[1], elems[2], "" - p.LogExtensionInitEvent(agentName, state, subscriptions, errorType) + p.Xray = append(p.Xray, XrayEvent{Msg: msg, TraceID: traceID, SegmentName: segmentName, SegmentID: segmentID, Timestamp: time.Now().UnixNano() / int64(time.Millisecond)}) } } diff --git a/lambda/rapidcore/telemetry/events_api.go b/lambda/rapidcore/telemetry/events_api.go new file mode 100644 index 0000000..7a882fd --- /dev/null +++ b/lambda/rapidcore/telemetry/events_api.go @@ -0,0 +1,97 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "sort" + "time" + + "go.amzn.com/lambda/telemetry" +) + +// EventType indicates the type of SandboxEvent. See full list: +type EventType = string + +const ( + PlatformInitRuntimeDone = EventType("platform.initRuntimeDone") + PlatformRestoreRuntimeDone = EventType("platform.restoreRuntimeDone") + PlatformRuntimeDone = EventType("platform.runtimeDone") + PlatformExtension = EventType("platform.extension") +) + +/* + SandboxEvent represents a generic sandbox event. 
For example: + {'time': '2021-03-16T13:10:42.358Z', + 'type': 'platform.extension', + 'record': { "name": "foo bar", "state": "Ready", "events": ["INVOKE", "SHUTDOWN"]}} +*/ +type SandboxEvent struct { + Time string `json:"time"` + Type EventType `json:"type"` + Record map[string]interface{} `json:"record"` +} + +type StandaloneEventLog struct { + requestID string + eventLog *EventLog +} + +func (s *StandaloneEventLog) SetCurrentRequestID(requestID string) { + s.requestID = requestID +} + +func (s *StandaloneEventLog) SendInitRuntimeDone(data *telemetry.InitRuntimeDoneData) error { + record := map[string]interface{}{"initializationType": data.InitSource, "status": data.Status} + s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformInitRuntimeDone, record}) + return nil +} + +func (s *StandaloneEventLog) SendRestoreRuntimeDone(status string) error { + record := map[string]interface{}{"status": status} + s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformRestoreRuntimeDone, record}) + return nil +} + +func (s *StandaloneEventLog) SendRuntimeDone(data telemetry.InvokeRuntimeDoneData) error { + // e.g. 'record': {'requestId': '1506eb3053d148f3bb7ec0fabe6f8d91','status': 'success', 'metrics': {...}, 'tracing': {...}} + record := map[string]interface{}{ + "requestId": s.requestID, + "status": data.Status, + "metrics": data.Metrics, + "internalMetrics": data.InternalMetrics, + "spans": data.Spans, + } + + if data.Tracing != nil { + record["tracing"] = map[string]string{ + "spanId": data.Tracing.SpanID, + "type": string(data.Tracing.Type), + "value": data.Tracing.Value, + } + } + + s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformRuntimeDone, record}) + return nil +} + +func (s *StandaloneEventLog) SendExtensionInit(agentName, state, errorType string, subscriptions []string) error { + // e.g. 
'record': { "name": "", "state": "", errorType: "", events: [""] } + sort.Strings(subscriptions) + record := map[string]interface{}{"name": agentName, "state": state, "events": subscriptions} + if len(errorType) > 0 { + record["errorType"] = errorType + } + s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformExtension, record}) + return nil +} + +func (s *StandaloneEventLog) SendImageErrorLog(logline string) { + // Called on bootstrap exec errors for OCI error modes, e.g. InvalidEntrypoint etc. +} + +func NewStandaloneEventLog(eventLog *EventLog) *StandaloneEventLog { + return &StandaloneEventLog{ + eventLog: eventLog, + } +} diff --git a/lambda/runtimecmd/runtime_command.go b/lambda/runtimecmd/runtime_command.go deleted file mode 100644 index adf7886..0000000 --- a/lambda/runtimecmd/runtime_command.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package runtimecmd - -import ( - "context" - "fmt" - "io" - "os" - "os/exec" - "path" - "syscall" -) - -// CustomRuntimeCmd wraps exec.Cmd -type CustomRuntimeCmd struct { - *exec.Cmd -} - -// NewCustomRuntimeCmd returns a new CustomRuntimeCmd -func NewCustomRuntimeCmd(ctx context.Context, bootstrapCmd []string, dir string, env []string, stdoutWriter io.Writer, stderrWriter io.Writer, extraFiles []*os.File) *CustomRuntimeCmd { - cmd := exec.CommandContext(ctx, bootstrapCmd[0], bootstrapCmd[1:]...) 
- cmd.Dir = dir - - cmd.Stdout = stdoutWriter - cmd.Stderr = stderrWriter - - cmd.Env = env - - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - - if len(extraFiles) > 0 { - cmd.ExtraFiles = extraFiles - } - - return &CustomRuntimeCmd{cmd} -} - -// Name returnes runtime executable name -func (cmd *CustomRuntimeCmd) Name() string { - return path.Base(cmd.Path) -} - -// Pid returns the pid of a started runtime process -func (cmd *CustomRuntimeCmd) Pid() int { - return cmd.Process.Pid -} - -// Wait waits for the started customer runtime process to exit -func (cmd *CustomRuntimeCmd) Wait() error { - if err := cmd.Cmd.Wait(); err != nil { - return fmt.Errorf("Runtime exited with error: %v", err) - } - - return fmt.Errorf("Runtime exited without providing a reason") -} diff --git a/lambda/runtimecmd/runtime_command_test.go b/lambda/runtimecmd/runtime_command_test.go deleted file mode 100644 index f99599d..0000000 --- a/lambda/runtimecmd/runtime_command_test.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -package runtimecmd - -import ( - "context" - "errors" - "io/ioutil" - "os" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestRuntimeCommandSetsEnvironmentVariables(t *testing.T) { - envVars := []string{"foo=1", "bar=2", "baz=3"} - - currentDir, err := os.Getwd() - assert.NoError(t, err, errors.New("Failed to get working directory to execute helper process")) - - execCmdArgs := []string{"foobar"} - runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, ioutil.Discard, nil) - - assert.ElementsMatch(t, envVars, runtimeCmd.Env) - assert.Equal(t, execCmdArgs, runtimeCmd.Args) -} - -func TestRuntimeCommandSetsCurrentWorkingDir(t *testing.T) { - envVars := []string{} - - currentDir, err := os.Getwd() - assert.NoError(t, err, errors.New("Failed to get working directory to execute helper process")) - - execCmdArgs := []string{"foobar"} - runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, ioutil.Discard, nil) - - assert.Equal(t, currentDir, runtimeCmd.Dir) -} - -func TestRuntimeCommandSetsMultipleArgs(t *testing.T) { - envVars := []string{} - - currentDir, err := os.Getwd() - assert.NoError(t, err, errors.New("Failed to get working directory to execute helper process")) - - execCmdArgs := []string{"foobar", "--baz", "22"} - runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, ioutil.Discard, nil) - - assert.Equal(t, execCmdArgs, runtimeCmd.Args) -} diff --git a/lambda/supervisor/local_supervisor.go b/lambda/supervisor/local_supervisor.go new file mode 100644 index 0000000..1174089 --- /dev/null +++ b/lambda/supervisor/local_supervisor.go @@ -0,0 +1,302 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package supervisor + +import ( + "errors" + "fmt" + "os/exec" + "sync" + "syscall" + "time" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/supervisor/model" +) + +// typecheck interface compliance +var _ model.SupervisorClient = (*LocalSupervisor)(nil) + +type process struct { + // pid of the running process + pid int + // channel that can be used to block + // while waiting on process termination. + termination chan struct{} +} + +type LocalSupervisor struct { + events chan model.Event + processMapLock sync.Mutex + processMap map[string]process +} + +func NewLocalSupervisor() model.Supervisor { + return model.Supervisor{ + SupervisorClient: &LocalSupervisor{ + events: make(chan model.Event), + processMap: make(map[string]process), + }, + OperatorConfig: model.DomainConfig{ + RootPath: "/", + }, + RuntimeConfig: model.DomainConfig{ + RootPath: "/", + }, + } +} + +func (*LocalSupervisor) Start(req *model.StartRequest) error { + return nil +} +func (*LocalSupervisor) Configure(req *model.ConfigureRequest) error { + return nil +} +func (s *LocalSupervisor) Exec(req *model.ExecRequest) error { + if req.Domain != "runtime" { + log.Debug("Exec is a no op if domain != runtime") + return nil + } + command := exec.Command(req.Path, req.Args...) 
+ + if req.Env != nil { + envStrings := make([]string, 0, len(*req.Env)) + for key, value := range *req.Env { + envStrings = append(envStrings, key+"="+value) + } + command.Env = envStrings + } + + if req.Cwd != nil && *req.Cwd != "" { + command.Dir = *req.Cwd + } + + if req.ExtraFiles != nil { + command.ExtraFiles = *req.ExtraFiles + } + + command.Stdout = req.StdoutWriter + command.Stderr = req.StderrWriter + + command.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + err := command.Start() + + if err != nil { + return err + // TODO Use supervisor specific error + } + + pid := command.Process.Pid + termination := make(chan struct{}) + s.processMapLock.Lock() + s.processMap[req.Name] = process{ + pid: pid, + termination: termination, + } + s.processMapLock.Unlock() + + go func() { + err = command.Wait() + // close the termination channel to unblock whoever's blocked on + // it (used to implement kill's blocking behaviour) + close(termination) + + var cell int32 + var exitStatus *int32 + var signo *int32 + var exitErr *exec.ExitError + + if err == nil { + exitStatus = &cell + } else if errors.As(err, &exitErr) { + if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { + if code := status.ExitStatus(); code >= 0 { + cell = int32(code) + exitStatus = &cell + } else { + cell = int32(status.Signal()) + signo = &cell + } + } + } + + if signo == nil && exitStatus == nil { + log.Error("Cannot convert process exit status to unix WaitStatus. This is unexpected. Assuming ExitStatus 1") + cell = 1 + exitStatus = &cell + } + s.events <- model.Event{ + Time: uint64(time.Now().UnixMilli()), + Event: model.EventData{ + Domain: &req.Domain, + Name: &req.Name, + Signo: signo, + ExitStatus: exitStatus, + }, + } + }() + + return nil +} + +func kill(p process, name string, timeout *time.Duration) error { + // kill should report success if the process terminated by the time + // supervisor receives the request. 
+ select { + // if this case is selected, the channel is closed, + // which means the process is terminated + case <-p.termination: + log.Debugf("Process %s already terminated.", name) + return nil + default: + log.Infof("Sending SIGKILL to %s(%d).", name, p.pid) + } + + if timeout != nil && *timeout <= 0 { + return fmt.Errorf("Timed out while trying to SIGKILL %s", name) + } + + pgid, err := syscall.Getpgid(p.pid) + + if err == nil { + // Negative pid sends signal to all in process group + syscall.Kill(-pgid, syscall.SIGKILL) + } else { + syscall.Kill(p.pid, syscall.SIGKILL) + } + + // the nil channel blocks forever + var timer <-chan time.Time + if timeout != nil { + timer = time.After(*timeout) + } + + // block until the (main) process exits + // or the timeout fires + select { + case <-p.termination: + return nil + case <-timer: + return fmt.Errorf("Timed out while trying to SIGKILL %s", name) + } +} + +func (s *LocalSupervisor) Kill(req *model.KillRequest) error { + if req.Domain != "runtime" { + log.Debug("Kill is a no op if domain != runtime") + return nil + } + s.processMapLock.Lock() + process, ok := s.processMap[req.Name] + s.processMapLock.Unlock() + if !ok { + msg := "Unknown process" + return &model.SupervisorError{ + Kind: model.NoSuchEntity, + Message: &msg, + } + } + timeout := convertTimeout(req.Timeout) + + return kill(process, req.Name, timeout) +} + +func (s *LocalSupervisor) Terminate(req *model.TerminateRequest) error { + if req.Domain != "runtime" { + log.Debug("Terminate is no op if domain != runtime") + return nil + } + s.processMapLock.Lock() + process, ok := s.processMap[req.Name] + pid := process.pid + s.processMapLock.Unlock() + if !ok { + msg := "Unknown process" + err := &model.SupervisorError{ + Kind: model.NoSuchEntity, + Message: &msg, + } + log.WithError(err).Errorf("Process %s not found in local supervisor map", req.Name) + return err + } + + pgid, err := syscall.Getpgid(pid) + + if err == nil { + // Negative pid sends signal to 
all in process group + // best effort, ignore errors + _ = syscall.Kill(-pgid, syscall.SIGTERM) + } else { + _ = syscall.Kill(pid, syscall.SIGTERM) + } + + return nil +} + +func (s *LocalSupervisor) Stop(req *model.StopRequest) error { + if req.Domain != "runtime" { + log.Debug("Shutdown is no op if domain != runtime") + return nil + } + timeout := convertTimeout(req.Timeout) + + // shut down kills all the processes in the map + s.processMapLock.Lock() + defer s.processMapLock.Unlock() + + nprocs := len(s.processMap) + + successes := make(chan struct{}) + errors := make(chan error) + for name, proc := range s.processMap { + go func(n string, p process) { + log.Debugf("Killing %s", n) + err := kill(p, n, timeout) + if err != nil { + errors <- err + } else { + successes <- struct{}{} + } + + }(name, proc) + } + + var err error + for i := 0; i < nprocs; i++ { + select { + case <-successes: + case e := <-errors: + if err == nil { + err = fmt.Errorf("Shutdown failed: %s", e.Error()) + } + } + + } + + s.processMap = make(map[string]process) + return err +} +func (*LocalSupervisor) Freeze(req *model.FreezeRequest) error { + return nil +} +func (*LocalSupervisor) Thaw(req *model.ThawRequest) error { + return nil +} +func (s *LocalSupervisor) Ping() error { + return nil +} + +func (s *LocalSupervisor) Events() (<-chan model.Event, error) { + return s.events, nil +} + +func convertTimeout(millis *uint64) *time.Duration { + var timeout *time.Duration + if millis != nil { + t := time.Duration(*millis) * time.Millisecond + timeout = &t + } + return timeout +} diff --git a/lambda/supervisor/local_supervisor_test.go b/lambda/supervisor/local_supervisor_test.go new file mode 100644 index 0000000..8b3336b --- /dev/null +++ b/lambda/supervisor/local_supervisor_test.go @@ -0,0 +1,215 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package supervisor + +import ( + "errors" + "fmt" + "syscall" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/supervisor/model" +) + +func TestRuntimeDomainExec(t *testing.T) { + supv := NewLocalSupervisor() + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + }) + + assert.Nil(t, err) +} + +func TestInvalidRuntimeDomainExec(t *testing.T) { + supv := NewLocalSupervisor() + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/none", + }) + + require.Error(t, err) +} + +func TestEvents(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + sync := make(chan struct{}) + go func() { + evt, ok := <-client.events + require.True(t, ok) + termination := evt.Event.ProcessTerminated() + require.NotNil(t, termination) + assert.Equal(t, "runtime", *termination.Domain) + assert.Equal(t, "agent", *termination.Name) + sync <- struct{}{} + }() + + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + }) + require.NoError(t, err) + <-sync +} + +func TestTerminate(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + Args: []string{"-c", "sleep 10s"}, + }) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + err = supv.Terminate(&model.TerminateRequest{ + Domain: "runtime", + Name: "agent", + }) + require.NoError(t, err) + // wait for process exit notification + ev := <-client.events + require.NotNil(t, ev.Event.ProcessTerminated()) + term := *ev.Event.ProcessTerminated() + require.Nil(t, term.Exited()) + require.NotNil(t, term.Signaled()) + require.EqualValues(t, syscall.SIGTERM, *term.Signo) +} + +// 
Termiante should not fail if the message is not delivered +func TestTerminateExited(t *testing.T) { + supv := NewLocalSupervisor() + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + }) + require.NoError(t, err) + // wait a short bit for bash to exit + time.Sleep(100 * time.Millisecond) + err = supv.Terminate(&model.TerminateRequest{ + Domain: "runtime", + Name: "agent", + }) + require.NoError(t, err) +} + +func TestKill(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + Args: []string{"-c", "sleep 10s"}, + }) + require.NoError(t, err) + err = supv.Kill(&model.KillRequest{ + Domain: "runtime", + Name: "agent", + }) + require.NoError(t, err) + timer := time.NewTimer(50 * time.Millisecond) + select { + case _, ok := <-client.events: + assert.True(t, ok) + case <-timer.C: + require.Fail(t, "Process should have exited by the time kill returns") + } +} + +func TestKillExited(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + }) + require.NoError(t, err) + //wait for natural exit event + <-client.events + err = supv.Kill(&model.KillRequest{ + Domain: "runtime", + Name: "agent", + }) + require.NoError(t, err, "Kill should succeed for exited processes") +} + +func TestKillUnknown(t *testing.T) { + supv := NewLocalSupervisor() + err := supv.Kill(&model.KillRequest{ + Domain: "runtime", + Name: "unknown", + }) + require.Error(t, err) + var supvError *model.SupervisorError + assert.True(t, errors.As(err, &supvError)) + assert.Equal(t, supvError.Kind, model.NoSuchEntity) +} + +func TestShutdown(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + log.Debug("hello") + // start a bunch of 
processes, some short running, some longer running + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent-0", + Path: "/bin/bash", + Args: []string{"-c", "sleep 1s"}, + }) + require.NoError(t, err) + + err = supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent-1", + Path: "/bin/bash", + }) + require.NoError(t, err) + + err = supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent-2", + Path: "/bin/bash", + Args: []string{"-c", "sleep 2s"}, + }) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + err = supv.Stop(&model.StopRequest{ + Domain: "runtime", + }) + require.NoError(t, err) + // Shutdown is expected to block untill all processes have exited + expected := map[string]struct{}{ + "agent-0": {}, + "agent-1": {}, + "agent-2": {}, + } + done := false + timer := time.NewTimer(200 * time.Millisecond) + for !done { + select { + case ev := <-client.events: + data := ev.Event.ProcessTerminated() + assert.NotNil(t, data) + _, ok := expected[*data.Name] + assert.True(t, ok) + delete(expected, *data.Name) + case <-timer.C: + fmt.Print(expected) + assert.Equal(t, 0, len(expected), "All process should terminate at shutdown") + done = true + } + } +} diff --git a/lambda/supervisor/model/model.go b/lambda/supervisor/model/model.go new file mode 100644 index 0000000..384726d --- /dev/null +++ b/lambda/supervisor/model/model.go @@ -0,0 +1,269 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
// SupervisorClient is the set of operations a supervisor implementation
// exposes. It is embedded in Supervisor, so these methods can be called
// directly on a Supervisor value. Every request-taking method receives a
// pointer to the matching *Request type and reports failure through its
// error result.
type SupervisorClient interface {
	// Start takes the domain and, optionally, the name of the cgroup profile
	// to start the domain in.
	Start(req *StartRequest) error
	// Configure takes the per-domain settings carried by ConfigureRequest.
	Configure(req *ConfigureRequest) error
	// Exec spawns a process identified by the (Domain, Name) pair of the
	// request.
	Exec(req *ExecRequest) error
	// Terminate asks for SIGTERM to be sent to a named process.
	Terminate(req *TerminateRequest) error
	// Kill force-terminates (SIGKILL) a named process.
	Kill(req *KillRequest) error
	// Stop stops a domain.
	Stop(req *StopRequest) error
	// Freeze takes the domain to freeze.
	Freeze(req *FreezeRequest) error
	// Thaw takes the domain to thaw and, optionally, a cgroup profile to
	// switch to.
	Thaw(req *ThawRequest) error
	// Ping takes no arguments and returns only an error.
	Ping() error
	// Events returns a receive-only channel of Event values.
	Events() (<-chan Event, error)
}
+type DriveMount struct { + Source string `json:"source,omitempty"` + Destination string `json:"destination,omitempty"` + FsType string `json:"fs_type,omitempty"` + Options []string `json:"options,omitempty"` + Chowner []uint32 `json:"chowner,omitempty"` // array of two integers representing a tuple + Chmode uint32 `json:"chmode,omitempty"` + // Lockhard also expects a "type" field here, which in our case is constant, so we provide it upon serialization below +} + +// Adds the "type": "drive" to json +func (m *DriveMount) MarshalJSON() ([]byte, error) { + type driveMountAlias DriveMount + + return json.Marshal(&struct { + Type string `json:"type,omitempty"` + *driveMountAlias + }{ + Type: "drive", + driveMountAlias: (*driveMountAlias)(m), + }) +} + +type Capabilities struct { + Ambient []string `json:"ambient,omitempty"` + Bounding []string `json:"bounding,omitempty"` + Effective []string `json:"effective,omitempty"` + Inheritable []string `json:"inheritable,omitempty"` + Permitted []string `json:"permitted,omitempty"` +} + +type CgroupProfile struct { + Name string `json:"name"` + CPUPct *float64 `json:"cpu_pct,omitempty"` + MemMaxBytes *uint64 `json:"mem_max,omitempty"` +} + +type ExecUser struct { + UID *uint32 `json:"uid"` + GID *uint32 `json:"gid"` +} + +type ConfigureRequest struct { + // domain to configure + Domain string `json:"domain"` + Mounts []DriveMount `json:"mounts,omitempty"` + Capabilities *Capabilities `json:"capabilities,omitempty"` + SeccompFilters []string `json:"seccomp_filters,omitempty"` + // list of cgroup profiles available for the domain + // cgroup profiles are set on boot or thaw requests + CgroupProfiles []CgroupProfile `json:"cgroup_profiles,omitempty"` + // uid and gid of the user the spawned process runs as (w.r.t. the domain user namespace). 
+ // If nil, Supervisor will use the ExecUser specified in the domain configuration file + ExecUser *ExecUser `json:"exec_user,omitempty"` + // additional hooks to execute on domain start + AdditionalStartHooks []Hook `json:"additional_start_hooks,omitempty"` +} + +type Event struct { + Time uint64 `json:"timestamp_millis"` + Event EventData `json:"event"` +} + +// EventData is a union type tagged by the "EventType" +// and "Cause" strings. +// you can use ProcessTermination() or EventLoss() to access +// the correct type of Event. +type EventData struct { + EvType string `json:"type"` + Domain *string `json:"domain"` + Name *string `json:"name"` + Cause *string `json:"cause"` + Signo *int32 `json:"signo"` + ExitStatus *int32 `json:"exit_status"` + Size *uint64 `json:"size"` +} + +// returns nil if the event is not a EventLoss event +// otherwise returns how many events were lost due to +// backpressure (slow reader) +func (d EventData) EventLoss() *uint64 { + return d.Size +} + +// Returns a ProcessTermination struct that describe the process +// which terminated. Use Signaled() or Exited() to check whether +// the process terminated because of a signal or exited on its own +func (d EventData) ProcessTerminated() *ProcessTermination { + if d.Signo != nil || d.ExitStatus != nil { + return &ProcessTermination{ + Domain: d.Domain, + Name: d.Name, + Signo: d.Signo, + ExitStatus: d.ExitStatus, + } + } + return nil +} + +// Event signalling that a process exited +type ProcessTermination struct { + Domain *string + Name *string + Signo *int32 + ExitStatus *int32 +} + +// If not nil, the process was terminated by an unhandled signal. +// The returned value is the number of the signal that terminated the process +func (t ProcessTermination) Signaled() *int32 { + return t.Signo +} + +// It not nil, the process exited (as opposed to killed by a signal). 
+// The returned value is the exit_status returned by the process +func (t ProcessTermination) Exited() *int32 { + return t.ExitStatus +} + +func (t ProcessTermination) Success() bool { + return t.ExitStatus != nil && *t.ExitStatus == 0 +} + +// Transform the process termination status in a string that +// is equal to what would be returned by golang exec.ExitError.Error() +// We used to rely on this format to report errors to customer (sigh) +// so we keep this for backwards compatibility +func (t ProcessTermination) String() string { + if t.ExitStatus != nil { + return fmt.Sprintf("exit status %d", *t.ExitStatus) + } + sig := syscall.Signal(*t.Signo) + return fmt.Sprintf("signal: %s", sig.String()) +} + +type Hook struct { + // Unique name identifying the hook + Name string `json:"name"` + // Path in the parent domain mount namespace that locates + // the executable to run as the hook + Path string `json:"path"` + // Args for the hook + Args []string `json:"args,omitempty"` + // Map of ENV variables to set when running the hook + Env *map[string]string `json:"envs,omitempty"` + // Maximum time for the hook to run. The hook will be considered failed + // if it takes more than this value (default 10_000) + TimeoutMillis *uint64 `json:"timeout_millis,omitempty"` +} + +type ExecRequest struct { + // Identifier that Supervisor will assign to the spawned process. + // The tuple (Domain,Name) must be unique. It is the caller's responsibility + // to generate the unique name + Name string `json:"name"` + Domain string `json:"domain"` + // Path pointing to the exectuable file within the domain's root filesystem + Path string `json:"path"` + Args []string `json:"args,omitempty"` + // If nil, root of the domain + Cwd *string `json:"cwd,omitempty"` + Env *map[string]string `json:"env,omitempty"` + // If not nil, points to the socket that Supervisor + // uses to get the processes stdout and stderr. 
// SupervisorError is the error type carrying a machine-readable Kind and an
// optional human-readable Message. It is used through a pointer: Error is
// declared on *SupervisorError, so only the pointer type satisfies the error
// interface.
type SupervisorError struct {
	// Kind classifies the failure; see the ErrorKind constants.
	Kind ErrorKind `json:"error_kind"`
	// Message is optional free-form detail; it may be nil.
	Message *string `json:"message"`
}

// Error implements the error interface. It returns only the Kind string;
// Message is not part of the returned text and has to be read from the
// struct field by callers that need it.
func (e *SupervisorError) Error() string {
	return string(e.Kind)
}
+ CgroupProfile *string `json:"cgroup_profile,omitempty"` +} diff --git a/lambda/telemetry/events_api.go b/lambda/telemetry/events_api.go index 132977e..e7c5c36 100644 --- a/lambda/telemetry/events_api.go +++ b/lambda/telemetry/events_api.go @@ -3,12 +3,136 @@ package telemetry +import ( + "time" + + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" + "go.amzn.com/lambda/rapi/model" +) + +type RuntimeDoneInvokeMetrics struct { + ProducedBytes int64 + DurationMs float64 +} + +func GetRuntimeDoneInvokeMetrics(invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics, runtimeDoneTime int64) *RuntimeDoneInvokeMetrics { + if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && invokeReceivedTime != 0 { + return &RuntimeDoneInvokeMetrics{ + ProducedBytes: invokeResponseMetrics.ProducedBytes, + // time taken from sending the invoke to the sandbox until the runtime calls GET /next + DurationMs: float64((runtimeDoneTime - invokeReceivedTime) / int64(time.Millisecond)), + } + } + + // when we get a reset before runtime called /response + if invokeReceivedTime != 0 { + return &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64((runtimeDoneTime - invokeReceivedTime) / int64(time.Millisecond)), + } + } + + // We didn't have time to register the invokeReceiveTime, which means we crash/reset very early, + // too early for the runtime to actual run. 
In such case, the runtimeDone event shouldn't be sent + // Not returning Nil even in this improbable case guarantees that we will always have some metrics to send to FluxPump + return &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(0), + } +} + +type InitRuntimeDoneData struct { + InitSource string + Status string +} + +type InvokeRuntimeDoneData struct { + Status string + Metrics *RuntimeDoneInvokeMetrics + InternalMetrics *interop.InvokeResponseMetrics + Tracing *TracingCtx + Spans []Span +} + +type Span struct { + Name string + Start string + DurationMs float64 +} + +func GetRuntimeDoneSpans(invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) []Span { + if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && invokeReceivedTime != 0 { + // time span from when the invoke is received in the sandbox to the moment the runtime calls PUT /response + responseLatencyMsSpan := Span{ + Name: "responseLatency", + Start: getEpochTimeInISO8601FormatFromMonotime(invokeReceivedTime), + DurationMs: float64((invokeResponseMetrics.StartReadingResponseMonoTimeMs - invokeReceivedTime) / int64(time.Millisecond)), + } + + // time span from when the runtime called PUT /response to the moment the body of the response is fully sent + responseDurationMsSpan := Span{ + Name: "responseDuration", + Start: getEpochTimeInISO8601FormatFromMonotime(invokeResponseMetrics.StartReadingResponseMonoTimeMs), + DurationMs: float64((invokeResponseMetrics.FinishReadingResponseMonoTimeMs - invokeResponseMetrics.StartReadingResponseMonoTimeMs) / int64(time.Millisecond)), + } + return []Span{responseLatencyMsSpan, responseDurationMsSpan} + } + + return []Span{} +} + +func getEpochTimeInISO8601FormatFromMonotime(monotime int64) string { + return time.Unix(0, metering.MonoToEpoch(monotime)).Format("2006-01-02T15:04:05.000Z") +} + +type TracingCtx struct { + SpanID string + Type model.TracingType + Value string +} + +func 
BuildTracingCtx(tracingType model.TracingType, traceID string, lambdaSegmentID string) *TracingCtx { + // it takes current tracing context and change its parent value with the provided lambda segment id + root, currentParent, sample := ParseTraceID(traceID) + if root == "" || sample != model.XRaySampled { + return nil + } + + return &TracingCtx{ + SpanID: currentParent, + Type: tracingType, + Value: BuildFullTraceID(root, lambdaSegmentID, sample), + } +} + +const ( + RuntimeDoneSuccess = "success" + RuntimeDoneFailure = "failure" +) + type EventsAPI interface { SetCurrentRequestID(requestID string) - SendRuntimeDone(status string) error + SendInitRuntimeDone(data *InitRuntimeDoneData) error + SendRestoreRuntimeDone(status string) error + SendRuntimeDone(data InvokeRuntimeDoneData) error + SendExtensionInit(agentName, state, errorType string, subscriptions []string) error + SendImageErrorLog(logline string) } type NoOpEventsAPI struct{} func (s *NoOpEventsAPI) SetCurrentRequestID(requestID string) {} -func (s *NoOpEventsAPI) SendRuntimeDone(status string) error { return nil } +func (s *NoOpEventsAPI) SendInitRuntimeDone(data *InitRuntimeDoneData) error { + return nil +} +func (s *NoOpEventsAPI) SendRestoreRuntimeDone(status string) error { + return nil +} +func (s *NoOpEventsAPI) SendRuntimeDone(data InvokeRuntimeDoneData) error { + return nil +} +func (s *NoOpEventsAPI) SendExtensionInit(agentName, state, errorType string, subscriptions []string) error { + return nil +} +func (s *NoOpEventsAPI) SendImageErrorLog(logline string) {} diff --git a/lambda/telemetry/events_api_test.go b/lambda/telemetry/events_api_test.go new file mode 100644 index 0000000..b943be9 --- /dev/null +++ b/lambda/telemetry/events_api_test.go @@ -0,0 +1,139 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" +) + +func TestGetRuntimeDoneInvokeMetrics(t *testing.T) { + now := metering.Monotime() + + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + ProducedBytes: int64(100), + RuntimeCalledResponse: true, + } + runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + + expected := &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(10), + } + + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, runtimeDoneTime)) +} + +func TestGetRuntimeDoneInvokeMetricsWhenRuntimeCalledError(t *testing.T) { + now := metering.Monotime() + + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + ProducedBytes: int64(100), + RuntimeCalledResponse: false, + } + runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + + expected := &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(10), + } + + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, runtimeDoneTime)) +} + +func TestGetRuntimeDoneInvokeMetricsWhenInvokeReceivedTimeIsZero(t *testing.T) { + now := int64(0) // January 1st, 1970 at 00:00:00 UTC + invokeReceivedTime := now + + runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + + expected := &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(0), + } + actual := GetRuntimeDoneInvokeMetrics(invokeReceivedTime, nil, runtimeDoneTime) + assert.Equal(t, expected, actual) +} + +func TestGetRuntimeDoneInvokeMetricsWhenInvokeResponseMetricsIsNil(t *testing.T) { + now := metering.Monotime() + invokeReceivedTime := now + + runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + + expected := 
&RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(10), + } + + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, nil, runtimeDoneTime)) +} + +func TestGetRuntimeDoneSpans(t *testing.T) { + now := metering.Monotime() + startReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(5)) + finishReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(7)) + + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, + FinishReadingResponseMonoTimeMs: finishReadingResponseMonoTimeMs, + RuntimeCalledResponse: true, + } + + expectedResponseLatencyMsStartTime := getEpochTimeInISO8601FormatFromMonotime(now) + expectedResponseDurationMsStartTime := getEpochTimeInISO8601FormatFromMonotime(startReadingResponseMonoTimeMs) + expected := []Span{ + Span{ + Name: "responseLatency", + Start: expectedResponseLatencyMsStartTime, + DurationMs: 5, + }, + Span{ + Name: "responseDuration", + Start: expectedResponseDurationMsStartTime, + DurationMs: 2, + }, + } + + assert.Equal(t, expected, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) +} + +func TestGetRuntimeDoneSpansWhenRuntimeCalledError(t *testing.T) { + now := metering.Monotime() + startReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(5)) + finishReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(7)) + + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, + FinishReadingResponseMonoTimeMs: finishReadingResponseMonoTimeMs, + RuntimeCalledResponse: false, + } + + assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) +} + +func TestGetRuntimeDoneSpansWhenInvokeResponseMetricsNil(t *testing.T) { + invokeReceivedTime := metering.Monotime() + + assert.Equal(t, []Span{}, 
GetRuntimeDoneSpans(invokeReceivedTime, nil)) +} + +func TestGetRuntimeDoneSpansWhenInvokeReceivedTimeIsZero(t *testing.T) { + now := int64(0) // January 1st, 1970 at 00:00:00 UTC + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + StartReadingResponseMonoTimeMs: now + int64(time.Millisecond*time.Duration(5)), + FinishReadingResponseMonoTimeMs: now + int64(time.Millisecond*time.Duration(7)), + } + + assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) +} diff --git a/lambda/telemetry/logs_egress_api.go b/lambda/telemetry/logs_egress_api.go index ac9a754..7e84fe2 100644 --- a/lambda/telemetry/logs_egress_api.go +++ b/lambda/telemetry/logs_egress_api.go @@ -8,7 +8,12 @@ import ( "os" ) -type LogsEgressAPI interface { +// StdLogsEgressAPI is the interface that wraps the basic methods required to setup +// logs channels for Runtime's stdout/stderr and Extension's stdout/stderr. +// +// Implementation should return a Writer implementor for stdout and another for +// stderr on success and an error on failure. 
+type StdLogsEgressAPI interface { GetExtensionSockets() (io.Writer, io.Writer, error) GetRuntimeSockets() (io.Writer, io.Writer, error) } diff --git a/lambda/telemetry/logs_subscription_api.go b/lambda/telemetry/logs_subscription_api.go index 3ea7a20..6ee9490 100644 --- a/lambda/telemetry/logs_subscription_api.go +++ b/lambda/telemetry/logs_subscription_api.go @@ -10,28 +10,37 @@ import ( "go.amzn.com/lambda/interop" ) -// LogsSubscriptionAPI represents interface that implementations of Telemetry API have to satisfy to be RAPID-compatible -type LogsSubscriptionAPI interface { +// SubscriptionAPI represents interface that implementations of Telemetry API have to satisfy to be RAPID-compatible +type SubscriptionAPI interface { Subscribe(agentName string, body io.Reader, headers map[string][]string) (resp []byte, status int, respHeaders map[string][]string, err error) RecordCounterMetric(metricName string, count int) - FlushMetrics() interop.LogsAPIMetrics + FlushMetrics() interop.TelemetrySubscriptionMetrics Clear() TurnOff() + GetEndpointURL() string + GetServiceClosedErrorMessage() string + GetServiceClosedErrorType() string } -type NoOpLogsSubscriptionAPI struct{} +type NoOpSubscriptionAPI struct{} // Subscribe writes response to a shared memory -func (m *NoOpLogsSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { +func (m *NoOpSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { return []byte(`{}`), http.StatusOK, map[string][]string{}, nil } -func (m *NoOpLogsSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} +func (m *NoOpSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} -func (m *NoOpLogsSubscriptionAPI) FlushMetrics() interop.LogsAPIMetrics { - return interop.LogsAPIMetrics(map[string]int{}) +func (m *NoOpSubscriptionAPI) FlushMetrics() 
interop.TelemetrySubscriptionMetrics { + return interop.TelemetrySubscriptionMetrics(map[string]int{}) } -func (m *NoOpLogsSubscriptionAPI) Clear() {} +func (m *NoOpSubscriptionAPI) Clear() {} -func (m *NoOpLogsSubscriptionAPI) TurnOff() {} +func (m *NoOpSubscriptionAPI) TurnOff() {} + +func (m *NoOpSubscriptionAPI) GetEndpointURL() string { return "" } + +func (m *NoOpSubscriptionAPI) GetServiceClosedErrorMessage() string { return "" } + +func (m *NoOpSubscriptionAPI) GetServiceClosedErrorType() string { return "" } diff --git a/lambda/telemetry/tracer.go b/lambda/telemetry/tracer.go index 1ac8325..affca60 100644 --- a/lambda/telemetry/tracer.go +++ b/lambda/telemetry/tracer.go @@ -11,6 +11,7 @@ import ( "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapi/model" ) type traceContextKey int @@ -129,3 +130,24 @@ func ParseTraceID(fullTraceID string) (rootID, parentID, sample string) { } return } + +// BuildFullTraceID takes individual components of X-Ray trace header +// and puts them together into a formatted trace header. +// If root is empty, returns an empty string. 
+func BuildFullTraceID(root, parent, sample string) string { + if root == "" { + return "" + } + + parts := make([]string, 0, 3) + parts = append(parts, "Root="+root) + if parent != "" { + parts = append(parts, "Parent="+parent) + } + if sample == "" { + sample = model.XRayNonSampled + } + parts = append(parts, "Sampled="+sample) + + return strings.Join(parts, ";") +} diff --git a/lambda/telemetry/tracer_test.go b/lambda/telemetry/tracer_test.go index 9ac1260..c31653f 100644 --- a/lambda/telemetry/tracer_test.go +++ b/lambda/telemetry/tracer_test.go @@ -5,6 +5,8 @@ package telemetry import ( "testing" + + "go.amzn.com/lambda/rapi/model" ) var parserTests = []struct { @@ -35,3 +37,47 @@ func TestParseTraceID(t *testing.T) { }) } } + +func TestBuildFullTraceID(t *testing.T) { + specs := map[string]struct { + root string + parent string + sample string + expectedTraceID string + }{ + "all non-empty components, sampled": { + root: "1-5b3cc918-939afd635f8891ba6a9e1df6", + parent: "c88d77b0aef840e9", + sample: model.XRaySampled, + expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", + }, + "all non-empty components, non-sampled": { + root: "1-5b3cc918-939afd635f8891ba6a9e1df6", + parent: "c88d77b0aef840e9", + sample: model.XRayNonSampled, + expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=0", + }, + "root is non-empty, parent and sample are empty": { + root: "1-5b3cc918-939afd635f8891ba6a9e1df6", + expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Sampled=0", + }, + "root is empty": { + parent: "c88d77b0aef840e9", + expectedTraceID: "", + }, + "sample is empty": { + root: "1-5b3cc918-939afd635f8891ba6a9e1df6", + parent: "c88d77b0aef840e9", + expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=0", + }, + } + + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + actual := BuildFullTraceID(spec.root, spec.parent, spec.sample) + 
if actual != spec.expectedTraceID { + t.Errorf("got %q, wanted %q", actual, spec.expectedTraceID) + } + }) + } +} diff --git a/lambda/testdata/agents/bash_stderr.sh b/lambda/testdata/agents/bash_stderr.sh deleted file mode 100755 index 65c0ff1..0000000 --- a/lambda/testdata/agents/bash_stderr.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash - -printf "stderr line 1\n" >&2 -printf "stderr line 2\n" >&2 -printf "stderr line 3\n" >&2 diff --git a/lambda/testdata/agents/bash_stdout.sh b/lambda/testdata/agents/bash_stdout.sh deleted file mode 100755 index d0cb893..0000000 --- a/lambda/testdata/agents/bash_stdout.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash - -printf "stdout line 1\n" -printf "stdout line 2\n" -printf "stdout line 3\n" diff --git a/lambda/testdata/agents/bash_stdout_and_stderr.sh b/lambda/testdata/agents/bash_stdout_and_stderr.sh deleted file mode 100755 index cf87e60..0000000 --- a/lambda/testdata/agents/bash_stdout_and_stderr.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env bash - -printf "stdout line 1\n" -printf "stderr line 1\n" >&2 -printf "stdout line 2\n" -printf "stderr line 2\n" >&2 -printf "stdout line 3\n" -printf "stderr line 3\n" >&2 diff --git a/lambda/testdata/flowtesting.go b/lambda/testdata/flowtesting.go index ee163bb..c028d7c 100644 --- a/lambda/testdata/flowtesting.go +++ b/lambda/testdata/flowtesting.go @@ -8,28 +8,31 @@ import ( "io" "io/ioutil" "net/http" + "time" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" - "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/rendering" "go.amzn.com/lambda/telemetry" "go.amzn.com/lambda/testdata/mockthread" ) +const ( + contentTypeHeader = "Content-Type" + functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" +) + type MockInteropServer struct { - Response []byte - ErrorResponse *interop.ErrorResponse - ResponseContentType string - ActiveInvokeID string + Response []byte + ErrorResponse *interop.ErrorResponse + 
ResponseContentType string + FunctionResponseMode string + ActiveInvokeID string } -// StartAcceptingDirectInvokes -func (i *MockInteropServer) StartAcceptingDirectInvokes() error { return nil } - // SendResponse writes response to a shared memory. -func (i *MockInteropServer) SendResponse(invokeID string, contentType string, reader io.Reader) error { +func (i *MockInteropServer) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error { bytes, err := ioutil.ReadAll(reader) if err != nil { return err @@ -41,7 +44,8 @@ func (i *MockInteropServer) SendResponse(invokeID string, contentType string, re } } i.Response = bytes - i.ResponseContentType = contentType + i.ResponseContentType = headers[contentTypeHeader] + i.FunctionResponseMode = headers[functionResponseModeHeader] return nil } @@ -49,68 +53,35 @@ func (i *MockInteropServer) SendResponse(invokeID string, contentType string, re func (i *MockInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorResponse) error { i.ErrorResponse = response i.ResponseContentType = response.ContentType + i.FunctionResponseMode = response.FunctionResponseMode return nil } -func (i *MockInteropServer) GetCurrentInvokeID() string { - return i.ActiveInvokeID +// SendInitErrorResponse writes error response during init to a shared memory and sends GIRD FAULT. +func (i *MockInteropServer) SendInitErrorResponse(invokeID string, response *interop.ErrorResponse) error { + i.ErrorResponse = response + i.ResponseContentType = response.ContentType + return nil } -func (i *MockInteropServer) CommitResponse() error { return nil } - -// SendRunning sends GIRD RUNNING. -func (i *MockInteropServer) SendRunning(*interop.Running) error { return nil } - -// SendDone sends GIRD DONE. -func (i *MockInteropServer) SendDone(*interop.Done) error { return nil } - -// SendDoneFail sends GIRD DONEFAIL. 
-func (i *MockInteropServer) SendDoneFail(*interop.DoneFail) error { return nil } - -// StartChan returns Start emitter -func (i *MockInteropServer) StartChan() <-chan *interop.Start { return nil } - -// InvokeChan returns Invoke emitter -func (i *MockInteropServer) InvokeChan() <-chan *interop.Invoke { return nil } - -// ResetChan returns Reset emitter -func (i *MockInteropServer) ResetChan() <-chan *interop.Reset { return nil } - -// ShutdownChan returns Shutdown emitter -func (i *MockInteropServer) ShutdownChan() <-chan *interop.Shutdown { return nil } - -// TransportErrorChan emits errors if there was parsing/connection issue -func (i *MockInteropServer) TransportErrorChan() <-chan error { return nil } - -func (i *MockInteropServer) Clear() {} - -func (i *MockInteropServer) IsResponseSent() bool { - return !(i.Response == nil && i.ErrorResponse == nil) +func (i *MockInteropServer) GetCurrentInvokeID() string { + return i.ActiveInvokeID } func (i *MockInteropServer) SendRuntimeReady() error { return nil } -func (i *MockInteropServer) SetInternalStateGetter(isd interop.InternalStateGetter) {} - -func (m *MockInteropServer) Init(i *interop.Start, invokeTimeoutMs int64) {} - -func (m *MockInteropServer) Invoke(w http.ResponseWriter, i *interop.Invoke) error { return nil } - -func (m *MockInteropServer) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription { - return nil -} - // FlowTest provides configuration for tests that involve synchronization flows. 
type FlowTest struct { - AppCtx appctx.ApplicationContext - InitFlow core.InitFlowSynchronization - InvokeFlow core.InvokeFlowSynchronization - RegistrationService core.RegistrationService - RenderingService *rendering.EventRenderingService - Runtime *core.Runtime - InteropServer *MockInteropServer - LogsSubscriptionAPI *telemetry.NoOpLogsSubscriptionAPI - CredentialsService core.CredentialsService + AppCtx appctx.ApplicationContext + InitFlow core.InitFlowSynchronization + InvokeFlow core.InvokeFlowSynchronization + RegistrationService core.RegistrationService + RenderingService *rendering.EventRenderingService + Runtime *core.Runtime + InteropServer *MockInteropServer + TelemetrySubscription *telemetry.NoOpSubscriptionAPI + CredentialsService core.CredentialsService + EventsAPI telemetry.EventsAPI } // ConfigureForInit initialize synchronization gates and states for init. @@ -125,13 +96,13 @@ func (s *FlowTest) ConfigureForInvoke(ctx context.Context, invoke *interop.Invok s.RenderingService.SetRenderer(rendering.NewInvokeRenderer(ctx, invoke, telemetry.GetCustomerTracingHeader)) } -func (s *FlowTest) ConfigureForInitCaching(token, awsKey, awsSecret, awsSession string) { - s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession) +func (s *FlowTest) ConfigureForRestore() { + s.RenderingService.SetRenderer(rendering.NewRestoreRenderer()) } -func (s *FlowTest) ConfigureForBlockedInitCaching(token, awsKey, awsSecret, awsSession string) { - s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession) - s.CredentialsService.BlockService() +func (s *FlowTest) ConfigureForInitCaching(token, awsKey, awsSecret, awsSession string) { + credentialsExpiration := time.Now().Add(30 * time.Minute) + s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession, credentialsExpiration) } // NewFlowTest returns new FlowTest configuration. 
@@ -145,16 +116,18 @@ func NewFlowTest() *FlowTest { runtime := core.NewRuntime(initFlow, invokeFlow) runtime.ManagedThread = &mockthread.MockManagedThread{} interopServer := &MockInteropServer{} + eventsAPI := telemetry.NoOpEventsAPI{} appctx.StoreInteropServer(appCtx, interopServer) return &FlowTest{ - AppCtx: appCtx, - InitFlow: initFlow, - InvokeFlow: invokeFlow, - RegistrationService: registrationService, - RenderingService: renderingService, - LogsSubscriptionAPI: &telemetry.NoOpLogsSubscriptionAPI{}, - Runtime: runtime, - InteropServer: interopServer, - CredentialsService: credentialsService, + AppCtx: appCtx, + InitFlow: initFlow, + InvokeFlow: invokeFlow, + RegistrationService: registrationService, + RenderingService: renderingService, + TelemetrySubscription: &telemetry.NoOpSubscriptionAPI{}, + Runtime: runtime, + InteropServer: interopServer, + CredentialsService: credentialsService, + EventsAPI: &eventsAPI, } } diff --git a/test/integration/local_lambda/test_end_to_end.py b/test/integration/local_lambda/test_end_to_end.py index a85abce..c5c3e63 100644 --- a/test/integration/local_lambda/test_end_to_end.py +++ b/test/integration/local_lambda/test_end_to_end.py @@ -53,6 +53,7 @@ def tearDownClass(cls): "remaining_time_in_default_deadline", "pre-runtime-api", "assert-overwritten", + "port_override" ] for image in images_to_delete: @@ -264,6 +265,23 @@ def test_function_name_is_overriden(self, arch, port): ) self.assertEqual(b'"My lambda ran succesfully"', r.content) + @parameterized.expand([("x86_64", "8011"), ("arm64", "9011"), ("", "9061")]) + def test_port_override(self, arch, port): + image, rie, image_name = self.tagged_name("port_override", arch) + + # Use port 8081 inside the container instead of 8080 + cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8081 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.success_handler --runtime-interface-emulator-address 
0.0.0.0:8081" + + Popen(cmd.split(" ")).communicate() + + # sleep 1s to give enough time for the endpoint to be up to curl + time.sleep(SLEEP_TIME) + + r = requests.post( + f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} + ) + self.assertEqual(b'"My lambda ran succesfully"', r.content) + if __name__ == "__main__": main()