diff --git a/pkg/cmd/server/defaults.go b/pkg/cmd/server/defaults.go index 1fd91091cf..d16297c0b2 100644 --- a/pkg/cmd/server/defaults.go +++ b/pkg/cmd/server/defaults.go @@ -168,15 +168,75 @@ const ( DefaultInternalMiddlewareServerSpecific = "servicespecific" ) +//go:generate go run github.com/ecordell/optgen -output zz_generated.middlewareoption.go . MiddlewareOption type MiddlewareOption struct { - logger zerolog.Logger - authFunc grpcauth.AuthFunc - enableVersionResponse bool - dispatcher dispatch.Dispatcher - ds datastore.Datastore - enableRequestLog bool - enableResponseLog bool - disableGRPCHistogram bool + Logger zerolog.Logger `debugmap:"hidden"` + AuthFunc grpcauth.AuthFunc `debugmap:"hidden"` + EnableVersionResponse bool `debugmap:"visible"` + DispatcherForMiddleware dispatch.Dispatcher `debugmap:"hidden"` + EnableRequestLog bool `debugmap:"visible"` + EnableResponseLog bool `debugmap:"visible"` + DisableGRPCHistogram bool `debugmap:"visible"` + + unaryDatastoreMiddleware *ReferenceableMiddleware[grpc.UnaryServerInterceptor] `debugmap:"hidden"` + streamDatastoreMiddleware *ReferenceableMiddleware[grpc.StreamServerInterceptor] `debugmap:"hidden"` +} + +type Middleware interface { + UnaryServerInterceptor() grpc.UnaryServerInterceptor + StreamServerInterceptor() grpc.StreamServerInterceptor +} + +func (m MiddlewareOption) WithDatastoreMiddleware(middleware Middleware) MiddlewareOption { + unary := NewUnaryMiddleware(). + WithName(DefaultInternalMiddlewareDatastore). + WithInternal(true). + WithInterceptor(middleware.UnaryServerInterceptor()). + Done() + + stream := NewStreamMiddleware(). + WithName(DefaultInternalMiddlewareDatastore). + WithInternal(true). + WithInterceptor(middleware.StreamServerInterceptor()). + Done() + + return MiddlewareOption{ + Logger: m.Logger, + AuthFunc: m.AuthFunc, + EnableVersionResponse: m.EnableVersionResponse, + DispatcherForMiddleware: m.DispatcherForMiddleware, + EnableRequestLog: m.EnableRequestLog, + EnableResponseLog: m.EnableResponseLog, + DisableGRPCHistogram: m.DisableGRPCHistogram, + unaryDatastoreMiddleware: &unary, + streamDatastoreMiddleware: &stream, + } +} + +func (m MiddlewareOption) WithDatastore(ds datastore.Datastore) MiddlewareOption { + unary := NewUnaryMiddleware(). + WithName(DefaultInternalMiddlewareDatastore). + WithInternal(true). + WithInterceptor(datastoremw.UnaryServerInterceptor(ds)). + Done() + + stream := NewStreamMiddleware(). + WithName(DefaultInternalMiddlewareDatastore). + WithInternal(true). + WithInterceptor(datastoremw.StreamServerInterceptor(ds)). + Done() + + return MiddlewareOption{ + Logger: m.Logger, + AuthFunc: m.AuthFunc, + EnableVersionResponse: m.EnableVersionResponse, + DispatcherForMiddleware: m.DispatcherForMiddleware, + EnableRequestLog: m.EnableRequestLog, + EnableResponseLog: m.EnableResponseLog, + DisableGRPCHistogram: m.DisableGRPCHistogram, + unaryDatastoreMiddleware: &unary, + streamDatastoreMiddleware: &stream, + } } // gRPCMetricsUnaryInterceptor creates the default prometheus metrics interceptor for unary gRPCs @@ -212,7 +272,7 @@ func doesNotMatchRoute(route string) func(_ context.Context, c interceptors.Call // DefaultUnaryMiddleware generates the default middleware chain used for the public SpiceDB Unary gRPC methods func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryServerInterceptor], error) { - grpcMetricsUnaryInterceptor, _ := GRPCMetrics(opts.disableGRPCHistogram) + grpcMetricsUnaryInterceptor, _ := GRPCMetrics(opts.DisableGRPCHistogram) chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.UnaryServerInterceptor]{ NewUnaryMiddleware(). WithName(DefaultMiddlewareRequestID). @@ -232,7 +292,7 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS NewUnaryMiddleware(). WithName(DefaultMiddlewareGRPCLog + "-debug"). WithInterceptor(selector.UnaryServerInterceptor( - grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption), + grpclog.UnaryServerInterceptor(InterceptorLogger(opts.Logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption), selector.MatchFunc(matchesRoute(healthCheckRoute)))). EnsureAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs), Done(), @@ -240,7 +300,7 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS NewUnaryMiddleware(). WithName(DefaultMiddlewareGRPCLog). WithInterceptor(selector.UnaryServerInterceptor( - grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption), + grpclog.UnaryServerInterceptor(InterceptorLogger(opts.Logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption), selector.MatchFunc(doesNotMatchRoute(healthCheckRoute)))). EnsureAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs), Done(), @@ -252,26 +312,22 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS NewUnaryMiddleware(). WithName(DefaultMiddlewareGRPCAuth). - WithInterceptor(grpcauth.UnaryServerInterceptor(opts.authFunc)). + WithInterceptor(grpcauth.UnaryServerInterceptor(opts.AuthFunc)). EnsureAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports auth failures Done(), NewUnaryMiddleware(). WithName(DefaultMiddlewareServerVersion). - WithInterceptor(serverversion.UnaryServerInterceptor(opts.enableVersionResponse)). + WithInterceptor(serverversion.UnaryServerInterceptor(opts.EnableVersionResponse)). Done(), NewUnaryMiddleware(). WithName(DefaultInternalMiddlewareDispatch). WithInternal(true). - WithInterceptor(dispatchmw.UnaryServerInterceptor(opts.dispatcher)). + WithInterceptor(dispatchmw.UnaryServerInterceptor(opts.DispatcherForMiddleware)). Done(), - NewUnaryMiddleware(). - WithName(DefaultInternalMiddlewareDatastore). - WithInternal(true). - WithInterceptor(datastoremw.UnaryServerInterceptor(opts.ds)). - Done(), + *opts.unaryDatastoreMiddleware, NewUnaryMiddleware(). WithName(DefaultInternalMiddlewareConsistency). @@ -290,7 +346,7 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS // DefaultStreamingMiddleware generates the default middleware chain used for the public SpiceDB Streaming gRPC methods func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.StreamServerInterceptor], error) { - _, grpcMetricsStreamingInterceptor := GRPCMetrics(opts.disableGRPCHistogram) + _, grpcMetricsStreamingInterceptor := GRPCMetrics(opts.DisableGRPCHistogram) chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.StreamServerInterceptor]{ NewStreamMiddleware(). WithName(DefaultMiddlewareRequestID). @@ -310,7 +366,7 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St NewStreamMiddleware(). WithName(DefaultMiddlewareGRPCLog + "-debug"). WithInterceptor(selector.StreamServerInterceptor( - grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption), + grpclog.StreamServerInterceptor(InterceptorLogger(opts.Logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption), selector.MatchFunc(matchesRoute(healthCheckRoute)))). EnsureInterceptorAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs), Done(), @@ -318,7 +374,7 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St NewStreamMiddleware(). WithName(DefaultMiddlewareGRPCLog). WithInterceptor(selector.StreamServerInterceptor( - grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption), + grpclog.StreamServerInterceptor(InterceptorLogger(opts.Logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption), selector.MatchFunc(doesNotMatchRoute(healthCheckRoute)))). EnsureInterceptorAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs), Done(), @@ -330,26 +386,22 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St NewStreamMiddleware(). WithName(DefaultMiddlewareGRPCAuth). - WithInterceptor(grpcauth.StreamServerInterceptor(opts.authFunc)). + WithInterceptor(grpcauth.StreamServerInterceptor(opts.AuthFunc)). EnsureInterceptorAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports auth failures Done(), NewStreamMiddleware(). WithName(DefaultMiddlewareServerVersion). - WithInterceptor(serverversion.StreamServerInterceptor(opts.enableVersionResponse)). + WithInterceptor(serverversion.StreamServerInterceptor(opts.EnableVersionResponse)). Done(), NewStreamMiddleware(). WithName(DefaultInternalMiddlewareDispatch). WithInternal(true). - WithInterceptor(dispatchmw.StreamServerInterceptor(opts.dispatcher)). + WithInterceptor(dispatchmw.StreamServerInterceptor(opts.DispatcherForMiddleware)). Done(), - NewStreamMiddleware(). - WithName(DefaultInternalMiddlewareDatastore). - WithInternal(true). - WithInterceptor(datastoremw.StreamServerInterceptor(opts.ds)). - Done(), + *opts.streamDatastoreMiddleware, NewStreamMiddleware(). WithName(DefaultInternalMiddlewareConsistency). @@ -368,11 +420,11 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St func determineEventsToLog(opts MiddlewareOption) grpclog.Option { eventsToLog := []grpclog.LoggableEvent{grpclog.FinishCall} - if opts.enableRequestLog { + if opts.EnableRequestLog { eventsToLog = append(eventsToLog, grpclog.PayloadReceived) } - if opts.enableResponseLog { + if opts.EnableResponseLog { eventsToLog = append(eventsToLog, grpclog.PayloadSent) } diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index 6444898883..16d2449913 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -381,11 +381,14 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { c.GRPCAuthFunc, !c.DisableVersionResponse, dispatcher, - ds, c.EnableRequestLogs, c.EnableResponseLogs, c.DisableGRPCLatencyHistogram, + nil, + nil, } + opts = opts.WithDatastore(ds) + defaultUnaryMiddlewareChain, err := DefaultUnaryMiddleware(opts) if err != nil { return nil, fmt.Errorf("error building default middlewares: %w", err) diff --git a/pkg/cmd/server/server_test.go b/pkg/cmd/server/server_test.go index 1a0b82ffa5..405eccc18c 100644 --- a/pkg/cmd/server/server_test.go +++ b/pkg/cmd/server/server_test.go @@ -231,7 +231,9 @@ func TestModifyUnaryMiddleware(t *testing.T) { }, }} - opt := MiddlewareOption{logging.Logger, nil, false, nil, nil, false, false, false} + opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, nil, nil} + opt = opt.WithDatastore(nil) + defaultMw, err := DefaultUnaryMiddleware(opt) require.NoError(t, err) @@ -257,7 +259,9 @@ func TestModifyStreamingMiddleware(t *testing.T) { }, }} - opt := MiddlewareOption{logging.Logger, nil, false, nil, nil, false, false, false} + opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, nil, nil} + opt = opt.WithDatastore(nil) + defaultMw, err := DefaultStreamingMiddleware(opt) require.NoError(t, err) diff --git a/pkg/cmd/server/zz_generated.middlewareoption.go b/pkg/cmd/server/zz_generated.middlewareoption.go new file mode 100644 index 0000000000..ba75dbc5ff --- /dev/null +++ b/pkg/cmd/server/zz_generated.middlewareoption.go @@ -0,0 +1,121 @@ +// Code generated by github.com/ecordell/optgen. DO NOT EDIT. +package server + +import ( + dispatch "github.com/authzed/spicedb/internal/dispatch" + defaults "github.com/creasty/defaults" + helpers "github.com/ecordell/optgen/helpers" + auth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" + zerolog "github.com/rs/zerolog" +) + +type MiddlewareOptionOption func(m *MiddlewareOption) + +// NewMiddlewareOptionWithOptions creates a new MiddlewareOption with the passed in options set +func NewMiddlewareOptionWithOptions(opts ...MiddlewareOptionOption) *MiddlewareOption { + m := &MiddlewareOption{} + for _, o := range opts { + o(m) + } + return m +} + +// NewMiddlewareOptionWithOptionsAndDefaults creates a new MiddlewareOption with the passed in options set starting from the defaults +func NewMiddlewareOptionWithOptionsAndDefaults(opts ...MiddlewareOptionOption) *MiddlewareOption { + m := &MiddlewareOption{} + defaults.MustSet(m) + for _, o := range opts { + o(m) + } + return m +} + +// ToOption returns a new MiddlewareOptionOption that sets the values from the passed in MiddlewareOption +func (m *MiddlewareOption) ToOption() MiddlewareOptionOption { + return func(to *MiddlewareOption) { + to.Logger = m.Logger + to.AuthFunc = m.AuthFunc + to.EnableVersionResponse = m.EnableVersionResponse + to.DispatcherForMiddleware = m.DispatcherForMiddleware + to.EnableRequestLog = m.EnableRequestLog + to.EnableResponseLog = m.EnableResponseLog + to.DisableGRPCHistogram = m.DisableGRPCHistogram + to.unaryDatastoreMiddleware = m.unaryDatastoreMiddleware + to.streamDatastoreMiddleware = m.streamDatastoreMiddleware + } +} + +// DebugMap returns a map form of MiddlewareOption for debugging +func (m MiddlewareOption) DebugMap() map[string]any { + debugMap := map[string]any{} + debugMap["EnableVersionResponse"] = helpers.DebugValue(m.EnableVersionResponse, false) + debugMap["EnableRequestLog"] = helpers.DebugValue(m.EnableRequestLog, false) + debugMap["EnableResponseLog"] = helpers.DebugValue(m.EnableResponseLog, false) + debugMap["DisableGRPCHistogram"] = helpers.DebugValue(m.DisableGRPCHistogram, false) + return debugMap +} + +// MiddlewareOptionWithOptions configures an existing MiddlewareOption with the passed in options set +func MiddlewareOptionWithOptions(m *MiddlewareOption, opts ...MiddlewareOptionOption) *MiddlewareOption { + for _, o := range opts { + o(m) + } + return m +} + +// WithOptions configures the receiver MiddlewareOption with the passed in options set +func (m *MiddlewareOption) WithOptions(opts ...MiddlewareOptionOption) *MiddlewareOption { + for _, o := range opts { + o(m) + } + return m +} + +// WithLogger returns an option that can set Logger on a MiddlewareOption +func WithLogger(logger zerolog.Logger) MiddlewareOptionOption { + return func(m *MiddlewareOption) { + m.Logger = logger + } +} + +// WithAuthFunc returns an option that can set AuthFunc on a MiddlewareOption +func WithAuthFunc(authFunc auth.AuthFunc) MiddlewareOptionOption { + return func(m *MiddlewareOption) { + m.AuthFunc = authFunc + } +} + +// WithEnableVersionResponse returns an option that can set EnableVersionResponse on a MiddlewareOption +func WithEnableVersionResponse(enableVersionResponse bool) MiddlewareOptionOption { + return func(m *MiddlewareOption) { + m.EnableVersionResponse = enableVersionResponse + } +} + +// WithDispatcherForMiddleware returns an option that can set DispatcherForMiddleware on a MiddlewareOption +func WithDispatcherForMiddleware(dispatcherForMiddleware dispatch.Dispatcher) MiddlewareOptionOption { + return func(m *MiddlewareOption) { + m.DispatcherForMiddleware = dispatcherForMiddleware + } +} + +// WithEnableRequestLog returns an option that can set EnableRequestLog on a MiddlewareOption +func WithEnableRequestLog(enableRequestLog bool) MiddlewareOptionOption { + return func(m *MiddlewareOption) { + m.EnableRequestLog = enableRequestLog + } +} + +// WithEnableResponseLog returns an option that can set EnableResponseLog on a MiddlewareOption +func WithEnableResponseLog(enableResponseLog bool) MiddlewareOptionOption { + return func(m *MiddlewareOption) { + m.EnableResponseLog = enableResponseLog + } +} + +// WithDisableGRPCHistogram returns an option that can set DisableGRPCHistogram on a MiddlewareOption +func WithDisableGRPCHistogram(disableGRPCHistogram bool) MiddlewareOptionOption { + return func(m *MiddlewareOption) { + m.DisableGRPCHistogram = disableGRPCHistogram + } +} diff --git a/pkg/cmd/testserver/testserver.go b/pkg/cmd/testserver/testserver.go index 78895f2b3a..2e81548bd1 100644 --- a/pkg/cmd/testserver/testserver.go +++ b/pkg/cmd/testserver/testserver.go @@ -12,14 +12,12 @@ import ( "github.com/authzed/spicedb/internal/dispatch/graph" "github.com/authzed/spicedb/internal/gateway" log "github.com/authzed/spicedb/internal/logging" - consistencymw "github.com/authzed/spicedb/internal/middleware/consistency" - dispatchmw "github.com/authzed/spicedb/internal/middleware/dispatcher" "github.com/authzed/spicedb/internal/middleware/pertoken" "github.com/authzed/spicedb/internal/middleware/readonly" - "github.com/authzed/spicedb/internal/middleware/servicespecific" "github.com/authzed/spicedb/internal/services" "github.com/authzed/spicedb/internal/services/health" v1svc "github.com/authzed/spicedb/internal/services/v1" + "github.com/authzed/spicedb/pkg/cmd/server" "github.com/authzed/spicedb/pkg/cmd/util" "github.com/authzed/spicedb/pkg/datastore" ) @@ -87,19 +85,26 @@ func (c *Config) Complete() (RunnableTestServer, error) { 1*time.Second, ) } + + opts := *server.NewMiddlewareOptionWithOptions(server.WithAuthFunc(func(ctx context.Context) (context.Context, error) { + // Turn off the default auth system. + return ctx, nil + })) + opts = opts.WithDatastoreMiddleware(datastoreMiddleware) + + unaryMiddleware, err := server.DefaultUnaryMiddleware(opts) + if err != nil { + return nil, err + } + + streamMiddleware, err := server.DefaultStreamingMiddleware(opts) + if err != nil { + return nil, err + } + gRPCSrv, err := c.GRPCServer.Complete(zerolog.InfoLevel, registerServices, - grpc.ChainUnaryInterceptor( - datastoreMiddleware.UnaryServerInterceptor(), - dispatchmw.UnaryServerInterceptor(dispatcher), - consistencymw.UnaryServerInterceptor(), - servicespecific.UnaryServerInterceptor, - ), - grpc.ChainStreamInterceptor( - datastoreMiddleware.StreamServerInterceptor(), - dispatchmw.StreamServerInterceptor(dispatcher), - consistencymw.StreamServerInterceptor(), - servicespecific.StreamServerInterceptor, - ), + grpc.ChainUnaryInterceptor(unaryMiddleware.ToGRPCInterceptors()...), + grpc.ChainStreamInterceptor(streamMiddleware.ToGRPCInterceptors()...), ) if err != nil { return nil, err @@ -107,18 +112,10 @@ func (c *Config) Complete() (RunnableTestServer, error) { readOnlyGRPCSrv, err := c.ReadOnlyGRPCServer.Complete(zerolog.InfoLevel, registerServices, grpc.ChainUnaryInterceptor( - datastoreMiddleware.UnaryServerInterceptor(), - readonly.UnaryServerInterceptor(), - dispatchmw.UnaryServerInterceptor(dispatcher), - consistencymw.UnaryServerInterceptor(), - servicespecific.UnaryServerInterceptor, + append(unaryMiddleware.ToGRPCInterceptors(), readonly.UnaryServerInterceptor())..., ), grpc.ChainStreamInterceptor( - datastoreMiddleware.StreamServerInterceptor(), - readonly.StreamServerInterceptor(), - dispatchmw.StreamServerInterceptor(dispatcher), - consistencymw.StreamServerInterceptor(), - servicespecific.StreamServerInterceptor, + append(streamMiddleware.ToGRPCInterceptors(), readonly.StreamServerInterceptor())..., ), ) if err != nil {