diff --git a/internal/datasets/subjectsetbytype.go b/internal/datasets/subjectsetbytype.go index 4ece06ecd6..0ba66cd432 100644 --- a/internal/datasets/subjectsetbytype.go +++ b/internal/datasets/subjectsetbytype.go @@ -53,7 +53,7 @@ func (s *SubjectByTypeSet) ForEachType(handler func(rr *core.RelationReference, } } -// Map runs the mapper function over each type of object in the set, returning a new ONRByTypeSet with +// Map runs the mapper function over each type of object in the set, returning a new SubjectByTypeSet with // the object type replaced by that returned by the mapper function. func (s *SubjectByTypeSet) Map(mapper func(rr *core.RelationReference) (*core.RelationReference, error)) (*SubjectByTypeSet, error) { mapped := NewSubjectByTypeSet() diff --git a/internal/developmentmembership/foundsubject.go b/internal/developmentmembership/foundsubject.go index 9b4509aa01..520ee72ac6 100644 --- a/internal/developmentmembership/foundsubject.go +++ b/internal/developmentmembership/foundsubject.go @@ -11,7 +11,7 @@ import ( // NewFoundSubject creates a new FoundSubject for a subject and a set of its resources. func NewFoundSubject(subject *core.DirectSubject, resources ...*core.ObjectAndRelation) FoundSubject { - return FoundSubject{subject.Subject, nil, subject.CaveatExpression, tuple.NewONRSet(resources...)} + return FoundSubject{subject.Subject, nil, subject.CaveatExpression, NewONRSet(resources...)} } // FoundSubject contains a single found subject and all the relationships in which that subject @@ -28,7 +28,7 @@ type FoundSubject struct { // relations are the relations under which the subject lives that informed the locating // of this subject for the root ONR. - relationships *tuple.ONRSet + relationships ONRSet } // GetSubjectId is named to match the Subject interface for the BaseSubjectSet. diff --git a/internal/developmentmembership/foundsubject_test.go b/internal/developmentmembership/foundsubject_test.go index 587997fe3d..2d907fb46c 100644 --- a/internal/developmentmembership/foundsubject_test.go +++ b/internal/developmentmembership/foundsubject_test.go @@ -7,7 +7,6 @@ import ( "github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/caveats" - "github.com/authzed/spicedb/pkg/tuple" "github.com/authzed/spicedb/pkg/validationfile/blocks" ) @@ -20,7 +19,7 @@ func cfs(subjectType string, subjectID string, subjectRel string, excludedSubjec return FoundSubject{ subject: ONR(subjectType, subjectID, subjectRel), excludedSubjects: excludedSubjects, - relationships: tuple.NewONRSet(), + relationships: NewONRSet(), caveatExpression: caveats.CaveatExprForTesting(caveatName), } } diff --git a/internal/developmentmembership/onrset.go b/internal/developmentmembership/onrset.go new file mode 100644 index 0000000000..caeea58f03 --- /dev/null +++ b/internal/developmentmembership/onrset.go @@ -0,0 +1,94 @@ +package developmentmembership + +import ( + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// onrStruct is a struct holding a namespace and relation. +type onrStruct struct { + Namespace string + ObjectID string + Relation string +} + +// ONRSet is a set of ObjectAndRelation's. +type ONRSet struct { + onrs *mapz.Set[onrStruct] +} + +// NewONRSet creates a new set. +func NewONRSet(onrs ...*core.ObjectAndRelation) ONRSet { + created := ONRSet{ + onrs: mapz.NewSet[onrStruct](), + } + created.Update(onrs) + return created +} + +// Length returns the size of the set. +func (ons ONRSet) Length() uint64 { + return uint64(ons.onrs.Len()) +} + +// IsEmpty returns whether the set is empty. +func (ons ONRSet) IsEmpty() bool { + return ons.onrs.IsEmpty() +} + +// Has returns true if the set contains the given ONR. +func (ons ONRSet) Has(onr *core.ObjectAndRelation) bool { + key := onrStruct{onr.Namespace, onr.ObjectId, onr.Relation} + return ons.onrs.Has(key) +} + +// Add adds the given ONR to the set. Returns true if the object was not in the set before this +// call and false otherwise. +func (ons ONRSet) Add(onr *core.ObjectAndRelation) bool { + key := onrStruct{onr.Namespace, onr.ObjectId, onr.Relation} + return ons.onrs.Add(key) +} + +// Update updates the set by adding the given ONRs to it. +func (ons ONRSet) Update(onrs []*core.ObjectAndRelation) { + for _, onr := range onrs { + ons.Add(onr) + } +} + +// UpdateFrom updates the set by adding the ONRs found in the other set to it. +func (ons ONRSet) UpdateFrom(otherSet ONRSet) { + if otherSet.onrs == nil { + return + } + ons.onrs.Merge(otherSet.onrs) +} + +// Intersect returns an intersection between this ONR set and the other set provided. +func (ons ONRSet) Intersect(otherSet ONRSet) ONRSet { + return ONRSet{ons.onrs.Intersect(otherSet.onrs)} +} + +// Subtract returns a subtraction from this ONR set of the other set provided. +func (ons ONRSet) Subtract(otherSet ONRSet) ONRSet { + return ONRSet{ons.onrs.Subtract(otherSet.onrs)} +} + +// Union returns a copy of this ONR set with the other set's elements added in. +func (ons ONRSet) Union(otherSet ONRSet) ONRSet { + return ONRSet{ons.onrs.Union(otherSet.onrs)} +} + +// AsSlice returns the ONRs found in the set as a slice. +func (ons ONRSet) AsSlice() []*core.ObjectAndRelation { + slice := make([]*core.ObjectAndRelation, 0, ons.Length()) + _ = ons.onrs.ForEach(func(rr onrStruct) error { + slice = append(slice, &core.ObjectAndRelation{ + Namespace: rr.Namespace, + ObjectId: rr.ObjectID, + Relation: rr.Relation, + }) + return nil + }) + return slice +} diff --git a/internal/developmentmembership/onrset_test.go b/internal/developmentmembership/onrset_test.go new file mode 100644 index 0000000000..d6384d9b70 --- /dev/null +++ b/internal/developmentmembership/onrset_test.go @@ -0,0 +1,144 @@ +package developmentmembership + +import ( + "testing" + + "github.com/stretchr/testify/require" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +func TestONRSet(t *testing.T) { + set := NewONRSet() + require.True(t, set.IsEmpty()) + require.Equal(t, uint64(0), set.Length()) + + require.True(t, set.Add(tuple.ParseONR("resource:1#viewer"))) + require.False(t, set.IsEmpty()) + require.Equal(t, uint64(1), set.Length()) + + require.True(t, set.Add(tuple.ParseONR("resource:2#viewer"))) + require.True(t, set.Add(tuple.ParseONR("resource:3#viewer"))) + require.Equal(t, uint64(3), set.Length()) + + require.False(t, set.Add(tuple.ParseONR("resource:1#viewer"))) + require.True(t, set.Add(tuple.ParseONR("resource:1#editor"))) + + require.True(t, set.Has(tuple.ParseONR("resource:1#viewer"))) + require.True(t, set.Has(tuple.ParseONR("resource:1#editor"))) + require.False(t, set.Has(tuple.ParseONR("resource:1#owner"))) + require.False(t, set.Has(tuple.ParseONR("resource:1#admin"))) + require.False(t, set.Has(tuple.ParseONR("resource:1#reader"))) + + require.True(t, set.Has(tuple.ParseONR("resource:2#viewer"))) +} + +func TestONRSetUpdate(t *testing.T) { + set := NewONRSet() + set.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:3#viewer"), + }) + require.Equal(t, uint64(3), set.Length()) + + set.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:1#editor"), + tuple.ParseONR("resource:1#owner"), + tuple.ParseONR("resource:1#admin"), + tuple.ParseONR("resource:1#reader"), + }) + require.Equal(t, uint64(7), set.Length()) +} + +func TestONRSetIntersect(t *testing.T) { + set1 := NewONRSet() + set1.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:3#viewer"), + }) + + set2 := NewONRSet() + set2.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:1#editor"), + tuple.ParseONR("resource:1#owner"), + tuple.ParseONR("resource:1#admin"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:1#reader"), + }) + + require.Equal(t, uint64(2), set1.Intersect(set2).Length()) + require.Equal(t, uint64(2), set2.Intersect(set1).Length()) +} + +func TestONRSetSubtract(t *testing.T) { + set1 := NewONRSet() + set1.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:3#viewer"), + }) + + set2 := NewONRSet() + set2.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:1#editor"), + tuple.ParseONR("resource:1#owner"), + tuple.ParseONR("resource:1#admin"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:1#reader"), + }) + + require.Equal(t, uint64(1), set1.Subtract(set2).Length()) + require.Equal(t, uint64(4), set2.Subtract(set1).Length()) +} + +func TestONRSetUnion(t *testing.T) { + set1 := NewONRSet() + set1.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:3#viewer"), + }) + + set2 := NewONRSet() + set2.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:1#editor"), + tuple.ParseONR("resource:1#owner"), + tuple.ParseONR("resource:1#admin"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:1#reader"), + }) + + require.Equal(t, uint64(7), set1.Union(set2).Length()) + require.Equal(t, uint64(7), set2.Union(set1).Length()) +} + +func TestONRSetWith(t *testing.T) { + set1 := NewONRSet() + set1.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:3#viewer"), + }) + + added := set1.Union(NewONRSet(tuple.ParseONR("resource:1#editor"))) + require.Equal(t, uint64(3), set1.Length()) + require.Equal(t, uint64(4), added.Length()) +} + +func TestONRSetAsSlice(t *testing.T) { + set := NewONRSet() + set.Update([]*core.ObjectAndRelation{ + tuple.ParseONR("resource:1#viewer"), + tuple.ParseONR("resource:2#viewer"), + tuple.ParseONR("resource:3#viewer"), + }) + + require.Equal(t, 3, len(set.AsSlice())) +} diff --git a/internal/developmentmembership/trackingsubjectset.go b/internal/developmentmembership/trackingsubjectset.go index 6cdd8ee462..b66a08940d 100644 --- a/internal/developmentmembership/trackingsubjectset.go +++ b/internal/developmentmembership/trackingsubjectset.go @@ -100,9 +100,7 @@ func (tss *TrackingSubjectSet) getSetForKey(key string) datasets.BaseSubjectSet[ fs.excludedSubjects = excludedSubjects fs.caveatExpression = caveatExpression for _, source := range sources { - if source.relationships != nil { - fs.relationships.UpdateFrom(source.relationships) - } + fs.relationships.UpdateFrom(source.relationships) } return fs }, diff --git a/internal/developmentmembership/trackingsubjectset_test.go b/internal/developmentmembership/trackingsubjectset_test.go index cf79b6968b..700662288e 100644 --- a/internal/developmentmembership/trackingsubjectset_test.go +++ b/internal/developmentmembership/trackingsubjectset_test.go @@ -7,7 +7,6 @@ import ( "github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" - "github.com/authzed/spicedb/pkg/tuple" ) func set(subjects ...*core.DirectSubject) *TrackingSubjectSet { @@ -52,7 +51,7 @@ func fs(subjectType string, subjectID string, subjectRel string, excludedSubject return FoundSubject{ subject: ONR(subjectType, subjectID, subjectRel), excludedSubjects: excludedSubjects, - relationships: tuple.NewONRSet(), + relationships: NewONRSet(), } } diff --git a/pkg/tuple/onrset.go b/pkg/tuple/onrset.go deleted file mode 100644 index 04a4ad66af..0000000000 --- a/pkg/tuple/onrset.go +++ /dev/null @@ -1,111 +0,0 @@ -package tuple - -import ( - "maps" - - expmaps "golang.org/x/exp/maps" - - core "github.com/authzed/spicedb/pkg/proto/core/v1" -) - -// ONRSet is a set of ObjectAndRelation's. -type ONRSet struct { - onrs map[string]*core.ObjectAndRelation -} - -// NewONRSet creates a new set. -func NewONRSet(onrs ...*core.ObjectAndRelation) *ONRSet { - created := &ONRSet{ - onrs: map[string]*core.ObjectAndRelation{}, - } - created.Update(onrs) - return created -} - -// Length returns the size of the set. -func (ons *ONRSet) Length() uint64 { - return uint64(len(ons.onrs)) -} - -// IsEmpty returns whether the set is empty. -func (ons *ONRSet) IsEmpty() bool { - return len(ons.onrs) == 0 -} - -// Has returns true if the set contains the given ONR. -func (ons *ONRSet) Has(onr *core.ObjectAndRelation) bool { - _, ok := ons.onrs[StringONR(onr)] - return ok -} - -// Add adds the given ONR to the set. Returns true if the object was not in the set before this -// call and false otherwise. -func (ons *ONRSet) Add(onr *core.ObjectAndRelation) bool { - if _, ok := ons.onrs[StringONR(onr)]; ok { - return false - } - - ons.onrs[StringONR(onr)] = onr - return true -} - -// Update updates the set by adding the given ONRs to it. -func (ons *ONRSet) Update(onrs []*core.ObjectAndRelation) { - for _, onr := range onrs { - ons.Add(onr) - } -} - -// UpdateFrom updates the set by adding the ONRs found in the other set to it. -func (ons *ONRSet) UpdateFrom(otherSet *ONRSet) { - for _, onr := range otherSet.onrs { - ons.Add(onr) - } -} - -// Intersect returns an intersection between this ONR set and the other set provided. -func (ons *ONRSet) Intersect(otherSet *ONRSet) *ONRSet { - updated := NewONRSet() - for _, onr := range ons.onrs { - if otherSet.Has(onr) { - updated.Add(onr) - } - } - return updated -} - -// Subtract returns a subtraction from this ONR set of the other set provided. -func (ons *ONRSet) Subtract(otherSet *ONRSet) *ONRSet { - updated := NewONRSet() - for _, onr := range ons.onrs { - if !otherSet.Has(onr) { - updated.Add(onr) - } - } - return updated -} - -// With returns a copy of this ONR set with the given element added. -func (ons *ONRSet) With(onr *core.ObjectAndRelation) *ONRSet { - updated := &ONRSet{ - onrs: maps.Clone(ons.onrs), - } - updated.Add(onr) - return updated -} - -// Union returns a copy of this ONR set with the other set's elements added in. -func (ons *ONRSet) Union(otherSet *ONRSet) *ONRSet { - updated := &ONRSet{ - onrs: maps.Clone(ons.onrs), - } - for _, current := range otherSet.onrs { - updated.Add(current) - } - return updated -} - -// AsSlice returns the ONRs found in the set as a slice. -func (ons *ONRSet) AsSlice() []*core.ObjectAndRelation { - return expmaps.Values(ons.onrs) -} diff --git a/pkg/tuple/onrset_test.go b/pkg/tuple/onrset_test.go deleted file mode 100644 index e102ca2d8e..0000000000 --- a/pkg/tuple/onrset_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package tuple - -import ( - "testing" - - "github.com/stretchr/testify/require" - - core "github.com/authzed/spicedb/pkg/proto/core/v1" -) - -func TestONRSet(t *testing.T) { - set := NewONRSet() - require.True(t, set.IsEmpty()) - require.Equal(t, uint64(0), set.Length()) - - require.True(t, set.Add(ParseONR("resource:1#viewer"))) - require.False(t, set.IsEmpty()) - require.Equal(t, uint64(1), set.Length()) - - require.True(t, set.Add(ParseONR("resource:2#viewer"))) - require.True(t, set.Add(ParseONR("resource:3#viewer"))) - require.Equal(t, uint64(3), set.Length()) - - require.False(t, set.Add(ParseONR("resource:1#viewer"))) - require.True(t, set.Add(ParseONR("resource:1#editor"))) - - require.True(t, set.Has(ParseONR("resource:1#viewer"))) - require.True(t, set.Has(ParseONR("resource:1#editor"))) - require.False(t, set.Has(ParseONR("resource:1#owner"))) - require.False(t, set.Has(ParseONR("resource:1#admin"))) - require.False(t, set.Has(ParseONR("resource:1#reader"))) - - require.True(t, set.Has(ParseONR("resource:2#viewer"))) -} - -func TestONRSetUpdate(t *testing.T) { - set := NewONRSet() - set.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:2#viewer"), - ParseONR("resource:3#viewer"), - }) - require.Equal(t, uint64(3), set.Length()) - - set.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:1#editor"), - ParseONR("resource:1#owner"), - ParseONR("resource:1#admin"), - ParseONR("resource:1#reader"), - }) - require.Equal(t, uint64(7), set.Length()) -} - -func TestONRSetIntersect(t *testing.T) { - set1 := NewONRSet() - set1.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:2#viewer"), - ParseONR("resource:3#viewer"), - }) - - set2 := NewONRSet() - set2.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:1#editor"), - ParseONR("resource:1#owner"), - ParseONR("resource:1#admin"), - ParseONR("resource:2#viewer"), - ParseONR("resource:1#reader"), - }) - - require.Equal(t, uint64(2), set1.Intersect(set2).Length()) - require.Equal(t, uint64(2), set2.Intersect(set1).Length()) -} - -func TestONRSetSubtract(t *testing.T) { - set1 := NewONRSet() - set1.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:2#viewer"), - ParseONR("resource:3#viewer"), - }) - - set2 := NewONRSet() - set2.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:1#editor"), - ParseONR("resource:1#owner"), - ParseONR("resource:1#admin"), - ParseONR("resource:2#viewer"), - ParseONR("resource:1#reader"), - }) - - require.Equal(t, uint64(1), set1.Subtract(set2).Length()) - require.Equal(t, uint64(4), set2.Subtract(set1).Length()) -} - -func TestONRSetUnion(t *testing.T) { - set1 := NewONRSet() - set1.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:2#viewer"), - ParseONR("resource:3#viewer"), - }) - - set2 := NewONRSet() - set2.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:1#editor"), - ParseONR("resource:1#owner"), - ParseONR("resource:1#admin"), - ParseONR("resource:2#viewer"), - ParseONR("resource:1#reader"), - }) - - require.Equal(t, uint64(7), set1.Union(set2).Length()) - require.Equal(t, uint64(7), set2.Union(set1).Length()) -} - -func TestONRSetWith(t *testing.T) { - set1 := NewONRSet() - set1.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:2#viewer"), - ParseONR("resource:3#viewer"), - }) - - added := set1.With(ParseONR("resource:1#editor")) - require.Equal(t, uint64(3), set1.Length()) - require.Equal(t, uint64(4), added.Length()) -} - -func TestONRSetAsSlice(t *testing.T) { - set := NewONRSet() - set.Update([]*core.ObjectAndRelation{ - ParseONR("resource:1#viewer"), - ParseONR("resource:2#viewer"), - ParseONR("resource:3#viewer"), - }) - - require.Equal(t, 3, len(set.AsSlice())) -}