Skip to content

Commit

Permalink
Merge pull request #2006 from josephschorr/version-middleware-serve-test
Browse files Browse the repository at this point in the history
Add server version middleware to serve-testing
  • Loading branch information
josephschorr authored Aug 5, 2024
2 parents 4429644 + 4ef7f38 commit 54cace8
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 60 deletions.
116 changes: 84 additions & 32 deletions pkg/cmd/server/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,75 @@ const (
DefaultInternalMiddlewareServerSpecific = "servicespecific"
)

//go:generate go run github.com/ecordell/optgen -output zz_generated.middlewareoption.go . MiddlewareOption
type MiddlewareOption struct {
logger zerolog.Logger
authFunc grpcauth.AuthFunc
enableVersionResponse bool
dispatcher dispatch.Dispatcher
ds datastore.Datastore
enableRequestLog bool
enableResponseLog bool
disableGRPCHistogram bool
Logger zerolog.Logger `debugmap:"hidden"`
AuthFunc grpcauth.AuthFunc `debugmap:"hidden"`
EnableVersionResponse bool `debugmap:"visible"`
DispatcherForMiddleware dispatch.Dispatcher `debugmap:"hidden"`
EnableRequestLog bool `debugmap:"visible"`
EnableResponseLog bool `debugmap:"visible"`
DisableGRPCHistogram bool `debugmap:"visible"`

unaryDatastoreMiddleware *ReferenceableMiddleware[grpc.UnaryServerInterceptor] `debugmap:"hidden"`
streamDatastoreMiddleware *ReferenceableMiddleware[grpc.StreamServerInterceptor] `debugmap:"hidden"`
}

type Middleware interface {
UnaryServerInterceptor() grpc.UnaryServerInterceptor
StreamServerInterceptor() grpc.StreamServerInterceptor
}

func (m MiddlewareOption) WithDatastoreMiddleware(middleware Middleware) MiddlewareOption {
unary := NewUnaryMiddleware().
WithName(DefaultInternalMiddlewareDatastore).
WithInternal(true).
WithInterceptor(middleware.UnaryServerInterceptor()).
Done()

stream := NewStreamMiddleware().
WithName(DefaultInternalMiddlewareDatastore).
WithInternal(true).
WithInterceptor(middleware.StreamServerInterceptor()).
Done()

return MiddlewareOption{
Logger: m.Logger,
AuthFunc: m.AuthFunc,
EnableVersionResponse: m.EnableVersionResponse,
DispatcherForMiddleware: m.DispatcherForMiddleware,
EnableRequestLog: m.EnableRequestLog,
EnableResponseLog: m.EnableResponseLog,
DisableGRPCHistogram: m.DisableGRPCHistogram,
unaryDatastoreMiddleware: &unary,
streamDatastoreMiddleware: &stream,
}
}

func (m MiddlewareOption) WithDatastore(ds datastore.Datastore) MiddlewareOption {
unary := NewUnaryMiddleware().
WithName(DefaultInternalMiddlewareDatastore).
WithInternal(true).
WithInterceptor(datastoremw.UnaryServerInterceptor(ds)).
Done()

stream := NewStreamMiddleware().
WithName(DefaultInternalMiddlewareDatastore).
WithInternal(true).
WithInterceptor(datastoremw.StreamServerInterceptor(ds)).
Done()

return MiddlewareOption{
Logger: m.Logger,
AuthFunc: m.AuthFunc,
EnableVersionResponse: m.EnableVersionResponse,
DispatcherForMiddleware: m.DispatcherForMiddleware,
EnableRequestLog: m.EnableRequestLog,
EnableResponseLog: m.EnableResponseLog,
DisableGRPCHistogram: m.DisableGRPCHistogram,
unaryDatastoreMiddleware: &unary,
streamDatastoreMiddleware: &stream,
}
}

// gRPCMetricsUnaryInterceptor creates the default prometheus metrics interceptor for unary gRPCs
Expand Down Expand Up @@ -212,7 +272,7 @@ func doesNotMatchRoute(route string) func(_ context.Context, c interceptors.Call

// DefaultUnaryMiddleware generates the default middleware chain used for the public SpiceDB Unary gRPC methods
func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryServerInterceptor], error) {
grpcMetricsUnaryInterceptor, _ := GRPCMetrics(opts.disableGRPCHistogram)
grpcMetricsUnaryInterceptor, _ := GRPCMetrics(opts.DisableGRPCHistogram)
chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
NewUnaryMiddleware().
WithName(DefaultMiddlewareRequestID).
Expand All @@ -232,15 +292,15 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS
NewUnaryMiddleware().
WithName(DefaultMiddlewareGRPCLog + "-debug").
WithInterceptor(selector.UnaryServerInterceptor(
grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption),
grpclog.UnaryServerInterceptor(InterceptorLogger(opts.Logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption),
selector.MatchFunc(matchesRoute(healthCheckRoute)))).
EnsureAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs),
Done(),

NewUnaryMiddleware().
WithName(DefaultMiddlewareGRPCLog).
WithInterceptor(selector.UnaryServerInterceptor(
grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption),
grpclog.UnaryServerInterceptor(InterceptorLogger(opts.Logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption),
selector.MatchFunc(doesNotMatchRoute(healthCheckRoute)))).
EnsureAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs),
Done(),
Expand All @@ -252,26 +312,22 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS

NewUnaryMiddleware().
WithName(DefaultMiddlewareGRPCAuth).
WithInterceptor(grpcauth.UnaryServerInterceptor(opts.authFunc)).
WithInterceptor(grpcauth.UnaryServerInterceptor(opts.AuthFunc)).
EnsureAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports auth failures
Done(),

NewUnaryMiddleware().
WithName(DefaultMiddlewareServerVersion).
WithInterceptor(serverversion.UnaryServerInterceptor(opts.enableVersionResponse)).
WithInterceptor(serverversion.UnaryServerInterceptor(opts.EnableVersionResponse)).
Done(),

NewUnaryMiddleware().
WithName(DefaultInternalMiddlewareDispatch).
WithInternal(true).
WithInterceptor(dispatchmw.UnaryServerInterceptor(opts.dispatcher)).
WithInterceptor(dispatchmw.UnaryServerInterceptor(opts.DispatcherForMiddleware)).
Done(),

NewUnaryMiddleware().
WithName(DefaultInternalMiddlewareDatastore).
WithInternal(true).
WithInterceptor(datastoremw.UnaryServerInterceptor(opts.ds)).
Done(),
*opts.unaryDatastoreMiddleware,

NewUnaryMiddleware().
WithName(DefaultInternalMiddlewareConsistency).
Expand All @@ -290,7 +346,7 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS

// DefaultStreamingMiddleware generates the default middleware chain used for the public SpiceDB Streaming gRPC methods
func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.StreamServerInterceptor], error) {
_, grpcMetricsStreamingInterceptor := GRPCMetrics(opts.disableGRPCHistogram)
_, grpcMetricsStreamingInterceptor := GRPCMetrics(opts.DisableGRPCHistogram)
chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.StreamServerInterceptor]{
NewStreamMiddleware().
WithName(DefaultMiddlewareRequestID).
Expand All @@ -310,15 +366,15 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St
NewStreamMiddleware().
WithName(DefaultMiddlewareGRPCLog + "-debug").
WithInterceptor(selector.StreamServerInterceptor(
grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption),
grpclog.StreamServerInterceptor(InterceptorLogger(opts.Logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption),
selector.MatchFunc(matchesRoute(healthCheckRoute)))).
EnsureInterceptorAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs),
Done(),

NewStreamMiddleware().
WithName(DefaultMiddlewareGRPCLog).
WithInterceptor(selector.StreamServerInterceptor(
grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption),
grpclog.StreamServerInterceptor(InterceptorLogger(opts.Logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption),
selector.MatchFunc(doesNotMatchRoute(healthCheckRoute)))).
EnsureInterceptorAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs),
Done(),
Expand All @@ -330,26 +386,22 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St

NewStreamMiddleware().
WithName(DefaultMiddlewareGRPCAuth).
WithInterceptor(grpcauth.StreamServerInterceptor(opts.authFunc)).
WithInterceptor(grpcauth.StreamServerInterceptor(opts.AuthFunc)).
EnsureInterceptorAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports auth failures
Done(),

NewStreamMiddleware().
WithName(DefaultMiddlewareServerVersion).
WithInterceptor(serverversion.StreamServerInterceptor(opts.enableVersionResponse)).
WithInterceptor(serverversion.StreamServerInterceptor(opts.EnableVersionResponse)).
Done(),

NewStreamMiddleware().
WithName(DefaultInternalMiddlewareDispatch).
WithInternal(true).
WithInterceptor(dispatchmw.StreamServerInterceptor(opts.dispatcher)).
WithInterceptor(dispatchmw.StreamServerInterceptor(opts.DispatcherForMiddleware)).
Done(),

NewStreamMiddleware().
WithName(DefaultInternalMiddlewareDatastore).
WithInternal(true).
WithInterceptor(datastoremw.StreamServerInterceptor(opts.ds)).
Done(),
*opts.streamDatastoreMiddleware,

NewStreamMiddleware().
WithName(DefaultInternalMiddlewareConsistency).
Expand All @@ -368,11 +420,11 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St

func determineEventsToLog(opts MiddlewareOption) grpclog.Option {
eventsToLog := []grpclog.LoggableEvent{grpclog.FinishCall}
if opts.enableRequestLog {
if opts.EnableRequestLog {
eventsToLog = append(eventsToLog, grpclog.PayloadReceived)
}

if opts.enableResponseLog {
if opts.EnableResponseLog {
eventsToLog = append(eventsToLog, grpclog.PayloadSent)
}

Expand Down
5 changes: 4 additions & 1 deletion pkg/cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,11 +381,14 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) {
c.GRPCAuthFunc,
!c.DisableVersionResponse,
dispatcher,
ds,
c.EnableRequestLogs,
c.EnableResponseLogs,
c.DisableGRPCLatencyHistogram,
nil,
nil,
}
opts = opts.WithDatastore(ds)

defaultUnaryMiddlewareChain, err := DefaultUnaryMiddleware(opts)
if err != nil {
return nil, fmt.Errorf("error building default middlewares: %w", err)
Expand Down
8 changes: 6 additions & 2 deletions pkg/cmd/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ func TestModifyUnaryMiddleware(t *testing.T) {
},
}}

opt := MiddlewareOption{logging.Logger, nil, false, nil, nil, false, false, false}
opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, nil, nil}
opt = opt.WithDatastore(nil)

defaultMw, err := DefaultUnaryMiddleware(opt)
require.NoError(t, err)

Expand All @@ -257,7 +259,9 @@ func TestModifyStreamingMiddleware(t *testing.T) {
},
}}

opt := MiddlewareOption{logging.Logger, nil, false, nil, nil, false, false, false}
opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, nil, nil}
opt = opt.WithDatastore(nil)

defaultMw, err := DefaultStreamingMiddleware(opt)
require.NoError(t, err)

Expand Down
121 changes: 121 additions & 0 deletions pkg/cmd/server/zz_generated.middlewareoption.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 54cace8

Please sign in to comment.