From 2171dc3b91d8aded820a568a8f814027fabc610f Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Tue, 20 Aug 2024 14:33:25 -0400 Subject: [PATCH 1/2] Move ONRSet into the one internal package in which its used and simplify --- internal/datasets/subjectsetbytype.go | 2 +- .../developmentmembership/foundsubject.go | 4 +- .../foundsubject_test.go | 3 +- internal/developmentmembership/onrset.go | 94 ++++++++++++ internal/developmentmembership/onrset_test.go | 144 ++++++++++++++++++ .../trackingsubjectset.go | 4 +- .../trackingsubjectset_test.go | 3 +- pkg/tuple/onrset.go | 111 -------------- pkg/tuple/onrset_test.go | 143 ----------------- 9 files changed, 244 insertions(+), 264 deletions(-) create mode 100644 internal/developmentmembership/onrset.go create mode 100644 internal/developmentmembership/onrset_test.go delete mode 100644 pkg/tuple/onrset.go delete mode 100644 pkg/tuple/onrset_test.go diff --git a/internal/datasets/subjectsetbytype.go b/internal/datasets/subjectsetbytype.go index 4ece06ecd6..0ba66cd432 100644 --- a/internal/datasets/subjectsetbytype.go +++ b/internal/datasets/subjectsetbytype.go @@ -53,7 +53,7 @@ func (s *SubjectByTypeSet) ForEachType(handler func(rr *core.RelationReference, } } -// Map runs the mapper function over each type of object in the set, returning a new ONRByTypeSet with +// Map runs the mapper function over each type of object in the set, returning a new SubjectByTypeSet with // the object type replaced by that returned by the mapper function. func (s *SubjectByTypeSet) Map(mapper func(rr *core.RelationReference) (*core.RelationReference, error)) (*SubjectByTypeSet, error) { mapped := NewSubjectByTypeSet() diff --git a/internal/developmentmembership/foundsubject.go b/internal/developmentmembership/foundsubject.go index 9b4509aa01..520ee72ac6 100644 --- a/internal/developmentmembership/foundsubject.go +++ b/internal/developmentmembership/foundsubject.go @@ -11,7 +11,7 @@ import ( // NewFoundSubject creates a new FoundSubject for a subject and a set of its resources. func NewFoundSubject(subject *core.DirectSubject, resources ...*core.ObjectAndRelation) FoundSubject { - return FoundSubject{subject.Subject, nil, subject.CaveatExpression, tuple.NewONRSet(resources...)} + return FoundSubject{subject.Subject, nil, subject.CaveatExpression, NewONRSet(resources...)} } // FoundSubject contains a single found subject and all the relationships in which that subject @@ -28,7 +28,7 @@ type FoundSubject struct { // relations are the relations under which the subject lives that informed the locating // of this subject for the root ONR. - relationships *tuple.ONRSet + relationships ONRSet } // GetSubjectId is named to match the Subject interface for the BaseSubjectSet. diff --git a/internal/developmentmembership/foundsubject_test.go b/internal/developmentmembership/foundsubject_test.go index 587997fe3d..2d907fb46c 100644 --- a/internal/developmentmembership/foundsubject_test.go +++ b/internal/developmentmembership/foundsubject_test.go @@ -7,7 +7,6 @@ import ( "github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/caveats" - "github.com/authzed/spicedb/pkg/tuple" "github.com/authzed/spicedb/pkg/validationfile/blocks" ) @@ -20,7 +19,7 @@ func cfs(subjectType string, subjectID string, subjectRel string, excludedSubjec return FoundSubject{ subject: ONR(subjectType, subjectID, subjectRel), excludedSubjects: excludedSubjects, - relationships: tuple.NewONRSet(), + relationships: NewONRSet(), caveatExpression: caveats.CaveatExprForTesting(caveatName), } } diff --git a/internal/developmentmembership/onrset.go b/internal/developmentmembership/onrset.go new file mode 100644 index 0000000000..caeea58f03 --- /dev/null +++ b/internal/developmentmembership/onrset.go @@ -0,0 +1,94 @@ +package developmentmembership + +import ( + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// onrStruct is a struct holding a namespace and relation. +type onrStruct struct { + Namespace string + ObjectID string + Relation string +} + +// ONRSet is a set of ObjectAndRelation's. +type ONRSet struct { + onrs *mapz.Set[onrStruct] +} + +// NewONRSet creates a new set. +func NewONRSet(onrs ...*core.ObjectAndRelation) ONRSet { + created := ONRSet{ + onrs: mapz.NewSet[onrStruct](), + } + created.Update(onrs) + return created +} + +// Length returns the size of the set. +func (ons ONRSet) Length() uint64 { + return uint64(ons.onrs.Len()) +} + +// IsEmpty returns whether the set is empty. +func (ons ONRSet) IsEmpty() bool { + return ons.onrs.IsEmpty() +} + +// Has returns true if the set contains the given ONR. +func (ons ONRSet) Has(onr *core.ObjectAndRelation) bool { + key := onrStruct{onr.Namespace, onr.ObjectId, onr.Relation} + return ons.onrs.Has(key) +} + +// Add adds the given ONR to the set. Returns true if the object was not in the set before this +// call and false otherwise. +func (ons ONRSet) Add(onr *core.ObjectAndRelation) bool { + key := onrStruct{onr.Namespace, onr.ObjectId, onr.Relation} + return ons.onrs.Add(key) +} + +// Update updates the set by adding the given ONRs to it. +func (ons ONRSet) Update(onrs []*core.ObjectAndRelation) { + for _, onr := range onrs { + ons.Add(onr) + } +} + +// UpdateFrom updates the set by adding the ONRs found in the other set to it. +func (ons ONRSet) UpdateFrom(otherSet ONRSet) { + if otherSet.onrs == nil { + return + } + ons.onrs.Merge(otherSet.onrs) +} + +// Intersect returns an intersection between this ONR set and the other set provided. +func (ons ONRSet) Intersect(otherSet ONRSet) ONRSet { + return ONRSet{ons.onrs.Intersect(otherSet.onrs)} +} + +// Subtract returns a subtraction from this ONR set of the other set provided. +func (ons ONRSet) Subtract(otherSet ONRSet) ONRSet { + return ONRSet{ons.onrs.Subtract(otherSet.onrs)} +} + +// Union returns a copy of this ONR set with the other set's elements added in. +func (ons ONRSet) Union(otherSet ONRSet) ONRSet { + return ONRSet{ons.onrs.Union(otherSet.onrs)} +} + +// AsSlice returns the ONRs found in the set as a slice. +func (ons ONRSet) AsSlice() []*core.ObjectAndRelation { + slice := make([]*core.ObjectAndRelation, 0, ons.Length()) + _ = ons.onrs.ForEach(func(rr onrStruct) error { + slice = append(slice, &core.ObjectAndRelation{ + Namespace: rr.Namespace, + ObjectId: rr.ObjectID, + Relation: rr.Relation, + }) + return nil + }) + return slice +} diff --git a/internal/developmentmembership/onrset_test.go b/internal/developmentmembership/onrset_test.go new file mode 100644 index 0000000000..d6384d9b70 --- /dev/null +++ b/internal/developmentmembership/onrset_test.go @@ -0,0 +1,144 @@ +package developmentmembership + +import ( + "testing" + + "github.com/stretchr/testify/require" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +func TestONRSet(t *testing.T) { + set := NewONRSet() + require.True(t, set.IsEmpty()) + require.Equal(t, uint64(0), set.Length()) + + require.True(t, set.Add(tuple.ParseONR("resource:1#viewer"))) + require.False(t, set.IsEmpty()) + require.Equal(t, uint64(1), set.Length()) + + require.True(t, set.Add(tuple.ParseONR("resource:2#viewer"))) + require.True(t, set.Add(tuple.ParseONR("resource:3#viewer"))) + require.Equal(t, uint64(3), set.Length()) + + require.False(t, set.Add(tuple.ParseONR("resource:1#viewer"))) + require.True(t, set.Add(tuple.ParseONR("resource:1#editor"))) + + require.True(t, set.Has(tuple.ParseONR("resource:1#viewer"))) + require.True(t, set.Has(tuple.ParseONR("resource:1#editor"))) + require.False(t, set.Has(tuple.ParseONR("resource:1#owner"))) + require.False(t, set.Has(tuple.ParseONR("resource:1#admin"))) + require.False(t, set.Has(tuple.ParseONR("resource:1#reader"))) + + require.True(t, set.Has(tuple.ParseONR("resource:2#viewer"))) +} + +func TestONRSetUpdate(t *testing.T) { + set := NewONRSet() + set.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:3#viewer"), + }) + require.Equal(t, uint64(3), set.Length()) + + set.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:1#editor"), + tuple.ParseONR("resource:1#owner"), + tuple.ParseONR("resource:1#admin"), + tuple.ParseONR("resource:1#reader"), + }) + require.Equal(t, uint64(7), set.Length()) +} + +func TestONRSetIntersect(t *testing.T) { + set1 := NewONRSet() + set1.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:3#viewer"), + }) + + set2 := NewONRSet() + set2.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:1#editor"), + tuple.ParseONR("resource:1#owner"), + tuple.ParseONR("resource:1#admin"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:1#reader"), + }) + + require.Equal(t, uint64(2), set1.Intersect(set2).Length()) + require.Equal(t, uint64(2), set2.Intersect(set1).Length()) +} + +func TestONRSetSubtract(t *testing.T) { + set1 := NewONRSet() + set1.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:3#viewer"), + }) + + set2 := NewONRSet() + set2.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:1#editor"), + tuple.ParseONR("resource:1#owner"), + tuple.ParseONR("resource:1#admin"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:1#reader"), + }) + + require.Equal(t, uint64(1), set1.Subtract(set2).Length()) + require.Equal(t, uint64(4), set2.Subtract(set1).Length()) +} + +func TestONRSetUnion(t *testing.T) { + set1 := NewONRSet() + set1.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:3#viewer"), + }) + + set2 := NewONRSet() + set2.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:1#editor"), + tuple.ParseONR("resource:1#owner"), + tuple.ParseONR("resource:1#admin"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:1#reader"), + }) + + require.Equal(t, uint64(7), set1.Union(set2).Length()) + require.Equal(t, uint64(7), set2.Union(set1).Length()) +} + +func TestONRSetWith(t *testing.T) { + set1 := NewONRSet() + set1.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:3#viewer"), + }) + + added := set1.Union(NewONRSet(tuple.ParseONR("resource:1#editor"))) + require.Equal(t, uint64(3), set1.Length()) + require.Equal(t, uint64(4), added.Length()) +} + +func TestONRSetAsSlice(t *testing.T) { + set := NewONRSet() + set.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:3#viewer"), + }) + + require.Equal(t, 3, len(set.AsSlice())) +} diff --git a/internal/developmentmembership/trackingsubjectset.go b/internal/developmentmembership/trackingsubjectset.go index 6cdd8ee462..b66a08940d 100644 --- a/internal/developmentmembership/trackingsubjectset.go +++ b/internal/developmentmembership/trackingsubjectset.go @@ -100,9 +100,7 @@ func (tss *TrackingSubjectSet) getSetForKey(key string) datasets.BaseSubjectSet[ fs.excludedSubjects = excludedSubjects fs.caveatExpression = caveatExpression for _, source := range sources { - if source.relationships != nil { - fs.relationships.UpdateFrom(source.relationships) - } + fs.relationships.UpdateFrom(source.relationships) } return fs }, diff --git a/internal/developmentmembership/trackingsubjectset_test.go b/internal/developmentmembership/trackingsubjectset_test.go index cf79b6968b..700662288e 100644 --- a/internal/developmentmembership/trackingsubjectset_test.go +++ b/internal/developmentmembership/trackingsubjectset_test.go @@ -7,7 +7,6 @@ import ( "github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" - "github.com/authzed/spicedb/pkg/tuple" ) func set(subjects ...*core.DirectSubject) *TrackingSubjectSet { @@ -52,7 +51,7 @@ func fs(subjectType string, subjectID string, subjectRel string, excludedSubject return FoundSubject{ subject: ONR(subjectType, subjectID, subjectRel), excludedSubjects: excludedSubjects, - relationships: tuple.NewONRSet(), + relationships: NewONRSet(), } } diff --git a/pkg/tuple/onrset.go b/pkg/tuple/onrset.go deleted file mode 100644 index 04a4ad66af..0000000000 --- a/pkg/tuple/onrset.go +++ /dev/null @@ -1,111 +0,0 @@ -package tuple - -import ( - "maps" - - expmaps "golang.org/x/exp/maps" - - core "github.com/authzed/spicedb/pkg/proto/core/v1" -) - -// ONRSet is a set of ObjectAndRelation's. -type ONRSet struct { - onrs map[string]*core.ObjectAndRelation -} - -// NewONRSet creates a new set. -func NewONRSet(onrs ...*core.ObjectAndRelation) *ONRSet { - created := &ONRSet{ - onrs: map[string]*core.ObjectAndRelation{}, - } - created.Update(onrs) - return created -} - -// Length returns the size of the set. -func (ons *ONRSet) Length() uint64 { - return uint64(len(ons.onrs)) -} - -// IsEmpty returns whether the set is empty. -func (ons *ONRSet) IsEmpty() bool { - return len(ons.onrs) == 0 -} - -// Has returns true if the set contains the given ONR. -func (ons *ONRSet) Has(onr *core.ObjectAndRelation) bool { - _, ok := ons.onrs[StringONR(onr)] - return ok -} - -// Add adds the given ONR to the set. Returns true if the object was not in the set before this -// call and false otherwise. -func (ons *ONRSet) Add(onr *core.ObjectAndRelation) bool { - if _, ok := ons.onrs[StringONR(onr)]; ok { - return false - } - - ons.onrs[StringONR(onr)] = onr - return true -} - -// Update updates the set by adding the given ONRs to it. -func (ons *ONRSet) Update(onrs []*core.ObjectAndRelation) { - for _, onr := range onrs { - ons.Add(onr) - } -} - -// UpdateFrom updates the set by adding the ONRs found in the other set to it. -func (ons *ONRSet) UpdateFrom(otherSet *ONRSet) { - for _, onr := range otherSet.onrs { - ons.Add(onr) - } -} - -// Intersect returns an intersection between this ONR set and the other set provided. -func (ons *ONRSet) Intersect(otherSet *ONRSet) *ONRSet { - updated := NewONRSet() - for _, onr := range ons.onrs { - if otherSet.Has(onr) { - updated.Add(onr) - } - } - return updated -} - -// Subtract returns a subtraction from this ONR set of the other set provided. -func (ons *ONRSet) Subtract(otherSet *ONRSet) *ONRSet { - updated := NewONRSet() - for _, onr := range ons.onrs { - if !otherSet.Has(onr) { - updated.Add(onr) - } - } - return updated -} - -// With returns a copy of this ONR set with the given element added. -func (ons *ONRSet) With(onr *core.ObjectAndRelation) *ONRSet { - updated := &ONRSet{ - onrs: maps.Clone(ons.onrs), - } - updated.Add(onr) - return updated -} - -// Union returns a copy of this ONR set with the other set's elements added in. -func (ons *ONRSet) Union(otherSet *ONRSet) *ONRSet { - updated := &ONRSet{ - onrs: maps.Clone(ons.onrs), - } - for _, current := range otherSet.onrs { - updated.Add(current) - } - return updated -} - -// AsSlice returns the ONRs found in the set as a slice. -func (ons *ONRSet) AsSlice() []*core.ObjectAndRelation { - return expmaps.Values(ons.onrs) -} diff --git a/pkg/tuple/onrset_test.go b/pkg/tuple/onrset_test.go deleted file mode 100644 index e102ca2d8e..0000000000 --- a/pkg/tuple/onrset_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package tuple - -import ( - "testing" - - "github.com/stretchr/testify/require" - - core "github.com/authzed/spicedb/pkg/proto/core/v1" -) - -func TestONRSet(t *testing.T) { - set := NewONRSet() - require.True(t, set.IsEmpty()) - require.Equal(t, uint64(0), set.Length()) - - require.True(t, set.Add(ParseONR("resource:1#viewer"))) - require.False(t, set.IsEmpty()) - require.Equal(t, uint64(1), set.Length()) - - require.True(t, set.Add(ParseONR("resource:2#viewer"))) - require.True(t, set.Add(ParseONR("resource:3#viewer"))) - require.Equal(t, uint64(3), set.Length()) - - require.False(t, set.Add(ParseONR("resource:1#viewer"))) - require.True(t, set.Add(ParseONR("resource:1#editor"))) - - require.True(t, set.Has(ParseONR("resource:1#viewer"))) - require.True(t, set.Has(ParseONR("resource:1#editor"))) - require.False(t, set.Has(ParseONR("resource:1#owner"))) - require.False(t, set.Has(ParseONR("resource:1#admin"))) - require.False(t, set.Has(ParseONR("resource:1#reader"))) - - require.True(t, set.Has(ParseONR("resource:2#viewer"))) -} - -func TestONRSetUpdate(t *testing.T) { - set := NewONRSet() - set.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:2#viewer"), - ParseONR("resource:3#viewer"), - }) - require.Equal(t, uint64(3), set.Length()) - - set.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:1#editor"), - ParseONR("resource:1#owner"), - ParseONR("resource:1#admin"), - ParseONR("resource:1#reader"), - }) - require.Equal(t, uint64(7), set.Length()) -} - -func TestONRSetIntersect(t *testing.T) { - set1 := NewONRSet() - set1.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:2#viewer"), - ParseONR("resource:3#viewer"), - }) - - set2 := NewONRSet() - set2.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:1#editor"), - ParseONR("resource:1#owner"), - ParseONR("resource:1#admin"), - ParseONR("resource:2#viewer"), - ParseONR("resource:1#reader"), - }) - - require.Equal(t, uint64(2), set1.Intersect(set2).Length()) - require.Equal(t, uint64(2), set2.Intersect(set1).Length()) -} - -func TestONRSetSubtract(t *testing.T) { - set1 := NewONRSet() - set1.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:2#viewer"), - ParseONR("resource:3#viewer"), - }) - - set2 := NewONRSet() - set2.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:1#editor"), - ParseONR("resource:1#owner"), - ParseONR("resource:1#admin"), - ParseONR("resource:2#viewer"), - ParseONR("resource:1#reader"), - }) - - require.Equal(t, uint64(1), set1.Subtract(set2).Length()) - require.Equal(t, uint64(4), set2.Subtract(set1).Length()) -} - -func TestONRSetUnion(t *testing.T) { - set1 := NewONRSet() - set1.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:2#viewer"), - ParseONR("resource:3#viewer"), - }) - - set2 := NewONRSet() - set2.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:1#editor"), - ParseONR("resource:1#owner"), - ParseONR("resource:1#admin"), - ParseONR("resource:2#viewer"), - ParseONR("resource:1#reader"), - }) - - require.Equal(t, uint64(7), set1.Union(set2).Length()) - require.Equal(t, uint64(7), set2.Union(set1).Length()) -} - -func TestONRSetWith(t *testing.T) { - set1 := NewONRSet() - set1.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:2#viewer"), - ParseONR("resource:3#viewer"), - }) - - added := set1.With(ParseONR("resource:1#editor")) - require.Equal(t, uint64(3), set1.Length()) - require.Equal(t, uint64(4), added.Length()) -} - -func TestONRSetAsSlice(t *testing.T) { - set := NewONRSet() - set.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:2#viewer"), - ParseONR("resource:3#viewer"), - }) - - require.Equal(t, 3, len(set.AsSlice())) -} From bb10dfeb4212e2aeeb9df72ff91d0e204b682f04 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Tue, 20 Aug 2024 16:17:16 -0400 Subject: [PATCH 2/2] Remove the ONRTypeSet and create a new combined CheckDispatchSet This new set does all the tracking and mapping previously handled by the ONRTypeSet and a custom multimap, in a more compact and better tested implementation This also allows us to avoid requiring all results for those redispatches that do not have caveats, even if one of the subject types does --- internal/graph/check.go | 167 +++------- internal/graph/checkdispatchset.go | 166 ++++++++++ internal/graph/checkdispatchset_test.go | 390 ++++++++++++++++++++++++ internal/graph/membershipset.go | 12 +- pkg/tuple/onrbytypeset.go | 82 ----- pkg/tuple/onrbytypeset_test.go | 61 ---- 6 files changed, 616 insertions(+), 262 deletions(-) create mode 100644 internal/graph/checkdispatchset.go create mode 100644 internal/graph/checkdispatchset_test.go delete mode 100644 pkg/tuple/onrbytypeset.go delete mode 100644 pkg/tuple/onrbytypeset_test.go diff --git a/internal/graph/check.go b/internal/graph/check.go index b9880acda8..77f498ed29 100644 --- a/internal/graph/check.go +++ b/internal/graph/check.go @@ -19,7 +19,6 @@ import ( "github.com/authzed/spicedb/internal/taskrunner" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/genutil/mapz" - "github.com/authzed/spicedb/pkg/genutil/slicez" nspkg "github.com/authzed/spicedb/pkg/namespace" core "github.com/authzed/spicedb/pkg/proto/core/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" @@ -296,11 +295,6 @@ func combineWithCheckHints(result CheckResult, req ValidatedCheckRequest) CheckR return result } -type directDispatch struct { - resourceType *core.RelationReference - resourceIds []string -} - func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequestContext, relation *core.Relation) CheckResult { ctx, span := tracer.Start(ctx, "checkDirect") defer span.End() @@ -447,56 +441,30 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest defer it.Close() queryCount += 1.0 - // Find the subjects over which to dispatch. - subjectsToDispatch := tuple.NewONRByTypeSet() - relationshipsBySubjectONR := mapz.NewMultiMap[string, *core.RelationTuple]() - hasCaveats := false - + // Build the set of subjects over which to dispatch, along with metadata for + // mapping over caveats (if any). + checksToDispatch := newCheckDispatchSet() for tpl := it.Next(); tpl != nil; tpl = it.Next() { if it.Err() != nil { return checkResultError(NewCheckFailureErr(it.Err()), emptyMetadata) } - - spiceerrors.DebugAssert(func() bool { return tpl.Subject.Relation != Ellipsis }, "got a terminal for a non-terminal query") - - // Add the subject as an object over which to dispatch. - subjectsToDispatch.Add(tpl.Subject) - relationshipsBySubjectONR.Add(tuple.StringONR(tpl.Subject), tpl) - if tpl.Caveat != nil && tpl.Caveat.CaveatName != "" { - hasCaveats = true - } + checksToDispatch.addForRelationship(tpl) } it.Close() - // Convert the subjects into batched requests. - // To simplify the logic, +1 is added to account for the situation where - // the number of elements is less than the chunk size, and spare us some annoying code. - expectedNumberOfChunks := subjectsToDispatch.ValueLen()/int(crc.dispatchChunkSize) + 1 - toDispatch := make([]directDispatch, 0, expectedNumberOfChunks) - subjectsToDispatch.ForEachType(func(rr *core.RelationReference, resourceIds []string) { - chunkCount := 0.0 - slicez.ForEachChunk(resourceIds, crc.dispatchChunkSize, func(resourceIdChunk []string) { - chunkCount++ - toDispatch = append(toDispatch, directDispatch{ - resourceType: rr, - resourceIds: resourceIdChunk, - }) - }) - dispatchChunkCountHistogram.Observe(chunkCount) - }) - - // If there are caveats on the incoming relationships, then we must require all results to be - // found, as we need to ensure that all caveats are used for building the final expression. - resultsSetting := crc.resultsSetting - if hasCaveats { - resultsSetting = v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS - } - // Dispatch and map to the associated resource ID(s). - result := union(ctx, crc, toDispatch, func(ctx context.Context, crc currentRequestContext, dd directDispatch) CheckResult { + toDispatch := checksToDispatch.dispatchChunks(crc.dispatchChunkSize) + result := union(ctx, crc, toDispatch, func(ctx context.Context, crc currentRequestContext, dd checkDispatchChunk) CheckResult { + // If there are caveats on any of the incoming relationships for the subjects to dispatch, then we must require all + // results to be found, as we need to ensure that all caveats are used for building the final expression. + resultsSetting := crc.resultsSetting + if dd.hasIncomingCaveats { + resultsSetting = v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS + } + childResult := cc.dispatch(ctx, crc, ValidatedCheckRequest{ &v1.DispatchCheckRequest{ - ResourceRelation: dd.resourceType, + ResourceRelation: tuple.RelationReference(dd.resourceType.namespace, dd.resourceType.relation), ResourceIds: dd.resourceIds, Subject: crc.parentReq.Subject, ResultsSetting: resultsSetting, @@ -513,21 +481,24 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest return childResult } - return mapFoundResources(childResult, dd.resourceType, relationshipsBySubjectONR) + return mapFoundResources(childResult, dd.resourceType, checksToDispatch) }, cc.concurrencyLimit) return combineResultWithFoundResources(result, foundResources) } -func mapFoundResources(result CheckResult, resourceType *core.RelationReference, relationshipsBySubjectONR *mapz.MultiMap[string, *core.RelationTuple]) CheckResult { +func mapFoundResources(result CheckResult, resourceType relationRef, checksToDispatch *checkDispatchSet) CheckResult { // Map any resources found to the parent resource IDs. membershipSet := NewMembershipSet() for foundResourceID, result := range result.Resp.ResultsByResourceId { - subjectKey := tuple.StringONRStrings(resourceType.Namespace, foundResourceID, resourceType.Relation) + resourceIDAndCaveats := checksToDispatch.mappingsForSubject(resourceType.namespace, foundResourceID, resourceType.relation) + + spiceerrors.DebugAssert(func() bool { + return len(resourceIDAndCaveats) > 0 + }, "found resource ID without associated caveats") - tuples, _ := relationshipsBySubjectONR.Get(subjectKey) - for _, relationTuple := range tuples { - membershipSet.AddMemberViaRelationship(relationTuple.ResourceAndRelation.ObjectId, result.Expression, relationTuple) + for _, riac := range resourceIDAndCaveats { + membershipSet.AddMemberWithParentCaveat(riac.resourceID, result.Expression, riac.caveat) } } @@ -595,7 +566,7 @@ func (cc *ConcurrentChecker) runSetOperation(ctx context.Context, crc currentReq } } -func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc currentRequestContext, cu *core.ComputedUserset, rr *core.RelationReference, resourceIds []string) CheckResult { +func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc currentRequestContext, cu *core.ComputedUserset, rr *relationRef, resourceIds []string) CheckResult { ctx, span := tracer.Start(ctx, cu.Relation) defer span.End() @@ -606,7 +577,7 @@ func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc curre return checkResultError(spiceerrors.MustBugf("computed userset for tupleset without tuples"), emptyMetadata) } - startNamespace = rr.Namespace + startNamespace = rr.namespace targetResourceIds = resourceIds } else if cu.Object == core.ComputedUserset_TUPLE_OBJECT { if rr != nil { @@ -693,7 +664,7 @@ type ttu[T relation] interface { type checkResultWithType struct { CheckResult - relationType *core.RelationReference + relationType relationRef } func checkIntersectionTupleToUserset( @@ -719,38 +690,21 @@ func checkIntersectionTupleToUserset( } defer it.Close() - subjectsToDispatch := tuple.NewONRByTypeSet() - relationshipsBySubjectONR := mapz.NewMultiMap[string, *core.RelationTuple]() + checksToDispatch := newCheckDispatchSet() subjectsByResourceID := mapz.NewMultiMap[string, *core.ObjectAndRelation]() for tpl := it.Next(); tpl != nil; tpl = it.Next() { if it.Err() != nil { return checkResultError(NewCheckFailureErr(it.Err()), emptyMetadata) } - subjectsToDispatch.Add(tpl.Subject) - relationshipsBySubjectONR.Add(tuple.StringONR(tpl.Subject), tpl) + checksToDispatch.addForRelationship(tpl) subjectsByResourceID.Add(tpl.ResourceAndRelation.ObjectId, tpl.Subject) } it.Close() // Convert the subjects into batched requests. - // To simplify the logic, +1 is added to account for the situation where - // the number of elements is less than the chunk size, and spare us some annoying code. - expectedNumberOfChunks := uint16(subjectsToDispatch.ValueLen())/crc.dispatchChunkSize + 1 - toDispatch := make([]directDispatch, 0, expectedNumberOfChunks) - subjectsToDispatch.ForEachType(func(rr *core.RelationReference, resourceIds []string) { - chunkCount := 0.0 - slicez.ForEachChunk(resourceIds, crc.dispatchChunkSize, func(resourceIdChunk []string) { - chunkCount++ - toDispatch = append(toDispatch, directDispatch{ - resourceType: rr, - resourceIds: resourceIdChunk, - }) - }) - dispatchChunkCountHistogram.Observe(chunkCount) - }) - - if subjectsToDispatch.IsEmpty() { + toDispatch := checksToDispatch.dispatchChunks(crc.dispatchChunkSize) + if len(toDispatch) == 0 { return noMembers() } @@ -766,8 +720,9 @@ func checkIntersectionTupleToUserset( dispatchChunkSize: crc.dispatchChunkSize, }, toDispatch, - func(ctx context.Context, crc currentRequestContext, dd directDispatch) checkResultWithType { - childResult := cc.checkComputedUserset(ctx, crc, ttu.GetComputedUserset(), dd.resourceType, dd.resourceIds) + func(ctx context.Context, crc currentRequestContext, dd checkDispatchChunk) checkResultWithType { + resourceType := dd.resourceType + childResult := cc.checkComputedUserset(ctx, crc, ttu.GetComputedUserset(), &resourceType, dd.resourceIds) return checkResultWithType{ CheckResult: childResult, relationType: dd.resourceType, @@ -780,19 +735,18 @@ func checkIntersectionTupleToUserset( } // Create a membership set per-subject-type, representing the membership for each of the dispatched subjects. - resultsByDispatchedSubject := map[string]*MembershipSet{} + resultsByDispatchedSubject := map[relationRef]*MembershipSet{} combinedMetadata := emptyMetadata for _, result := range chunkResults { if result.Err != nil { return checkResultError(result.Err, emptyMetadata) } - typeKey := tuple.StringRR(result.relationType) - if _, ok := resultsByDispatchedSubject[typeKey]; !ok { - resultsByDispatchedSubject[typeKey] = NewMembershipSet() + if _, ok := resultsByDispatchedSubject[result.relationType]; !ok { + resultsByDispatchedSubject[result.relationType] = NewMembershipSet() } - resultsByDispatchedSubject[typeKey].UnionWith(result.Resp.ResultsByResourceId) + resultsByDispatchedSubject[result.relationType].UnionWith(result.Resp.ResultsByResourceId) combinedMetadata = combineResponseMetadata(combinedMetadata, result.Resp.Metadata) } @@ -813,11 +767,7 @@ func checkIntersectionTupleToUserset( // was found for each. If any are not found, then the resource ID is not a member. // We also collect up the caveats for each subject, as they will be added to the final result. for _, subject := range subjects { - subjectTypeKey := tuple.StringRR(&core.RelationReference{ - Namespace: subject.Namespace, - Relation: subject.Relation, - }) - + subjectTypeKey := relationRef{subject.Namespace, subject.Relation} results, ok := resultsByDispatchedSubject[subjectTypeKey] if !ok { hasAllSubjects = false @@ -835,11 +785,10 @@ func checkIntersectionTupleToUserset( } // Add any caveats on the subject from the starting relationship(s) as well. - subjectKey := tuple.StringONR(subject) - tuples, _ := relationshipsBySubjectONR.Get(subjectKey) - for _, relationTuple := range tuples { - if relationTuple.Caveat != nil { - caveats = append(caveats, wrapCaveat(relationTuple.Caveat)) + resourceIDAndCaveats := checksToDispatch.mappingsForSubject(subject.Namespace, subject.ObjectId, subject.Relation) + for _, riac := range resourceIDAndCaveats { + if riac.caveat != nil { + caveats = append(caveats, wrapCaveat(riac.caveat)) } } } @@ -904,46 +853,28 @@ func checkTupleToUserset[T relation]( } defer it.Close() - subjectsToDispatch := tuple.NewONRByTypeSet() - relationshipsBySubjectONR := mapz.NewMultiMap[string, *core.RelationTuple]() + checksToDispatch := newCheckDispatchSet() for tpl := it.Next(); tpl != nil; tpl = it.Next() { if it.Err() != nil { return checkResultError(NewCheckFailureErr(it.Err()), emptyMetadata) } - - subjectsToDispatch.Add(tpl.Subject) - relationshipsBySubjectONR.Add(tuple.StringONR(tpl.Subject), tpl) + checksToDispatch.addForRelationship(tpl) } it.Close() - // Convert the subjects into batched requests. - // To simplify the logic, +1 is added to account for the situation where - // the number of elements is less than the chunk size, and spare us some annoying code. - expectedNumberOfChunks := uint16(subjectsToDispatch.ValueLen())/crc.dispatchChunkSize + 1 - toDispatch := make([]directDispatch, 0, expectedNumberOfChunks) - subjectsToDispatch.ForEachType(func(rr *core.RelationReference, resourceIds []string) { - chunkCount := 0.0 - slicez.ForEachChunk(resourceIds, crc.dispatchChunkSize, func(resourceIdChunk []string) { - chunkCount++ - toDispatch = append(toDispatch, directDispatch{ - resourceType: rr, - resourceIds: resourceIdChunk, - }) - }) - dispatchChunkCountHistogram.Observe(chunkCount) - }) - + toDispatch := checksToDispatch.dispatchChunks(crc.dispatchChunkSize) return combineWithComputedHints(union( ctx, crc, toDispatch, - func(ctx context.Context, crc currentRequestContext, dd directDispatch) CheckResult { - childResult := cc.checkComputedUserset(ctx, crc, ttu.GetComputedUserset(), dd.resourceType, dd.resourceIds) + func(ctx context.Context, crc currentRequestContext, dd checkDispatchChunk) CheckResult { + resourceType := dd.resourceType + childResult := cc.checkComputedUserset(ctx, crc, ttu.GetComputedUserset(), &resourceType, dd.resourceIds) if childResult.Err != nil { return childResult } - return mapFoundResources(childResult, dd.resourceType, relationshipsBySubjectONR) + return mapFoundResources(childResult, dd.resourceType, checksToDispatch) }, cc.concurrencyLimit, ), hintsToReturn) diff --git a/internal/graph/checkdispatchset.go b/internal/graph/checkdispatchset.go new file mode 100644 index 0000000000..331157e580 --- /dev/null +++ b/internal/graph/checkdispatchset.go @@ -0,0 +1,166 @@ +package graph + +import ( + "sort" + + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/genutil/slicez" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// checkDispatchSet is the set of subjects over which check will need to dispatch +// as subproblems in order to answer the parent problem. +type checkDispatchSet struct { + // bySubjectType is a map from the type of subject to the set of subjects of that type + // over which to dispatch, along with information indicating whether caveats are present + // for that chunk. + bySubjectType map[relationRef]map[string]bool + + // bySubject is a map from the subject to the set of resources for which the subject + // has a relationship, along with the caveats that apply to that relationship. + bySubject *mapz.MultiMap[subjectRef, resourceIDAndCaveat] +} + +// checkDispatchChunk is a chunk of subjects over which to dispatch a check operation. +type checkDispatchChunk struct { + // resourceType is the type of the subjects in this chunk. + resourceType relationRef + + // resourceIds is the set of subjects in this chunk. + resourceIds []string + + // hasIncomingCaveats is true if any of the subjects in this chunk have incoming caveats. + // This is used to determine whether the check operation should be dispatched requiring + // all results. + hasIncomingCaveats bool +} + +// subjectIDAndHasCaveat is a tuple of a subject ID and whether it has a caveat. +type subjectIDAndHasCaveat struct { + // objectID is the ID of the subject. + objectID string + + // hasIncomingCaveats is true if the subject has a caveat. + hasIncomingCaveats bool +} + +// resourceIDAndCaveat is a tuple of a resource ID and a caveat. +type resourceIDAndCaveat struct { + // resourceID is the ID of the resource. + resourceID string + + // caveat is the caveat that applies to the relationship between the subject and the resource. + // May be nil. + caveat *core.ContextualizedCaveat +} + +// relationRef is a tuple of a namespace and a relation. +type relationRef struct { + namespace string + relation string +} + +// subjectRef is a tuple of a namespace, an object ID, and a relation. +type subjectRef struct { + namespace string + objectID string + relation string +} + +// newCheckDispatchSet creates and returns a new checkDispatchSet. +func newCheckDispatchSet() *checkDispatchSet { + return &checkDispatchSet{ + bySubjectType: map[relationRef]map[string]bool{}, + bySubject: mapz.NewMultiMap[subjectRef, resourceIDAndCaveat](), + } +} + +// Add adds the specified ObjectAndRelation to the set. +func (s *checkDispatchSet) addForRelationship(tpl *core.RelationTuple) { + // Add an entry for the subject pointing to the resource ID and caveat for the subject. + riac := resourceIDAndCaveat{ + resourceID: tpl.ResourceAndRelation.ObjectId, + caveat: tpl.Caveat, + } + subjectRef := subjectRef{ + namespace: tpl.Subject.Namespace, + objectID: tpl.Subject.ObjectId, + relation: tpl.Subject.Relation, + } + s.bySubject.Add(subjectRef, riac) + + // Add the subject ID to the map of subjects for the type of subject. + siac := subjectIDAndHasCaveat{ + objectID: tpl.Subject.ObjectId, + hasIncomingCaveats: tpl.Caveat != nil && tpl.Caveat.CaveatName != "", + } + subjectTypeRef := relationRef{namespace: tpl.Subject.Namespace, relation: tpl.Subject.Relation} + + subjectIDsForType, ok := s.bySubjectType[subjectTypeRef] + if !ok { + subjectIDsForType = make(map[string]bool) + s.bySubjectType[subjectTypeRef] = subjectIDsForType + } + + // If a caveat exists for the subject ID in any branch, the whole branch is considered caveated. + subjectIDsForType[tpl.Subject.ObjectId] = siac.hasIncomingCaveats || subjectIDsForType[tpl.Subject.ObjectId] +} + +func (s *checkDispatchSet) dispatchChunks(dispatchChunkSize uint16) []checkDispatchChunk { + // Start with an estimate of one chunk per type, plus one for the remainder. + expectedNumberOfChunks := len(s.bySubjectType) + 1 + toDispatch := make([]checkDispatchChunk, 0, expectedNumberOfChunks) + + // For each type of subject, create chunks of the IDs over which to dispatch. + for subjectType, subjectIDsAndHasCaveats := range s.bySubjectType { + entries := make([]subjectIDAndHasCaveat, 0, len(subjectIDsAndHasCaveats)) + for objectID, hasIncomingCaveats := range subjectIDsAndHasCaveats { + entries = append(entries, subjectIDAndHasCaveat{objectID: objectID, hasIncomingCaveats: hasIncomingCaveats}) + } + + // Sort the list of subject IDs by whether they have caveats and then the ID itself. + sort.Slice(entries, func(i, j int) bool { + iHasCaveat := entries[i].hasIncomingCaveats + jHasCaveat := entries[j].hasIncomingCaveats + if iHasCaveat == jHasCaveat { + return entries[i].objectID < entries[j].objectID + } + return iHasCaveat && !jHasCaveat + }) + + chunkCount := 0.0 + slicez.ForEachChunk(entries, dispatchChunkSize, func(subjectIdChunk []subjectIDAndHasCaveat) { + chunkCount++ + + subjectIDsToDispatch := make([]string, 0, len(subjectIdChunk)) + hasIncomingCaveats := false + for _, entry := range subjectIdChunk { + subjectIDsToDispatch = append(subjectIDsToDispatch, entry.objectID) + hasIncomingCaveats = hasIncomingCaveats || entry.hasIncomingCaveats + } + + toDispatch = append(toDispatch, checkDispatchChunk{ + resourceType: subjectType, + resourceIds: subjectIDsToDispatch, + hasIncomingCaveats: hasIncomingCaveats, + }) + }) + dispatchChunkCountHistogram.Observe(chunkCount) + } + + return toDispatch +} + +// mappingsForSubject returns the mappings that apply to the relationship between the specified +// subject and any of its resources. The returned caveats include the resource ID of the resource +// that the subject has a relationship with. +func (s *checkDispatchSet) mappingsForSubject(subjectType string, subjectObjectID string, subjectRelation string) []resourceIDAndCaveat { + results, ok := s.bySubject.Get(subjectRef{ + namespace: subjectType, + objectID: subjectObjectID, + relation: subjectRelation, + }) + spiceerrors.DebugAssert(func() bool { return ok }, "no caveats found for subject %s:%s:%s", subjectType, subjectObjectID, subjectRelation) + return results +} diff --git a/internal/graph/checkdispatchset_test.go b/internal/graph/checkdispatchset_test.go new file mode 100644 index 0000000000..cefbd17aed --- /dev/null +++ b/internal/graph/checkdispatchset_test.go @@ -0,0 +1,390 @@ +package graph + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/internal/caveats" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +var caveatForTesting = caveats.CaveatForTesting + +func TestCheckDispatchSet(t *testing.T) { + tcs := []struct { + name string + relationships []*core.RelationTuple + dispatchChunkSize uint16 + expectedChunks []checkDispatchChunk + expectedMappings map[string][]resourceIDAndCaveat + }{ + { + "basic", + []*core.RelationTuple{ + tuple.MustParse("document:somedoc#viewer@group:1#member"), + tuple.MustParse("document:somedoc#viewer@group:2#member"), + tuple.MustParse("document:somedoc#viewer@group:3#member"), + }, + 100, + []checkDispatchChunk{ + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"1", "2", "3"}, + hasIncomingCaveats: false, + }, + }, + map[string][]resourceIDAndCaveat{ + "group:1#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:2#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:3#member": { + {resourceID: "somedoc", caveat: nil}, + }, + }, + }, + { + "basic chunking", + []*core.RelationTuple{ + tuple.MustParse("document:somedoc#viewer@group:1#member"), + tuple.MustParse("document:somedoc#viewer@group:2#member"), + tuple.MustParse("document:somedoc#viewer@group:3#member"), + }, + 2, + []checkDispatchChunk{ + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"1", "2"}, + hasIncomingCaveats: false, + }, + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"3"}, + hasIncomingCaveats: false, + }, + }, + map[string][]resourceIDAndCaveat{ + "group:1#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:2#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:3#member": { + {resourceID: "somedoc", caveat: nil}, + }, + }, + }, + { + "different subject types", + []*core.RelationTuple{ + tuple.MustParse("document:somedoc#viewer@group:1#member"), + tuple.MustParse("document:somedoc#viewer@group:2#member"), + tuple.MustParse("document:somedoc#viewer@group:3#member"), + tuple.MustParse("document:somedoc#viewer@anothertype:1#member"), + tuple.MustParse("document:somedoc#viewer@anothertype:2#member"), + tuple.MustParse("document:somedoc#viewer@anothertype:3#member"), + }, + 100, + []checkDispatchChunk{ + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"1", "2", "3"}, + hasIncomingCaveats: false, + }, + { + resourceType: relationRef{namespace: "anothertype", relation: "member"}, + resourceIds: []string{"1", "2", "3"}, + hasIncomingCaveats: false, + }, + }, + map[string][]resourceIDAndCaveat{ + "group:1#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:2#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:3#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "anothertype:1#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "anothertype:2#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "anothertype:3#member": { + {resourceID: "somedoc", caveat: nil}, + }, + }, + }, + { + "different subject types mixed", + []*core.RelationTuple{ + tuple.MustParse("document:somedoc#viewer@group:1#member"), + tuple.MustParse("document:somedoc#viewer@anothertype:1#member"), + tuple.MustParse("document:somedoc#viewer@anothertype:2#member"), + tuple.MustParse("document:somedoc#viewer@group:2#member"), + tuple.MustParse("document:somedoc#viewer@group:3#member"), + tuple.MustParse("document:somedoc#viewer@anothertype:3#member"), + }, + 100, + []checkDispatchChunk{ + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"1", "2", "3"}, + hasIncomingCaveats: false, + }, + { + resourceType: relationRef{namespace: "anothertype", relation: "member"}, + resourceIds: []string{"1", "2", "3"}, + hasIncomingCaveats: false, + }, + }, + map[string][]resourceIDAndCaveat{ + "group:1#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:2#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:3#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "anothertype:1#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "anothertype:2#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "anothertype:3#member": { + {resourceID: "somedoc", caveat: nil}, + }, + }, + }, + { + "different subject types with chunking", + []*core.RelationTuple{ + tuple.MustParse("document:somedoc#viewer@group:1#member"), + tuple.MustParse("document:somedoc#viewer@group:2#member"), + tuple.MustParse("document:somedoc#viewer@group:3#member"), + tuple.MustParse("document:somedoc#viewer@anothertype:1#member"), + tuple.MustParse("document:somedoc#viewer@anothertype:2#member"), + tuple.MustParse("document:somedoc#viewer@anothertype:3#member"), + }, + 2, + []checkDispatchChunk{ + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"1", "2"}, + hasIncomingCaveats: false, + }, + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"3"}, + hasIncomingCaveats: false, + }, + { + resourceType: relationRef{namespace: "anothertype", relation: "member"}, + resourceIds: []string{"1", "2"}, + hasIncomingCaveats: false, + }, + { + resourceType: relationRef{namespace: "anothertype", relation: "member"}, + resourceIds: []string{"3"}, + hasIncomingCaveats: false, + }, + }, + map[string][]resourceIDAndCaveat{ + "group:1#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:2#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:3#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "anothertype:1#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "anothertype:2#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "anothertype:3#member": { + {resourceID: "somedoc", caveat: nil}, + }, + }, + }, + { + "some caveated members", + []*core.RelationTuple{ + tuple.MustParse("document:somedoc#viewer@group:1#member[somecaveat]"), + tuple.MustParse("document:somedoc#viewer@group:2#member"), + tuple.MustParse("document:somedoc#viewer@group:3#member"), + }, + 100, + []checkDispatchChunk{ + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"1", "2", "3"}, + hasIncomingCaveats: true, + }, + }, + map[string][]resourceIDAndCaveat{ + "group:1#member": { + {resourceID: "somedoc", caveat: caveatForTesting("somecaveat")}, + }, + "group:2#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:3#member": { + {resourceID: "somedoc", caveat: nil}, + }, + }, + }, + { + "caveated members combined when chunking", + []*core.RelationTuple{ + tuple.MustParse("document:somedoc#viewer@group:1#member[somecaveat]"), + tuple.MustParse("document:somedoc#viewer@group:2#member"), + tuple.MustParse("document:somedoc#viewer@group:3#member"), + tuple.MustParse("document:somedoc#viewer@group:4#member[somecaveat]"), + }, + 2, + []checkDispatchChunk{ + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"2", "3"}, + hasIncomingCaveats: false, + }, + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"1", "4"}, + hasIncomingCaveats: true, + }, + }, + map[string][]resourceIDAndCaveat{ + "group:1#member": { + {resourceID: "somedoc", caveat: caveatForTesting("somecaveat")}, + }, + "group:2#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:3#member": { + {resourceID: "somedoc", caveat: nil}, + }, + "group:4#member": { + {resourceID: "somedoc", caveat: caveatForTesting("somecaveat")}, + }, + }, + }, + { + "different resources leading to the same subject", + []*core.RelationTuple{ + tuple.MustParse("document:somedoc#viewer@group:1#member"), + tuple.MustParse("document:anotherdoc#viewer@group:1#member"), + tuple.MustParse("document:somedoc#viewer@group:2#member"), + tuple.MustParse("document:anotherdoc#viewer@group:2#member"), + }, + 2, + []checkDispatchChunk{ + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"1", "2"}, + hasIncomingCaveats: false, + }, + }, + map[string][]resourceIDAndCaveat{ + "group:1#member": { + {resourceID: "somedoc", caveat: nil}, + {resourceID: "anotherdoc", caveat: nil}, + }, + "group:2#member": { + {resourceID: "somedoc", caveat: nil}, + {resourceID: "anotherdoc", caveat: nil}, + }, + }, + }, + { + "different resources leading to the same subject with caveats", + []*core.RelationTuple{ + tuple.MustParse("document:somedoc#viewer@group:1#member[somecaveat]"), + tuple.MustParse("document:anotherdoc#viewer@group:1#member"), + tuple.MustParse("document:somedoc#viewer@group:2#member"), + tuple.MustParse("document:anotherdoc#viewer@group:2#member[somecaveat]"), + }, + 2, + []checkDispatchChunk{ + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"1", "2"}, + hasIncomingCaveats: true, + }, + }, + map[string][]resourceIDAndCaveat{ + "group:1#member": { + {resourceID: "somedoc", caveat: caveatForTesting("somecaveat")}, + {resourceID: "anotherdoc", caveat: nil}, + }, + "group:2#member": { + {resourceID: "somedoc", caveat: nil}, + {resourceID: "anotherdoc", caveat: caveatForTesting("somecaveat")}, + }, + }, + }, + { + "different resource leading to the same subject with caveats", + []*core.RelationTuple{ + tuple.MustParse("document:anotherdoc#viewer@group:1#member"), + tuple.MustParse("document:thirddoc#viewer@group:1#member"), + tuple.MustParse("document:somedoc#viewer@group:1#member[somecaveat]"), + }, + 2, + []checkDispatchChunk{ + { + resourceType: relationRef{namespace: "group", relation: "member"}, + resourceIds: []string{"1"}, + hasIncomingCaveats: true, + }, + }, + map[string][]resourceIDAndCaveat{ + "group:1#member": { + {resourceID: "somedoc", caveat: caveatForTesting("somecaveat")}, + {resourceID: "anotherdoc", caveat: nil}, + {resourceID: "thirddoc", caveat: nil}, + }, + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + set := newCheckDispatchSet() + for _, rel := range tc.relationships { + set.addForRelationship(rel) + } + + chunks := set.dispatchChunks(tc.dispatchChunkSize) + for _, c := range chunks { + sort.Strings(c.resourceIds) + } + + require.ElementsMatch(t, tc.expectedChunks, chunks, "difference in expected chunks. found: %v", chunks) + + for subjectString, expectedMappings := range tc.expectedMappings { + parsed := tuple.ParseSubjectONR(subjectString) + require.NotNil(t, parsed) + + mappings := set.mappingsForSubject(parsed.Namespace, parsed.ObjectId, parsed.Relation) + require.ElementsMatch(t, expectedMappings, mappings) + } + }) + } +} diff --git a/internal/graph/membershipset.go b/internal/graph/membershipset.go index 1eacf1fd86..6c9157a329 100644 --- a/internal/graph/membershipset.go +++ b/internal/graph/membershipset.go @@ -58,7 +58,17 @@ func (ms *MembershipSet) AddMemberViaRelationship( resourceCaveatExpression *core.CaveatExpression, parentRelationship *core.RelationTuple, ) { - intersection := caveatAnd(wrapCaveat(parentRelationship.Caveat), resourceCaveatExpression) + ms.AddMemberWithParentCaveat(resourceID, resourceCaveatExpression, parentRelationship.Caveat) +} + +// AddMemberWithParentCaveat adds the given resource ID as a member with the parent caveat +// combined via intersection with the resource's caveat. The parent caveat may be nil. +func (ms *MembershipSet) AddMemberWithParentCaveat( + resourceID string, + resourceCaveatExpression *core.CaveatExpression, + parentCaveat *core.ContextualizedCaveat, +) { + intersection := caveatAnd(wrapCaveat(parentCaveat), resourceCaveatExpression) ms.addMember(resourceID, intersection) } diff --git a/pkg/tuple/onrbytypeset.go b/pkg/tuple/onrbytypeset.go deleted file mode 100644 index 56da96ddbc..0000000000 --- a/pkg/tuple/onrbytypeset.go +++ /dev/null @@ -1,82 +0,0 @@ -package tuple - -import ( - "github.com/samber/lo" - - core "github.com/authzed/spicedb/pkg/proto/core/v1" -) - -// ONRByTypeSet is a set of ObjectAndRelation's, grouped by namespace+relation. -type ONRByTypeSet struct { - byType map[string][]string -} - -// NewONRByTypeSet creates and returns a new ONRByTypeSet. -func NewONRByTypeSet() *ONRByTypeSet { - return &ONRByTypeSet{ - byType: map[string][]string{}, - } -} - -// Add adds the specified ObjectAndRelation to the set. -func (s *ONRByTypeSet) Add(onr *core.ObjectAndRelation) { - key := JoinRelRef(onr.Namespace, onr.Relation) - if _, ok := s.byType[key]; !ok { - s.byType[key] = []string{} - } - - s.byType[key] = append(s.byType[key], onr.ObjectId) -} - -// ForEachType invokes the handler for each type of ObjectAndRelation found in the set, along -// with all IDs of objects of that type. -func (s *ONRByTypeSet) ForEachType(handler func(rr *core.RelationReference, objectIds []string)) { - for key, objectIds := range s.byType { - ns, rel := MustSplitRelRef(key) - handler(&core.RelationReference{ - Namespace: ns, - Relation: rel, - }, lo.Uniq(objectIds)) - } -} - -// Map runs the mapper function over each type of object in the set, returning a new ONRByTypeSet with -// the object type replaced by that returned by the mapper function. -func (s *ONRByTypeSet) Map(mapper func(rr *core.RelationReference) (*core.RelationReference, error)) (*ONRByTypeSet, error) { - mapped := NewONRByTypeSet() - for key, objectIds := range s.byType { - ns, rel := MustSplitRelRef(key) - updatedType, err := mapper(&core.RelationReference{ - Namespace: ns, - Relation: rel, - }) - if err != nil { - return nil, err - } - if updatedType == nil { - continue - } - mapped.byType[JoinRelRef(updatedType.Namespace, updatedType.Relation)] = lo.Uniq(objectIds) - } - return mapped, nil -} - -// IsEmpty returns true if the set is empty. -func (s *ONRByTypeSet) IsEmpty() bool { - return len(s.byType) == 0 -} - -// KeyLen returns the number of keys in the set. -func (s *ONRByTypeSet) KeyLen() int { - return len(s.byType) -} - -// ValueLen returns the number of values in the set. -func (s *ONRByTypeSet) ValueLen() int { - var total int - for _, vals := range s.byType { - total += len(vals) - } - - return total -} diff --git a/pkg/tuple/onrbytypeset_test.go b/pkg/tuple/onrbytypeset_test.go deleted file mode 100644 index 00254e0330..0000000000 --- a/pkg/tuple/onrbytypeset_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package tuple - -import ( - "sort" - "testing" - - core "github.com/authzed/spicedb/pkg/proto/core/v1" - - "github.com/stretchr/testify/require" -) - -func RR(namespaceName string, relationName string) *core.RelationReference { - return &core.RelationReference{ - Namespace: namespaceName, - Relation: relationName, - } -} - -func TestONRByTypeSet(t *testing.T) { - assertHasObjectIds := func(s *ONRByTypeSet, rr *core.RelationReference, expected []string) { - wasFound := false - s.ForEachType(func(foundRR *core.RelationReference, objectIds []string) { - if rr.EqualVT(foundRR) { - sort.Strings(objectIds) - require.Equal(t, expected, objectIds) - wasFound = true - } - }) - require.True(t, wasFound) - } - - set := NewONRByTypeSet() - require.True(t, set.IsEmpty()) - - // Add some ONRs. - set.Add(ParseONR("document:foo#viewer")) - set.Add(ParseONR("document:bar#viewer")) - set.Add(ParseONR("team:something#member")) - set.Add(ParseONR("team:other#member")) - set.Add(ParseONR("team:other#manager")) - require.False(t, set.IsEmpty()) - - // Run for each type over the set - assertHasObjectIds(set, RR("document", "viewer"), []string{"bar", "foo"}) - assertHasObjectIds(set, RR("team", "member"), []string{"other", "something"}) - assertHasObjectIds(set, RR("team", "manager"), []string{"other"}) - - // Map - mapped, err := set.Map(func(rr *core.RelationReference) (*core.RelationReference, error) { - if rr.Namespace == "document" { - return RR("doc", rr.Relation), nil - } - - return rr, nil - }) - require.NoError(t, err) - - assertHasObjectIds(mapped, RR("doc", "viewer"), []string{"bar", "foo"}) - assertHasObjectIds(mapped, RR("team", "member"), []string{"other", "something"}) - assertHasObjectIds(mapped, RR("team", "manager"), []string{"other"}) -}