From c26efc5a951863ae6d7d4e5c400fda2a34dc49b3 Mon Sep 17 00:00:00 2001 From: Eliott Bouhana <47679741+eliottness@users.noreply.github.com> Date: Thu, 12 Sep 2024 11:27:35 +0200 Subject: [PATCH] feat: WAF Run Scope support for RASP Metrics (#109) - [x] Add the `Scope` system to classfiy Waf Runs - [x] Refactor Stats and Metrics without breaking the current API - [x] Remove the `_dd.appsec` to all metrics for the futur when we support telemetry metrics --------- Signed-off-by: Eliott Bouhana --- context.go | 68 ++++++++++++++++++++++---------- handle.go | 14 ++++++- metrics.go | 111 +++++++++++++++++++++++++++++++++++----------------- waf_test.go | 57 +++++++++++++++++++++------ 4 files changed, 179 insertions(+), 71 deletions(-) diff --git a/context.go b/context.go index a6dc7677..509f7f0c 100644 --- a/context.go +++ b/context.go @@ -7,14 +7,13 @@ package waf import ( "sync" + "sync/atomic" "time" "github.com/DataDog/go-libddwaf/v3/errors" "github.com/DataDog/go-libddwaf/v3/internal/bindings" "github.com/DataDog/go-libddwaf/v3/internal/unsafe" "github.com/DataDog/go-libddwaf/v3/timer" - - "sync/atomic" ) // Context is a WAF execution context. It allows running the WAF incrementally @@ -26,9 +25,10 @@ type Context struct { cgoRefs cgoRefPool // Used to retain go data referenced by WAF Objects the context holds cContext bindings.WafContext // The C ddwaf_context pointer - timeoutCount atomic.Uint64 // Cumulative timeout count for this context. + // timeoutCount count all calls which have timeout'ed by scope. Keys are fixed at creation time. + timeoutCount map[Scope]*atomic.Uint64 - // Mutex protecting the use of cContext which is not thread-safe and cgoRefs. + // mutex protecting the use of cContext which is not thread-safe and cgoRefs. mutex sync.Mutex // timer registers the time spent in the WAF and go-libddwaf @@ -39,7 +39,7 @@ type Context struct { // truncations provides details about truncations that occurred while // encoding address data for WAF execution. - truncations map[TruncationReason][]int + truncations map[Scope]map[TruncationReason][]int } // RunAddressData provides address data to the Context.Run method. If a given key is present in both @@ -51,6 +51,8 @@ type RunAddressData struct { // Ephemeral address data is scoped to a given Context.Run call and is not persisted across calls. This is used for // protocols such as gRPC client/server streaming or GraphQL, where a single request can incur multiple subrequests. Ephemeral map[string]any + // Scope is the way to classify the different runs in the same context in order to have different metrics + Scope Scope } func (d RunAddressData) isEmpty() bool { @@ -70,9 +72,13 @@ func (context *Context) Run(addressData RunAddressData) (res Result, err error) return } + if addressData.Scope == "" { + addressData.Scope = DefaultScope + } + defer func() { if err == errors.ErrTimeout { - context.timeoutCount.Add(1) + context.timeoutCount[addressData.Scope].Add(1) } }() @@ -94,13 +100,13 @@ func (context *Context) Run(addressData RunAddressData) (res Result, err error) runTimer.Start() defer func() { - context.metrics.add(wafRunTag, runTimer.Stop()) - context.metrics.merge(runTimer.Stats()) + context.metrics.add(addressData.Scope, wafRunTag, runTimer.Stop()) + context.metrics.merge(addressData.Scope, runTimer.Stats()) }() wafEncodeTimer := runTimer.MustLeaf(wafEncodeTag) wafEncodeTimer.Start() - persistentData, persistentEncoder, err := context.encodeOneAddressType(addressData.Persistent, wafEncodeTimer) + persistentData, persistentEncoder, err := context.encodeOneAddressType(addressData.Scope, addressData.Persistent, wafEncodeTimer) if err != nil { wafEncodeTimer.Stop() return res, err @@ -108,7 +114,7 @@ func (context *Context) Run(addressData RunAddressData) (res Result, err error) // The WAF releases ephemeral address data at the max of each run call, so we need not keep the Go values live beyond // that in the same way we need for persistent data. We hence use a separate encoder. - ephemeralData, ephemeralEncoder, err := context.encodeOneAddressType(addressData.Ephemeral, wafEncodeTimer) + ephemeralData, ephemeralEncoder, err := context.encodeOneAddressType(addressData.Scope, addressData.Ephemeral, wafEncodeTimer) if err != nil { wafEncodeTimer.Stop() return res, err @@ -180,7 +186,7 @@ func merge[K comparable, V any](a, b map[K][]V) (merged map[K][]V) { // is a nil map, but this behaviour is expected since either persistent or ephemeral addresses are allowed to be null // one at a time. In this case, Encode will return nil contrary to Encode which will return a nil wafObject, // which is what we need to send to ddwaf_run to signal that the address data is empty. -func (context *Context) encodeOneAddressType(addressData map[string]any, timer timer.Timer) (*bindings.WafObject, encoder, error) { +func (context *Context) encodeOneAddressType(scope Scope, addressData map[string]any, timer timer.Timer) (*bindings.WafObject, encoder, error) { encoder := newLimitedEncoder(timer) if addressData == nil { return nil, encoder, nil @@ -191,7 +197,7 @@ func (context *Context) encodeOneAddressType(addressData map[string]any, timer t context.mutex.Lock() defer context.mutex.Unlock() - context.truncations = merge(context.truncations, encoder.truncations) + context.truncations[scope] = merge(context.truncations[scope], encoder.truncations) } if timer.Exhausted() { @@ -269,14 +275,15 @@ func (context *Context) Close() { // TotalRuntime returns the cumulated WAF runtime across various run calls within the same WAF context. // Returned time is in nanoseconds. -// Deprecated: use Timings instead +// Deprecated: use Stats instead func (context *Context) TotalRuntime() (uint64, uint64) { - return uint64(context.metrics.get(wafRunTag)), uint64(context.metrics.get(wafDurationTag)) + return uint64(context.metrics.get(DefaultScope, wafRunTag)), uint64(context.metrics.get(DefaultScope, wafDurationTag)) } // TotalTimeouts returns the cumulated amount of WAF timeouts across various run calls within the same WAF context. +// Deprecated: use Stats instead func (context *Context) TotalTimeouts() uint64 { - return context.timeoutCount.Load() + return context.timeoutCount[DefaultScope].Load() } // Stats returns the cumulative time spent in various parts of the WAF, all in nanoseconds @@ -285,15 +292,36 @@ func (context *Context) Stats() Stats { context.mutex.Lock() defer context.mutex.Unlock() - truncations := make(map[TruncationReason][]int, len(context.truncations)) - for reason, counts := range context.truncations { + truncations := make(map[TruncationReason][]int, len(context.truncations[DefaultScope])) + for reason, counts := range context.truncations[DefaultScope] { truncations[reason] = make([]int, len(counts)) copy(truncations[reason], counts) } + raspTruncations := make(map[TruncationReason][]int, len(context.truncations[RASPScope])) + for reason, counts := range context.truncations[RASPScope] { + raspTruncations[reason] = make([]int, len(counts)) + copy(raspTruncations[reason], counts) + } + + var ( + timeoutDefault uint64 + timeoutRASP uint64 + ) + + if atomic, ok := context.timeoutCount[DefaultScope]; ok { + timeoutDefault = atomic.Load() + } + + if atomic, ok := context.timeoutCount[RASPScope]; ok { + timeoutRASP = atomic.Load() + } + return Stats{ - Timers: context.metrics.copy(), - TimeoutCount: context.timeoutCount.Load(), - Truncations: truncations, + Timers: context.metrics.timers(), + TimeoutCount: timeoutDefault, + TimeoutRASPCount: timeoutRASP, + Truncations: truncations, + TruncationsRASP: raspTruncations, } } diff --git a/handle.go b/handle.go index 18235d38..8c37f90e 100644 --- a/handle.go +++ b/handle.go @@ -69,7 +69,7 @@ func NewHandle(rules any, keyObfuscatorRegex string, valueObfuscatorRegex string cHandle := wafLib.WafInit(obj, config, diagnosticsWafObj) // Upon failure, the WAF may have produced some diagnostics to help signal what went wrong... var ( - diags *Diagnostics + diags = new(Diagnostics) diagsErr error ) if !diagnosticsWafObj.IsInvalid() { @@ -132,7 +132,17 @@ func (handle *Handle) NewContextWithBudget(budget time.Duration) (*Context, erro return nil, err } - return &Context{handle: handle, cContext: cContext, timer: timer, metrics: metricsStore{data: make(map[string]time.Duration, 5)}}, nil + return &Context{ + handle: handle, + cContext: cContext, + timer: timer, + metrics: metricsStore{data: make(map[metricKey]time.Duration, 5)}, + truncations: make(map[Scope]map[TruncationReason][]int, 2), + timeoutCount: map[Scope]*atomic.Uint64{ + DefaultScope: new(atomic.Uint64), + RASPScope: new(atomic.Uint64), + }, + }, nil } // Diagnostics returns the rules initialization metrics for the current WAF handle diff --git a/metrics.go b/metrics.go index 86130995..50973653 100644 --- a/metrics.go +++ b/metrics.go @@ -6,92 +6,131 @@ package waf import ( - "fmt" + "strings" "sync" "time" ) // Stats stores the metrics collected by the WAF. -type Stats struct { - // Timers returns a map of metrics and their durations. - Timers map[string]time.Duration +type ( + Stats struct { + // Timers returns a map of metrics and their durations. + Timers map[string]time.Duration - // Timeout - TimeoutCount uint64 + // TimeoutCount for the Default Scope i.e. "waf" + TimeoutCount uint64 - // Truncations provides details about truncations that occurred while - // encoding address data for WAF execution. - Truncations map[TruncationReason][]int -} + // TimeoutRASPCount for the RASP Scope i.e. "rasp" + TimeoutRASPCount uint64 + + // Truncations provides details about truncations that occurred while + // encoding address data for WAF execution. + Truncations map[TruncationReason][]int + + // TruncationsRASP provides details about truncations that occurred while + // encoding address data for RASP execution. + TruncationsRASP map[TruncationReason][]int + } + + // Scope is the way to classify the different runs in the same context in order to have different metrics + Scope string + + metricKey struct { + scope Scope + component string + } + + metricsStore struct { + data map[metricKey]time.Duration + mutex sync.RWMutex + } +) const ( - wafEncodeTag = "_dd.appsec.waf.encode" - wafRunTag = "_dd.appsec.waf.duration_ext" - wafDurationTag = "_dd.appsec.waf.duration" - wafDecodeTag = "_dd.appsec.waf.decode" - wafTimeoutTag = "_dd.appsec.waf.timeouts" - wafTruncationTag = "_dd.appsec.waf.truncations" + DefaultScope Scope = "waf" + RASPScope Scope = "rasp" ) -// Metrics transform the stats returned by the WAF into a map of key value metrics for datadog backend +const ( + wafEncodeTag = "encode" + wafRunTag = "duration_ext" + wafDurationTag = "duration" + wafDecodeTag = "decode" + wafTimeoutTag = "timeouts" + wafTruncationTag = "truncations" +) + +func dot(parts ...string) string { + return strings.Join(parts, ".") +} + +// Metrics transform the stats returned by the WAF into a map of key value metrics with values in microseconds. +// ex. {"waf.encode": 100, "waf.duration_ext": 300, "waf.duration": 200, "rasp.encode": 100, "rasp.duration_ext": 300, "rasp.duration": 200} func (stats Stats) Metrics() map[string]any { tags := make(map[string]any, len(stats.Timers)+len(stats.Truncations)+1) for k, v := range stats.Timers { tags[k] = float64(v.Nanoseconds()) / float64(time.Microsecond) // The metrics should be in microseconds } - tags[wafTimeoutTag] = stats.TimeoutCount + if stats.TimeoutCount > 0 { + tags[dot(string(DefaultScope), wafTimeoutTag)] = stats.TimeoutCount + } + + if stats.TimeoutRASPCount > 0 { + tags[dot(string(RASPScope), wafTimeoutTag)] = stats.TimeoutRASPCount + } + for reason, list := range stats.Truncations { - tags[fmt.Sprintf("%s.%s", wafTruncationTag, reason.String())] = list + tags[dot(string(DefaultScope), wafTruncationTag, reason.String())] = list } - return tags -} + for reason, list := range stats.TruncationsRASP { + tags[dot(string(RASPScope), wafTruncationTag, reason.String())] = list + } -type metricsStore struct { - data map[string]time.Duration - mutex sync.RWMutex + return tags } -func (metrics *metricsStore) add(key string, duration time.Duration) { +func (metrics *metricsStore) add(scope Scope, component string, duration time.Duration) { metrics.mutex.Lock() defer metrics.mutex.Unlock() if metrics.data == nil { - metrics.data = make(map[string]time.Duration, 5) + metrics.data = make(map[metricKey]time.Duration, 5) } - metrics.data[key] += duration + metrics.data[metricKey{scope, component}] += duration } -func (metrics *metricsStore) get(key string) time.Duration { +func (metrics *metricsStore) get(scope Scope, component string) time.Duration { metrics.mutex.RLock() defer metrics.mutex.RUnlock() - return metrics.data[key] + return metrics.data[metricKey{scope, component}] } -func (metrics *metricsStore) copy() map[string]time.Duration { +func (metrics *metricsStore) timers() map[string]time.Duration { metrics.mutex.Lock() defer metrics.mutex.Unlock() if metrics.data == nil { return nil } - copy := make(map[string]time.Duration, len(metrics.data)) + timers := make(map[string]time.Duration, len(metrics.data)) for k, v := range metrics.data { - copy[k] = v + timers[dot(string(k.scope), k.component)] = v } - return copy + return timers } // merge merges the current metrics with new ones -func (metrics *metricsStore) merge(other map[string]time.Duration) { +func (metrics *metricsStore) merge(scope Scope, other map[string]time.Duration) { metrics.mutex.Lock() defer metrics.mutex.Unlock() if metrics.data == nil { - metrics.data = make(map[string]time.Duration, 5) + metrics.data = make(map[metricKey]time.Duration, 5) } - for key, val := range other { + for component, val := range other { + key := metricKey{scope, component} prev, ok := metrics.data[key] if !ok { prev = 0 diff --git a/waf_test.go b/waf_test.go index ad38b64a..7d69a336 100644 --- a/waf_test.go +++ b/waf_test.go @@ -379,10 +379,10 @@ func TestTimeout(t *testing.T) { _, err = context.Run(RunAddressData{Persistent: normalValue, Ephemeral: normalValue}) require.NoError(t, err) require.NotEmpty(t, context.Stats()) - require.NotZero(t, context.Stats().Timers["_dd.appsec.waf.decode"]) - require.NotZero(t, context.Stats().Timers["_dd.appsec.waf.encode"]) - require.NotZero(t, context.Stats().Timers["_dd.appsec.waf.duration_ext"]) - require.NotZero(t, context.Stats().Timers["_dd.appsec.waf.duration"]) + require.NotZero(t, context.Stats().Timers["waf.decode"]) + require.NotZero(t, context.Stats().Timers["waf.encode"]) + require.NotZero(t, context.Stats().Timers["waf.duration_ext"]) + require.NotZero(t, context.Stats().Timers["waf.duration"]) }) t.Run("not-empty-metrics-no-match", func(t *testing.T) { @@ -394,10 +394,10 @@ func TestTimeout(t *testing.T) { _, err = context.Run(RunAddressData{Persistent: map[string]any{"my.input": "curl/7.88"}}) require.NoError(t, err) require.NotEmpty(t, context.Stats()) - require.NotZero(t, context.Stats().Timers["_dd.appsec.waf.decode"]) - require.NotZero(t, context.Stats().Timers["_dd.appsec.waf.encode"]) - require.NotZero(t, context.Stats().Timers["_dd.appsec.waf.duration_ext"]) - require.NotZero(t, context.Stats().Timers["_dd.appsec.waf.duration"]) + require.NotZero(t, context.Stats().Timers["waf.decode"]) + require.NotZero(t, context.Stats().Timers["waf.encode"]) + require.NotZero(t, context.Stats().Timers["waf.duration_ext"]) + require.NotZero(t, context.Stats().Timers["waf.duration"]) }) t.Run("timeout-persistent-encoder", func(t *testing.T) { @@ -408,8 +408,8 @@ func TestTimeout(t *testing.T) { _, err = context.Run(RunAddressData{Persistent: largeValue}) require.Equal(t, errors.ErrTimeout, err) - require.GreaterOrEqual(t, context.Stats().Timers["_dd.appsec.waf.duration_ext"], time.Millisecond) - require.GreaterOrEqual(t, context.Stats().Timers["_dd.appsec.waf.encode"], time.Millisecond) + require.GreaterOrEqual(t, context.Stats().Timers["waf.duration_ext"], time.Millisecond) + require.GreaterOrEqual(t, context.Stats().Timers["waf.encode"], time.Millisecond) }) t.Run("timeout-ephemeral-encoder", func(t *testing.T) { @@ -420,8 +420,8 @@ func TestTimeout(t *testing.T) { _, err = context.Run(RunAddressData{Ephemeral: largeValue}) require.Equal(t, errors.ErrTimeout, err) - require.GreaterOrEqual(t, context.Stats().Timers["_dd.appsec.waf.duration_ext"], time.Millisecond) - require.GreaterOrEqual(t, context.Stats().Timers["_dd.appsec.waf.encode"], time.Millisecond) + require.GreaterOrEqual(t, context.Stats().Timers["waf.duration_ext"], time.Millisecond) + require.GreaterOrEqual(t, context.Stats().Timers["waf.encode"], time.Millisecond) }) t.Run("many-runs", func(t *testing.T) { @@ -436,6 +436,37 @@ func TestTimeout(t *testing.T) { require.Equal(t, errors.ErrTimeout, err) }) + + t.Run("rasp-simple", func(t *testing.T) { + waf, err := newDefaultHandle(newArachniTestRule([]ruleInput{{Address: "my.input"}}, nil)) + require.NoError(t, err) + require.NotNil(t, waf) + + context, err := waf.NewContext() + require.NoError(t, err) + require.NotNil(t, context) + defer context.Close() + + _, err = context.Run(RunAddressData{Persistent: normalValue, Ephemeral: normalValue, Scope: RASPScope}) + require.NoError(t, err) + require.NotZero(t, context.Stats().Timers["rasp.duration_ext"]) + require.NotZero(t, context.Stats().Timers["rasp.duration"]) + require.NotZero(t, context.Stats().Timers["rasp.encode"]) + require.NotZero(t, context.Stats().Timers["rasp.decode"]) + }) + + t.Run("rasp-timeout", func(t *testing.T) { + context, err := waf.NewContextWithBudget(time.Millisecond) + require.NoError(t, err) + require.NotNil(t, context) + defer context.Close() + + _, err = context.Run(RunAddressData{Persistent: largeValue, Scope: RASPScope}) + require.Equal(t, errors.ErrTimeout, err) + require.GreaterOrEqual(t, context.Stats().Timers["rasp.duration_ext"], time.Millisecond) + require.GreaterOrEqual(t, context.Stats().Timers["rasp.encode"], time.Millisecond) + require.EqualValues(t, 1, context.Stats().TimeoutRASPCount) + }) } func TestMatching(t *testing.T) { @@ -1289,7 +1320,7 @@ func TestTruncationInformation(t *testing.T) { require.Equal(t, map[TruncationReason][]int{ StringTooLong: {bindings.WafMaxStringLength + extra + 2, bindings.WafMaxStringLength + extra}, ContainerTooLarge: {bindings.WafMaxContainerSize + extra + 2, bindings.WafMaxContainerSize + extra}, - }, ctx.truncations) + }, ctx.truncations[DefaultScope]) } func BenchmarkEncoder(b *testing.B) {