diff --git a/internal/datastore/common/changes.go b/internal/datastore/common/changes.go index 9449e47a5a..33c13f2905 100644 --- a/internal/datastore/common/changes.go +++ b/internal/datastore/common/changes.go @@ -5,6 +5,7 @@ import ( "sort" "golang.org/x/exp/maps" + "google.golang.org/protobuf/types/known/structpb" log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" @@ -35,6 +36,7 @@ type changeRecord[R datastore.Revision] struct { definitionsChanged map[string]datastore.SchemaDefinition namespacesDeleted map[string]struct{} caveatsDeleted map[string]struct{} + metadata map[string]any } // NewChanges creates a new Changes object for change tracking and de-duplication. @@ -125,6 +127,25 @@ func (ch *Changes[R, K]) adjustByteSize(item sized, delta int) error { return nil } +// SetRevisionMetadata sets the metadata for the given revision. +func (ch *Changes[R, K]) SetRevisionMetadata(ctx context.Context, rev R, metadata map[string]any) error { + if len(metadata) == 0 { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + if len(record.metadata) > 0 { + return spiceerrors.MustBugf("metadata already set for revision") + } + + maps.Copy(record.metadata, metadata) + return nil +} + func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) { k := ch.keyFunc(rev) revisionChanges, ok := ch.records[k] @@ -136,6 +157,7 @@ func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) { make(map[string]datastore.SchemaDefinition), make(map[string]struct{}), make(map[string]struct{}), + make(map[string]any), } ch.records[k] = revisionChanges } @@ -241,21 +263,25 @@ func (ch *Changes[R, K]) AddChangedDefinition( // AsRevisionChanges returns the list of changes processed so far as a datastore watch // compatible, ordered, changelist. -func (ch *Changes[R, K]) AsRevisionChanges(lessThanFunc func(lhs, rhs K) bool) []datastore.RevisionChanges { +func (ch *Changes[R, K]) AsRevisionChanges(lessThanFunc func(lhs, rhs K) bool) ([]datastore.RevisionChanges, error) { return ch.revisionChanges(lessThanFunc, *new(R), false) } // FilterAndRemoveRevisionChanges filters a list of changes processed up to the bound revision from the changes list, removing them // and returning the filtered changes. -func (ch *Changes[R, K]) FilterAndRemoveRevisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R) []datastore.RevisionChanges { - changes := ch.revisionChanges(lessThanFunc, boundRev, true) +func (ch *Changes[R, K]) FilterAndRemoveRevisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R) ([]datastore.RevisionChanges, error) { + changes, err := ch.revisionChanges(lessThanFunc, boundRev, true) + if err != nil { + return nil, err + } + ch.removeAllChangesBefore(boundRev) - return changes + return changes, nil } -func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R, withBound bool) []datastore.RevisionChanges { +func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R, withBound bool) ([]datastore.RevisionChanges, error) { if ch.IsEmpty() { - return nil + return nil, nil } revisionsWithChanges := make([]K, 0, len(ch.records)) @@ -266,7 +292,7 @@ func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, bou } if len(revisionsWithChanges) == 0 { - return nil + return nil, nil } sort.Slice(revisionsWithChanges, func(i int, j int) bool { @@ -292,9 +318,18 @@ func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, bou changes[i].ChangedDefinitions = maps.Values(revisionChangeRecord.definitionsChanged) changes[i].DeletedNamespaces = maps.Keys(revisionChangeRecord.namespacesDeleted) changes[i].DeletedCaveats = maps.Keys(revisionChangeRecord.caveatsDeleted) + + if len(revisionChangeRecord.metadata) > 0 { + metadata, err := structpb.NewStruct(revisionChangeRecord.metadata) + if err != nil { + return nil, spiceerrors.MustBugf("failed to convert metadata to structpb: %v", err) + } + + changes[i].Metadata = metadata + } } - return changes + return changes, nil } func (ch *Changes[R, K]) removeAllChangesBefore(boundRev R) { diff --git a/internal/datastore/common/changes_test.go b/internal/datastore/common/changes_test.go index db496a4d77..7b16ab3e46 100644 --- a/internal/datastore/common/changes_test.go +++ b/internal/datastore/common/changes_test.go @@ -330,9 +330,12 @@ func TestChanges(t *testing.T) { } } + actual, err := ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) + require.NoError(err) + require.Equal( canonicalize(tc.expected), - canonicalize(ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc)), + canonicalize(actual), ) }) } @@ -347,6 +350,23 @@ func TestFilteredSchemaChanges(t *testing.T) { require.True(t, ch.IsEmpty()) } +func TestSetMetadata(t *testing.T) { + ctx := context.Background() + ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) + require.True(t, ch.IsEmpty()) + + err := ch.SetRevisionMetadata(ctx, rev1, map[string]any{"foo": "bar"}) + require.NoError(t, err) + require.False(t, ch.IsEmpty()) + + results, err := ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, rev2) + require.NoError(t, err) + require.Equal(t, 1, len(results)) + require.True(t, ch.IsEmpty()) + + require.Equal(t, map[string]any{"foo": "bar"}, results[0].Metadata.AsMap()) +} + func TestFilteredRelationshipChanges(t *testing.T) { ctx := context.Background() ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships, 0) @@ -374,7 +394,8 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) { require.False(t, ch.IsEmpty()) - results := ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, rev3) + results, err := ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, rev3) + require.NoError(t, err) require.Equal(t, 2, len(results)) require.False(t, ch.IsEmpty()) @@ -393,8 +414,9 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) { }, }, results) - remaining := ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) + remaining, err := ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) require.Equal(t, 1, len(remaining)) + require.NoError(t, err) require.Equal(t, []datastore.RevisionChanges{ { @@ -405,11 +427,13 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) { }, }, remaining) - results = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillion) + results, err = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillion) + require.NoError(t, err) require.Equal(t, 1, len(results)) require.True(t, ch.IsEmpty()) - results = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillionOne) + results, err = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillionOne) + require.NoError(t, err) require.Equal(t, 0, len(results)) require.True(t, ch.IsEmpty()) } @@ -432,7 +456,8 @@ func TestHLCOrdering(t *testing.T) { err = ch.AddRelationshipChange(ctx, rev0, tuple.MustParse("document:foo#viewer@user:tom"), core.RelationTupleUpdate_TOUCH) require.NoError(t, err) - remaining := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc) + remaining, err := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc) + require.NoError(t, err) require.Equal(t, 2, len(remaining)) require.Equal(t, []datastore.RevisionChanges{ @@ -475,7 +500,8 @@ func TestHLCSameRevision(t *testing.T) { err = ch.AddRelationshipChange(ctx, rev0again, tuple.MustParse("document:foo#viewer@user:sarah"), core.RelationTupleUpdate_TOUCH) require.NoError(t, err) - remaining := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc) + remaining, err := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc) + require.NoError(t, err) require.Equal(t, 1, len(remaining)) expected := []*core.RelationTupleUpdate{ diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index 08d5bafb23..c5ac0a1492 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -52,6 +52,7 @@ const ( tableTransactions = "transactions" tableCaveat = "caveat" tableRelationshipCounter = "relationship_counter" + tableTransactionMetadata = "transaction_metadata" colNamespace = "namespace" colConfig = "serialized_config" @@ -77,6 +78,8 @@ const ( colCounterSerializedFilter = "serialized_filter" colCounterCurrentCount = "current_count" colCounterUpdatedAt = "updated_at_timestamp" + colExpiresAt = "expires_at" + colMetadata = "metadata" errUnableToInstantiate = "unable to instantiate datastore" errRevision = "unable to find revision: %w" @@ -199,6 +202,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas analyzeBeforeStatistics: config.analyzeBeforeStatistics, filterMaximumIDCount: config.filterMaximumIDCount, supportsIntegrity: config.withIntegrity, + gcWindow: config.gcWindow, } ds.RemoteClockRevisions.SetNowFunc(ds.headRevisionInternal) @@ -281,6 +285,7 @@ type crdbDatastore struct { writeOverlapKeyer overlapKeyer overlapKeyInit func(ctx context.Context) keySet analyzeBeforeStatistics bool + gcWindow time.Duration beginChangefeedQuery string transactionNowQuery string @@ -324,6 +329,23 @@ func (cds *crdbDatastore) ReadWriteTx( Executor: pgxcommon.NewPGXExecutorWithIntegrityOption(querier, cds.supportsIntegrity), } + // If metadata is to be attached, write that row now. + if config.Metadata != nil { + expiresAt := time.Now().Add(cds.gcWindow).Add(1 * time.Minute) + insertTransactionMetadata := psql.Insert(tableTransactionMetadata). + Columns(colExpiresAt, colMetadata). + Values(expiresAt, config.Metadata.AsMap()) + + sql, args, err := insertTransactionMetadata.ToSql() + if err != nil { + return fmt.Errorf("error building metadata insert: %w", err) + } + + if _, err := tx.Exec(ctx, sql, args...); err != nil { + return fmt.Errorf("error writing metadata: %w", err) + } + } + rwt := &crdbReadWriteTXN{ &crdbReader{ querier, diff --git a/internal/datastore/crdb/migrations/zz_migration.0008_add_transaction_metadata_table.go b/internal/datastore/crdb/migrations/zz_migration.0008_add_transaction_metadata_table.go new file mode 100644 index 0000000000..a98a93c8f2 --- /dev/null +++ b/internal/datastore/crdb/migrations/zz_migration.0008_add_transaction_metadata_table.go @@ -0,0 +1,75 @@ +package migrations + +import ( + "context" + "strings" + + "github.com/jackc/pgx/v5" +) + +const ( + // ttl_expiration_expression support was added in CRDB v22.2, but the E2E tests + // use v21.2. + addTransactionMetadataTableQueryWithBasicTTL = ` + CREATE TABLE transaction_metadata ( + key UUID PRIMARY KEY DEFAULT gen_random_uuid(), + expires_at TIMESTAMPTZ, + metadata JSONB + ) WITH (ttl_expire_after = '1d'); + ` + + addTransactionMetadataTableQuery = ` + CREATE TABLE transaction_metadata ( + key UUID PRIMARY KEY DEFAULT gen_random_uuid(), + expires_at TIMESTAMPTZ, + metadata JSONB + ) WITH (ttl_expiration_expression = 'expires_at', ttl_job_cron = '@daily'); + ` + + // See: https://www.cockroachlabs.com/docs/stable/changefeed-messages#prevent-changefeeds-from-emitting-row-level-ttl-deletes + // for why we set ttl_disable_changefeed_replication = 'true'. This isn't stricly necessary as the Watch API will ignore the + // deletions of these metadata rows, but no reason to even have it in the changefeed. + // NOTE: This only applies on CRDB v24 and later. + addTransactionMetadataTableQueryWithTTLIgnore = ` + CREATE TABLE transaction_metadata ( + key UUID PRIMARY KEY DEFAULT gen_random_uuid(), + expires_at TIMESTAMPTZ, + metadata JSONB + ) WITH (ttl_expiration_expression = 'expires_at', ttl_job_cron = '@daily', ttl_disable_changefeed_replication = 'true'); + ` +) + +func init() { + err := CRDBMigrations.Register("add-transaction-metadata-table", "add-integrity-relationtuple-table", addTransactionMetadataTable, noAtomicMigration) + if err != nil { + panic("failed to register migration: " + err.Error()) + } +} + +func addTransactionMetadataTable(ctx context.Context, conn *pgx.Conn) error { + row := conn.QueryRow(ctx, "select version()") + var version string + if err := row.Scan(&version); err != nil { + return err + } + + if strings.Contains(version, "v22.1") { + if _, err := conn.Exec(ctx, addTransactionMetadataTableQueryWithBasicTTL); err != nil { + return err + } + return nil + } + + if strings.Contains(version, "v24.") { + if _, err := conn.Exec(ctx, addTransactionMetadataTableQueryWithTTLIgnore); err != nil { + return err + } + return nil + } + + // CRDB v23 and v22.2.* + if _, err := conn.Exec(ctx, addTransactionMetadataTableQuery); err != nil { + return err + } + return nil +} diff --git a/internal/datastore/crdb/watch.go b/internal/datastore/crdb/watch.go index 2241ae33f7..d03989c6fe 100644 --- a/internal/datastore/crdb/watch.go +++ b/internal/datastore/crdb/watch.go @@ -55,6 +55,8 @@ type changeDetails struct { IntegrityKeyID *string `json:"integrity_key_id"` IntegrityHashAsHex *string `json:"integrity_hash"` TimestampAsString *string `json:"timestamp"` + + Metadata map[string]any `json:"metadata"` } } @@ -110,7 +112,8 @@ func (cds *crdbDatastore) watch( } defer func() { _ = conn.Close(ctx) }() - tableNames := make([]string, 0, 3) + tableNames := make([]string, 0, 4) + tableNames = append(tableNames, tableTransactionMetadata) if opts.Content&datastore.WatchRelationships == datastore.WatchRelationships { tableNames = append(tableNames, cds.tableTupleName()) } @@ -217,7 +220,13 @@ func (cds *crdbDatastore) watch( return } - for _, revChange := range tracked.FilterAndRemoveRevisionChanges(revisions.HLCKeyLessThanFunc, rev) { + filtered, err := tracked.FilterAndRemoveRevisionChanges(revisions.HLCKeyLessThanFunc, rev) + if err != nil { + sendError(err) + return + } + + for _, revChange := range filtered { revChange := revChange if !sendChange(&revChange) { return @@ -393,6 +402,24 @@ func (cds *crdbDatastore) watch( return } } + + case tableTransactionMetadata: + if details.After != nil { + rev, err := revisions.HLCRevisionFromString(details.Updated) + if err != nil { + sendError(fmt.Errorf("malformed update timestamp: %w", err)) + return + } + + if err := tracked.SetRevisionMetadata(ctx, rev, details.After.Metadata); err != nil { + sendError(err) + return + } + } + + default: + sendError(spiceerrors.MustBugf("unexpected table name in changefeed: %s", tableName)) + return } } diff --git a/internal/datastore/memdb/memdb.go b/internal/datastore/memdb/memdb.go index ffec0e2523..f7e0b3b0c1 100644 --- a/internal/datastore/memdb/memdb.go +++ b/internal/datastore/memdb/memdb.go @@ -204,6 +204,12 @@ func (mdb *memdbDatastore) ReadWriteTx( tracked := common.NewChanges(revisions.TimestampIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) if tx != nil { + if config.Metadata != nil { + if err := tracked.SetRevisionMetadata(ctx, newRevision, config.Metadata.AsMap()); err != nil { + return datastore.NoRevision, err + } + } + for _, change := range tx.Changes() { switch change.Table { case tableRelationship: @@ -270,7 +276,11 @@ func (mdb *memdbDatastore) ReadWriteTx( } var rc datastore.RevisionChanges - changes := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) + changes, err := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) + if err != nil { + return datastore.NoRevision, err + } + if len(changes) > 1 { return datastore.NoRevision, spiceerrors.MustBugf("unexpected MemDB transaction with multiple revision changes") } else if len(changes) == 1 { diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index ed1fc7749a..33425dd9dd 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -37,6 +37,7 @@ const ( colID = "id" colTimestamp = "timestamp" + colMetadata = "metadata" colNamespace = "namespace" colConfig = "serialized_config" colCreatedTxn = "created_transaction" @@ -207,10 +208,7 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option driver := migrations.NewMySQLDriverFromDB(db, config.tablePrefix) queryBuilder := NewQueryBuilder(driver) - createTxn, _, err := sb.Insert(driver.RelationTupleTransaction()).Values().ToSql() - if err != nil { - return nil, fmt.Errorf("NewMySQLDatastore: %w", err) - } + createTxn := sb.Insert(driver.RelationTupleTransaction()).Columns(colMetadata) // used for seeding the initial relation_tuple_transaction. using INSERT IGNORE on a known // ID value makes this idempotent (i.e. safe to execute concurrently). @@ -339,7 +337,12 @@ func (mds *Datastore) ReadWriteTx( for i := uint8(0); i <= mds.maxRetries; i++ { var newTxnID uint64 if err = migrations.BeginTxFunc(ctx, mds.db, &sql.TxOptions{Isolation: sql.LevelSerializable}, func(tx *sql.Tx) error { - newTxnID, err = mds.createNewTransaction(ctx, tx) + var metadata map[string]any + if config.Metadata != nil { + metadata = config.Metadata.AsMap() + } + + newTxnID, err = mds.createNewTransaction(ctx, tx, metadata) if err != nil { return fmt.Errorf("unable to create new txn ID: %w", err) } @@ -442,7 +445,7 @@ func newMySQLExecutor(tx querier) common.ExecuteQueryFunc { } var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper err := rows.Scan( &nextTuple.ResourceAndRelation.Namespace, &nextTuple.ResourceAndRelation.ObjectId, @@ -497,7 +500,7 @@ type Datastore struct { cancelGc context.CancelFunc gcHasRun atomic.Bool - createTxn string + createTxn sq.InsertBuilder createBaseTxn string *QueryBuilder diff --git a/internal/datastore/mysql/datastore_test.go b/internal/datastore/mysql/datastore_test.go index 23390ba1eb..c9eeeec066 100644 --- a/internal/datastore/mysql/datastore_test.go +++ b/internal/datastore/mysql/datastore_test.go @@ -640,7 +640,7 @@ func TransactionTimestampsTest(t *testing.T, ds datastore.Datastore) { // Transaction timestamp should not be stored in system time zone tx, err := db.BeginTx(ctx, nil) req.NoError(err) - txID, err := ds.(*Datastore).createNewTransaction(ctx, tx) + txID, err := ds.(*Datastore).createNewTransaction(ctx, tx, nil) req.NoError(err) err = tx.Commit() req.NoError(err) diff --git a/internal/datastore/mysql/migrations/zz_migration.0009_add_metadata_to_transaction_table.go b/internal/datastore/mysql/migrations/zz_migration.0009_add_metadata_to_transaction_table.go new file mode 100644 index 0000000000..64e66a83e8 --- /dev/null +++ b/internal/datastore/mysql/migrations/zz_migration.0009_add_metadata_to_transaction_table.go @@ -0,0 +1,18 @@ +package migrations + +import "fmt" + +func addMetadataToTransactionTable(t *tables) string { + return fmt.Sprintf(`ALTER TABLE %s + ADD COLUMN metadata BLOB NULL DEFAULT NULL;`, + t.RelationTupleTransaction(), + ) +} + +func init() { + mustRegisterMigration("add_metadata_to_transaction_table", "add_relationship_counters_table", noNonatomicMigration, + newStatementBatch( + addMetadataToTransactionTable, + ).execute, + ) +} diff --git a/internal/datastore/mysql/query_builder.go b/internal/datastore/mysql/query_builder.go index 3e1b700811..2356dc605a 100644 --- a/internal/datastore/mysql/query_builder.go +++ b/internal/datastore/mysql/query_builder.go @@ -9,8 +9,8 @@ import ( // QueryBuilder captures all parameterizable queries used // by the MySQL datastore implementation type QueryBuilder struct { - GetLastRevision sq.SelectBuilder - GetRevisionRange sq.SelectBuilder + GetLastRevision sq.SelectBuilder + LoadRevisionRange sq.SelectBuilder WriteNamespaceQuery sq.InsertBuilder ReadNamespaceQuery sq.SelectBuilder @@ -43,7 +43,7 @@ func NewQueryBuilder(driver *migrations.MySQLDriver) *QueryBuilder { // transaction builders builder.GetLastRevision = getLastRevision(driver.RelationTupleTransaction()) - builder.GetRevisionRange = getRevisionRange(driver.RelationTupleTransaction()) + builder.LoadRevisionRange = loadRevisionRange(driver.RelationTupleTransaction()) // namespace builders builder.WriteNamespaceQuery = writeNamespace(driver.Namespace()) @@ -99,8 +99,8 @@ func getLastRevision(tableTransaction string) sq.SelectBuilder { return sb.Select("MAX(id)").From(tableTransaction).Limit(1) } -func getRevisionRange(tableTransaction string) sq.SelectBuilder { - return sb.Select("MIN(id)", "MAX(id)").From(tableTransaction) +func loadRevisionRange(tableTransaction string) sq.SelectBuilder { + return sb.Select(colID, colMetadata).From(tableTransaction) } func readCounter(tableRelationshipCounters string) sq.SelectBuilder { diff --git a/internal/datastore/mysql/readwrite.go b/internal/datastore/mysql/readwrite.go index 26c7f4a017..92ca509319 100644 --- a/internal/datastore/mysql/readwrite.go +++ b/internal/datastore/mysql/readwrite.go @@ -49,10 +49,10 @@ type mysqlReadWriteTXN struct { newTxnID uint64 } -// caveatContextWrapper is used to marshall maps into MySQLs JSON data type -type caveatContextWrapper map[string]any +// structpbWrapper is used to marshall maps into MySQLs JSON data type +type structpbWrapper map[string]any -func (cc *caveatContextWrapper) Scan(val any) error { +func (cc *structpbWrapper) Scan(val any) error { v, ok := val.([]byte) if !ok { return fmt.Errorf("unsupported type: %T", v) @@ -60,7 +60,7 @@ func (cc *caveatContextWrapper) Scan(val any) error { return json.Unmarshal(v, &cc) } -func (cc *caveatContextWrapper) Value() (driver.Value, error) { +func (cc *structpbWrapper) Value() (driver.Value, error) { return json.Marshal(&cc) } @@ -220,7 +220,7 @@ func (rwt *mysqlReadWriteTXN) WriteRelationships(ctx context.Context, mutations } var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper tupleIdsToDelete := make([]int64, 0, len(clauses)) for rows.Next() { @@ -281,7 +281,7 @@ func (rwt *mysqlReadWriteTXN) WriteRelationships(ctx context.Context, mutations tpl := mut.Tuple var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper if tpl.Caveat != nil { caveatName = tpl.Caveat.CaveatName caveatContext = tpl.Caveat.Context.AsMap() @@ -497,7 +497,7 @@ func (rwt *mysqlReadWriteTXN) BulkLoad(ctx context.Context, iter datastore.BulkW } var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper if tpl.Caveat != nil { caveatName = tpl.Caveat.CaveatName caveatContext = tpl.Caveat.Context.AsMap() diff --git a/internal/datastore/mysql/revisions.go b/internal/datastore/mysql/revisions.go index 5570e670a9..80b259f77d 100644 --- a/internal/datastore/mysql/revisions.go +++ b/internal/datastore/mysql/revisions.go @@ -74,8 +74,6 @@ func (mds *Datastore) optimizedRevisionFunc(ctx context.Context) (datastore.Revi } func (mds *Datastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { - // implementation deviates slightly from PSQL implementation in order to support - // database seeding in runtime, instead of through migrate command revision, err := mds.loadRevision(ctx) if err != nil { return datastore.NoRevision, err @@ -156,16 +154,26 @@ func (mds *Datastore) checkValidTransaction(ctx context.Context, revisionTx uint return freshEnough.Bool, unknown.Bool, nil } -func (mds *Datastore) createNewTransaction(ctx context.Context, tx *sql.Tx) (newTxnID uint64, err error) { +func (mds *Datastore) createNewTransaction(ctx context.Context, tx *sql.Tx, metadata map[string]any) (newTxnID uint64, err error) { ctx, span := tracer.Start(ctx, "createNewTransaction") defer span.End() - createQuery := mds.createTxn + var wrappedMetadata structpbWrapper + if len(metadata) > 0 { + wrappedMetadata = metadata + } + + createQuery := mds.createTxn.Values(&wrappedMetadata) + if err != nil { + return 0, fmt.Errorf("createNewTransaction: %w", err) + } + + sql, args, err := createQuery.ToSql() if err != nil { return 0, fmt.Errorf("createNewTransaction: %w", err) } - result, err := tx.ExecContext(ctx, createQuery) + result, err := tx.ExecContext(ctx, sql, args...) if err != nil { return 0, fmt.Errorf("createNewTransaction: %w", err) } diff --git a/internal/datastore/mysql/watch.go b/internal/datastore/mysql/watch.go index f3c2dffc84..14e356796f 100644 --- a/internal/datastore/mysql/watch.go +++ b/internal/datastore/mysql/watch.go @@ -125,7 +125,52 @@ func (mds *Datastore) loadChanges( return } - sql, args, err := mds.QueryChangedQuery.Where(sq.Or{ + stagedChanges := common.NewChanges(revisions.TransactionIDKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) + + // Load any metadata for the revision range. + sql, args, err := mds.LoadRevisionRange.Where(sq.Or{ + sq.And{ + sq.Gt{colID: afterRevision}, + sq.LtOrEq{colID: newRevision}, + }, + }).ToSql() + if err != nil { + return + } + + rows, err := mds.db.QueryContext(ctx, sql, args...) + if err != nil { + if errors.Is(err, context.Canceled) { + err = datastore.NewWatchCanceledErr() + } + return + } + defer common.LogOnError(ctx, rows.Close) + + for rows.Next() { + var txnID uint64 + var metadata structpbWrapper + err = rows.Scan( + &txnID, + &metadata, + ) + if err != nil { + return nil, 0, err + } + + if len(metadata) > 0 { + if err := stagedChanges.SetRevisionMetadata(ctx, revisions.NewForTransactionID(txnID), metadata); err != nil { + return nil, 0, err + } + } + } + rows.Close() + if err = rows.Err(); err != nil { + return + } + + // Load the changes relationships for the revision range. + sql, args, err = mds.QueryChangedQuery.Where(sq.Or{ sq.And{ sq.Gt{colCreatedTxn: afterRevision}, sq.LtOrEq{colCreatedTxn: newRevision}, @@ -139,7 +184,7 @@ func (mds *Datastore) loadChanges( return } - rows, err := mds.db.QueryContext(ctx, sql, args...) + rows, err = mds.db.QueryContext(ctx, sql, args...) if err != nil { if errors.Is(err, context.Canceled) { err = datastore.NewWatchCanceledErr() @@ -148,8 +193,6 @@ func (mds *Datastore) loadChanges( } defer common.LogOnError(ctx, rows.Close) - stagedChanges := common.NewChanges(revisions.TransactionIDKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) - for rows.Next() { nextTuple := &core.RelationTuple{ ResourceAndRelation: &core.ObjectAndRelation{}, @@ -159,7 +202,7 @@ func (mds *Datastore) loadChanges( var createdTxn uint64 var deletedTxn uint64 var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper err = rows.Scan( &nextTuple.ResourceAndRelation.Namespace, &nextTuple.ResourceAndRelation.ObjectId, @@ -196,6 +239,6 @@ func (mds *Datastore) loadChanges( return } - changes = stagedChanges.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) + changes, err = stagedChanges.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) return } diff --git a/internal/datastore/postgres/migrations/zz_migration.0019_add_metadata_to_transaction_table.go b/internal/datastore/postgres/migrations/zz_migration.0019_add_metadata_to_transaction_table.go new file mode 100644 index 0000000000..162ba71f18 --- /dev/null +++ b/internal/datastore/postgres/migrations/zz_migration.0019_add_metadata_to_transaction_table.go @@ -0,0 +1,23 @@ +package migrations + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" +) + +const addMetadataToTransactionTable = `ALTER TABLE relation_tuple_transaction ADD COLUMN IF NOT EXISTS metadata JSONB NOT NULL DEFAULT '{}'` + +func init() { + if err := DatabaseMigrations.Register("add-metadata-to-transaction-table", "create-relationships-counters-table", + func(ctx context.Context, conn *pgx.Conn) error { + if _, err := conn.Exec(ctx, addMetadataToTransactionTable); err != nil { + return fmt.Errorf("failed to add metadata to transaction table: %w", err) + } + return nil + }, + noTxMigration); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index 717b67f864..c091ae9fb0 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -49,6 +49,7 @@ const ( colXID = "xid" colTimestamp = "timestamp" + colMetadata = "metadata" colNamespace = "namespace" colConfig = "serialized_config" colCreatedXid = "created_xid" @@ -102,12 +103,7 @@ var ( OrderByClause(fmt.Sprintf("%s DESC", colXID)). Limit(1) - createTxn = fmt.Sprintf( - "INSERT INTO %s DEFAULT VALUES RETURNING %s, %s", - tableTransaction, - colXID, - colSnapshot, - ) + createTxn = psql.Insert(tableTransaction).Columns(colMetadata) getNow = psql.Select("NOW()") @@ -435,7 +431,12 @@ func (pgd *pgDatastore) ReadWriteTx( var newSnapshot pgSnapshot err = wrapError(pgx.BeginTxFunc(ctx, pgd.writePool, pgx.TxOptions{IsoLevel: pgx.Serializable}, func(tx pgx.Tx) error { var err error - newXID, newSnapshot, err = createNewTransaction(ctx, tx) + var metadata map[string]any + if config.Metadata != nil { + metadata = config.Metadata.AsMap() + } + + newXID, newSnapshot, err = createNewTransaction(ctx, tx, metadata) if err != nil { return err } diff --git a/internal/datastore/postgres/postgres_shared_test.go b/internal/datastore/postgres/postgres_shared_test.go index fd7a7fb3c3..33dd5c8f06 100644 --- a/internal/datastore/postgres/postgres_shared_test.go +++ b/internal/datastore/postgres/postgres_shared_test.go @@ -438,7 +438,7 @@ func TransactionTimestampsTest(t *testing.T, ds datastore.Datastore) { tx, err := pgd.writePool.Begin(ctx) require.NoError(err) - txXID, _, err := createNewTransaction(ctx, tx) + txXID, _, err := createNewTransaction(ctx, tx, nil) require.NoError(err) err = tx.Commit(ctx) diff --git a/internal/datastore/postgres/revisions.go b/internal/datastore/postgres/revisions.go index 5cc0f1cf28..1b6eb782d7 100644 --- a/internal/datastore/postgres/revisions.go +++ b/internal/datastore/postgres/revisions.go @@ -235,11 +235,22 @@ func parseRevisionDecimal(revisionStr string) (datastore.Revision, error) { }}, nil } -func createNewTransaction(ctx context.Context, tx pgx.Tx) (newXID xid8, newSnapshot pgSnapshot, err error) { +var emptyMetadata = map[string]any{} + +func createNewTransaction(ctx context.Context, tx pgx.Tx, metadata map[string]any) (newXID xid8, newSnapshot pgSnapshot, err error) { ctx, span := tracer.Start(ctx, "createNewTransaction") defer span.End() - cterr := tx.QueryRow(ctx, createTxn).Scan(&newXID, &newSnapshot) + if metadata == nil { + metadata = emptyMetadata + } + + sql, args, err := createTxn.Values(metadata).Suffix("RETURNING xid, snapshot").ToSql() + if err != nil { + return + } + + cterr := tx.QueryRow(ctx, sql, args...).Scan(&newXID, &newSnapshot) if cterr != nil { err = fmt.Errorf("error when trying to create a new transaction: %w", cterr) } @@ -250,6 +261,7 @@ type postgresRevision struct { snapshot pgSnapshot optionalTxID xid8 optionalNanosTimestamp uint64 + optionalMetadata map[string]any } func (pr postgresRevision) Equal(rhsRaw datastore.Revision) bool { diff --git a/internal/datastore/postgres/watch.go b/internal/datastore/postgres/watch.go index 2c7a2d3c97..029ddb3253 100644 --- a/internal/datastore/postgres/watch.go +++ b/internal/datastore/postgres/watch.go @@ -25,10 +25,10 @@ var ( // xid8 is one of the last ~2 billion transaction IDs generated. We should be garbage // collecting these transactions long before we get to that point. newRevisionsQuery = fmt.Sprintf(` - SELECT %[1]s, %[2]s, %[3]s FROM %[4]s + SELECT %[1]s, %[2]s, %[3]s, %[4]s FROM %[5]s WHERE %[1]s >= pg_snapshot_xmax($1) OR ( %[1]s >= pg_snapshot_xmin($1) AND NOT pg_visible_in_snapshot(%[1]s, $1) - ) ORDER BY pg_xact_commit_timestamp(%[1]s::xid), %[1]s;`, colXID, colSnapshot, colTimestamp, tableTransaction) + ) ORDER BY pg_xact_commit_timestamp(%[1]s::xid), %[1]s;`, colXID, colSnapshot, colMetadata, colTimestamp, tableTransaction) queryChangedTuples = psql.Select( colNamespace, @@ -199,8 +199,9 @@ func (pgd *pgDatastore) getNewRevisions(ctx context.Context, afterTX postgresRev for rows.Next() { var nextXID xid8 var nextSnapshot pgSnapshot + var metadata map[string]any var timestamp time.Time - if err := rows.Scan(&nextXID, &nextSnapshot, ×tamp); err != nil { + if err := rows.Scan(&nextXID, &nextSnapshot, &metadata, ×tamp); err != nil { return fmt.Errorf("unable to decode new revision: %w", err) } @@ -208,6 +209,7 @@ func (pgd *pgDatastore) getNewRevisions(ctx context.Context, afterTX postgresRev snapshot: nextSnapshot.markComplete(nextXID.Uint64), optionalTxID: nextXID, optionalNanosTimestamp: uint64(timestamp.UnixNano()), + optionalMetadata: metadata, }) } if rows.Err() != nil { @@ -227,6 +229,8 @@ func (pgd *pgDatastore) loadChanges(ctx context.Context, revisions []postgresRev filter := make(map[uint64]int, len(revisions)) txidToRevision := make(map[uint64]postgresRevision, len(revisions)) + tracked := common.NewChanges(revisionKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) + for i, rev := range revisions { if rev.optionalTxID.Uint64 < xmin { xmin = rev.optionalTxID.Uint64 @@ -236,9 +240,13 @@ func (pgd *pgDatastore) loadChanges(ctx context.Context, revisions []postgresRev } filter[rev.optionalTxID.Uint64] = i txidToRevision[rev.optionalTxID.Uint64] = rev - } - tracked := common.NewChanges(revisionKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) + if len(rev.optionalMetadata) > 0 { + if err := tracked.SetRevisionMetadata(ctx, rev, rev.optionalMetadata); err != nil { + return nil, err + } + } + } // Load relationship changes. if options.Content&datastore.WatchRelationships == datastore.WatchRelationships { @@ -265,10 +273,9 @@ func (pgd *pgDatastore) loadChanges(ctx context.Context, revisions []postgresRev } // Reconcile the changes. - reconciledChanges := tracked.AsRevisionChanges(func(lhs, rhs uint64) bool { + return tracked.AsRevisionChanges(func(lhs, rhs uint64) bool { return filter[lhs] < filter[rhs] }) - return reconciledChanges, nil } func (pgd *pgDatastore) loadRelationshipChanges(ctx context.Context, xmin uint64, xmax uint64, txidToRevision map[uint64]postgresRevision, filter map[uint64]int, tracked *common.Changes[postgresRevision, uint64]) error { diff --git a/internal/datastore/spanner/migrations/zz_migration.0010_add_transaction_metadata_table.go b/internal/datastore/spanner/migrations/zz_migration.0010_add_transaction_metadata_table.go new file mode 100644 index 0000000000..5c3165ed5b --- /dev/null +++ b/internal/datastore/spanner/migrations/zz_migration.0010_add_transaction_metadata_table.go @@ -0,0 +1,34 @@ +package migrations + +import ( + "context" + + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" +) + +const ( + addTransactionMetadataTable = `CREATE TABLE transaction_metadata ( + transaction_tag STRING(36) NOT NULL, + created_at TIMESTAMP DEFAULT (CURRENT_TIMESTAMP()), + metadata JSON + ) PRIMARY KEY (transaction_tag), + ROW DELETION POLICY (OLDER_THAN(created_at, INTERVAL 2 DAY)) + ` +) + +func init() { + if err := SpannerMigrations.Register("add-transaction-metadata-table", "add-relationship-counter-table", func(ctx context.Context, w Wrapper) error { + updateOp, err := w.adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{ + Database: w.client.DatabaseName(), + Statements: []string{ + addTransactionMetadataTable, + }, + }) + if err != nil { + return err + } + return updateOp.Wait(ctx) + }, nil); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/spanner/schema.go b/internal/datastore/spanner/schema.go index 364fc549fc..3cde345a2a 100644 --- a/internal/datastore/spanner/schema.go +++ b/internal/datastore/spanner/schema.go @@ -30,6 +30,10 @@ const ( colCounterSerializedFilter = "serialized_filter" colCounterCurrentCount = "current_count" colCounterUpdatedAtTimestamp = "updated_at_timestamp" + + tableTransactionMetadata = "transaction_metadata" + colTransactionTag = "transaction_tag" + colMetadata = "metadata" ) var allRelationshipCols = []string{ diff --git a/internal/datastore/spanner/spanner.go b/internal/datastore/spanner/spanner.go index d9bfcbf360..6392551189 100644 --- a/internal/datastore/spanner/spanner.go +++ b/internal/datastore/spanner/spanner.go @@ -12,6 +12,7 @@ import ( "cloud.google.com/go/spanner" ocprom "contrib.go.opencensus.io/exporter/prometheus" sq "github.com/Masterminds/squirrel" + "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "go.opencensus.io/plugin/ocgrpc" "go.opencensus.io/stats/view" @@ -239,18 +240,50 @@ func (sd *spannerDatastore) SnapshotReader(revisionRaw datastore.Revision) datas return spannerReader{executor, txSource, sd.filterMaximumIDCount} } +func (sd *spannerDatastore) readTransactionMetadata(ctx context.Context, transactionTag string) (map[string]any, error) { + row, err := sd.client.Single().ReadRow(ctx, tableTransactionMetadata, spanner.Key{transactionTag}, []string{colMetadata}) + if err != nil { + if spanner.ErrCode(err) == codes.NotFound { + return map[string]any{}, nil + } + + return nil, err + } + + var metadata map[string]any + if err := row.Columns(&metadata); err != nil { + return nil, err + } + + return metadata, nil +} + func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUserFunc, opts ...options.RWTOptionsOption) (datastore.Revision, error) { config := options.NewRWTOptionsWithOptions(opts...) ctx, span := tracer.Start(ctx, "ReadWriteTx") defer span.End() + transactionTag := "sdb-rwt-" + uuid.NewString() + ctx, cancel := context.WithCancel(ctx) - ts, err := sd.client.ReadWriteTransaction(ctx, func(ctx context.Context, spannerRWT *spanner.ReadWriteTransaction) error { + rs, err := sd.client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, spannerRWT *spanner.ReadWriteTransaction) error { txSource := func() readTX { return &traceableRTX{delegate: spannerRWT} } + if config.Metadata != nil { + // Insert the metadata into the transaction metadata table. + mutation := spanner.Insert(tableTransactionMetadata, + []string{colTransactionTag, colMetadata}, + []any{transactionTag, config.Metadata.AsMap()}, + ) + + if err := spannerRWT.BufferWrite([]*spanner.Mutation{mutation}); err != nil { + return fmt.Errorf("unable to write metadata: %w", err) + } + } + executor := common.QueryExecutor{Executor: queryExecutor(txSource)} rwt := spannerReadWriteTXN{ spannerReader{executor, txSource, sd.filterMaximumIDCount}, @@ -270,7 +303,7 @@ func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUser } return nil - }) + }, spanner.TransactionOptions{TransactionTag: transactionTag}) if err != nil { if cerr := convertToWriteConstraintError(err); cerr != nil { return datastore.NoRevision, cerr @@ -278,7 +311,7 @@ func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUser return datastore.NoRevision, err } - return revisions.NewForTime(ts), nil + return revisions.NewForTime(rs.CommitTs), nil } func (sd *spannerDatastore) ReadyState(ctx context.Context) (datastore.ReadyState, error) { diff --git a/internal/datastore/spanner/watch.go b/internal/datastore/spanner/watch.go index edd4eab37e..5a533e4b8a 100644 --- a/internal/datastore/spanner/watch.go +++ b/internal/datastore/spanner/watch.go @@ -156,6 +156,23 @@ func (sd *spannerDatastore) watch( } defer reader.Close() + metadataForTransactionTag := map[string]map[string]any{} + + addMetadataForTransactionTag := func(ctx context.Context, tracked *common.Changes[revisions.TimestampRevision, int64], revision revisions.TimestampRevision, transactionTag string) error { + if metadata, ok := metadataForTransactionTag[transactionTag]; ok { + return tracked.SetRevisionMetadata(ctx, revision, metadata) + } + + // Otherwise, load the metadata from the transactions metadata table. + transactionMetadata, err := sd.readTransactionMetadata(ctx, transactionTag) + if err != nil { + return err + } + + metadataForTransactionTag[transactionTag] = transactionMetadata + return tracked.SetRevisionMetadata(ctx, revision, transactionMetadata) + } + err = reader.Read(ctx, func(result *changestreams.ReadResult) error { // See: https://cloud.google.com/spanner/docs/change-streams/details for _, record := range result.ChangeRecords { @@ -165,6 +182,12 @@ func (sd *spannerDatastore) watch( changeRevision := revisions.NewForTime(dcr.CommitTimestamp) modType := dcr.ModType // options are INSERT, UPDATE, DELETE + if len(dcr.TransactionTag) > 0 { + if err := addMetadataForTransactionTag(ctx, tracked, changeRevision, dcr.TransactionTag); err != nil { + return err + } + } + for _, mod := range dcr.Mods { primaryKeyColumnValues, ok := mod.Keys.Value.(map[string]any) if !ok { @@ -312,7 +335,12 @@ func (sd *spannerDatastore) watch( } if !tracked.IsEmpty() { - for _, revChange := range tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) { + changes, err := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) + if err != nil { + return err + } + + for _, revChange := range changes { revChange := revChange if !sendChange(&revChange) { return datastore.NewWatchDisconnectedErr() diff --git a/pkg/datastore/datastore.go b/pkg/datastore/datastore.go index c31d2592a6..99804ce71f 100644 --- a/pkg/datastore/datastore.go +++ b/pkg/datastore/datastore.go @@ -9,6 +9,7 @@ import ( "time" "github.com/rs/zerolog" + "google.golang.org/protobuf/types/known/structpb" "github.com/authzed/spicedb/pkg/tuple" @@ -61,6 +62,9 @@ type RevisionChanges struct { // up until and including the Revision and that no additional schema updates can // have occurred before this point. IsCheckpoint bool + + // Metadata is the metadata associated with the revision, if any. + Metadata *structpb.Struct } func (rc *RevisionChanges) MarshalZerologObject(e *zerolog.Event) { diff --git a/pkg/datastore/options/options.go b/pkg/datastore/options/options.go index 34daf2e58a..1d13f2ca11 100644 --- a/pkg/datastore/options/options.go +++ b/pkg/datastore/options/options.go @@ -1,6 +1,8 @@ package options import ( + "google.golang.org/protobuf/types/known/structpb" + core "github.com/authzed/spicedb/pkg/proto/core/v1" ) @@ -51,7 +53,8 @@ type ResourceRelation struct { // RWTOptions are options that can affect the way a read-write transaction is // executed. type RWTOptions struct { - DisableRetries bool `debugmap:"visible"` + DisableRetries bool `debugmap:"visible"` + Metadata *structpb.Struct `debugmap:"visible"` } // DeleteOptions are the options that can affect the results of a delete relationships diff --git a/pkg/datastore/options/zz_generated.query_options.go b/pkg/datastore/options/zz_generated.query_options.go index f87f310b7c..f761b06b66 100644 --- a/pkg/datastore/options/zz_generated.query_options.go +++ b/pkg/datastore/options/zz_generated.query_options.go @@ -4,6 +4,7 @@ package options import ( defaults "github.com/creasty/defaults" helpers "github.com/ecordell/optgen/helpers" + structpb "google.golang.org/protobuf/types/known/structpb" ) type QueryOptionsOption func(q *QueryOptions) @@ -192,6 +193,7 @@ func NewRWTOptionsWithOptionsAndDefaults(opts ...RWTOptionsOption) *RWTOptions { func (r *RWTOptions) ToOption() RWTOptionsOption { return func(to *RWTOptions) { to.DisableRetries = r.DisableRetries + to.Metadata = r.Metadata } } @@ -199,6 +201,7 @@ func (r *RWTOptions) ToOption() RWTOptionsOption { func (r RWTOptions) DebugMap() map[string]any { debugMap := map[string]any{} debugMap["DisableRetries"] = helpers.DebugValue(r.DisableRetries, false) + debugMap["Metadata"] = helpers.DebugValue(r.Metadata, false) return debugMap } @@ -224,3 +227,10 @@ func WithDisableRetries(disableRetries bool) RWTOptionsOption { r.DisableRetries = disableRetries } } + +// WithMetadata returns an option that can set Metadata on a RWTOptions +func WithMetadata(metadata *structpb.Struct) RWTOptionsOption { + return func(r *RWTOptions) { + r.Metadata = metadata + } +} diff --git a/pkg/datastore/test/datastore.go b/pkg/datastore/test/datastore.go index 05f034c6ba..cc8a0999f5 100644 --- a/pkg/datastore/test/datastore.go +++ b/pkg/datastore/test/datastore.go @@ -162,6 +162,7 @@ func AllWithExceptions(t *testing.T, tester DatastoreTester, except Categories) t.Run("TestCaveatedRelationshipWatch", func(t *testing.T) { CaveatedRelationshipWatchTest(t, tester) }) t.Run("TestWatchWithTouch", func(t *testing.T) { WatchWithTouchTest(t, tester) }) t.Run("TestWatchWithDelete", func(t *testing.T) { WatchWithDeleteTest(t, tester) }) + t.Run("TestWatchWithMetadata", func(t *testing.T) { WatchWithMetadataTest(t, tester) }) } if !except.Watch() && !except.WatchSchema() { diff --git a/pkg/datastore/test/watch.go b/pkg/datastore/test/watch.go index 0191796bbc..a91ec8eefc 100644 --- a/pkg/datastore/test/watch.go +++ b/pkg/datastore/test/watch.go @@ -13,9 +13,11 @@ import ( "github.com/scylladb/go-set/strset" "github.com/stretchr/testify/require" "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/structpb" "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/tuple" @@ -176,6 +178,50 @@ func VerifyUpdates( require.False(expectDisconnect, "all changes verified without expected disconnect") } +func VerifyUpdatesWithMetadata( + require *require.Assertions, + testUpdates []updateWithMetadata, + changes <-chan *datastore.RevisionChanges, + errchan <-chan error, + expectDisconnect bool, +) { + for _, expected := range testUpdates { + changeWait := time.NewTimer(waitForChangesTimeout) + select { + case change, ok := <-changes: + if !ok { + require.True(expectDisconnect, "unexpected disconnect") + errWait := time.NewTimer(waitForChangesTimeout) + select { + case err := <-errchan: + require.True(errors.As(err, &datastore.ErrWatchDisconnected{})) + return + case <-errWait.C: + require.Fail("Timed out waiting for ErrWatchDisconnected") + } + return + } + + expectedChangeSet := setOfChanges(expected.updates) + actualChangeSet := setOfChanges(change.RelationshipChanges) + + missingExpected := strset.Difference(expectedChangeSet, actualChangeSet) + unexpected := strset.Difference(actualChangeSet, expectedChangeSet) + + require.True(missingExpected.IsEmpty(), "expected changes missing: %s", missingExpected) + require.True(unexpected.IsEmpty(), "unexpected changes: %s", unexpected) + + require.Equal(expected.metadata, change.Metadata.AsMap(), "metadata mismatch") + + time.Sleep(1 * time.Millisecond) + case <-changeWait.C: + require.Fail("Timed out", "waited for changes: %s", expected) + } + } + + require.False(expectDisconnect, "all changes verified without expected disconnect") +} + func setOfChanges(changes []*core.RelationTupleUpdate) *strset.Set { changeSet := strset.NewWithSize(len(changes)) for _, change := range changes { @@ -337,6 +383,50 @@ func WatchWithTouchTest(t *testing.T, tester DatastoreTester) { ) } +type updateWithMetadata struct { + updates []*core.RelationTupleUpdate + metadata map[string]any +} + +func WatchWithMetadataTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 16) + require.NoError(err) + + setupDatastore(ds, require) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + lowestRevision, err := ds.HeadRevision(ctx) + require.NoError(err) + + changes, errchan := ds.Watch(ctx, lowestRevision, datastore.WatchJustRelationships()) + require.Zero(len(errchan)) + + metadata, err := structpb.NewStruct(map[string]any{"somekey": "somevalue"}) + require.NoError(err) + + _, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + return rwt.WriteRelationships(ctx, []*core.RelationTupleUpdate{ + tuple.Create(tuple.MustParse("document:firstdoc#viewer@user:tom")), + }) + }, options.WithMetadata(metadata)) + require.NoError(err) + + VerifyUpdatesWithMetadata(require, []updateWithMetadata{ + { + updates: []*core.RelationTupleUpdate{tuple.Touch(tuple.Parse("document:firstdoc#viewer@user:tom"))}, + metadata: map[string]any{"somekey": "somevalue"}, + }, + }, + changes, + errchan, + false, + ) +} + func WatchWithDeleteTest(t *testing.T, tester DatastoreTester) { require := require.New(t)