Skip to content

Commit

Permalink
Implement support for metadata associated with read-write transactions
Browse files Browse the repository at this point in the history
This will allow callers of APIs such as WriteRelationships and DeleteRelationships to assign
metadata to the transaction that will be mirrored back out in the Watch API, to provide a means
for correlating updates
  • Loading branch information
josephschorr committed Sep 15, 2024
1 parent e629448 commit e7f5ad8
Show file tree
Hide file tree
Showing 27 changed files with 588 additions and 71 deletions.
51 changes: 43 additions & 8 deletions internal/datastore/common/changes.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"sort"

"golang.org/x/exp/maps"
"google.golang.org/protobuf/types/known/structpb"

log "github.com/authzed/spicedb/internal/logging"
"github.com/authzed/spicedb/pkg/datastore"
Expand Down Expand Up @@ -35,6 +36,7 @@ type changeRecord[R datastore.Revision] struct {
definitionsChanged map[string]datastore.SchemaDefinition
namespacesDeleted map[string]struct{}
caveatsDeleted map[string]struct{}
metadata map[string]any
}

// NewChanges creates a new Changes object for change tracking and de-duplication.
Expand Down Expand Up @@ -125,6 +127,25 @@ func (ch *Changes[R, K]) adjustByteSize(item sized, delta int) error {
return nil
}

// SetRevisionMetadata sets the metadata for the given revision.
func (ch *Changes[R, K]) SetRevisionMetadata(ctx context.Context, rev R, metadata map[string]any) error {
if len(metadata) == 0 {
return nil
}

record, err := ch.recordForRevision(rev)
if err != nil {
return err
}

if len(record.metadata) > 0 {
return spiceerrors.MustBugf("metadata already set for revision")
}

maps.Copy(record.metadata, metadata)
return nil
}

func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) {
k := ch.keyFunc(rev)
revisionChanges, ok := ch.records[k]
Expand All @@ -136,6 +157,7 @@ func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) {
make(map[string]datastore.SchemaDefinition),
make(map[string]struct{}),
make(map[string]struct{}),
make(map[string]any),
}
ch.records[k] = revisionChanges
}
Expand Down Expand Up @@ -241,21 +263,25 @@ func (ch *Changes[R, K]) AddChangedDefinition(

// AsRevisionChanges returns the list of changes processed so far as a datastore watch
// compatible, ordered, changelist.
func (ch *Changes[R, K]) AsRevisionChanges(lessThanFunc func(lhs, rhs K) bool) []datastore.RevisionChanges {
func (ch *Changes[R, K]) AsRevisionChanges(lessThanFunc func(lhs, rhs K) bool) ([]datastore.RevisionChanges, error) {
return ch.revisionChanges(lessThanFunc, *new(R), false)
}

// FilterAndRemoveRevisionChanges filters a list of changes processed up to the bound revision from the changes list, removing them
// and returning the filtered changes.
func (ch *Changes[R, K]) FilterAndRemoveRevisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R) []datastore.RevisionChanges {
changes := ch.revisionChanges(lessThanFunc, boundRev, true)
func (ch *Changes[R, K]) FilterAndRemoveRevisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R) ([]datastore.RevisionChanges, error) {
changes, err := ch.revisionChanges(lessThanFunc, boundRev, true)
if err != nil {
return nil, err
}

ch.removeAllChangesBefore(boundRev)
return changes
return changes, nil
}

func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R, withBound bool) []datastore.RevisionChanges {
func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R, withBound bool) ([]datastore.RevisionChanges, error) {
if ch.IsEmpty() {
return nil
return nil, nil
}

revisionsWithChanges := make([]K, 0, len(ch.records))
Expand All @@ -266,7 +292,7 @@ func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, bou
}

if len(revisionsWithChanges) == 0 {
return nil
return nil, nil
}

sort.Slice(revisionsWithChanges, func(i int, j int) bool {
Expand All @@ -292,9 +318,18 @@ func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, bou
changes[i].ChangedDefinitions = maps.Values(revisionChangeRecord.definitionsChanged)
changes[i].DeletedNamespaces = maps.Keys(revisionChangeRecord.namespacesDeleted)
changes[i].DeletedCaveats = maps.Keys(revisionChangeRecord.caveatsDeleted)

if len(revisionChangeRecord.metadata) > 0 {
metadata, err := structpb.NewStruct(revisionChangeRecord.metadata)
if err != nil {
return nil, spiceerrors.MustBugf("failed to convert metadata to structpb: %v", err)
}

changes[i].Metadata = metadata
}
}

return changes
return changes, nil
}

func (ch *Changes[R, K]) removeAllChangesBefore(boundRev R) {
Expand Down
40 changes: 33 additions & 7 deletions internal/datastore/common/changes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,12 @@ func TestChanges(t *testing.T) {
}
}

actual, err := ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc)
require.NoError(err)

require.Equal(
canonicalize(tc.expected),
canonicalize(ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc)),
canonicalize(actual),
)
})
}
Expand All @@ -347,6 +350,23 @@ func TestFilteredSchemaChanges(t *testing.T) {
require.True(t, ch.IsEmpty())
}

func TestSetMetadata(t *testing.T) {
ctx := context.Background()
ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0)
require.True(t, ch.IsEmpty())

err := ch.SetRevisionMetadata(ctx, rev1, map[string]any{"foo": "bar"})
require.NoError(t, err)
require.False(t, ch.IsEmpty())

results, err := ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, rev2)
require.NoError(t, err)
require.Equal(t, 1, len(results))
require.True(t, ch.IsEmpty())

require.Equal(t, map[string]any{"foo": "bar"}, results[0].Metadata.AsMap())
}

func TestFilteredRelationshipChanges(t *testing.T) {
ctx := context.Background()
ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships, 0)
Expand Down Expand Up @@ -374,7 +394,8 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) {

require.False(t, ch.IsEmpty())

results := ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, rev3)
results, err := ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, rev3)
require.NoError(t, err)
require.Equal(t, 2, len(results))
require.False(t, ch.IsEmpty())

Expand All @@ -393,8 +414,9 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) {
},
}, results)

remaining := ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc)
remaining, err := ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc)
require.Equal(t, 1, len(remaining))
require.NoError(t, err)

require.Equal(t, []datastore.RevisionChanges{
{
Expand All @@ -405,11 +427,13 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) {
},
}, remaining)

results = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillion)
results, err = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillion)
require.NoError(t, err)
require.Equal(t, 1, len(results))
require.True(t, ch.IsEmpty())

results = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillionOne)
results, err = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillionOne)
require.NoError(t, err)
require.Equal(t, 0, len(results))
require.True(t, ch.IsEmpty())
}
Expand All @@ -432,7 +456,8 @@ func TestHLCOrdering(t *testing.T) {
err = ch.AddRelationshipChange(ctx, rev0, tuple.MustParse("document:foo#viewer@user:tom"), core.RelationTupleUpdate_TOUCH)
require.NoError(t, err)

remaining := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc)
remaining, err := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc)
require.NoError(t, err)
require.Equal(t, 2, len(remaining))

require.Equal(t, []datastore.RevisionChanges{
Expand Down Expand Up @@ -475,7 +500,8 @@ func TestHLCSameRevision(t *testing.T) {
err = ch.AddRelationshipChange(ctx, rev0again, tuple.MustParse("document:foo#viewer@user:sarah"), core.RelationTupleUpdate_TOUCH)
require.NoError(t, err)

remaining := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc)
remaining, err := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc)
require.NoError(t, err)
require.Equal(t, 1, len(remaining))

expected := []*core.RelationTupleUpdate{
Expand Down
22 changes: 22 additions & 0 deletions internal/datastore/crdb/crdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ const (
tableTransactions = "transactions"
tableCaveat = "caveat"
tableRelationshipCounter = "relationship_counter"
tableTransactionMetadata = "transaction_metadata"

colNamespace = "namespace"
colConfig = "serialized_config"
Expand All @@ -77,6 +78,8 @@ const (
colCounterSerializedFilter = "serialized_filter"
colCounterCurrentCount = "current_count"
colCounterUpdatedAt = "updated_at_timestamp"
colExpiresAt = "expires_at"
colMetadata = "metadata"

errUnableToInstantiate = "unable to instantiate datastore"
errRevision = "unable to find revision: %w"
Expand Down Expand Up @@ -199,6 +202,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas
analyzeBeforeStatistics: config.analyzeBeforeStatistics,
filterMaximumIDCount: config.filterMaximumIDCount,
supportsIntegrity: config.withIntegrity,
gcWindow: config.gcWindow,
}
ds.RemoteClockRevisions.SetNowFunc(ds.headRevisionInternal)

Expand Down Expand Up @@ -281,6 +285,7 @@ type crdbDatastore struct {
writeOverlapKeyer overlapKeyer
overlapKeyInit func(ctx context.Context) keySet
analyzeBeforeStatistics bool
gcWindow time.Duration

beginChangefeedQuery string
transactionNowQuery string
Expand Down Expand Up @@ -324,6 +329,23 @@ func (cds *crdbDatastore) ReadWriteTx(
Executor: pgxcommon.NewPGXExecutorWithIntegrityOption(querier, cds.supportsIntegrity),
}

// If metadata is to be attached, write that row now.
if config.Metadata != nil {
expiresAt := time.Now().Add(cds.gcWindow).Add(1 * time.Minute)
insertTransactionMetadata := psql.Insert(tableTransactionMetadata).
Columns(colExpiresAt, colMetadata).
Values(expiresAt, config.Metadata.AsMap())

sql, args, err := insertTransactionMetadata.ToSql()
if err != nil {
return fmt.Errorf("error building metadata insert: %w", err)
}

if _, err := tx.Exec(ctx, sql, args...); err != nil {
return fmt.Errorf("error writing metadata: %w", err)
}
}

rwt := &crdbReadWriteTXN{
&crdbReader{
querier,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package migrations

import (
"context"
"strings"

"github.com/jackc/pgx/v5"
)

const (
// ttl_expiration_expression support was added in CRDB v22.2, but the E2E tests
// use v21.2.
addTransactionMetadataTableQueryWithBasicTTL = `
CREATE TABLE transaction_metadata (
key UUID PRIMARY KEY DEFAULT gen_random_uuid(),
expires_at TIMESTAMPTZ,
metadata JSONB
) WITH (ttl_expire_after = '1d');
`

addTransactionMetadataTableQuery = `
CREATE TABLE transaction_metadata (
key UUID PRIMARY KEY DEFAULT gen_random_uuid(),
expires_at TIMESTAMPTZ,
metadata JSONB
) WITH (ttl_expiration_expression = 'expires_at', ttl_job_cron = '@daily');
`

// See: https://www.cockroachlabs.com/docs/stable/changefeed-messages#prevent-changefeeds-from-emitting-row-level-ttl-deletes
// for why we set ttl_disable_changefeed_replication = 'true'. This isn't stricly necessary as the Watch API will ignore the
// deletions of these metadata rows, but no reason to even have it in the changefeed.
// NOTE: This only applies on CRDB v24 and later.
addTransactionMetadataTableQueryWithTTLIgnore = `
CREATE TABLE transaction_metadata (
key UUID PRIMARY KEY DEFAULT gen_random_uuid(),
expires_at TIMESTAMPTZ,
metadata JSONB
) WITH (ttl_expiration_expression = 'expires_at', ttl_job_cron = '@daily', ttl_disable_changefeed_replication = 'true');
`
)

func init() {
err := CRDBMigrations.Register("add-transaction-metadata-table", "add-integrity-relationtuple-table", addTransactionMetadataTable, noAtomicMigration)
if err != nil {
panic("failed to register migration: " + err.Error())
}
}

func addTransactionMetadataTable(ctx context.Context, conn *pgx.Conn) error {
row := conn.QueryRow(ctx, "select version()")
var version string
if err := row.Scan(&version); err != nil {
return err
}

if strings.Contains(version, "v22.1") {
if _, err := conn.Exec(ctx, addTransactionMetadataTableQueryWithBasicTTL); err != nil {
return err
}
return nil
}

if strings.Contains(version, "v24.") {
if _, err := conn.Exec(ctx, addTransactionMetadataTableQueryWithTTLIgnore); err != nil {
return err
}
return nil
}

// CRDB v23 and v22.2.*
if _, err := conn.Exec(ctx, addTransactionMetadataTableQuery); err != nil {
return err
}
return nil
}
31 changes: 29 additions & 2 deletions internal/datastore/crdb/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ type changeDetails struct {
IntegrityKeyID *string `json:"integrity_key_id"`
IntegrityHashAsHex *string `json:"integrity_hash"`
TimestampAsString *string `json:"timestamp"`

Metadata map[string]any `json:"metadata"`
}
}

Expand Down Expand Up @@ -110,7 +112,8 @@ func (cds *crdbDatastore) watch(
}
defer func() { _ = conn.Close(ctx) }()

tableNames := make([]string, 0, 3)
tableNames := make([]string, 0, 4)
tableNames = append(tableNames, tableTransactionMetadata)
if opts.Content&datastore.WatchRelationships == datastore.WatchRelationships {
tableNames = append(tableNames, cds.tableTupleName())
}
Expand Down Expand Up @@ -217,7 +220,13 @@ func (cds *crdbDatastore) watch(
return
}

for _, revChange := range tracked.FilterAndRemoveRevisionChanges(revisions.HLCKeyLessThanFunc, rev) {
filtered, err := tracked.FilterAndRemoveRevisionChanges(revisions.HLCKeyLessThanFunc, rev)
if err != nil {
sendError(err)
return
}

for _, revChange := range filtered {
revChange := revChange
if !sendChange(&revChange) {
return
Expand Down Expand Up @@ -393,6 +402,24 @@ func (cds *crdbDatastore) watch(
return
}
}

case tableTransactionMetadata:
if details.After != nil {
rev, err := revisions.HLCRevisionFromString(details.Updated)
if err != nil {
sendError(fmt.Errorf("malformed update timestamp: %w", err))
return
}

if err := tracked.SetRevisionMetadata(ctx, rev, details.After.Metadata); err != nil {
sendError(err)
return
}
}

default:
sendError(spiceerrors.MustBugf("unexpected table name in changefeed: %s", tableName))
return
}
}

Expand Down
Loading

0 comments on commit e7f5ad8

Please sign in to comment.