diff --git a/internal/datastore/common/changes.go b/internal/datastore/common/changes.go index df558eaefb..956576543a 100644 --- a/internal/datastore/common/changes.go +++ b/internal/datastore/common/changes.go @@ -21,9 +21,11 @@ const ( // Changes represents a set of datastore mutations that are kept self-consistent // across one or more transaction revisions. type Changes[R datastore.Revision, K comparable] struct { - records map[K]changeRecord[R] - keyFunc func(R) K - content datastore.WatchContent + records map[K]changeRecord[R] + keyFunc func(R) K + content datastore.WatchContent + maxByteSize uint64 + currentByteSize int64 } type changeRecord[R datastore.Revision] struct { @@ -36,11 +38,13 @@ type changeRecord[R datastore.Revision] struct { } // NewChanges creates a new Changes object for change tracking and de-duplication. -func NewChanges[R datastore.Revision, K comparable](keyFunc func(R) K, content datastore.WatchContent) *Changes[R, K] { +func NewChanges[R datastore.Revision, K comparable](keyFunc func(R) K, content datastore.WatchContent, maxByteSize uint64) *Changes[R, K] { return &Changes[R, K]{ - records: make(map[K]changeRecord[R], 0), - keyFunc: keyFunc, - content: content, + records: make(map[K]changeRecord[R], 0), + keyFunc: keyFunc, + content: content, + maxByteSize: maxByteSize, + currentByteSize: 0, } } @@ -60,20 +64,38 @@ func (ch *Changes[R, K]) AddRelationshipChange( return nil } - record := ch.recordForRevision(rev) + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + tplKey := tuple.StringWithoutCaveat(tpl) switch op { case core.RelationTupleUpdate_TOUCH: // If there was a delete for the same tuple at the same revision, drop it - delete(record.tupleDeletes, tplKey) + existing, ok := record.tupleDeletes[tplKey] + if ok { + delete(record.tupleDeletes, tplKey) + if err := ch.adjustByteSize(existing, -1); err != nil { + return err + } + } + record.tupleTouches[tplKey] = tpl + if err := ch.adjustByteSize(tpl, 1); err != nil { + return err + } case core.RelationTupleUpdate_DELETE: _, alreadyTouched := record.tupleTouches[tplKey] if !alreadyTouched { record.tupleDeletes[tplKey] = tpl + if err := ch.adjustByteSize(tpl, 1); err != nil { + return err + } } + default: log.Ctx(ctx).Warn().Stringer("operation", op).Msg("unknown change operation") return spiceerrors.MustBugf("unknown change operation") @@ -81,7 +103,29 @@ func (ch *Changes[R, K]) AddRelationshipChange( return nil } -func (ch *Changes[R, K]) recordForRevision(rev R) changeRecord[R] { +type sized interface { + SizeVT() int +} + +func (ch *Changes[R, K]) adjustByteSize(item sized, delta int) error { + if ch.maxByteSize == 0 { + return nil + } + + size := item.SizeVT() * delta + ch.currentByteSize += int64(size) + if ch.currentByteSize < 0 { + return spiceerrors.MustBugf("byte size underflow") + } + + if ch.currentByteSize > int64(ch.maxByteSize) { + return NewMaximumChangesSizeExceededError(ch.maxByteSize) + } + + return nil +} + +func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) { k := ch.keyFunc(rev) revisionChanges, ok := ch.records[k] if !ok { @@ -96,7 +140,7 @@ func (ch *Changes[R, K]) recordForRevision(rev R) changeRecord[R] { ch.records[k] = revisionChanges } - return revisionChanges + return revisionChanges, nil } // AddDeletedNamespace adds a change indicating that the namespace with the name was deleted. @@ -104,15 +148,20 @@ func (ch *Changes[R, K]) AddDeletedNamespace( _ context.Context, rev R, namespaceName string, -) { +) error { if ch.content&datastore.WatchSchema != datastore.WatchSchema { - return + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err } - record := ch.recordForRevision(rev) delete(record.definitionsChanged, nsPrefix+namespaceName) record.namespacesDeleted[namespaceName] = struct{}{} + return nil } // AddDeletedCaveat adds a change indicating that the caveat with the name was deleted. @@ -120,15 +169,20 @@ func (ch *Changes[R, K]) AddDeletedCaveat( _ context.Context, rev R, caveatName string, -) { +) error { if ch.content&datastore.WatchSchema != datastore.WatchSchema { - return + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err } - record := ch.recordForRevision(rev) delete(record.definitionsChanged, caveatPrefix+caveatName) record.caveatsDeleted[caveatName] = struct{}{} + return nil } // AddChangedDefinition adds a change indicating that the schema definition (namespace or caveat) @@ -137,24 +191,52 @@ func (ch *Changes[R, K]) AddChangedDefinition( ctx context.Context, rev R, def datastore.SchemaDefinition, -) { +) error { if ch.content&datastore.WatchSchema != datastore.WatchSchema { - return + return nil } - record := ch.recordForRevision(rev) + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } switch t := def.(type) { case *core.NamespaceDefinition: delete(record.namespacesDeleted, t.Name) + + if existing, ok := record.definitionsChanged[nsPrefix+t.Name]; ok { + if err := ch.adjustByteSize(existing, -1); err != nil { + return err + } + } + record.definitionsChanged[nsPrefix+t.Name] = t + if err := ch.adjustByteSize(t, 1); err != nil { + return err + } + case *core.CaveatDefinition: delete(record.caveatsDeleted, t.Name) + + if existing, ok := record.definitionsChanged[nsPrefix+t.Name]; ok { + if err := ch.adjustByteSize(existing, -1); err != nil { + return err + } + } + record.definitionsChanged[caveatPrefix+t.Name] = t + + if err := ch.adjustByteSize(t, 1); err != nil { + return err + } + default: log.Ctx(ctx).Fatal().Msg("unknown schema definition kind") } + + return nil } // AsRevisionChanges returns the list of changes processed so far as a datastore watch diff --git a/internal/datastore/common/changes_test.go b/internal/datastore/common/changes_test.go index 62cb2e37a8..db496a4d77 100644 --- a/internal/datastore/common/changes_test.go +++ b/internal/datastore/common/changes_test.go @@ -306,7 +306,7 @@ func TestChanges(t *testing.T) { require := require.New(t) ctx := context.Background() - ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema) + ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) for _, step := range tc.script { if step.relationship != "" { rel := tuple.MustParse(step.relationship) @@ -315,15 +315,18 @@ func TestChanges(t *testing.T) { } for _, changed := range step.changedDefinitions { - ch.AddChangedDefinition(ctx, revisions.NewForTransactionID(step.revision), changed) + err := ch.AddChangedDefinition(ctx, revisions.NewForTransactionID(step.revision), changed) + require.NoError(err) } for _, ns := range step.deletedNamespaces { - ch.AddDeletedNamespace(ctx, revisions.NewForTransactionID(step.revision), ns) + err := ch.AddDeletedNamespace(ctx, revisions.NewForTransactionID(step.revision), ns) + require.NoError(err) } for _, c := range step.deletedCaveats { - ch.AddDeletedCaveat(ctx, revisions.NewForTransactionID(step.revision), c) + err := ch.AddDeletedCaveat(ctx, revisions.NewForTransactionID(step.revision), c) + require.NoError(err) } } @@ -337,7 +340,7 @@ func TestChanges(t *testing.T) { func TestFilteredSchemaChanges(t *testing.T) { ctx := context.Background() - ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchSchema) + ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchSchema, 0) require.True(t, ch.IsEmpty()) require.NoError(t, ch.AddRelationshipChange(ctx, rev1, tuple.MustParse("document:firstdoc#viewer@user:tom"), core.RelationTupleUpdate_TOUCH)) @@ -346,22 +349,28 @@ func TestFilteredSchemaChanges(t *testing.T) { func TestFilteredRelationshipChanges(t *testing.T) { ctx := context.Background() - ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships) + ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships, 0) require.True(t, ch.IsEmpty()) - ch.AddDeletedNamespace(ctx, rev3, "deletedns3") + err := ch.AddDeletedNamespace(ctx, rev3, "deletedns3") + require.NoError(t, err) require.True(t, ch.IsEmpty()) } func TestFilterAndRemoveRevisionChanges(t *testing.T) { ctx := context.Background() - ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema) + ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) require.True(t, ch.IsEmpty()) - ch.AddDeletedNamespace(ctx, rev1, "deletedns1") - ch.AddDeletedNamespace(ctx, rev2, "deletedns2") - ch.AddDeletedNamespace(ctx, rev3, "deletedns3") + err := ch.AddDeletedNamespace(ctx, rev1, "deletedns1") + require.NoError(t, err) + + err = ch.AddDeletedNamespace(ctx, rev2, "deletedns2") + require.NoError(t, err) + + err = ch.AddDeletedNamespace(ctx, rev3, "deletedns3") + require.NoError(t, err) require.False(t, ch.IsEmpty()) @@ -408,7 +417,7 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) { func TestHLCOrdering(t *testing.T) { ctx := context.Background() - ch := NewChanges(revisions.HLCKeyFunc, datastore.WatchRelationships|datastore.WatchSchema) + ch := NewChanges(revisions.HLCKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) require.True(t, ch.IsEmpty()) rev1, err := revisions.HLCRevisionFromString("1.0000000001") @@ -451,7 +460,7 @@ func TestHLCOrdering(t *testing.T) { func TestHLCSameRevision(t *testing.T) { ctx := context.Background() - ch := NewChanges(revisions.HLCKeyFunc, datastore.WatchRelationships|datastore.WatchSchema) + ch := NewChanges(revisions.HLCKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) require.True(t, ch.IsEmpty()) rev0, err := revisions.HLCRevisionFromString("1") @@ -496,6 +505,56 @@ func TestHLCSameRevision(t *testing.T) { }, remaining) } +func TestMaximumSize(t *testing.T) { + ctx := context.Background() + + ch := NewChanges(revisions.HLCKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 150) + require.True(t, ch.IsEmpty()) + + rev0, err := revisions.HLCRevisionFromString("1") + require.NoError(t, err) + + rev1, err := revisions.HLCRevisionFromString("2") + require.NoError(t, err) + + rev2, err := revisions.HLCRevisionFromString("3") + require.NoError(t, err) + + rev3, err := revisions.HLCRevisionFromString("4") + require.NoError(t, err) + + err = ch.AddRelationshipChange(ctx, rev0, tuple.MustParse("document:foo#viewer@user:tom"), core.RelationTupleUpdate_TOUCH) + require.NoError(t, err) + + err = ch.AddRelationshipChange(ctx, rev1, tuple.MustParse("document:foo#viewer@user:tom"), core.RelationTupleUpdate_TOUCH) + require.NoError(t, err) + + err = ch.AddRelationshipChange(ctx, rev2, tuple.MustParse("document:foo#viewer@user:tom"), core.RelationTupleUpdate_TOUCH) + require.NoError(t, err) + + err = ch.AddRelationshipChange(ctx, rev3, tuple.MustParse("document:foo#viewer@user:tom"), core.RelationTupleUpdate_TOUCH) + require.Error(t, err) + require.ErrorContains(t, err, "maximum changes byte size of 150 exceeded") +} + +func TestMaximumSizeReplacement(t *testing.T) { + ctx := context.Background() + + ch := NewChanges(revisions.HLCKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 43) + require.True(t, ch.IsEmpty()) + + rev0, err := revisions.HLCRevisionFromString("1") + require.NoError(t, err) + + err = ch.AddRelationshipChange(ctx, rev0, tuple.MustParse("document:foo#viewer@user:tom"), core.RelationTupleUpdate_TOUCH) + require.NoError(t, err) + require.Equal(t, int64(43), ch.currentByteSize) + + err = ch.AddRelationshipChange(ctx, rev0, tuple.MustParse("document:foo#viewer@user:tom"), core.RelationTupleUpdate_DELETE) + require.NoError(t, err) + require.Equal(t, int64(43), ch.currentByteSize) +} + func TestCanonicalize(t *testing.T) { testCases := []struct { name string diff --git a/internal/datastore/common/errors.go b/internal/datastore/common/errors.go index 81537bab84..b8a24dae5c 100644 --- a/internal/datastore/common/errors.go +++ b/internal/datastore/common/errors.go @@ -153,3 +153,14 @@ type RevisionUnavailableError struct { func NewRevisionUnavailableError(err error) error { return RevisionUnavailableError{err} } + +// MaximumChangesSizeExceededError is returned when the maximum size of changes is exceeded. +type MaximumChangesSizeExceededError struct { + error + maxSize uint64 +} + +// NewMaximumChangesSizeExceededError creates a new MaximumChangesSizeExceededError. +func NewMaximumChangesSizeExceededError(maxSize uint64) error { + return MaximumChangesSizeExceededError{fmt.Errorf("maximum changes byte size of %d exceeded", maxSize), maxSize} +} diff --git a/internal/datastore/crdb/watch.go b/internal/datastore/crdb/watch.go index 68383d8a55..2bc28ecdc7 100644 --- a/internal/datastore/crdb/watch.go +++ b/internal/datastore/crdb/watch.go @@ -184,7 +184,7 @@ func (cds *crdbDatastore) watch( // no return value so we're not really losing anything. defer func() { go changes.Close() }() - tracked := common.NewChanges(revisions.HLCKeyFunc, opts.Content) + tracked := common.NewChanges(revisions.HLCKeyFunc, opts.Content, opts.MaximumBufferedChangesByteSize) for changes.Next() { var tableNameBytes []byte @@ -311,9 +311,17 @@ func (cds *crdbDatastore) watch( sendError(fmt.Errorf("could not unmarshal namespace definition: %w", err)) return } - tracked.AddChangedDefinition(ctx, rev, namespaceDef) + err = tracked.AddChangedDefinition(ctx, rev, namespaceDef) + if err != nil { + sendError(err) + return + } } else { - tracked.AddDeletedNamespace(ctx, rev, definitionName) + err = tracked.AddDeletedNamespace(ctx, rev, definitionName) + if err != nil { + sendError(err) + return + } } case tableCaveat: @@ -342,9 +350,18 @@ func (cds *crdbDatastore) watch( sendError(fmt.Errorf("could not unmarshal caveat definition: %w", err)) return } - tracked.AddChangedDefinition(ctx, rev, caveatDef) + + err = tracked.AddChangedDefinition(ctx, rev, caveatDef) + if err != nil { + sendError(err) + return + } } else { - tracked.AddDeletedCaveat(ctx, rev, definitionName) + err = tracked.AddDeletedCaveat(ctx, rev, definitionName) + if err != nil { + sendError(err) + return + } } } } diff --git a/internal/datastore/memdb/memdb.go b/internal/datastore/memdb/memdb.go index e23a7f3560..d13ade7fdb 100644 --- a/internal/datastore/memdb/memdb.go +++ b/internal/datastore/memdb/memdb.go @@ -198,7 +198,7 @@ func (mdb *memdbDatastore) ReadWriteTx( mdb.Lock() defer mdb.Unlock() - tracked := common.NewChanges(revisions.TimestampIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema) + tracked := common.NewChanges(revisions.TimestampIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) if tx != nil { for _, change := range tx.Changes() { switch change.Table { @@ -231,9 +231,15 @@ func (mdb *memdbDatastore) ReadWriteTx( return datastore.NoRevision, err } - tracked.AddChangedDefinition(ctx, newRevision, loaded) + err := tracked.AddChangedDefinition(ctx, newRevision, loaded) + if err != nil { + return datastore.NoRevision, err + } } else if change.After == nil && change.Before != nil { - tracked.AddDeletedNamespace(ctx, newRevision, change.Before.(*namespace).name) + err := tracked.AddDeletedNamespace(ctx, newRevision, change.Before.(*namespace).name) + if err != nil { + return datastore.NoRevision, err + } } else { return datastore.NoRevision, spiceerrors.MustBugf("unexpected namespace change") } @@ -244,9 +250,15 @@ func (mdb *memdbDatastore) ReadWriteTx( return datastore.NoRevision, err } - tracked.AddChangedDefinition(ctx, newRevision, loaded) + err := tracked.AddChangedDefinition(ctx, newRevision, loaded) + if err != nil { + return datastore.NoRevision, err + } } else if change.After == nil && change.Before != nil { - tracked.AddDeletedCaveat(ctx, newRevision, change.Before.(*caveat).name) + err := tracked.AddDeletedCaveat(ctx, newRevision, change.Before.(*caveat).name) + if err != nil { + return datastore.NoRevision, err + } } else { return datastore.NoRevision, spiceerrors.MustBugf("unexpected namespace change") } diff --git a/internal/datastore/mysql/watch.go b/internal/datastore/mysql/watch.go index 170394a2a7..f3c2dffc84 100644 --- a/internal/datastore/mysql/watch.go +++ b/internal/datastore/mysql/watch.go @@ -148,7 +148,7 @@ func (mds *Datastore) loadChanges( } defer common.LogOnError(ctx, rows.Close) - stagedChanges := common.NewChanges(revisions.TransactionIDKeyFunc, options.Content) + stagedChanges := common.NewChanges(revisions.TransactionIDKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) for rows.Next() { nextTuple := &core.RelationTuple{ diff --git a/internal/datastore/postgres/watch.go b/internal/datastore/postgres/watch.go index 120a2d188c..2c7a2d3c97 100644 --- a/internal/datastore/postgres/watch.go +++ b/internal/datastore/postgres/watch.go @@ -238,7 +238,7 @@ func (pgd *pgDatastore) loadChanges(ctx context.Context, revisions []postgresRev txidToRevision[rev.optionalTxID.Uint64] = rev } - tracked := common.NewChanges(revisionKeyFunc, options.Content) + tracked := common.NewChanges(revisionKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) // Load relationship changes. if options.Content&datastore.WatchRelationships == datastore.WatchRelationships { @@ -384,10 +384,16 @@ func (pgd *pgDatastore) loadNamespaceChanges(ctx context.Context, xmin uint64, x } if _, found := filter[createdXID.Uint64]; found { - tracked.AddChangedDefinition(ctx, txidToRevision[deletedXID.Uint64], loaded) + err := tracked.AddChangedDefinition(ctx, txidToRevision[deletedXID.Uint64], loaded) + if err != nil { + return err + } } if _, found := filter[deletedXID.Uint64]; found { - tracked.AddDeletedNamespace(ctx, txidToRevision[deletedXID.Uint64], loaded.Name) + err := tracked.AddDeletedNamespace(ctx, txidToRevision[deletedXID.Uint64], loaded.Name) + if err != nil { + return err + } } } if changes.Err() != nil { @@ -437,10 +443,16 @@ func (pgd *pgDatastore) loadCaveatChanges(ctx context.Context, min uint64, max u } if _, found := filter[createdXID.Uint64]; found { - tracked.AddChangedDefinition(ctx, txidToRevision[deletedXID.Uint64], loaded) + err := tracked.AddChangedDefinition(ctx, txidToRevision[deletedXID.Uint64], loaded) + if err != nil { + return err + } } if _, found := filter[deletedXID.Uint64]; found { - tracked.AddDeletedCaveat(ctx, txidToRevision[deletedXID.Uint64], loaded.Name) + err := tracked.AddDeletedCaveat(ctx, txidToRevision[deletedXID.Uint64], loaded.Name) + if err != nil { + return err + } } } if changes.Err() != nil { diff --git a/internal/datastore/spanner/watch.go b/internal/datastore/spanner/watch.go index 0cc707eb3d..edd4eab37e 100644 --- a/internal/datastore/spanner/watch.go +++ b/internal/datastore/spanner/watch.go @@ -159,7 +159,7 @@ func (sd *spannerDatastore) watch( err = reader.Read(ctx, func(result *changestreams.ReadResult) error { // See: https://cloud.google.com/spanner/docs/change-streams/details for _, record := range result.ChangeRecords { - tracked := common.NewChanges(revisions.TimestampIDKeyFunc, opts.Content) + tracked := common.NewChanges(revisions.TimestampIDKeyFunc, opts.Content, opts.MaximumBufferedChangesByteSize) for _, dcr := range record.DataChangeRecords { changeRevision := revisions.NewForTime(dcr.CommitTimestamp) @@ -203,7 +203,10 @@ func (sd *spannerDatastore) watch( return spiceerrors.MustBugf("error converting namespace name: %v", primaryKeyColumnValues[colNamespaceName]) } - tracked.AddDeletedNamespace(ctx, changeRevision, namespaceName) + err := tracked.AddDeletedNamespace(ctx, changeRevision, namespaceName) + if err != nil { + return err + } case tableCaveat: caveatNameValue, ok := primaryKeyColumnValues[colNamespaceName] @@ -216,7 +219,10 @@ func (sd *spannerDatastore) watch( return spiceerrors.MustBugf("error converting caveat name: %v", primaryKeyColumnValues[colName]) } - tracked.AddDeletedCaveat(ctx, changeRevision, caveatName) + err := tracked.AddDeletedCaveat(ctx, changeRevision, caveatName) + if err != nil { + return err + } default: return spiceerrors.MustBugf("unknown table name %s in delete of change stream", dcr.TableName) @@ -274,7 +280,10 @@ func (sd *spannerDatastore) watch( return err } - tracked.AddChangedDefinition(ctx, changeRevision, ns) + err := tracked.AddChangedDefinition(ctx, changeRevision, ns) + if err != nil { + return err + } case tableCaveat: caveatDefValue, ok := newValues[colCaveatDefinition] @@ -287,7 +296,10 @@ func (sd *spannerDatastore) watch( return err } - tracked.AddChangedDefinition(ctx, changeRevision, caveat) + err := tracked.AddChangedDefinition(ctx, changeRevision, caveat) + if err != nil { + return err + } default: return spiceerrors.MustBugf("unknown table name %s in delete of change stream", dcr.TableName) diff --git a/pkg/datastore/datastore.go b/pkg/datastore/datastore.go index d709018010..875f2fa79c 100644 --- a/pkg/datastore/datastore.go +++ b/pkg/datastore/datastore.go @@ -387,6 +387,7 @@ func (sf SubjectsFilter) AsSelector() SubjectsSelector { // SchemaDefinition represents a namespace or caveat definition under a schema. type SchemaDefinition interface { GetName() string + SizeVT() int } // RevisionedDefinition holds a schema definition and its last updated revision. @@ -517,6 +518,11 @@ type WatchOptions struct { // If given the zero value, the datastore's default will be used. // May not be supported by the datastore. WatchConnectTimeout time.Duration + + // MaximumBufferedChangesByteSize is the maximum byte size of the buffered changes struct. + // If unspecified, no maximum will be enforced. If the maximum is reached before + // the changes can be sent, the watch will be closed with an error. + MaximumBufferedChangesByteSize uint64 } // WatchJustRelationships returns watch options for just relationships.