Skip to content

Commit

Permalink
Add an additional mode to replica support that uses a strict read mode
Browse files Browse the repository at this point in the history
This will allow replicas behind load balancers to be supported (just in Postgres for now)
  • Loading branch information
josephschorr committed Jun 26, 2024
1 parent 35cfc23 commit 3d4516e
Show file tree
Hide file tree
Showing 10 changed files with 448 additions and 58 deletions.
10 changes: 10 additions & 0 deletions internal/datastore/common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,13 @@ func RedactAndLogSensitiveConnString(ctx context.Context, baseErr string, err er
log.Ctx(ctx).Trace().Msg(baseErr + ": " + filtered)
return fmt.Errorf("%s. To view details of this error (that may contain sensitive information), please run with --log-level=trace", baseErr)
}

// RevisionUnavailableError is returned when a revision is not available on a replica.
// It embeds the underlying error, so its Error() message is that of the wrapped error.
type RevisionUnavailableError struct {
	error
}

// Unwrap returns the wrapped error. Embedding an error only promotes Error(), so
// without this method errors.Is and errors.As could not inspect the underlying cause.
func (e RevisionUnavailableError) Unwrap() error {
	return e.error
}

// NewRevisionUnavailableError creates a new RevisionUnavailableError wrapping err.
// The result is returned as a value (not a pointer), so callers match it with
// errors.As against a *RevisionUnavailableError target.
func NewRevisionUnavailableError(err error) error {
	return RevisionUnavailableError{err}
}
12 changes: 12 additions & 0 deletions internal/datastore/postgres/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type postgresOptions struct {
enablePrometheusStats bool
analyzeBeforeStatistics bool
gcEnabled bool
readStrictMode bool

migrationPhase string

Expand Down Expand Up @@ -61,6 +62,7 @@ const (
defaultMaxRetries = 10
defaultGCEnabled = true
defaultCredentialsProviderName = ""
defaultReadStrictMode = false
)

// Option provides the facility to configure how clients within the
Expand All @@ -80,6 +82,7 @@ func generateConfig(options []Option) (postgresOptions, error) {
maxRetries: defaultMaxRetries,
gcEnabled: defaultGCEnabled,
credentialsProviderName: defaultCredentialsProviderName,
readStrictMode: defaultReadStrictMode,
queryInterceptor: nil,
}

Expand All @@ -103,6 +106,15 @@ func generateConfig(options []Option) (postgresOptions, error) {
return computed, nil
}

// ReadStrictMode controls whether the Postgres reader runs in strict read mode.
// When enabled, an assertion accompanies every read query to ensure that the
// revision being read is available on the read connection.
//
// Strict mode is disabled by default, as the default behavior is to read from the primary.
func ReadStrictMode(readStrictMode bool) Option {
	return func(opts *postgresOptions) {
		opts.readStrictMode = readStrictMode
	}
}

// ReadConnHealthCheckInterval is the frequency at which both idle and max
// lifetime connections are checked, and also the frequency at which the
// minimum number of connections is checked.
Expand Down
16 changes: 15 additions & 1 deletion internal/datastore/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,11 @@ func newPostgresDatastore(
maxRetries: config.maxRetries,
credentialsProvider: credentialsProvider,
isPrimary: isPrimary,
inStrictReadMode: config.readStrictMode,
}

if isPrimary && config.readStrictMode {
return nil, spiceerrors.MustBugf("strict read mode is not supported on primary instances")
}

if isPrimary {
Expand Down Expand Up @@ -376,6 +381,7 @@ type pgDatastore struct {
maxRetries uint8
watchEnabled bool
isPrimary bool
inStrictReadMode bool

credentialsProvider datastore.CredentialsProvider

Expand All @@ -389,6 +395,10 @@ func (pgd *pgDatastore) SnapshotReader(revRaw datastore.Revision) datastore.Read
rev := revRaw.(postgresRevision)

queryFuncs := pgxcommon.QuerierFuncsFor(pgd.readPool)
if pgd.inStrictReadMode {
queryFuncs = strictReaderQueryFuncs{wrapped: queryFuncs, revision: rev}
}

executor := common.QueryExecutor{
Executor: pgxcommon.NewPGXExecutor(queryFuncs),
}
Expand Down Expand Up @@ -590,7 +600,11 @@ func (pgd *pgDatastore) Close() error {
}

pgd.readPool.Close()
pgd.writePool.Close()

if pgd.writePool != nil {
pgd.writePool.Close()
}

return nil
}

Expand Down
70 changes: 63 additions & 7 deletions internal/datastore/postgres/postgres_shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func testPostgresDatastore(t *testing.T, pc []postgresConfig) {

test.All(t, test.DatastoreTesterFunc(func(revisionQuantization, gcInterval, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) {
ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore {
ds, err := newPostgresDatastore(ctx, uri, -1,
ds, err := newPostgresDatastore(ctx, uri, primaryInstanceID,
RevisionQuantization(revisionQuantization),
GCWindow(gcWindow),
GCInterval(gcInterval),
Expand Down Expand Up @@ -176,6 +176,16 @@ func testPostgresDatastore(t *testing.T, pc []postgresConfig) {
WatchBufferLength(50),
MigrationPhase(config.migrationPhase),
))

t.Run("TestStrictReadMode", createReplicaDatastoreTest(
b,
StrictReadModeTest,
RevisionQuantization(0),
GCWindow(1000*time.Second),
WatchBufferLength(50),
MigrationPhase(config.migrationPhase),
ReadStrictMode(true),
))
}

t.Run("OTelTracing", createDatastoreTest(
Expand Down Expand Up @@ -203,7 +213,7 @@ func testPostgresDatastoreWithoutCommitTimestamps(t *testing.T, pc []postgresCon
// NOTE: watch API requires the commit timestamps, so we skip those tests here.
test.AllWithExceptions(t, test.DatastoreTesterFunc(func(revisionQuantization, gcInterval, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) {
ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore {
ds, err := newPostgresDatastore(ctx, uri, -1,
ds, err := newPostgresDatastore(ctx, uri, primaryInstanceID,
RevisionQuantization(revisionQuantization),
GCWindow(gcWindow),
GCInterval(gcInterval),
Expand All @@ -225,7 +235,21 @@ func createDatastoreTest(b testdatastore.RunningEngineForTest, tf datastoreTestF
return func(t *testing.T) {
ctx := context.Background()
ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore {
ds, err := newPostgresDatastore(ctx, uri, -1, options...)
ds, err := newPostgresDatastore(ctx, uri, primaryInstanceID, options...)
require.NoError(t, err)
return ds
})
defer ds.Close()

tf(t, ds)
}
}

func createReplicaDatastoreTest(b testdatastore.RunningEngineForTest, tf datastoreTestFunc, options ...Option) func(*testing.T) {
return func(t *testing.T) {
ctx := context.Background()
ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore {
ds, err := newPostgresDatastore(ctx, uri, 42, options...)
require.NoError(t, err)
return ds
})
Expand Down Expand Up @@ -635,7 +659,7 @@ func QuantizedRevisionTest(t *testing.T, b testdatastore.RunningEngineForTest) {
ds, err := newPostgresDatastore(
ctx,
uri,
-1,
primaryInstanceID,
RevisionQuantization(5*time.Second),
GCWindow(24*time.Hour),
WatchBufferLength(1),
Expand Down Expand Up @@ -1137,7 +1161,7 @@ func WatchNotEnabledTest(t *testing.T, _ testdatastore.RunningEngineForTest, pgV
ds := testdatastore.RunPostgresForTestingWithCommitTimestamps(t, "", migrate.Head, false, pgVersion, false).NewDatastore(t, func(engine, uri string) datastore.Datastore {
ctx := context.Background()
ds, err := newPostgresDatastore(ctx, uri,
-1,
primaryInstanceID,
RevisionQuantization(0),
GCWindow(time.Millisecond*1),
WatchBufferLength(1),
Expand All @@ -1164,7 +1188,7 @@ func BenchmarkPostgresQuery(b *testing.B) {
ds := testdatastore.RunPostgresForTesting(b, "", migrate.Head, pgversion.MinimumSupportedPostgresVersion, false).NewDatastore(b, func(engine, uri string) datastore.Datastore {
ctx := context.Background()
ds, err := newPostgresDatastore(ctx, uri,
-1,
primaryInstanceID,
RevisionQuantization(0),
GCWindow(time.Millisecond*1),
WatchBufferLength(1),
Expand Down Expand Up @@ -1200,7 +1224,7 @@ func datastoreWithInterceptorAndTestData(t *testing.T, interceptor pgcommon.Quer
ds := testdatastore.RunPostgresForTestingWithCommitTimestamps(t, "", migrate.Head, false, pgVersion, false).NewDatastore(t, func(engine, uri string) datastore.Datastore {
ctx := context.Background()
ds, err := newPostgresDatastore(ctx, uri,
-1,
primaryInstanceID,
RevisionQuantization(0),
GCWindow(time.Millisecond*1),
WatchBufferLength(1),
Expand Down Expand Up @@ -1410,6 +1434,38 @@ func RepairTransactionsTest(t *testing.T, ds datastore.Datastore) {
require.Greater(t, currentMaximumID, 12345)
}

// StrictReadModeTest checks a datastore running in strict read mode: a read at the
// head revision must succeed, while a read at a hand-built revision beyond head must
// fail with a RevisionUnavailableError.
func StrictReadModeTest(t *testing.T, ds datastore.Datastore) {
	require := require.New(t)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// The same filter is used for both the passing and the failing read.
	filter := datastore.RelationshipsFilter{OptionalResourceType: "resource"}

	// Reading at the head revision should succeed.
	headRev, err := ds.HeadRevision(ctx)
	require.NoError(err)

	iter, err := ds.SnapshotReader(headRev).QueryRelationships(ctx, filter)
	require.NoError(err)
	iter.Close()

	// Reading at a manually constructed revision beyond head should fail.
	unavailableRev := postgresRevision{snapshot: pgSnapshot{xmax: 9999999999999999999}}

	_, err = ds.SnapshotReader(unavailableRev).QueryRelationships(ctx, filter)
	require.Error(err)
	require.ErrorContains(err, "is not available on the replica")
	require.ErrorAs(err, &common.RevisionUnavailableError{})
}

func NullCaveatWatchTest(t *testing.T, ds datastore.Datastore) {
require := require.New(t)

Expand Down
65 changes: 65 additions & 0 deletions internal/datastore/postgres/strictreader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package postgres

import (
"context"
"errors"
"fmt"
"strings"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"

"github.com/authzed/spicedb/internal/datastore/common"
pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common"
)

// pgInvalidArgument is the Postgres error code (SQLSTATE) that rewriteError matches,
// together with an "is in the future" message, to detect a transaction ID that the
// replica has not yet seen.
const pgInvalidArgument = "22023"

// strictReaderQueryFuncs wraps a DBFuncQuerier and adds a strict read assertion to all queries.
// This assertion ensures that the transaction is not reading from the future or from a
// transaction that has not been committed on the replica. Errors raised by the assertion
// are translated into a RevisionUnavailableError by rewriteError.
type strictReaderQueryFuncs struct {
// wrapped is the underlying querier that actually executes the rewritten SQL.
wrapped pgxcommon.DBFuncQuerier
// revision is the revision being read; its snapshot is used to build the assertion
// and its string form appears in the resulting error message.
revision postgresRevision
}

// ExecFunc runs the statement through the wrapped querier with the strict read
// assertion prepended, translating assertion failures into a RevisionUnavailableError.
func (srqf strictReaderQueryFuncs) ExecFunc(ctx context.Context, tagFunc func(ctx context.Context, tag pgconn.CommandTag, err error) error, sql string, args ...any) error {
	// pgx.QueryExecModeSimpleProtocol *must* be the first argument: otherwise pgx wraps
	// the query as a prepared statement, which cannot run more than a single statement
	// at a time (the assertion plus the original query).
	simpleArgs := make([]any, 0, len(args)+1)
	simpleArgs = append(simpleArgs, pgx.QueryExecModeSimpleProtocol)
	simpleArgs = append(simpleArgs, args...)

	err := srqf.wrapped.ExecFunc(ctx, tagFunc, srqf.addAssertToSQL(sql), simpleArgs...)
	return srqf.rewriteError(err)
}

// QueryFunc runs the query through the wrapped querier with the strict read
// assertion prepended, translating assertion failures into a RevisionUnavailableError.
func (srqf strictReaderQueryFuncs) QueryFunc(ctx context.Context, rowsFunc func(ctx context.Context, rows pgx.Rows) error, sql string, args ...any) error {
	// The simple protocol mode is passed as the leading argument so that the assertion
	// and the query can be sent together as multiple statements.
	simpleArgs := make([]any, 0, len(args)+1)
	simpleArgs = append(simpleArgs, pgx.QueryExecModeSimpleProtocol)
	simpleArgs = append(simpleArgs, args...)

	err := srqf.wrapped.QueryFunc(ctx, rowsFunc, srqf.addAssertToSQL(sql), simpleArgs...)
	return srqf.rewriteError(err)
}

// QueryRowFunc runs the single-row query through the wrapped querier with the strict
// read assertion prepended, translating assertion failures into a RevisionUnavailableError.
func (srqf strictReaderQueryFuncs) QueryRowFunc(ctx context.Context, rowFunc func(ctx context.Context, row pgx.Row) error, sql string, args ...any) error {
	// The simple protocol mode is passed as the leading argument so that the assertion
	// and the query can be sent together as multiple statements.
	simpleArgs := make([]any, 0, len(args)+1)
	simpleArgs = append(simpleArgs, pgx.QueryExecModeSimpleProtocol)
	simpleArgs = append(simpleArgs, args...)

	err := srqf.wrapped.QueryRowFunc(ctx, rowFunc, srqf.addAssertToSQL(sql), simpleArgs...)
	return srqf.rewriteError(err)
}

// rewriteError converts the Postgres errors produced by the strict read assertion
// into a RevisionUnavailableError. Any other error (including nil) is returned as-is.
func (srqf strictReaderQueryFuncs) rewriteError(err error) error {
	if err == nil {
		return nil
	}

	var pgErr *pgconn.PgError
	if !errors.As(err, &pgErr) {
		return err
	}

	// The xid is unknown to this connection: invalid-argument code plus the
	// "is in the future" message.
	xidInFuture := pgErr.Code == pgInvalidArgument && strings.Contains(pgErr.Message, "is in the future")

	// The assertion itself failed: its message is the one given in addAssertToSQL.
	assertFailed := strings.Contains(pgErr.Message, "replica missing revision")

	if !xidInFuture && !assertFailed {
		return err
	}

	return common.NewRevisionUnavailableError(fmt.Errorf("revision %s is not available on the replica", srqf.revision.String()))
}

// addAssertToSQL returns sql with the strict read assertion statement prepended.
//
// The assertion checks that the transaction is not reading from the future or from a
// transaction that is still in-progress on the replica:
//   - If the transaction is not yet available on the replica at all, `pg_xact_status`
//     fails with an invalid argument error whose message says the xid "is in the future".
//   - If the transaction exists but has not yet been committed (or aborted),
//     `pg_xact_status` returns "in progress" and the assert fails.
//
// rewriteError catches both cases and returns a RevisionUnavailableError.
func (srqf strictReaderQueryFuncs) addAssertToSQL(sql string) string {
	const assertFormat = `do $$ begin assert (select pg_xact_status(%d::text::xid8) != 'in progress'), 'replica missing revision';end;$$;`

	// NOTE(review): xmin looks to be an unsigned integer; if it can be zero, this
	// subtraction wraps around. Confirm against the postgresRevision definition.
	checkedXID := srqf.revision.snapshot.xmin - 1

	return fmt.Sprintf(assertFormat, checkedXID) + sql
}
10 changes: 5 additions & 5 deletions internal/datastore/proxy/cachedcheckrev.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@ import (
func newCachedCheckRevision(ds datastore.ReadOnlyDatastore) datastore.ReadOnlyDatastore {
return &cachedCheckRevision{
ReadOnlyDatastore: ds,
lastCheckRevision: atomic.Value{},
lastCheckRevision: atomic.Pointer[datastore.Revision]{},
}
}

type cachedCheckRevision struct {
datastore.ReadOnlyDatastore
lastCheckRevision atomic.Value
lastCheckRevision atomic.Pointer[datastore.Revision]
}

func (c *cachedCheckRevision) CheckRevision(ctx context.Context, rev datastore.Revision) error {
// Check if we've already seen a revision at least as fresh as that specified. If so, we can skip the check.
lastChecked := c.lastCheckRevision.Load()
if lastChecked != nil {
lastCheckedRev := lastChecked.(datastore.Revision)
lastCheckedRev := *lastChecked
if lastCheckedRev.Equal(rev) || lastCheckedRev.GreaterThan(rev) {
return nil
}
Expand All @@ -36,8 +36,8 @@ func (c *cachedCheckRevision) CheckRevision(ctx context.Context, rev datastore.R
return err
}

if lastChecked == nil || rev.LessThan(lastChecked.(datastore.Revision)) {
c.lastCheckRevision.Store(rev)
if lastChecked == nil || rev.LessThan(*lastChecked) {
c.lastCheckRevision.CompareAndSwap(lastChecked, &rev)
}

return nil
Expand Down
Loading

0 comments on commit 3d4516e

Please sign in to comment.