From e7f5ad8de203abba2e0e6d8f962475137cf338ec Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Sun, 26 May 2024 17:36:16 -0400 Subject: [PATCH] Implement support for metadata associated with read-write transactions This will allow callers of APIs such as WriteRelationships and DeleteRelationships to assign metadata to the transaction that will be mirrored back out in the Watch API, to provide a means for correlating updates --- internal/datastore/common/changes.go | 51 +++++++++-- internal/datastore/common/changes_test.go | 40 +++++++-- internal/datastore/crdb/crdb.go | 22 +++++ ...ion.0008_add_transaction_metadata_table.go | 75 ++++++++++++++++ internal/datastore/crdb/watch.go | 31 ++++++- internal/datastore/memdb/memdb.go | 12 ++- internal/datastore/mysql/datastore.go | 17 ++-- internal/datastore/mysql/datastore_test.go | 2 +- ....0009_add_metadata_to_transaction_table.go | 18 ++++ internal/datastore/mysql/query_builder.go | 10 +-- internal/datastore/mysql/readwrite.go | 14 +-- internal/datastore/mysql/revisions.go | 18 ++-- internal/datastore/mysql/watch.go | 55 ++++++++++-- ....0019_add_metadata_to_transaction_table.go | 23 +++++ internal/datastore/postgres/postgres.go | 15 ++-- .../postgres/postgres_shared_test.go | 2 +- internal/datastore/postgres/revisions.go | 16 +++- internal/datastore/postgres/watch.go | 21 +++-- ...ion.0010_add_transaction_metadata_table.go | 34 +++++++ internal/datastore/spanner/schema.go | 4 + internal/datastore/spanner/spanner.go | 39 +++++++- internal/datastore/spanner/watch.go | 30 ++++++- pkg/datastore/datastore.go | 4 + pkg/datastore/options/options.go | 5 +- .../options/zz_generated.query_options.go | 10 +++ pkg/datastore/test/datastore.go | 1 + pkg/datastore/test/watch.go | 90 +++++++++++++++++++ 27 files changed, 588 insertions(+), 71 deletions(-) create mode 100644 internal/datastore/crdb/migrations/zz_migration.0008_add_transaction_metadata_table.go create mode 100644 internal/datastore/mysql/migrations/zz_migration.0009_add_metadata_to_transaction_table.go create mode 100644 internal/datastore/postgres/migrations/zz_migration.0019_add_metadata_to_transaction_table.go create mode 100644 internal/datastore/spanner/migrations/zz_migration.0010_add_transaction_metadata_table.go diff --git a/internal/datastore/common/changes.go b/internal/datastore/common/changes.go index 9449e47a5a..33c13f2905 100644 --- a/internal/datastore/common/changes.go +++ b/internal/datastore/common/changes.go @@ -5,6 +5,7 @@ import ( "sort" "golang.org/x/exp/maps" + "google.golang.org/protobuf/types/known/structpb" log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" @@ -35,6 +36,7 @@ type changeRecord[R datastore.Revision] struct { definitionsChanged map[string]datastore.SchemaDefinition namespacesDeleted map[string]struct{} caveatsDeleted map[string]struct{} + metadata map[string]any } // NewChanges creates a new Changes object for change tracking and de-duplication. @@ -125,6 +127,25 @@ func (ch *Changes[R, K]) adjustByteSize(item sized, delta int) error { return nil } +// SetRevisionMetadata sets the metadata for the given revision. +func (ch *Changes[R, K]) SetRevisionMetadata(ctx context.Context, rev R, metadata map[string]any) error { + if len(metadata) == 0 { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + if len(record.metadata) > 0 { + return spiceerrors.MustBugf("metadata already set for revision") + } + + maps.Copy(record.metadata, metadata) + return nil +} + func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) { k := ch.keyFunc(rev) revisionChanges, ok := ch.records[k] @@ -136,6 +157,7 @@ func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) { make(map[string]datastore.SchemaDefinition), make(map[string]struct{}), make(map[string]struct{}), + make(map[string]any), } ch.records[k] = revisionChanges } @@ -241,21 +263,25 @@ func (ch *Changes[R, K]) AddChangedDefinition( // AsRevisionChanges returns the list of changes processed so far as a datastore watch // compatible, ordered, changelist. -func (ch *Changes[R, K]) AsRevisionChanges(lessThanFunc func(lhs, rhs K) bool) []datastore.RevisionChanges { +func (ch *Changes[R, K]) AsRevisionChanges(lessThanFunc func(lhs, rhs K) bool) ([]datastore.RevisionChanges, error) { return ch.revisionChanges(lessThanFunc, *new(R), false) } // FilterAndRemoveRevisionChanges filters a list of changes processed up to the bound revision from the changes list, removing them // and returning the filtered changes. -func (ch *Changes[R, K]) FilterAndRemoveRevisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R) []datastore.RevisionChanges { - changes := ch.revisionChanges(lessThanFunc, boundRev, true) +func (ch *Changes[R, K]) FilterAndRemoveRevisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R) ([]datastore.RevisionChanges, error) { + changes, err := ch.revisionChanges(lessThanFunc, boundRev, true) + if err != nil { + return nil, err + } + ch.removeAllChangesBefore(boundRev) - return changes + return changes, nil } -func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R, withBound bool) []datastore.RevisionChanges { +func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R, withBound bool) ([]datastore.RevisionChanges, error) { if ch.IsEmpty() { - return nil + return nil, nil } revisionsWithChanges := make([]K, 0, len(ch.records)) @@ -266,7 +292,7 @@ func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, bou } if len(revisionsWithChanges) == 0 { - return nil + return nil, nil } sort.Slice(revisionsWithChanges, func(i int, j int) bool { @@ -292,9 +318,18 @@ func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, bou changes[i].ChangedDefinitions = maps.Values(revisionChangeRecord.definitionsChanged) changes[i].DeletedNamespaces = maps.Keys(revisionChangeRecord.namespacesDeleted) changes[i].DeletedCaveats = maps.Keys(revisionChangeRecord.caveatsDeleted) + + if len(revisionChangeRecord.metadata) > 0 { + metadata, err := structpb.NewStruct(revisionChangeRecord.metadata) + if err != nil { + return nil, spiceerrors.MustBugf("failed to convert metadata to structpb: %v", err) + } + + changes[i].Metadata = metadata + } } - return changes + return changes, nil } func (ch *Changes[R, K]) removeAllChangesBefore(boundRev R) { diff --git a/internal/datastore/common/changes_test.go b/internal/datastore/common/changes_test.go index db496a4d77..7b16ab3e46 100644 --- a/internal/datastore/common/changes_test.go +++ b/internal/datastore/common/changes_test.go @@ -330,9 +330,12 @@ func TestChanges(t *testing.T) { } } + actual, err := ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) + require.NoError(err) + require.Equal( canonicalize(tc.expected), - canonicalize(ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc)), + canonicalize(actual), ) }) } @@ -347,6 +350,23 @@ func TestFilteredSchemaChanges(t *testing.T) { require.True(t, ch.IsEmpty()) } +func TestSetMetadata(t *testing.T) { + ctx := context.Background() + ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) + require.True(t, ch.IsEmpty()) + + err := ch.SetRevisionMetadata(ctx, rev1, map[string]any{"foo": "bar"}) + require.NoError(t, err) + require.False(t, ch.IsEmpty()) + + results, err := ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, rev2) + require.NoError(t, err) + require.Equal(t, 1, len(results)) + require.True(t, ch.IsEmpty()) + + require.Equal(t, map[string]any{"foo": "bar"}, results[0].Metadata.AsMap()) +} + func TestFilteredRelationshipChanges(t *testing.T) { ctx := context.Background() ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships, 0) @@ -374,7 +394,8 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) { require.False(t, ch.IsEmpty()) - results := ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, rev3) + results, err := ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, rev3) + require.NoError(t, err) require.Equal(t, 2, len(results)) require.False(t, ch.IsEmpty()) @@ -393,8 +414,9 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) { }, }, results) - remaining := ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) + remaining, err := ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) require.Equal(t, 1, len(remaining)) + require.NoError(t, err) require.Equal(t, []datastore.RevisionChanges{ { @@ -405,11 +427,13 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) { }, }, remaining) - results = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillion) + results, err = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillion) + require.NoError(t, err) require.Equal(t, 1, len(results)) require.True(t, ch.IsEmpty()) - results = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillionOne) + results, err = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillionOne) + require.NoError(t, err) require.Equal(t, 0, len(results)) require.True(t, ch.IsEmpty()) } @@ -432,7 +456,8 @@ func TestHLCOrdering(t *testing.T) { err = ch.AddRelationshipChange(ctx, rev0, tuple.MustParse("document:foo#viewer@user:tom"), core.RelationTupleUpdate_TOUCH) require.NoError(t, err) - remaining := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc) + remaining, err := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc) + require.NoError(t, err) require.Equal(t, 2, len(remaining)) require.Equal(t, []datastore.RevisionChanges{ @@ -475,7 +500,8 @@ func TestHLCSameRevision(t *testing.T) { err = ch.AddRelationshipChange(ctx, rev0again, tuple.MustParse("document:foo#viewer@user:sarah"), core.RelationTupleUpdate_TOUCH) require.NoError(t, err) - remaining := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc) + remaining, err := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc) + require.NoError(t, err) require.Equal(t, 1, len(remaining)) expected := []*core.RelationTupleUpdate{ diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index 08d5bafb23..c5ac0a1492 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -52,6 +52,7 @@ const ( tableTransactions = "transactions" tableCaveat = "caveat" tableRelationshipCounter = "relationship_counter" + tableTransactionMetadata = "transaction_metadata" colNamespace = "namespace" colConfig = "serialized_config" @@ -77,6 +78,8 @@ const ( colCounterSerializedFilter = "serialized_filter" colCounterCurrentCount = "current_count" colCounterUpdatedAt = "updated_at_timestamp" + colExpiresAt = "expires_at" + colMetadata = "metadata" errUnableToInstantiate = "unable to instantiate datastore" errRevision = "unable to find revision: %w" @@ -199,6 +202,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas analyzeBeforeStatistics: config.analyzeBeforeStatistics, filterMaximumIDCount: config.filterMaximumIDCount, supportsIntegrity: config.withIntegrity, + gcWindow: config.gcWindow, } ds.RemoteClockRevisions.SetNowFunc(ds.headRevisionInternal) @@ -281,6 +285,7 @@ type crdbDatastore struct { writeOverlapKeyer overlapKeyer overlapKeyInit func(ctx context.Context) keySet analyzeBeforeStatistics bool + gcWindow time.Duration beginChangefeedQuery string transactionNowQuery string @@ -324,6 +329,23 @@ func (cds *crdbDatastore) ReadWriteTx( Executor: pgxcommon.NewPGXExecutorWithIntegrityOption(querier, cds.supportsIntegrity), } + // If metadata is to be attached, write that row now. + if config.Metadata != nil { + expiresAt := time.Now().Add(cds.gcWindow).Add(1 * time.Minute) + insertTransactionMetadata := psql.Insert(tableTransactionMetadata). + Columns(colExpiresAt, colMetadata). + Values(expiresAt, config.Metadata.AsMap()) + + sql, args, err := insertTransactionMetadata.ToSql() + if err != nil { + return fmt.Errorf("error building metadata insert: %w", err) + } + + if _, err := tx.Exec(ctx, sql, args...); err != nil { + return fmt.Errorf("error writing metadata: %w", err) + } + } + rwt := &crdbReadWriteTXN{ &crdbReader{ querier, diff --git a/internal/datastore/crdb/migrations/zz_migration.0008_add_transaction_metadata_table.go b/internal/datastore/crdb/migrations/zz_migration.0008_add_transaction_metadata_table.go new file mode 100644 index 0000000000..a98a93c8f2 --- /dev/null +++ b/internal/datastore/crdb/migrations/zz_migration.0008_add_transaction_metadata_table.go @@ -0,0 +1,75 @@ +package migrations + +import ( + "context" + "strings" + + "github.com/jackc/pgx/v5" +) + +const ( + // ttl_expiration_expression support was added in CRDB v22.2, but the E2E tests + // use v21.2. + addTransactionMetadataTableQueryWithBasicTTL = ` + CREATE TABLE transaction_metadata ( + key UUID PRIMARY KEY DEFAULT gen_random_uuid(), + expires_at TIMESTAMPTZ, + metadata JSONB + ) WITH (ttl_expire_after = '1d'); + ` + + addTransactionMetadataTableQuery = ` + CREATE TABLE transaction_metadata ( + key UUID PRIMARY KEY DEFAULT gen_random_uuid(), + expires_at TIMESTAMPTZ, + metadata JSONB + ) WITH (ttl_expiration_expression = 'expires_at', ttl_job_cron = '@daily'); + ` + + // See: https://www.cockroachlabs.com/docs/stable/changefeed-messages#prevent-changefeeds-from-emitting-row-level-ttl-deletes + // for why we set ttl_disable_changefeed_replication = 'true'. This isn't stricly necessary as the Watch API will ignore the + // deletions of these metadata rows, but no reason to even have it in the changefeed. + // NOTE: This only applies on CRDB v24 and later. + addTransactionMetadataTableQueryWithTTLIgnore = ` + CREATE TABLE transaction_metadata ( + key UUID PRIMARY KEY DEFAULT gen_random_uuid(), + expires_at TIMESTAMPTZ, + metadata JSONB + ) WITH (ttl_expiration_expression = 'expires_at', ttl_job_cron = '@daily', ttl_disable_changefeed_replication = 'true'); + ` +) + +func init() { + err := CRDBMigrations.Register("add-transaction-metadata-table", "add-integrity-relationtuple-table", addTransactionMetadataTable, noAtomicMigration) + if err != nil { + panic("failed to register migration: " + err.Error()) + } +} + +func addTransactionMetadataTable(ctx context.Context, conn *pgx.Conn) error { + row := conn.QueryRow(ctx, "select version()") + var version string + if err := row.Scan(&version); err != nil { + return err + } + + if strings.Contains(version, "v22.1") { + if _, err := conn.Exec(ctx, addTransactionMetadataTableQueryWithBasicTTL); err != nil { + return err + } + return nil + } + + if strings.Contains(version, "v24.") { + if _, err := conn.Exec(ctx, addTransactionMetadataTableQueryWithTTLIgnore); err != nil { + return err + } + return nil + } + + // CRDB v23 and v22.2.* + if _, err := conn.Exec(ctx, addTransactionMetadataTableQuery); err != nil { + return err + } + return nil +} diff --git a/internal/datastore/crdb/watch.go b/internal/datastore/crdb/watch.go index 2241ae33f7..d03989c6fe 100644 --- a/internal/datastore/crdb/watch.go +++ b/internal/datastore/crdb/watch.go @@ -55,6 +55,8 @@ type changeDetails struct { IntegrityKeyID *string `json:"integrity_key_id"` IntegrityHashAsHex *string `json:"integrity_hash"` TimestampAsString *string `json:"timestamp"` + + Metadata map[string]any `json:"metadata"` } } @@ -110,7 +112,8 @@ func (cds *crdbDatastore) watch( } defer func() { _ = conn.Close(ctx) }() - tableNames := make([]string, 0, 3) + tableNames := make([]string, 0, 4) + tableNames = append(tableNames, tableTransactionMetadata) if opts.Content&datastore.WatchRelationships == datastore.WatchRelationships { tableNames = append(tableNames, cds.tableTupleName()) } @@ -217,7 +220,13 @@ func (cds *crdbDatastore) watch( return } - for _, revChange := range tracked.FilterAndRemoveRevisionChanges(revisions.HLCKeyLessThanFunc, rev) { + filtered, err := tracked.FilterAndRemoveRevisionChanges(revisions.HLCKeyLessThanFunc, rev) + if err != nil { + sendError(err) + return + } + + for _, revChange := range filtered { revChange := revChange if !sendChange(&revChange) { return @@ -393,6 +402,24 @@ func (cds *crdbDatastore) watch( return } } + + case tableTransactionMetadata: + if details.After != nil { + rev, err := revisions.HLCRevisionFromString(details.Updated) + if err != nil { + sendError(fmt.Errorf("malformed update timestamp: %w", err)) + return + } + + if err := tracked.SetRevisionMetadata(ctx, rev, details.After.Metadata); err != nil { + sendError(err) + return + } + } + + default: + sendError(spiceerrors.MustBugf("unexpected table name in changefeed: %s", tableName)) + return } } diff --git a/internal/datastore/memdb/memdb.go b/internal/datastore/memdb/memdb.go index ffec0e2523..f7e0b3b0c1 100644 --- a/internal/datastore/memdb/memdb.go +++ b/internal/datastore/memdb/memdb.go @@ -204,6 +204,12 @@ func (mdb *memdbDatastore) ReadWriteTx( tracked := common.NewChanges(revisions.TimestampIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) if tx != nil { + if config.Metadata != nil { + if err := tracked.SetRevisionMetadata(ctx, newRevision, config.Metadata.AsMap()); err != nil { + return datastore.NoRevision, err + } + } + for _, change := range tx.Changes() { switch change.Table { case tableRelationship: @@ -270,7 +276,11 @@ func (mdb *memdbDatastore) ReadWriteTx( } var rc datastore.RevisionChanges - changes := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) + changes, err := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) + if err != nil { + return datastore.NoRevision, err + } + if len(changes) > 1 { return datastore.NoRevision, spiceerrors.MustBugf("unexpected MemDB transaction with multiple revision changes") } else if len(changes) == 1 { diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index ed1fc7749a..33425dd9dd 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -37,6 +37,7 @@ const ( colID = "id" colTimestamp = "timestamp" + colMetadata = "metadata" colNamespace = "namespace" colConfig = "serialized_config" colCreatedTxn = "created_transaction" @@ -207,10 +208,7 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option driver := migrations.NewMySQLDriverFromDB(db, config.tablePrefix) queryBuilder := NewQueryBuilder(driver) - createTxn, _, err := sb.Insert(driver.RelationTupleTransaction()).Values().ToSql() - if err != nil { - return nil, fmt.Errorf("NewMySQLDatastore: %w", err) - } + createTxn := sb.Insert(driver.RelationTupleTransaction()).Columns(colMetadata) // used for seeding the initial relation_tuple_transaction. using INSERT IGNORE on a known // ID value makes this idempotent (i.e. safe to execute concurrently). @@ -339,7 +337,12 @@ func (mds *Datastore) ReadWriteTx( for i := uint8(0); i <= mds.maxRetries; i++ { var newTxnID uint64 if err = migrations.BeginTxFunc(ctx, mds.db, &sql.TxOptions{Isolation: sql.LevelSerializable}, func(tx *sql.Tx) error { - newTxnID, err = mds.createNewTransaction(ctx, tx) + var metadata map[string]any + if config.Metadata != nil { + metadata = config.Metadata.AsMap() + } + + newTxnID, err = mds.createNewTransaction(ctx, tx, metadata) if err != nil { return fmt.Errorf("unable to create new txn ID: %w", err) } @@ -442,7 +445,7 @@ func newMySQLExecutor(tx querier) common.ExecuteQueryFunc { } var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper err := rows.Scan( &nextTuple.ResourceAndRelation.Namespace, &nextTuple.ResourceAndRelation.ObjectId, @@ -497,7 +500,7 @@ type Datastore struct { cancelGc context.CancelFunc gcHasRun atomic.Bool - createTxn string + createTxn sq.InsertBuilder createBaseTxn string *QueryBuilder diff --git a/internal/datastore/mysql/datastore_test.go b/internal/datastore/mysql/datastore_test.go index 23390ba1eb..c9eeeec066 100644 --- a/internal/datastore/mysql/datastore_test.go +++ b/internal/datastore/mysql/datastore_test.go @@ -640,7 +640,7 @@ func TransactionTimestampsTest(t *testing.T, ds datastore.Datastore) { // Transaction timestamp should not be stored in system time zone tx, err := db.BeginTx(ctx, nil) req.NoError(err) - txID, err := ds.(*Datastore).createNewTransaction(ctx, tx) + txID, err := ds.(*Datastore).createNewTransaction(ctx, tx, nil) req.NoError(err) err = tx.Commit() req.NoError(err) diff --git a/internal/datastore/mysql/migrations/zz_migration.0009_add_metadata_to_transaction_table.go b/internal/datastore/mysql/migrations/zz_migration.0009_add_metadata_to_transaction_table.go new file mode 100644 index 0000000000..64e66a83e8 --- /dev/null +++ b/internal/datastore/mysql/migrations/zz_migration.0009_add_metadata_to_transaction_table.go @@ -0,0 +1,18 @@ +package migrations + +import "fmt" + +func addMetadataToTransactionTable(t *tables) string { + return fmt.Sprintf(`ALTER TABLE %s + ADD COLUMN metadata BLOB NULL DEFAULT NULL;`, + t.RelationTupleTransaction(), + ) +} + +func init() { + mustRegisterMigration("add_metadata_to_transaction_table", "add_relationship_counters_table", noNonatomicMigration, + newStatementBatch( + addMetadataToTransactionTable, + ).execute, + ) +} diff --git a/internal/datastore/mysql/query_builder.go b/internal/datastore/mysql/query_builder.go index 3e1b700811..2356dc605a 100644 --- a/internal/datastore/mysql/query_builder.go +++ b/internal/datastore/mysql/query_builder.go @@ -9,8 +9,8 @@ import ( // QueryBuilder captures all parameterizable queries used // by the MySQL datastore implementation type QueryBuilder struct { - GetLastRevision sq.SelectBuilder - GetRevisionRange sq.SelectBuilder + GetLastRevision sq.SelectBuilder + LoadRevisionRange sq.SelectBuilder WriteNamespaceQuery sq.InsertBuilder ReadNamespaceQuery sq.SelectBuilder @@ -43,7 +43,7 @@ func NewQueryBuilder(driver *migrations.MySQLDriver) *QueryBuilder { // transaction builders builder.GetLastRevision = getLastRevision(driver.RelationTupleTransaction()) - builder.GetRevisionRange = getRevisionRange(driver.RelationTupleTransaction()) + builder.LoadRevisionRange = loadRevisionRange(driver.RelationTupleTransaction()) // namespace builders builder.WriteNamespaceQuery = writeNamespace(driver.Namespace()) @@ -99,8 +99,8 @@ func getLastRevision(tableTransaction string) sq.SelectBuilder { return sb.Select("MAX(id)").From(tableTransaction).Limit(1) } -func getRevisionRange(tableTransaction string) sq.SelectBuilder { - return sb.Select("MIN(id)", "MAX(id)").From(tableTransaction) +func loadRevisionRange(tableTransaction string) sq.SelectBuilder { + return sb.Select(colID, colMetadata).From(tableTransaction) } func readCounter(tableRelationshipCounters string) sq.SelectBuilder { diff --git a/internal/datastore/mysql/readwrite.go b/internal/datastore/mysql/readwrite.go index 26c7f4a017..92ca509319 100644 --- a/internal/datastore/mysql/readwrite.go +++ b/internal/datastore/mysql/readwrite.go @@ -49,10 +49,10 @@ type mysqlReadWriteTXN struct { newTxnID uint64 } -// caveatContextWrapper is used to marshall maps into MySQLs JSON data type -type caveatContextWrapper map[string]any +// structpbWrapper is used to marshall maps into MySQLs JSON data type +type structpbWrapper map[string]any -func (cc *caveatContextWrapper) Scan(val any) error { +func (cc *structpbWrapper) Scan(val any) error { v, ok := val.([]byte) if !ok { return fmt.Errorf("unsupported type: %T", v) @@ -60,7 +60,7 @@ func (cc *caveatContextWrapper) Scan(val any) error { return json.Unmarshal(v, &cc) } -func (cc *caveatContextWrapper) Value() (driver.Value, error) { +func (cc *structpbWrapper) Value() (driver.Value, error) { return json.Marshal(&cc) } @@ -220,7 +220,7 @@ func (rwt *mysqlReadWriteTXN) WriteRelationships(ctx context.Context, mutations } var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper tupleIdsToDelete := make([]int64, 0, len(clauses)) for rows.Next() { @@ -281,7 +281,7 @@ func (rwt *mysqlReadWriteTXN) WriteRelationships(ctx context.Context, mutations tpl := mut.Tuple var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper if tpl.Caveat != nil { caveatName = tpl.Caveat.CaveatName caveatContext = tpl.Caveat.Context.AsMap() @@ -497,7 +497,7 @@ func (rwt *mysqlReadWriteTXN) BulkLoad(ctx context.Context, iter datastore.BulkW } var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper if tpl.Caveat != nil { caveatName = tpl.Caveat.CaveatName caveatContext = tpl.Caveat.Context.AsMap() diff --git a/internal/datastore/mysql/revisions.go b/internal/datastore/mysql/revisions.go index 5570e670a9..80b259f77d 100644 --- a/internal/datastore/mysql/revisions.go +++ b/internal/datastore/mysql/revisions.go @@ -74,8 +74,6 @@ func (mds *Datastore) optimizedRevisionFunc(ctx context.Context) (datastore.Revi } func (mds *Datastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { - // implementation deviates slightly from PSQL implementation in order to support - // database seeding in runtime, instead of through migrate command revision, err := mds.loadRevision(ctx) if err != nil { return datastore.NoRevision, err @@ -156,16 +154,26 @@ func (mds *Datastore) checkValidTransaction(ctx context.Context, revisionTx uint return freshEnough.Bool, unknown.Bool, nil } -func (mds *Datastore) createNewTransaction(ctx context.Context, tx *sql.Tx) (newTxnID uint64, err error) { +func (mds *Datastore) createNewTransaction(ctx context.Context, tx *sql.Tx, metadata map[string]any) (newTxnID uint64, err error) { ctx, span := tracer.Start(ctx, "createNewTransaction") defer span.End() - createQuery := mds.createTxn + var wrappedMetadata structpbWrapper + if len(metadata) > 0 { + wrappedMetadata = metadata + } + + createQuery := mds.createTxn.Values(&wrappedMetadata) + if err != nil { + return 0, fmt.Errorf("createNewTransaction: %w", err) + } + + sql, args, err := createQuery.ToSql() if err != nil { return 0, fmt.Errorf("createNewTransaction: %w", err) } - result, err := tx.ExecContext(ctx, createQuery) + result, err := tx.ExecContext(ctx, sql, args...) if err != nil { return 0, fmt.Errorf("createNewTransaction: %w", err) } diff --git a/internal/datastore/mysql/watch.go b/internal/datastore/mysql/watch.go index f3c2dffc84..14e356796f 100644 --- a/internal/datastore/mysql/watch.go +++ b/internal/datastore/mysql/watch.go @@ -125,7 +125,52 @@ func (mds *Datastore) loadChanges( return } - sql, args, err := mds.QueryChangedQuery.Where(sq.Or{ + stagedChanges := common.NewChanges(revisions.TransactionIDKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) + + // Load any metadata for the revision range. + sql, args, err := mds.LoadRevisionRange.Where(sq.Or{ + sq.And{ + sq.Gt{colID: afterRevision}, + sq.LtOrEq{colID: newRevision}, + }, + }).ToSql() + if err != nil { + return + } + + rows, err := mds.db.QueryContext(ctx, sql, args...) + if err != nil { + if errors.Is(err, context.Canceled) { + err = datastore.NewWatchCanceledErr() + } + return + } + defer common.LogOnError(ctx, rows.Close) + + for rows.Next() { + var txnID uint64 + var metadata structpbWrapper + err = rows.Scan( + &txnID, + &metadata, + ) + if err != nil { + return nil, 0, err + } + + if len(metadata) > 0 { + if err := stagedChanges.SetRevisionMetadata(ctx, revisions.NewForTransactionID(txnID), metadata); err != nil { + return nil, 0, err + } + } + } + rows.Close() + if err = rows.Err(); err != nil { + return + } + + // Load the changes relationships for the revision range. + sql, args, err = mds.QueryChangedQuery.Where(sq.Or{ sq.And{ sq.Gt{colCreatedTxn: afterRevision}, sq.LtOrEq{colCreatedTxn: newRevision}, @@ -139,7 +184,7 @@ func (mds *Datastore) loadChanges( return } - rows, err := mds.db.QueryContext(ctx, sql, args...) + rows, err = mds.db.QueryContext(ctx, sql, args...) if err != nil { if errors.Is(err, context.Canceled) { err = datastore.NewWatchCanceledErr() @@ -148,8 +193,6 @@ func (mds *Datastore) loadChanges( } defer common.LogOnError(ctx, rows.Close) - stagedChanges := common.NewChanges(revisions.TransactionIDKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) - for rows.Next() { nextTuple := &core.RelationTuple{ ResourceAndRelation: &core.ObjectAndRelation{}, @@ -159,7 +202,7 @@ func (mds *Datastore) loadChanges( var createdTxn uint64 var deletedTxn uint64 var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper err = rows.Scan( &nextTuple.ResourceAndRelation.Namespace, &nextTuple.ResourceAndRelation.ObjectId, @@ -196,6 +239,6 @@ func (mds *Datastore) loadChanges( return } - changes = stagedChanges.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) + changes, err = stagedChanges.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) return } diff --git a/internal/datastore/postgres/migrations/zz_migration.0019_add_metadata_to_transaction_table.go b/internal/datastore/postgres/migrations/zz_migration.0019_add_metadata_to_transaction_table.go new file mode 100644 index 0000000000..162ba71f18 --- /dev/null +++ b/internal/datastore/postgres/migrations/zz_migration.0019_add_metadata_to_transaction_table.go @@ -0,0 +1,23 @@ +package migrations + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" +) + +const addMetadataToTransactionTable = `ALTER TABLE relation_tuple_transaction ADD COLUMN IF NOT EXISTS metadata JSONB NOT NULL DEFAULT '{}'` + +func init() { + if err := DatabaseMigrations.Register("add-metadata-to-transaction-table", "create-relationships-counters-table", + func(ctx context.Context, conn *pgx.Conn) error { + if _, err := conn.Exec(ctx, addMetadataToTransactionTable); err != nil { + return fmt.Errorf("failed to add metadata to transaction table: %w", err) + } + return nil + }, + noTxMigration); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index 717b67f864..c091ae9fb0 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -49,6 +49,7 @@ const ( colXID = "xid" colTimestamp = "timestamp" + colMetadata = "metadata" colNamespace = "namespace" colConfig = "serialized_config" colCreatedXid = "created_xid" @@ -102,12 +103,7 @@ var ( OrderByClause(fmt.Sprintf("%s DESC", colXID)). Limit(1) - createTxn = fmt.Sprintf( - "INSERT INTO %s DEFAULT VALUES RETURNING %s, %s", - tableTransaction, - colXID, - colSnapshot, - ) + createTxn = psql.Insert(tableTransaction).Columns(colMetadata) getNow = psql.Select("NOW()") @@ -435,7 +431,12 @@ func (pgd *pgDatastore) ReadWriteTx( var newSnapshot pgSnapshot err = wrapError(pgx.BeginTxFunc(ctx, pgd.writePool, pgx.TxOptions{IsoLevel: pgx.Serializable}, func(tx pgx.Tx) error { var err error - newXID, newSnapshot, err = createNewTransaction(ctx, tx) + var metadata map[string]any + if config.Metadata != nil { + metadata = config.Metadata.AsMap() + } + + newXID, newSnapshot, err = createNewTransaction(ctx, tx, metadata) if err != nil { return err } diff --git a/internal/datastore/postgres/postgres_shared_test.go b/internal/datastore/postgres/postgres_shared_test.go index fd7a7fb3c3..33dd5c8f06 100644 --- a/internal/datastore/postgres/postgres_shared_test.go +++ b/internal/datastore/postgres/postgres_shared_test.go @@ -438,7 +438,7 @@ func TransactionTimestampsTest(t *testing.T, ds datastore.Datastore) { tx, err := pgd.writePool.Begin(ctx) require.NoError(err) - txXID, _, err := createNewTransaction(ctx, tx) + txXID, _, err := createNewTransaction(ctx, tx, nil) require.NoError(err) err = tx.Commit(ctx) diff --git a/internal/datastore/postgres/revisions.go b/internal/datastore/postgres/revisions.go index 5cc0f1cf28..1b6eb782d7 100644 --- a/internal/datastore/postgres/revisions.go +++ b/internal/datastore/postgres/revisions.go @@ -235,11 +235,22 @@ func parseRevisionDecimal(revisionStr string) (datastore.Revision, error) { }}, nil } -func createNewTransaction(ctx context.Context, tx pgx.Tx) (newXID xid8, newSnapshot pgSnapshot, err error) { +var emptyMetadata = map[string]any{} + +func createNewTransaction(ctx context.Context, tx pgx.Tx, metadata map[string]any) (newXID xid8, newSnapshot pgSnapshot, err error) { ctx, span := tracer.Start(ctx, "createNewTransaction") defer span.End() - cterr := tx.QueryRow(ctx, createTxn).Scan(&newXID, &newSnapshot) + if metadata == nil { + metadata = emptyMetadata + } + + sql, args, err := createTxn.Values(metadata).Suffix("RETURNING xid, snapshot").ToSql() + if err != nil { + return + } + + cterr := tx.QueryRow(ctx, sql, args...).Scan(&newXID, &newSnapshot) if cterr != nil { err = fmt.Errorf("error when trying to create a new transaction: %w", cterr) } @@ -250,6 +261,7 @@ type postgresRevision struct { snapshot pgSnapshot optionalTxID xid8 optionalNanosTimestamp uint64 + optionalMetadata map[string]any } func (pr postgresRevision) Equal(rhsRaw datastore.Revision) bool { diff --git a/internal/datastore/postgres/watch.go b/internal/datastore/postgres/watch.go index 2c7a2d3c97..029ddb3253 100644 --- a/internal/datastore/postgres/watch.go +++ b/internal/datastore/postgres/watch.go @@ -25,10 +25,10 @@ var ( // xid8 is one of the last ~2 billion transaction IDs generated. We should be garbage // collecting these transactions long before we get to that point. newRevisionsQuery = fmt.Sprintf(` - SELECT %[1]s, %[2]s, %[3]s FROM %[4]s + SELECT %[1]s, %[2]s, %[3]s, %[4]s FROM %[5]s WHERE %[1]s >= pg_snapshot_xmax($1) OR ( %[1]s >= pg_snapshot_xmin($1) AND NOT pg_visible_in_snapshot(%[1]s, $1) - ) ORDER BY pg_xact_commit_timestamp(%[1]s::xid), %[1]s;`, colXID, colSnapshot, colTimestamp, tableTransaction) + ) ORDER BY pg_xact_commit_timestamp(%[1]s::xid), %[1]s;`, colXID, colSnapshot, colMetadata, colTimestamp, tableTransaction) queryChangedTuples = psql.Select( colNamespace, @@ -199,8 +199,9 @@ func (pgd *pgDatastore) getNewRevisions(ctx context.Context, afterTX postgresRev for rows.Next() { var nextXID xid8 var nextSnapshot pgSnapshot + var metadata map[string]any var timestamp time.Time - if err := rows.Scan(&nextXID, &nextSnapshot, ×tamp); err != nil { + if err := rows.Scan(&nextXID, &nextSnapshot, &metadata, ×tamp); err != nil { return fmt.Errorf("unable to decode new revision: %w", err) } @@ -208,6 +209,7 @@ func (pgd *pgDatastore) getNewRevisions(ctx context.Context, afterTX postgresRev snapshot: nextSnapshot.markComplete(nextXID.Uint64), optionalTxID: nextXID, optionalNanosTimestamp: uint64(timestamp.UnixNano()), + optionalMetadata: metadata, }) } if rows.Err() != nil { @@ -227,6 +229,8 @@ func (pgd *pgDatastore) loadChanges(ctx context.Context, revisions []postgresRev filter := make(map[uint64]int, len(revisions)) txidToRevision := make(map[uint64]postgresRevision, len(revisions)) + tracked := common.NewChanges(revisionKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) + for i, rev := range revisions { if rev.optionalTxID.Uint64 < xmin { xmin = rev.optionalTxID.Uint64 @@ -236,9 +240,13 @@ func (pgd *pgDatastore) loadChanges(ctx context.Context, revisions []postgresRev } filter[rev.optionalTxID.Uint64] = i txidToRevision[rev.optionalTxID.Uint64] = rev - } - tracked := common.NewChanges(revisionKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) + if len(rev.optionalMetadata) > 0 { + if err := tracked.SetRevisionMetadata(ctx, rev, rev.optionalMetadata); err != nil { + return nil, err + } + } + } // Load relationship changes. if options.Content&datastore.WatchRelationships == datastore.WatchRelationships { @@ -265,10 +273,9 @@ func (pgd *pgDatastore) loadChanges(ctx context.Context, revisions []postgresRev } // Reconcile the changes. - reconciledChanges := tracked.AsRevisionChanges(func(lhs, rhs uint64) bool { + return tracked.AsRevisionChanges(func(lhs, rhs uint64) bool { return filter[lhs] < filter[rhs] }) - return reconciledChanges, nil } func (pgd *pgDatastore) loadRelationshipChanges(ctx context.Context, xmin uint64, xmax uint64, txidToRevision map[uint64]postgresRevision, filter map[uint64]int, tracked *common.Changes[postgresRevision, uint64]) error { diff --git a/internal/datastore/spanner/migrations/zz_migration.0010_add_transaction_metadata_table.go b/internal/datastore/spanner/migrations/zz_migration.0010_add_transaction_metadata_table.go new file mode 100644 index 0000000000..5c3165ed5b --- /dev/null +++ b/internal/datastore/spanner/migrations/zz_migration.0010_add_transaction_metadata_table.go @@ -0,0 +1,34 @@ +package migrations + +import ( + "context" + + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" +) + +const ( + addTransactionMetadataTable = `CREATE TABLE transaction_metadata ( + transaction_tag STRING(36) NOT NULL, + created_at TIMESTAMP DEFAULT (CURRENT_TIMESTAMP()), + metadata JSON + ) PRIMARY KEY (transaction_tag), + ROW DELETION POLICY (OLDER_THAN(created_at, INTERVAL 2 DAY)) + ` +) + +func init() { + if err := SpannerMigrations.Register("add-transaction-metadata-table", "add-relationship-counter-table", func(ctx context.Context, w Wrapper) error { + updateOp, err := w.adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{ + Database: w.client.DatabaseName(), + Statements: []string{ + addTransactionMetadataTable, + }, + }) + if err != nil { + return err + } + return updateOp.Wait(ctx) + }, nil); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/spanner/schema.go b/internal/datastore/spanner/schema.go index 364fc549fc..3cde345a2a 100644 --- a/internal/datastore/spanner/schema.go +++ b/internal/datastore/spanner/schema.go @@ -30,6 +30,10 @@ const ( colCounterSerializedFilter = "serialized_filter" colCounterCurrentCount = "current_count" colCounterUpdatedAtTimestamp = "updated_at_timestamp" + + tableTransactionMetadata = "transaction_metadata" + colTransactionTag = "transaction_tag" + colMetadata = "metadata" ) var allRelationshipCols = []string{ diff --git a/internal/datastore/spanner/spanner.go b/internal/datastore/spanner/spanner.go index d9bfcbf360..6392551189 100644 --- a/internal/datastore/spanner/spanner.go +++ b/internal/datastore/spanner/spanner.go @@ -12,6 +12,7 @@ import ( "cloud.google.com/go/spanner" ocprom "contrib.go.opencensus.io/exporter/prometheus" sq "github.com/Masterminds/squirrel" + "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "go.opencensus.io/plugin/ocgrpc" "go.opencensus.io/stats/view" @@ -239,18 +240,50 @@ func (sd *spannerDatastore) SnapshotReader(revisionRaw datastore.Revision) datas return spannerReader{executor, txSource, sd.filterMaximumIDCount} } +func (sd *spannerDatastore) readTransactionMetadata(ctx context.Context, transactionTag string) (map[string]any, error) { + row, err := sd.client.Single().ReadRow(ctx, tableTransactionMetadata, spanner.Key{transactionTag}, []string{colMetadata}) + if err != nil { + if spanner.ErrCode(err) == codes.NotFound { + return map[string]any{}, nil + } + + return nil, err + } + + var metadata map[string]any + if err := row.Columns(&metadata); err != nil { + return nil, err + } + + return metadata, nil +} + func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUserFunc, opts ...options.RWTOptionsOption) (datastore.Revision, error) { config := options.NewRWTOptionsWithOptions(opts...) ctx, span := tracer.Start(ctx, "ReadWriteTx") defer span.End() + transactionTag := "sdb-rwt-" + uuid.NewString() + ctx, cancel := context.WithCancel(ctx) - ts, err := sd.client.ReadWriteTransaction(ctx, func(ctx context.Context, spannerRWT *spanner.ReadWriteTransaction) error { + rs, err := sd.client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, spannerRWT *spanner.ReadWriteTransaction) error { txSource := func() readTX { return &traceableRTX{delegate: spannerRWT} } + if config.Metadata != nil { + // Insert the metadata into the transaction metadata table. + mutation := spanner.Insert(tableTransactionMetadata, + []string{colTransactionTag, colMetadata}, + []any{transactionTag, config.Metadata.AsMap()}, + ) + + if err := spannerRWT.BufferWrite([]*spanner.Mutation{mutation}); err != nil { + return fmt.Errorf("unable to write metadata: %w", err) + } + } + executor := common.QueryExecutor{Executor: queryExecutor(txSource)} rwt := spannerReadWriteTXN{ spannerReader{executor, txSource, sd.filterMaximumIDCount}, @@ -270,7 +303,7 @@ func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUser } return nil - }) + }, spanner.TransactionOptions{TransactionTag: transactionTag}) if err != nil { if cerr := convertToWriteConstraintError(err); cerr != nil { return datastore.NoRevision, cerr @@ -278,7 +311,7 @@ func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUser return datastore.NoRevision, err } - return revisions.NewForTime(ts), nil + return revisions.NewForTime(rs.CommitTs), nil } func (sd *spannerDatastore) ReadyState(ctx context.Context) (datastore.ReadyState, error) { diff --git a/internal/datastore/spanner/watch.go b/internal/datastore/spanner/watch.go index edd4eab37e..5a533e4b8a 100644 --- a/internal/datastore/spanner/watch.go +++ b/internal/datastore/spanner/watch.go @@ -156,6 +156,23 @@ func (sd *spannerDatastore) watch( } defer reader.Close() + metadataForTransactionTag := map[string]map[string]any{} + + addMetadataForTransactionTag := func(ctx context.Context, tracked *common.Changes[revisions.TimestampRevision, int64], revision revisions.TimestampRevision, transactionTag string) error { + if metadata, ok := metadataForTransactionTag[transactionTag]; ok { + return tracked.SetRevisionMetadata(ctx, revision, metadata) + } + + // Otherwise, load the metadata from the transactions metadata table. + transactionMetadata, err := sd.readTransactionMetadata(ctx, transactionTag) + if err != nil { + return err + } + + metadataForTransactionTag[transactionTag] = transactionMetadata + return tracked.SetRevisionMetadata(ctx, revision, transactionMetadata) + } + err = reader.Read(ctx, func(result *changestreams.ReadResult) error { // See: https://cloud.google.com/spanner/docs/change-streams/details for _, record := range result.ChangeRecords { @@ -165,6 +182,12 @@ func (sd *spannerDatastore) watch( changeRevision := revisions.NewForTime(dcr.CommitTimestamp) modType := dcr.ModType // options are INSERT, UPDATE, DELETE + if len(dcr.TransactionTag) > 0 { + if err := addMetadataForTransactionTag(ctx, tracked, changeRevision, dcr.TransactionTag); err != nil { + return err + } + } + for _, mod := range dcr.Mods { primaryKeyColumnValues, ok := mod.Keys.Value.(map[string]any) if !ok { @@ -312,7 +335,12 @@ func (sd *spannerDatastore) watch( } if !tracked.IsEmpty() { - for _, revChange := range tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) { + changes, err := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) + if err != nil { + return err + } + + for _, revChange := range changes { revChange := revChange if !sendChange(&revChange) { return datastore.NewWatchDisconnectedErr() diff --git a/pkg/datastore/datastore.go b/pkg/datastore/datastore.go index c31d2592a6..99804ce71f 100644 --- a/pkg/datastore/datastore.go +++ b/pkg/datastore/datastore.go @@ -9,6 +9,7 @@ import ( "time" "github.com/rs/zerolog" + "google.golang.org/protobuf/types/known/structpb" "github.com/authzed/spicedb/pkg/tuple" @@ -61,6 +62,9 @@ type RevisionChanges struct { // up until and including the Revision and that no additional schema updates can // have occurred before this point. IsCheckpoint bool + + // Metadata is the metadata associated with the revision, if any. + Metadata *structpb.Struct } func (rc *RevisionChanges) MarshalZerologObject(e *zerolog.Event) { diff --git a/pkg/datastore/options/options.go b/pkg/datastore/options/options.go index 34daf2e58a..1d13f2ca11 100644 --- a/pkg/datastore/options/options.go +++ b/pkg/datastore/options/options.go @@ -1,6 +1,8 @@ package options import ( + "google.golang.org/protobuf/types/known/structpb" + core "github.com/authzed/spicedb/pkg/proto/core/v1" ) @@ -51,7 +53,8 @@ type ResourceRelation struct { // RWTOptions are options that can affect the way a read-write transaction is // executed. type RWTOptions struct { - DisableRetries bool `debugmap:"visible"` + DisableRetries bool `debugmap:"visible"` + Metadata *structpb.Struct `debugmap:"visible"` } // DeleteOptions are the options that can affect the results of a delete relationships diff --git a/pkg/datastore/options/zz_generated.query_options.go b/pkg/datastore/options/zz_generated.query_options.go index f87f310b7c..f761b06b66 100644 --- a/pkg/datastore/options/zz_generated.query_options.go +++ b/pkg/datastore/options/zz_generated.query_options.go @@ -4,6 +4,7 @@ package options import ( defaults "github.com/creasty/defaults" helpers "github.com/ecordell/optgen/helpers" + structpb "google.golang.org/protobuf/types/known/structpb" ) type QueryOptionsOption func(q *QueryOptions) @@ -192,6 +193,7 @@ func NewRWTOptionsWithOptionsAndDefaults(opts ...RWTOptionsOption) *RWTOptions { func (r *RWTOptions) ToOption() RWTOptionsOption { return func(to *RWTOptions) { to.DisableRetries = r.DisableRetries + to.Metadata = r.Metadata } } @@ -199,6 +201,7 @@ func (r *RWTOptions) ToOption() RWTOptionsOption { func (r RWTOptions) DebugMap() map[string]any { debugMap := map[string]any{} debugMap["DisableRetries"] = helpers.DebugValue(r.DisableRetries, false) + debugMap["Metadata"] = helpers.DebugValue(r.Metadata, false) return debugMap } @@ -224,3 +227,10 @@ func WithDisableRetries(disableRetries bool) RWTOptionsOption { r.DisableRetries = disableRetries } } + +// WithMetadata returns an option that can set Metadata on a RWTOptions +func WithMetadata(metadata *structpb.Struct) RWTOptionsOption { + return func(r *RWTOptions) { + r.Metadata = metadata + } +} diff --git a/pkg/datastore/test/datastore.go b/pkg/datastore/test/datastore.go index 05f034c6ba..cc8a0999f5 100644 --- a/pkg/datastore/test/datastore.go +++ b/pkg/datastore/test/datastore.go @@ -162,6 +162,7 @@ func AllWithExceptions(t *testing.T, tester DatastoreTester, except Categories) t.Run("TestCaveatedRelationshipWatch", func(t *testing.T) { CaveatedRelationshipWatchTest(t, tester) }) t.Run("TestWatchWithTouch", func(t *testing.T) { WatchWithTouchTest(t, tester) }) t.Run("TestWatchWithDelete", func(t *testing.T) { WatchWithDeleteTest(t, tester) }) + t.Run("TestWatchWithMetadata", func(t *testing.T) { WatchWithMetadataTest(t, tester) }) } if !except.Watch() && !except.WatchSchema() { diff --git a/pkg/datastore/test/watch.go b/pkg/datastore/test/watch.go index 0191796bbc..a91ec8eefc 100644 --- a/pkg/datastore/test/watch.go +++ b/pkg/datastore/test/watch.go @@ -13,9 +13,11 @@ import ( "github.com/scylladb/go-set/strset" "github.com/stretchr/testify/require" "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/structpb" "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/tuple" @@ -176,6 +178,50 @@ func VerifyUpdates( require.False(expectDisconnect, "all changes verified without expected disconnect") } +func VerifyUpdatesWithMetadata( + require *require.Assertions, + testUpdates []updateWithMetadata, + changes <-chan *datastore.RevisionChanges, + errchan <-chan error, + expectDisconnect bool, +) { + for _, expected := range testUpdates { + changeWait := time.NewTimer(waitForChangesTimeout) + select { + case change, ok := <-changes: + if !ok { + require.True(expectDisconnect, "unexpected disconnect") + errWait := time.NewTimer(waitForChangesTimeout) + select { + case err := <-errchan: + require.True(errors.As(err, &datastore.ErrWatchDisconnected{})) + return + case <-errWait.C: + require.Fail("Timed out waiting for ErrWatchDisconnected") + } + return + } + + expectedChangeSet := setOfChanges(expected.updates) + actualChangeSet := setOfChanges(change.RelationshipChanges) + + missingExpected := strset.Difference(expectedChangeSet, actualChangeSet) + unexpected := strset.Difference(actualChangeSet, expectedChangeSet) + + require.True(missingExpected.IsEmpty(), "expected changes missing: %s", missingExpected) + require.True(unexpected.IsEmpty(), "unexpected changes: %s", unexpected) + + require.Equal(expected.metadata, change.Metadata.AsMap(), "metadata mismatch") + + time.Sleep(1 * time.Millisecond) + case <-changeWait.C: + require.Fail("Timed out", "waited for changes: %s", expected) + } + } + + require.False(expectDisconnect, "all changes verified without expected disconnect") +} + func setOfChanges(changes []*core.RelationTupleUpdate) *strset.Set { changeSet := strset.NewWithSize(len(changes)) for _, change := range changes { @@ -337,6 +383,50 @@ func WatchWithTouchTest(t *testing.T, tester DatastoreTester) { ) } +type updateWithMetadata struct { + updates []*core.RelationTupleUpdate + metadata map[string]any +} + +func WatchWithMetadataTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 16) + require.NoError(err) + + setupDatastore(ds, require) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + lowestRevision, err := ds.HeadRevision(ctx) + require.NoError(err) + + changes, errchan := ds.Watch(ctx, lowestRevision, datastore.WatchJustRelationships()) + require.Zero(len(errchan)) + + metadata, err := structpb.NewStruct(map[string]any{"somekey": "somevalue"}) + require.NoError(err) + + _, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + return rwt.WriteRelationships(ctx, []*core.RelationTupleUpdate{ + tuple.Create(tuple.MustParse("document:firstdoc#viewer@user:tom")), + }) + }, options.WithMetadata(metadata)) + require.NoError(err) + + VerifyUpdatesWithMetadata(require, []updateWithMetadata{ + { + updates: []*core.RelationTupleUpdate{tuple.Touch(tuple.Parse("document:firstdoc#viewer@user:tom"))}, + metadata: map[string]any{"somekey": "somevalue"}, + }, + }, + changes, + errchan, + false, + ) +} + func WatchWithDeleteTest(t *testing.T, tester DatastoreTester) { require := require.New(t)