diff --git a/internal/services/v1/bulkcheck.go b/internal/services/v1/bulkcheck.go
index d5b49bd679..3b07363f72 100644
--- a/internal/services/v1/bulkcheck.go
+++ b/internal/services/v1/bulkcheck.go
@@ -35,6 +35,8 @@ type bulkChecker struct {
 	dispatchChunkSize uint16
 }
 
+const maxBulkCheckCount = 10000
+
 func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBulkPermissionsRequest) (*v1.CheckBulkPermissionsResponse, error) {
 	atRevision, checkedAt, err := consistency.RevisionFromContext(ctx)
 	if err != nil {
diff --git a/internal/services/v1/experimental.go b/internal/services/v1/experimental.go
index 0c596a7458..ae2ce67336 100644
--- a/internal/services/v1/experimental.go
+++ b/internal/services/v1/experimental.go
@@ -3,6 +3,7 @@ package v1
 import (
 	"context"
 	"errors"
+	"io"
 	"slices"
 	"sort"
 	"strings"
@@ -21,6 +22,7 @@ import (
 	"github.com/authzed/spicedb/internal/middleware/handwrittenvalidation"
 	"github.com/authzed/spicedb/internal/middleware/streamtimeout"
 	"github.com/authzed/spicedb/internal/middleware/usagemetrics"
+	"github.com/authzed/spicedb/internal/relationships"
 	"github.com/authzed/spicedb/internal/services/shared"
 	"github.com/authzed/spicedb/internal/services/v1/options"
 	"github.com/authzed/spicedb/pkg/cursor"
@@ -100,6 +102,7 @@ func NewExperimentalServer(dispatch dispatch.Dispatcher, permServerConfig Permis
 				streamtimeout.MustStreamServerInterceptor(config.StreamReadTimeout),
 			),
 		},
+		maxBatchSize: uint64(config.MaxExportBatchSize),
 		bulkChecker: &bulkChecker{
 			maxAPIDepth:          permServerConfig.MaximumAPIDepth,
 			maxCaveatContextSize: permServerConfig.MaxCaveatContextSize,
@@ -107,10 +110,6 @@ func NewExperimentalServer(dispatch dispatch.Dispatcher, permServerConfig Permis
 			dispatch:          dispatch,
 			dispatchChunkSize: chunkSize,
 		},
-		bulkImporter: &bulkImporter{},
-		bulkExporter: &bulkExporter{
-			maxBatchSize: uint64(config.MaxExportBatchSize),
-		},
 	}
 }
 
@@ -118,9 +117,68 @@ type experimentalServer struct {
 	v1.UnimplementedExperimentalServiceServer
 	shared.WithServiceSpecificInterceptors
 
+	maxBatchSize uint64
+
 	bulkChecker *bulkChecker
-	bulkImporter *bulkImporter
-	bulkExporter *bulkExporter
 }
+
+// bulkLoadAdapter adapts a BulkImportRelationships request stream into the
+// relationship source consumed by the datastore's BulkLoad: each call to Next
+// validates and returns one relationship from the current batch.
+type bulkLoadAdapter struct {
+	stream                 v1.ExperimentalService_BulkImportRelationshipsServer
+	referencedNamespaceMap map[string]*typesystem.TypeSystem
+	referencedCaveatMap    map[string]*core.CaveatDefinition
+	current                core.RelationTuple
+	caveat                 core.ContextualizedCaveat
+
+	awaitingNamespaces []string
+	awaitingCaveats    []string
+
+	currentBatch []*v1.Relationship
+	numSent      int
+	err          error
+}
+
+func (a *bulkLoadAdapter) Next(_ context.Context) (*core.RelationTuple, error) {
+	for a.err == nil && a.numSent == len(a.currentBatch) {
+		// Load a new batch
+		batch, err := a.stream.Recv()
+		if err != nil {
+			a.err = err
+			if errors.Is(a.err, io.EOF) {
+				return nil, nil
+			}
+			return nil, a.err
+		}
+
+		a.currentBatch = batch.Relationships
+		a.numSent = 0
+
+		a.awaitingNamespaces, a.awaitingCaveats = extractBatchNewReferencedNamespacesAndCaveats(
+			a.currentBatch,
+			a.referencedNamespaceMap,
+			a.referencedCaveatMap,
+		)
+	}
+
+	if len(a.awaitingNamespaces) > 0 || len(a.awaitingCaveats) > 0 {
+		// Shut down the stream to give our caller a chance to fill in this information
+		return nil, nil
+	}
+
+	a.current.Caveat = &a.caveat
+	a.current.Integrity = nil
+	tuple.CopyRelationshipToRelationTuple(a.currentBatch[a.numSent], &a.current)
+
+	if err := relationships.ValidateOneRelationship(
+		a.referencedNamespaceMap,
+		a.referencedCaveatMap,
+		&a.current,
+		relationships.ValidateRelationshipForCreateOrTouch,
+	); err != nil {
+		return nil, err
+	}
+
+	a.numSent++
+	return &a.current, nil
+}
 
 func extractBatchNewReferencedNamespacesAndCaveats(
@@ -147,165 +205,8 @@ func extractBatchNewReferencedNamespacesAndCaveats(
 	return lo.Keys(newNamespaces), lo.Keys(newCaveats)
 }
 
-func (es *experimentalServer) BulkImportRelationships(stream grpc.ClientStreamingServer[v1.BulkImportRelationshipsRequest, v1.BulkImportRelationshipsResponse]) error {
-}
-
-func (es *experimentalServer) BulkExportRelationships(
-	req *v1.BulkExportRelationshipsRequest,
-	resp grpc.ServerStreamingServer[v1.BulkExportRelationshipsResponse],
-) error {
-	ctx := resp.Context()
-	atRevision, _, err := consistency.RevisionFromContext(ctx)
-	if err != nil {
-		return shared.RewriteErrorWithoutConfig(ctx, err)
-	}
-
-	return BulkExport(ctx, datastoremw.MustFromContext(ctx), es.bulkExporter.maxBatchSize, req, atRevision, resp.Send)
-}
-
-// BulkExport implements the BulkExportRelationships API functionality. Given a datastore.Datastore, it will
-// export stream via the sender all relationships matched by the incoming request.
-// If no cursor is provided, it will fallback to the provided revision.
-func BulkExport(ctx context.Context, ds datastore.Datastore, batchSize uint64, req *v1.BulkExportRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.BulkExportRelationshipsResponse) error) error {
-	if req.OptionalLimit > 0 && uint64(req.OptionalLimit) > batchSize {
-		return shared.RewriteErrorWithoutConfig(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), batchSize))
-	}
-
-	atRevision := fallbackRevision
-	var curNamespace string
-	var cur dsoptions.Cursor
-	if req.OptionalCursor != nil {
-		var err error
-		atRevision, curNamespace, cur, err = decodeCursor(ds, req.OptionalCursor)
-		if err != nil {
-			return shared.RewriteErrorWithoutConfig(ctx, err)
-		}
-	}
-
-	reader := ds.SnapshotReader(atRevision)
-
-	namespaces, err := reader.ListAllNamespaces(ctx)
-	if err != nil {
-		return shared.RewriteErrorWithoutConfig(ctx, err)
-	}
-
-	// Make sure the namespaces are always in a stable order
-	slices.SortFunc(namespaces, func(
-		lhs datastore.RevisionedDefinition[*core.NamespaceDefinition],
-		rhs datastore.RevisionedDefinition[*core.NamespaceDefinition],
-	) int {
-		return strings.Compare(lhs.Definition.Name, rhs.Definition.Name)
-	})
-
-	// Skip the namespaces that are already fully returned
-	for cur != nil && len(namespaces) > 0 && namespaces[0].Definition.Name < curNamespace {
-		namespaces = namespaces[1:]
-	}
-
-	limit := batchSize
-	if req.OptionalLimit > 0 {
-		limit = uint64(req.OptionalLimit)
-	}
-
-	// Pre-allocate all of the relationships that we might need in order to
-	// make export easier and faster for the garbage collector.
-	relsArray := make([]v1.Relationship, limit)
-	objArray := make([]v1.ObjectReference, limit)
-	subArray := make([]v1.SubjectReference, limit)
-	subObjArray := make([]v1.ObjectReference, limit)
-	caveatArray := make([]v1.ContextualizedCaveat, limit)
-	for i := range relsArray {
-		relsArray[i].Resource = &objArray[i]
-		relsArray[i].Subject = &subArray[i]
-		relsArray[i].Subject.Object = &subObjArray[i]
-	}
-
-	emptyRels := make([]*v1.Relationship, limit)
-	for _, ns := range namespaces {
-		rels := emptyRels
-
-		// Reset the cursor between namespaces.
-		if ns.Definition.Name != curNamespace {
-			cur = nil
-		}
-
-		// Skip this namespace if a resource type filter was specified.
- if req.OptionalRelationshipFilter != nil && req.OptionalRelationshipFilter.ResourceType != "" { - if ns.Definition.Name != req.OptionalRelationshipFilter.ResourceType { - continue - } - } - - // Setup the filter to use for the relationships. - relationshipFilter := datastore.RelationshipsFilter{OptionalResourceType: ns.Definition.Name} - if req.OptionalRelationshipFilter != nil { - rf, err := datastore.RelationshipsFilterFromPublicFilter(req.OptionalRelationshipFilter) - if err != nil { - return shared.RewriteErrorWithoutConfig(ctx, err) - } - - // Overload the namespace name with the one from the request, because each iteration is for a different namespace. - rf.OptionalResourceType = ns.Definition.Name - relationshipFilter = rf - } - - // We want to keep iterating as long as we're sending full batches. - // To bootstrap this loop, we enter the first time with a full rels - // slice of dummy rels that were never sent. - for uint64(len(rels)) == limit { - // Lop off any rels we've already sent - rels = rels[:0] - - tplFn := func(tpl *core.RelationTuple) { - offset := len(rels) - rels = append(rels, &relsArray[offset]) // nozero - tuple.CopyRelationTupleToRelationship(tpl, &relsArray[offset], &caveatArray[offset]) - } - - cur, err = queryForEach( - ctx, - reader, - relationshipFilter, - tplFn, - dsoptions.WithLimit(&limit), - dsoptions.WithAfter(cur), - dsoptions.WithSort(dsoptions.ByResource), - ) - if err != nil { - return shared.RewriteErrorWithoutConfig(ctx, err) - } - - if len(rels) == 0 { - continue - } - - encoded, err := cursor.Encode(&implv1.DecodedCursor{ - VersionOneof: &implv1.DecodedCursor_V1{ - V1: &implv1.V1Cursor{ - Revision: atRevision.String(), - Sections: []string{ - ns.Definition.Name, - tuple.MustString(cur), - }, - }, - }, - }) - if err != nil { - return shared.RewriteErrorWithoutConfig(ctx, err) - } - - if err := sender(&v1.BulkExportRelationshipsResponse{ - AfterResultCursor: encoded, - Relationships: rels, - }); err != nil { - return shared.RewriteErrorWithoutConfig(ctx, err) - } - } - } - return nil -} - -func (es *experimentalServer) ImportBulkRelationships(stream v1.PermissionsService_ImportBulkRelationshipsServer) error { +// TODO: this is now duplicate code with ImportBulkRelationships +func (es *experimentalServer) BulkImportRelationships(stream v1.ExperimentalService_BulkImportRelationshipsServer) error { ds := datastoremw.MustFromContext(stream.Context()) var numWritten uint64 @@ -377,6 +278,20 @@ func (es *experimentalServer) ImportBulkRelationships(stream v1.PermissionsServi }) } +// TODO: this is now duplicate code with ExportBulkRelationships +func (es *experimentalServer) BulkExportRelationships( + req *v1.BulkExportRelationshipsRequest, + resp grpc.ServerStreamingServer[v1.BulkExportRelationshipsResponse], +) error { + ctx := resp.Context() + atRevision, _, err := consistency.RevisionFromContext(ctx) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + return BulkExport(ctx, datastoremw.MustFromContext(ctx), es.maxBatchSize, req, atRevision, resp.Send) +} + // BulkExport implements the BulkExportRelationships API functionality. Given a datastore.Datastore, it will // export stream via the sender all relationships matched by the incoming request. // If no cursor is provided, it will fallback to the provided revision. 
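+// The batchSize parameter bounds both the page size used when querying the
+// datastore and the largest OptionalLimit a caller may request; larger limits
+// are rejected with NewExceedsMaximumLimitErr.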
@@ -519,8 +434,6 @@ func BulkExport(ctx context.Context, ds datastore.Datastore, batchSize uint64, r
 	return nil
 }
 
-const maxBulkCheckCount = 10000
-
 func (es *experimentalServer) BulkCheckPermission(ctx context.Context, req *v1.BulkCheckPermissionRequest) (*v1.BulkCheckPermissionResponse, error) {
 	convertedReq := toCheckBulkPermissionsRequest(req)
 	res, err := es.bulkChecker.checkBulkPermissions(ctx, convertedReq)
diff --git a/internal/services/v1/permissions.go b/internal/services/v1/permissions.go
index 0fd7362a82..72866dca7f 100644
--- a/internal/services/v1/permissions.go
+++ b/internal/services/v1/permissions.go
@@ -2,7 +2,11 @@ package v1
 
 import (
 	"context"
+	"errors"
 	"fmt"
+	"io"
+	"slices"
+	"strings"
 
 	"github.com/authzed/authzed-go/pkg/requestmeta"
 	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
@@ -21,14 +25,18 @@ import (
 	datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
 	"github.com/authzed/spicedb/internal/middleware/usagemetrics"
 	"github.com/authzed/spicedb/internal/namespace"
+	"github.com/authzed/spicedb/internal/relationships"
 	"github.com/authzed/spicedb/internal/services/shared"
 	"github.com/authzed/spicedb/pkg/cursor"
 	"github.com/authzed/spicedb/pkg/datastore"
+	dsoptions "github.com/authzed/spicedb/pkg/datastore/options"
 	"github.com/authzed/spicedb/pkg/middleware/consistency"
 	core "github.com/authzed/spicedb/pkg/proto/core/v1"
 	dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+	implv1 "github.com/authzed/spicedb/pkg/proto/impl/v1"
 	"github.com/authzed/spicedb/pkg/spiceerrors"
 	"github.com/authzed/spicedb/pkg/tuple"
+	"github.com/authzed/spicedb/pkg/typesystem"
 )
 
 func (ps *permissionServer) rewriteError(ctx context.Context, err error) error {
@@ -865,13 +873,288 @@ func GetCaveatContext(ctx context.Context, caveatCtx *structpb.Struct, maxCaveat
 	return caveatContext, nil
 }
 
-func (ps *permissionServer) ImportBulkRelationships(stream grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error {
-	adapter := &bulkLoadAdapter{
-		stream: stream,
-	}
-	response, err := ps.bulkImporter.bulkImportRelationships(stream.Context(), adapter)
-	if err != nil {
-		return shared.RewriteErrorWithoutConfig(stream.Context(), err)
-	}
-	return stream.SendAndClose(response)
-}
+// loadBulkAdapter mirrors bulkLoadAdapter in experimental.go for the
+// PermissionsService stream type.
+type loadBulkAdapter struct {
+	stream                 grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]
+	referencedNamespaceMap map[string]*typesystem.TypeSystem
+	referencedCaveatMap    map[string]*core.CaveatDefinition
+	current                core.RelationTuple
+	caveat                 core.ContextualizedCaveat
+
+	awaitingNamespaces []string
+	awaitingCaveats    []string
+
+	currentBatch []*v1.Relationship
+	numSent      int
+	err          error
+}
+
+func (a *loadBulkAdapter) Next(_ context.Context) (*core.RelationTuple, error) {
+	for a.err == nil && a.numSent == len(a.currentBatch) {
+		// Load a new batch
+		batch, err := a.stream.Recv()
+		if err != nil {
+			a.err = err
+			if errors.Is(a.err, io.EOF) {
+				return nil, nil
+			}
+			return nil, a.err
+		}
+
+		a.currentBatch = batch.Relationships
+		a.numSent = 0
+
+		a.awaitingNamespaces, a.awaitingCaveats = extractBatchNewReferencedNamespacesAndCaveats(
+			a.currentBatch,
+			a.referencedNamespaceMap,
+			a.referencedCaveatMap,
+		)
+	}
+
+	if len(a.awaitingNamespaces) > 0 || len(a.awaitingCaveats) > 0 {
+		// Shut down the stream to give our caller a chance to fill in this information
+		return nil, nil
+	}
+
+	a.current.Caveat = &a.caveat
+	a.current.Integrity = nil
+	tuple.CopyRelationshipToRelationTuple(a.currentBatch[a.numSent], &a.current)
+
+	if err := relationships.ValidateOneRelationship(
+		a.referencedNamespaceMap,
+		a.referencedCaveatMap,
+		&a.current,
+		relationships.ValidateRelationshipForCreateOrTouch,
+	); err != nil {
+		return nil, err
+	}
+
+	a.numSent++
+	return &a.current, nil
+}
+
+func (ps *permissionServer) ImportBulkRelationships(stream grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error {
+	ds := datastoremw.MustFromContext(stream.Context())
+
+	var numWritten uint64
+	if _, err := ds.ReadWriteTx(stream.Context(), func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
+		loadedNamespaces := make(map[string]*typesystem.TypeSystem, 2)
+		loadedCaveats := make(map[string]*core.CaveatDefinition, 0)
+
+		adapter := &loadBulkAdapter{
+			stream:                 stream,
+			referencedNamespaceMap: loadedNamespaces,
+			referencedCaveatMap:    loadedCaveats,
+			current: core.RelationTuple{
+				ResourceAndRelation: &core.ObjectAndRelation{},
+				Subject:             &core.ObjectAndRelation{},
+			},
+			caveat: core.ContextualizedCaveat{},
+		}
+		resolver := typesystem.ResolverForDatastoreReader(rwt)
+
+		var streamWritten uint64
+		var err error
+		for ; adapter.err == nil && err == nil; streamWritten, err = rwt.BulkLoad(stream.Context(), adapter) {
+			numWritten += streamWritten
+
+			// The stream has terminated because we're awaiting namespace and/or caveat information
+			if len(adapter.awaitingNamespaces) > 0 {
+				nsDefs, err := rwt.LookupNamespacesWithNames(stream.Context(), adapter.awaitingNamespaces)
+				if err != nil {
+					return err
+				}
+
+				for _, nsDef := range nsDefs {
+					nts, err := typesystem.NewNamespaceTypeSystem(nsDef.Definition, resolver)
+					if err != nil {
+						return err
+					}
+
+					loadedNamespaces[nsDef.Definition.Name] = nts
+				}
+				adapter.awaitingNamespaces = nil
+			}
+
+			if len(adapter.awaitingCaveats) > 0 {
+				caveats, err := rwt.LookupCaveatsWithNames(stream.Context(), adapter.awaitingCaveats)
+				if err != nil {
+					return err
+				}
+
+				for _, caveat := range caveats {
+					loadedCaveats[caveat.Definition.Name] = caveat.Definition
+				}
+				adapter.awaitingCaveats = nil
+			}
+		}
+		numWritten += streamWritten
+		return err
+	}, dsoptions.WithDisableRetries(true)); err != nil {
+		return shared.RewriteErrorWithoutConfig(stream.Context(), err)
+	}
+
+	usagemetrics.SetInContext(stream.Context(), &dispatch.ResponseMeta{
+		// One request for the whole load
+		DispatchCount: 1,
+	})
+
+	return stream.SendAndClose(&v1.ImportBulkRelationshipsResponse{
+		NumLoaded: numWritten,
+	})
+}
+
+func (ps *permissionServer) ExportBulkRelationships(
+	req *v1.ExportBulkRelationshipsRequest,
+	resp grpc.ServerStreamingServer[v1.ExportBulkRelationshipsResponse],
+) error {
+	ctx := resp.Context()
+	atRevision, _, err := consistency.RevisionFromContext(ctx)
+	if err != nil {
+		return shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	return ExportBulk(ctx, datastoremw.MustFromContext(ctx), uint64(ps.config.MaxBulkExportRelationshipsLimit), req, atRevision, resp.Send)
+}
+
+// ExportBulk implements the ExportBulkRelationships API functionality. Given a datastore.Datastore,
+// it will export, via the sender, all relationships matched by the incoming request.
+// If no cursor is provided, it falls back to the provided revision.
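+// The cursor attached to each response batch encodes the revision, the current
+// namespace, and the last relationship sent, allowing an interrupted export to
+// resume mid-namespace from the same snapshot.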
+func ExportBulk(ctx context.Context, ds datastore.Datastore, batchSize uint64, req *v1.ExportBulkRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.ExportBulkRelationshipsResponse) error) error { + if req.OptionalLimit > 0 && uint64(req.OptionalLimit) > batchSize { + return shared.RewriteErrorWithoutConfig(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), batchSize)) + } + + atRevision := fallbackRevision + var curNamespace string + var cur dsoptions.Cursor + if req.OptionalCursor != nil { + var err error + atRevision, curNamespace, cur, err = decodeCursor(ds, req.OptionalCursor) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + } + + reader := ds.SnapshotReader(atRevision) + + namespaces, err := reader.ListAllNamespaces(ctx) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + // Make sure the namespaces are always in a stable order + slices.SortFunc(namespaces, func( + lhs datastore.RevisionedDefinition[*core.NamespaceDefinition], + rhs datastore.RevisionedDefinition[*core.NamespaceDefinition], + ) int { + return strings.Compare(lhs.Definition.Name, rhs.Definition.Name) + }) + + // Skip the namespaces that are already fully returned + for cur != nil && len(namespaces) > 0 && namespaces[0].Definition.Name < curNamespace { + namespaces = namespaces[1:] + } + + limit := batchSize + if req.OptionalLimit > 0 { + limit = uint64(req.OptionalLimit) + } + + // Pre-allocate all of the relationships that we might need in order to + // make export easier and faster for the garbage collector. + relsArray := make([]v1.Relationship, limit) + objArray := make([]v1.ObjectReference, limit) + subArray := make([]v1.SubjectReference, limit) + subObjArray := make([]v1.ObjectReference, limit) + caveatArray := make([]v1.ContextualizedCaveat, limit) + for i := range relsArray { + relsArray[i].Resource = &objArray[i] + relsArray[i].Subject = &subArray[i] + relsArray[i].Subject.Object = &subObjArray[i] + } + + emptyRels := make([]*v1.Relationship, limit) + for _, ns := range namespaces { + rels := emptyRels + + // Reset the cursor between namespaces. + if ns.Definition.Name != curNamespace { + cur = nil + } + + // Skip this namespace if a resource type filter was specified. + if req.OptionalRelationshipFilter != nil && req.OptionalRelationshipFilter.ResourceType != "" { + if ns.Definition.Name != req.OptionalRelationshipFilter.ResourceType { + continue + } + } + + // Setup the filter to use for the relationships. + relationshipFilter := datastore.RelationshipsFilter{OptionalResourceType: ns.Definition.Name} + if req.OptionalRelationshipFilter != nil { + rf, err := datastore.RelationshipsFilterFromPublicFilter(req.OptionalRelationshipFilter) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + // Overload the namespace name with the one from the request, because each iteration is for a different namespace. + rf.OptionalResourceType = ns.Definition.Name + relationshipFilter = rf + } + + // We want to keep iterating as long as we're sending full batches. + // To bootstrap this loop, we enter the first time with a full rels + // slice of dummy rels that were never sent. 
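+		// Each pass requests up to `limit` relationships; receiving a short or
+		// empty batch means the namespace is exhausted and the loop exits.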
+		for uint64(len(rels)) == limit {
+			// Lop off any rels we've already sent
+			rels = rels[:0]
+
+			tplFn := func(tpl *core.RelationTuple) {
+				offset := len(rels)
+				rels = append(rels, &relsArray[offset]) // nozero
+				tuple.CopyRelationTupleToRelationship(tpl, &relsArray[offset], &caveatArray[offset])
+			}
+
+			cur, err = queryForEach(
+				ctx,
+				reader,
+				relationshipFilter,
+				tplFn,
+				dsoptions.WithLimit(&limit),
+				dsoptions.WithAfter(cur),
+				dsoptions.WithSort(dsoptions.ByResource),
+			)
+			if err != nil {
+				return shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+
+			if len(rels) == 0 {
+				continue
+			}
+
+			encoded, err := cursor.Encode(&implv1.DecodedCursor{
+				VersionOneof: &implv1.DecodedCursor_V1{
+					V1: &implv1.V1Cursor{
+						Revision: atRevision.String(),
+						Sections: []string{
+							ns.Definition.Name,
+							tuple.MustString(cur),
+						},
+					},
+				},
+			})
+			if err != nil {
+				return shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+
+			if err := sender(&v1.ExportBulkRelationshipsResponse{
+				AfterResultCursor: encoded,
+				Relationships:     rels,
+			}); err != nil {
+				return shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+		}
+	}
+	return nil
+}
diff --git a/internal/services/v1/permissions_test.go b/internal/services/v1/permissions_test.go
index 4bdd6a1698..fc3b76de7c 100644
--- a/internal/services/v1/permissions_test.go
+++ b/internal/services/v1/permissions_test.go
@@ -6,8 +6,10 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"math"
 	"math/rand"
 	"slices"
+	"strconv"
 	"strings"
 	"testing"
 	"time"
@@ -17,6 +19,7 @@ import (
 	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
 	"github.com/authzed/grpcutil"
 	"github.com/ccoveille/go-safecast"
+	"github.com/scylladb/go-set"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/goleak"
 	"google.golang.org/genproto/googleapis/rpc/errdetails"
@@ -2043,3 +2046,361 @@ func relToCheckBulkRequestItem(rel string) *v1.CheckBulkPermissionsRequestItem {
 	}
 	return item
 }
+
+func TestImportBulkRelationships(t *testing.T) {
+	testCases := []struct {
+		name       string
+		batchSize  func() uint64
+		numBatches int
+	}{
+		{"one small batch", constBatch(1), 1},
+		{"one big batch", constBatch(10_000), 1},
+		{"many small batches", constBatch(5), 1_000},
+		{"one empty batch", constBatch(0), 1},
+		{"small random batches", randomBatch(1, 10), 100},
+		{"big random batches", randomBatch(1_000, 3_000), 50},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			require := require.New(t)
+
+			conn, cleanup, _, _ := testserver.NewTestServer(require, 0, memdb.DisableGC, true, tf.StandardDatastoreWithSchema)
+			client := v1.NewPermissionsServiceClient(conn)
+			t.Cleanup(cleanup)
+
+			ctx := context.Background()
+
+			writer, err := client.ImportBulkRelationships(ctx)
+			require.NoError(err)
+
+			var expectedTotal uint64
+			for batchNum := 0; batchNum < tc.numBatches; batchNum++ {
+				batchSize := tc.batchSize()
+				batch := make([]*v1.Relationship, 0, batchSize)
+
+				for i := uint64(0); i < batchSize; i++ {
+					batch = append(batch, rel(
+						tf.DocumentNS.Name,
+						strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10),
+						"viewer",
+						tf.UserNS.Name,
+						strconv.FormatUint(i, 10),
+						"",
+					))
+				}
+
+				err := writer.Send(&v1.ImportBulkRelationshipsRequest{
+					Relationships: batch,
+				})
+				require.NoError(err)
+
+				expectedTotal += batchSize
+			}
+
+			resp, err := writer.CloseAndRecv()
+			require.NoError(err)
+			require.Equal(expectedTotal, resp.NumLoaded)
+
+			readerClient := v1.NewPermissionsServiceClient(conn)
+			stream, err := readerClient.ReadRelationships(ctx, &v1.ReadRelationshipsRequest{
+				RelationshipFilter: 
&v1.RelationshipFilter{ + ResourceType: tf.DocumentNS.Name, + }, + Consistency: &v1.Consistency{ + Requirement: &v1.Consistency_FullyConsistent{FullyConsistent: true}, + }, + }) + require.NoError(err) + + var readBack uint64 + for _, err = stream.Recv(); err == nil; _, err = stream.Recv() { + readBack++ + } + require.ErrorIs(err, io.EOF) + require.Equal(expectedTotal, readBack) + }) + } +} + +func TestExportBulkRelationshipsBeyondAllowedLimit(t *testing.T) { + require := require.New(t) + conn, cleanup, _, _ := testserver.NewTestServer(require, 0, memdb.DisableGC, true, tf.StandardDatastoreWithData) + client := v1.NewPermissionsServiceClient(conn) + t.Cleanup(cleanup) + + resp, err := client.ExportBulkRelationships(context.Background(), &v1.ExportBulkRelationshipsRequest{ + OptionalLimit: 10000005, + }) + require.NoError(err) + + _, err = resp.Recv() + require.Error(err) + require.Contains(err.Error(), "provided limit 10000005 is greater than maximum allowed of 100000") +} + +func TestExportBulkRelationships(t *testing.T) { + conn, cleanup, _, _ := testserver.NewTestServer(require.New(t), 0, memdb.DisableGC, true, tf.StandardDatastoreWithSchema) + client := v1.NewPermissionsServiceClient(conn) + t.Cleanup(cleanup) + + nsAndRels := []struct { + namespace string + relation string + }{ + {tf.DocumentNS.Name, "viewer"}, + {tf.FolderNS.Name, "viewer"}, + {tf.DocumentNS.Name, "owner"}, + {tf.FolderNS.Name, "owner"}, + {tf.DocumentNS.Name, "editor"}, + {tf.FolderNS.Name, "editor"}, + } + + totalToWrite := 1_000 + expectedRels := set.NewStringSetWithSize(totalToWrite) + batch := make([]*v1.Relationship, totalToWrite) + for i := range batch { + nsAndRel := nsAndRels[i%len(nsAndRels)] + rel := rel( + nsAndRel.namespace, + strconv.Itoa(i), + nsAndRel.relation, + tf.UserNS.Name, + strconv.Itoa(i), + "", + ) + batch[i] = rel + expectedRels.Add(tuple.MustStringRelationship(rel)) + } + + ctx := context.Background() + writer, err := client.ImportBulkRelationships(ctx) + require.NoError(t, err) + + require.NoError(t, writer.Send(&v1.ImportBulkRelationshipsRequest{ + Relationships: batch, + })) + + resp, err := writer.CloseAndRecv() + require.NoError(t, err) + numLoaded, err := safecast.ToInt(resp.NumLoaded) + require.NoError(t, err) + require.Equal(t, totalToWrite, numLoaded) + + testCases := []struct { + batchSize uint32 + paginateEveryN int + }{ + {1_000, math.MaxInt}, + {10, math.MaxInt}, + {1_000, 1}, + {100, 5}, + {97, 7}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%d-%d", tc.batchSize, tc.paginateEveryN), func(t *testing.T) { + require := require.New(t) + + var totalRead int + remainingRels := expectedRels.Copy() + require.Equal(totalToWrite, expectedRels.Size()) + var cursor *v1.Cursor + + var done bool + for !done { + streamCtx, cancel := context.WithCancel(ctx) + + stream, err := client.ExportBulkRelationships(streamCtx, &v1.ExportBulkRelationshipsRequest{ + OptionalLimit: tc.batchSize, + OptionalCursor: cursor, + }) + require.NoError(err) + + for i := 0; i < tc.paginateEveryN; i++ { + batch, err := stream.Recv() + if errors.Is(err, io.EOF) { + done = true + break + } + + require.NoError(err) + require.LessOrEqual(uint64(len(batch.Relationships)), uint64(tc.batchSize)) + require.NotNil(batch.AfterResultCursor) + require.NotEmpty(batch.AfterResultCursor.Token) + + cursor = batch.AfterResultCursor + totalRead += len(batch.Relationships) + + for _, rel := range batch.Relationships { + remainingRels.Remove(tuple.MustStringRelationship(rel)) + } + } + + cancel() + } + + 
require.Equal(totalToWrite, totalRead) + require.True(remainingRels.IsEmpty(), "rels were not exported %#v", remainingRels.List()) + }) + } +} + +func TestExportBulkRelationshipsWithFilter(t *testing.T) { + testCases := []struct { + name string + filter *v1.RelationshipFilter + expectedCount int + }{ + { + "basic filter", + &v1.RelationshipFilter{ + ResourceType: tf.DocumentNS.Name, + }, + 500, + }, + { + "filter by resource ID", + &v1.RelationshipFilter{ + OptionalResourceId: "12", + }, + 1, + }, + { + "filter by resource ID prefix", + &v1.RelationshipFilter{ + OptionalResourceIdPrefix: "1", + }, + 111, + }, + { + "filter by resource ID prefix and resource type", + &v1.RelationshipFilter{ + ResourceType: tf.DocumentNS.Name, + OptionalResourceIdPrefix: "1", + }, + 55, + }, + { + "filter by invalid resource type", + &v1.RelationshipFilter{ + ResourceType: "invalid", + }, + 0, + }, + } + + batchSize := uint32(14) + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + + conn, cleanup, _, _ := testserver.NewTestServer(require, 0, memdb.DisableGC, true, tf.StandardDatastoreWithSchema) + client := v1.NewPermissionsServiceClient(conn) + t.Cleanup(cleanup) + + nsAndRels := []struct { + namespace string + relation string + }{ + {tf.DocumentNS.Name, "viewer"}, + {tf.FolderNS.Name, "viewer"}, + {tf.DocumentNS.Name, "owner"}, + {tf.FolderNS.Name, "owner"}, + {tf.DocumentNS.Name, "editor"}, + {tf.FolderNS.Name, "editor"}, + } + + expectedRels := set.NewStringSetWithSize(1000) + batch := make([]*v1.Relationship, 1000) + for i := range batch { + nsAndRel := nsAndRels[i%len(nsAndRels)] + rel := rel( + nsAndRel.namespace, + strconv.Itoa(i), + nsAndRel.relation, + tf.UserNS.Name, + strconv.Itoa(i), + "", + ) + batch[i] = rel + + if tc.filter != nil { + filter, err := datastore.RelationshipsFilterFromPublicFilter(tc.filter) + require.NoError(err) + if !filter.Test(tuple.MustFromRelationship(rel)) { + continue + } + } + + expectedRels.Add(tuple.MustStringRelationship(rel)) + } + + require.Equal(tc.expectedCount, expectedRels.Size()) + + ctx := context.Background() + writer, err := client.ImportBulkRelationships(ctx) + require.NoError(err) + + require.NoError(writer.Send(&v1.ImportBulkRelationshipsRequest{ + Relationships: batch, + })) + + _, err = writer.CloseAndRecv() + require.NoError(err) + + var totalRead uint64 + remainingRels := expectedRels.Copy() + var cursor *v1.Cursor + + foundRels := mapz.NewSet[string]() + for { + streamCtx, cancel := context.WithCancel(ctx) + + stream, err := client.ExportBulkRelationships(streamCtx, &v1.ExportBulkRelationshipsRequest{ + OptionalRelationshipFilter: tc.filter, + OptionalLimit: batchSize, + OptionalCursor: cursor, + }) + require.NoError(err) + + batch, err := stream.Recv() + if errors.Is(err, io.EOF) { + cancel() + break + } + + require.NoError(err) + relLength, err := safecast.ToUint32(len(batch.Relationships)) + require.NoError(err) + require.LessOrEqual(relLength, batchSize) + require.NotNil(batch.AfterResultCursor) + require.NotEmpty(batch.AfterResultCursor.Token) + + cursor = batch.AfterResultCursor + totalRead += uint64(len(batch.Relationships)) + + for _, rel := range batch.Relationships { + if tc.filter != nil { + filter, err := datastore.RelationshipsFilterFromPublicFilter(tc.filter) + require.NoError(err) + require.True(filter.Test(tuple.MustFromRelationship(rel)), "relationship did not match filter: %s", rel) + } + + require.True(remainingRels.Has(tuple.MustStringRelationship(rel)), "relationship 
was not expected or was repeated: %s", rel) + remainingRels.Remove(tuple.MustStringRelationship(rel)) + foundRels.Add(tuple.MustStringRelationship(rel)) + } + + cancel() + } + + // These are statically defined. + expectedCount, _ := safecast.ToUint64(tc.expectedCount) + require.Equal(expectedCount, totalRead, "found: %v", foundRels.AsSlice()) + require.True(remainingRels.IsEmpty(), "rels were not exported %#v", remainingRels.List()) + }) + } +} diff --git a/internal/services/v1/relationships.go b/internal/services/v1/relationships.go index ff99cd65e1..b8feb62b74 100644 --- a/internal/services/v1/relationships.go +++ b/internal/services/v1/relationships.go @@ -141,8 +141,6 @@ func NewPermissionsServer( dispatch: dispatch, dispatchChunkSize: configWithDefaults.DispatchChunkSize, }, - bulkImporter: &bulkImporter{}, - bulkExporter: &bulkExporter{}, } } @@ -154,8 +152,6 @@ type permissionServer struct { config PermissionsServerConfig bulkChecker *bulkChecker - bulkImporter *bulkImporter - bulkExporter *bulkExporter } func (ps *permissionServer) ReadRelationships(req *v1.ReadRelationshipsRequest, resp v1.PermissionsService_ReadRelationshipsServer) error {