diff --git a/internal/datastore/memdb/schema.go b/internal/datastore/memdb/schema.go index 5dc441dbd7..49cc62db0c 100644 --- a/internal/datastore/memdb/schema.go +++ b/internal/datastore/memdb/schema.go @@ -130,7 +130,10 @@ func (r relationship) RelationTuple() (*core.RelationTuple, error) { return nil, err } - ig := r.integrity.RelationshipIntegrity() + var ig *core.RelationshipIntegrity + if r.integrity != nil { + ig = r.integrity.RelationshipIntegrity() + } return &core.RelationTuple{ ResourceAndRelation: &core.ObjectAndRelation{ diff --git a/internal/datastore/proxy/relationshipintegrity.go b/internal/datastore/proxy/relationshipintegrity.go new file mode 100644 index 0000000000..32e34c18c7 --- /dev/null +++ b/internal/datastore/proxy/relationshipintegrity.go @@ -0,0 +1,397 @@ +package proxy + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "fmt" + "hash" + "strings" + "sync" + "time" + + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// KeyConfig is a configuration for a key used to sign relationships. +type KeyConfig struct { + ID string + ExpiredAt *time.Time + Bytes []byte +} + +type hmacConfig struct { + keyID string + expiredAt *time.Time + hmacPool *sync.Pool +} + +var ( + alg = sha256.New + versionByte = byte(0x01) +) + +// NewRelationshipIntegrityProxy creates a new datastore proxy that ensures the integrity of +// relationships by using HMACs to sign the data. The current key is used to sign new data, +// and the expired keys are used to verify old data, if any. +func NewRelationshipIntegrityProxy(ds datastore.Datastore, currentKey KeyConfig, expiredKeys []KeyConfig) (datastore.Datastore, error) { + currentKeyHMAC := hmacConfig{ + keyID: currentKey.ID, + expiredAt: currentKey.ExpiredAt, + hmacPool: &sync.Pool{ + New: func() any { return hmac.New(alg, currentKey.Bytes) }, + }, + } + + if currentKey.ExpiredAt != nil { + return nil, spiceerrors.MustBugf("current key cannot have an expiration") + } + + keysByID := make(map[string]hmacConfig, len(expiredKeys)+1) + keysByID[currentKey.ID] = currentKeyHMAC + + for _, key := range expiredKeys { + if key.ExpiredAt == nil { + return nil, spiceerrors.MustBugf("expired key missing expiration time") + } + + if _, ok := keysByID[key.ID]; ok { + return nil, spiceerrors.MustBugf("found duplicate key ID: %s", key.ID) + } + + keysByID[key.ID] = hmacConfig{ + keyID: key.ID, + expiredAt: key.ExpiredAt, + hmacPool: &sync.Pool{ + New: func() any { return hmac.New(alg, key.Bytes) }, + }, + } + } + + return &relationshipIntegrityProxy{ + ds: ds, + primaryKey: currentKeyHMAC, + keysByID: keysByID, + }, nil +} + +type relationshipIntegrityProxy struct { + ds datastore.Datastore + primaryKey hmacConfig + keysByID map[string]hmacConfig +} + +func (r *relationshipIntegrityProxy) lookupKey(keyID string) (hmacConfig, error) { + key, ok := r.keysByID[keyID] + if !ok { + return hmacConfig{}, fmt.Errorf("key not found: %s", keyID) + } + + return key, nil +} + +// computeRelationshipHash computes the HMAC hash of a relationship tuple. +func computeRelationshipHash(tpl *corev1.RelationTuple, key hmacConfig) ([]byte, error) { + hmac := key.hmacPool.Get().(hash.Hash) + defer key.hmacPool.Put(hmac) + defer hmac.Reset() + + var sb strings.Builder + sb.WriteString(tpl.ResourceAndRelation.Namespace) + sb.WriteString(":") + sb.WriteString(tpl.ResourceAndRelation.ObjectId) + sb.WriteString("#") + sb.WriteString(tpl.ResourceAndRelation.Relation) + sb.WriteString("@") + sb.WriteString(tpl.Subject.Namespace) + sb.WriteString(":") + sb.WriteString(tpl.Subject.ObjectId) + sb.WriteString("#") + sb.WriteString(tpl.Subject.Relation) + + if tpl.Caveat != nil && tpl.Caveat.CaveatName != "" { + sb.WriteString(" with ") + sb.WriteString(tpl.Caveat.CaveatName) + sb.WriteString(":") + sb.WriteString(tpl.Caveat.Context.String()) + } + + if _, err := hmac.Write([]byte(sb.String())); err != nil { + return nil, err + } + return hmac.Sum(nil)[:16], nil +} + +func (r *relationshipIntegrityProxy) SnapshotReader(rev datastore.Revision) datastore.Reader { + return relationshipIntegrityReader{ + parent: r, + wrapped: r.ds.SnapshotReader(rev), + } +} + +func (r *relationshipIntegrityProxy) ReadWriteTx(ctx context.Context, f datastore.TxUserFunc, opts ...options.RWTOptionsOption) (datastore.Revision, error) { + return r.ds.ReadWriteTx(ctx, func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + return f(ctx, &relationshipIntegrityTx{ + ReadWriteTransaction: tx, + parent: r, + }) + }, opts...) +} + +func (r *relationshipIntegrityProxy) CheckRevision(ctx context.Context, revision datastore.Revision) error { + return r.ds.CheckRevision(ctx, revision) +} + +func (r *relationshipIntegrityProxy) Close() error { + return r.ds.Close() +} + +func (r *relationshipIntegrityProxy) Features(ctx context.Context) (*datastore.Features, error) { + return r.ds.Features(ctx) +} + +func (r *relationshipIntegrityProxy) HeadRevision(ctx context.Context) (datastore.Revision, error) { + return r.ds.HeadRevision(ctx) +} + +func (r *relationshipIntegrityProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { + return r.ds.OptimizedRevision(ctx) +} + +func (r *relationshipIntegrityProxy) ReadyState(ctx context.Context) (datastore.ReadyState, error) { + return r.ds.ReadyState(ctx) +} + +func (r *relationshipIntegrityProxy) RevisionFromString(serialized string) (datastore.Revision, error) { + return r.ds.RevisionFromString(serialized) +} + +func (r *relationshipIntegrityProxy) Statistics(ctx context.Context) (datastore.Stats, error) { + return r.ds.Statistics(ctx) +} + +func (r *relationshipIntegrityProxy) Watch(ctx context.Context, afterRevision datastore.Revision, options datastore.WatchOptions) (<-chan *datastore.RevisionChanges, <-chan error) { + return r.ds.Watch(ctx, afterRevision, options) +} + +type relationshipIntegrityReader struct { + parent *relationshipIntegrityProxy + wrapped datastore.Reader +} + +func (r relationshipIntegrityReader) QueryRelationships(ctx context.Context, filter datastore.RelationshipsFilter, options ...options.QueryOptionsOption) (datastore.RelationshipIterator, error) { + it, err := r.wrapped.QueryRelationships(ctx, filter, options...) + if err != nil { + return nil, err + } + + return &relationshipIntegrityIterator{ + parent: r, + wrapped: it, + }, nil +} + +func (r relationshipIntegrityReader) ReverseQueryRelationships(ctx context.Context, subjectsFilter datastore.SubjectsFilter, options ...options.ReverseQueryOptionsOption) (datastore.RelationshipIterator, error) { + it, err := r.wrapped.ReverseQueryRelationships(ctx, subjectsFilter, options...) + if err != nil { + return nil, err + } + + return &relationshipIntegrityIterator{ + parent: r, + wrapped: it, + }, nil +} + +func (r relationshipIntegrityReader) CountRelationships(ctx context.Context, name string) (int, error) { + return r.wrapped.CountRelationships(ctx, name) +} + +func (r relationshipIntegrityReader) ListAllCaveats(ctx context.Context) ([]datastore.RevisionedDefinition[*corev1.CaveatDefinition], error) { + return r.wrapped.ListAllCaveats(ctx) +} + +func (r relationshipIntegrityReader) ListAllNamespaces(ctx context.Context) ([]datastore.RevisionedDefinition[*corev1.NamespaceDefinition], error) { + return r.wrapped.ListAllNamespaces(ctx) +} + +func (r relationshipIntegrityReader) LookupCaveatsWithNames(ctx context.Context, names []string) ([]datastore.RevisionedDefinition[*corev1.CaveatDefinition], error) { + return r.wrapped.LookupCaveatsWithNames(ctx, names) +} + +func (r relationshipIntegrityReader) LookupCounters(ctx context.Context) ([]datastore.RelationshipCounter, error) { + return r.wrapped.LookupCounters(ctx) +} + +func (r relationshipIntegrityReader) LookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*corev1.NamespaceDefinition], error) { + return r.wrapped.LookupNamespacesWithNames(ctx, nsNames) +} + +func (r relationshipIntegrityReader) ReadCaveatByName(ctx context.Context, name string) (caveat *corev1.CaveatDefinition, lastWritten datastore.Revision, err error) { + return r.wrapped.ReadCaveatByName(ctx, name) +} + +func (r relationshipIntegrityReader) ReadNamespaceByName(ctx context.Context, nsName string) (ns *corev1.NamespaceDefinition, lastWritten datastore.Revision, err error) { + return r.wrapped.ReadNamespaceByName(ctx, nsName) +} + +type relationshipIntegrityIterator struct { + parent relationshipIntegrityReader + wrapped datastore.RelationshipIterator + err error +} + +func (r *relationshipIntegrityIterator) Close() { + r.wrapped.Close() +} + +func (r *relationshipIntegrityIterator) Cursor() (options.Cursor, error) { + return r.wrapped.Cursor() +} + +func (r *relationshipIntegrityIterator) Err() error { + if r.err != nil { + return r.err + } + + return r.wrapped.Err() +} + +func (r *relationshipIntegrityIterator) Next() *corev1.RelationTuple { + tpl := r.wrapped.Next() + if tpl == nil { + return nil + } + + // Ensure the relationship has integrity data. + // TODO(jschorr): Do this in parallel with returning the relationship from + // the iterator and have the error be added to the iterator if it fails. + if tpl.Integrity == nil || len(tpl.Integrity.Hash) == 0 || tpl.Integrity.KeyId == "" { + str, err := tuple.String(tpl) + if err != nil { + r.err = err + return nil + } + + r.err = fmt.Errorf("relationship %s is missing required integrity data", str) + return nil + } + + hashWithoutByte := tpl.Integrity.Hash[1:] + if tpl.Integrity.Hash[0] != versionByte || len(hashWithoutByte) != 16 { + r.err = fmt.Errorf("relationship %s has invalid integrity data", tpl) + return nil + } + + // Validate the integrity of the relationship. + key, err := r.parent.parent.lookupKey(tpl.Integrity.KeyId) + if err != nil { + r.err = err + return nil + } + + if key.expiredAt != nil && key.expiredAt.Before(tpl.Integrity.HashedAt.AsTime()) { + r.err = fmt.Errorf("relationship %s is signed by an expired key", tpl) + return nil + } + + computedHash, err := computeRelationshipHash(tpl, key) + if err != nil { + r.err = err + return nil + } + + if !hmac.Equal(computedHash, hashWithoutByte) { + str, err := tuple.String(tpl) + if err != nil { + r.err = err + return nil + } + + r.err = fmt.Errorf("relationship %s has invalid integrity hash", str) + return nil + } + + return tpl +} + +type relationshipIntegrityTx struct { + datastore.ReadWriteTransaction + + parent *relationshipIntegrityProxy +} + +func (r *relationshipIntegrityTx) WriteRelationships( + ctx context.Context, + mutations []*corev1.RelationTupleUpdate, +) error { + // Add integrity data to the relationships. + key := r.parent.primaryKey + hashedAt := timestamppb.Now() + + // TODO(jschorr): Do this in parallel + for _, mutation := range mutations { + if mutation.Tuple.Integrity != nil { + return spiceerrors.MustBugf("relationship %s already has integrity data", mutation.Tuple) + } + + hash, err := computeRelationshipHash(mutation.Tuple, key) + if err != nil { + return err + } + + mutation.Tuple.Integrity = &corev1.RelationshipIntegrity{ + HashedAt: hashedAt, + Hash: append([]byte{versionByte}, hash...), + KeyId: key.keyID, + } + } + + return r.ReadWriteTransaction.WriteRelationships(ctx, mutations) +} + +func (r *relationshipIntegrityTx) BulkLoad( + ctx context.Context, + iter datastore.BulkWriteRelationshipSource, +) (uint64, error) { + wrapped := &wrappedBulkLoadIterator{iter, r.parent} + return r.ReadWriteTransaction.BulkLoad(ctx, wrapped) +} + +type wrappedBulkLoadIterator struct { + wrapped datastore.BulkWriteRelationshipSource + parent *relationshipIntegrityProxy +} + +func (w wrappedBulkLoadIterator) Next(ctx context.Context) (*corev1.RelationTuple, error) { + tpl, err := w.wrapped.Next(ctx) + if err != nil { + return nil, err + } + + key := w.parent.primaryKey + hashedAt := timestamppb.Now() + + hash, err := computeRelationshipHash(tpl, key) + if err != nil { + return nil, err + } + + if tpl.Integrity != nil { + return nil, spiceerrors.MustBugf("relationship %s already has integrity data", tpl) + } + + tpl.Integrity = &corev1.RelationshipIntegrity{ + HashedAt: hashedAt, + Hash: append([]byte{versionByte}, hash...), + KeyId: key.keyID, + } + + return tpl, nil +} diff --git a/internal/datastore/proxy/relationshipintegrity_test.go b/internal/datastore/proxy/relationshipintegrity_test.go new file mode 100644 index 0000000000..d508b9a466 --- /dev/null +++ b/internal/datastore/proxy/relationshipintegrity_test.go @@ -0,0 +1,534 @@ +package proxy + +import ( + "context" + "crypto/hmac" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +var defaultKeyForTesting = KeyConfig{ + ID: "defaultfortest", + Bytes: []byte("somedefaultkeyfortesting"), + ExpiredAt: nil, +} + +var toBeExpiredKeyForTesting = KeyConfig{ + ID: "expiredkeyfortest", + Bytes: []byte("somexpiredkeyfortesting"), +} + +var expiredKeyForTesting = KeyConfig{ + ID: "expiredkeyfortest", + Bytes: []byte("somexpiredkeyfortesting"), + ExpiredAt: (func() *time.Time { + t, err := time.Parse("2006-01-02", "2021-01-01") + if err != nil { + panic(err) + } + return &t + })(), +} + +var notYetExpiredKeyForTesting = KeyConfig{ + ID: "expiredkeyfortest", + Bytes: []byte("somexpiredkeyfortesting"), + ExpiredAt: (func() *time.Time { + t, err := time.Parse("2006-01-02", "2999-01-01") + if err != nil { + panic(err) + } + return &t + })(), +} + +func TestWriteWithPredefinedIntegrity(t *testing.T) { + ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + pds, err := NewRelationshipIntegrityProxy(ds, defaultKeyForTesting, nil) + require.NoError(t, err) + + require.Panics(t, func() { + pds.ReadWriteTx(context.Background(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + tpl := tuple.MustParse("resource:foo#viewer@user:tom") + tpl.Integrity = &core.RelationshipIntegrity{} + return tx.WriteRelationships(context.Background(), []*core.RelationTupleUpdate{ + tuple.Create(tpl), + }) + }) + }) +} + +func TestReadWithMissingIntegrity(t *testing.T) { + ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + // Write a relationship to the underlying datastore without integrity information. + _, err = ds.ReadWriteTx(context.Background(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + tpl := tuple.MustParse("resource:foo#viewer@user:tom") + return tx.WriteRelationships(context.Background(), []*core.RelationTupleUpdate{ + tuple.Create(tpl), + }) + }) + require.NoError(t, err) + + // Attempt to read, which should return an error. + pds, err := NewRelationshipIntegrityProxy(ds, defaultKeyForTesting, nil) + require.NoError(t, err) + + headRev, err := pds.HeadRevision(context.Background()) + require.NoError(t, err) + + reader := pds.SnapshotReader(headRev) + iter, err := reader.QueryRelationships( + context.Background(), + datastore.RelationshipsFilter{OptionalResourceType: "resource"}, + ) + require.NoError(t, err) + + found := iter.Next() + require.Nil(t, found) + require.Error(t, iter.Err()) + require.ErrorContains(t, iter.Err(), "is missing required integrity data") + + iter.Close() +} + +func TestBasicIntegrity(t *testing.T) { + ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + pds, err := NewRelationshipIntegrityProxy(ds, defaultKeyForTesting, nil) + require.NoError(t, err) + + beforeWriteTime := time.Now() + + // Write some relationships. + _, err = pds.ReadWriteTx(context.Background(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + return tx.WriteRelationships(context.Background(), []*core.RelationTupleUpdate{ + tuple.Create(tuple.MustParse("resource:foo#viewer@user:tom")), + tuple.Create(tuple.MustParse("resource:foo#viewer@user:fred")), + tuple.Touch(tuple.MustParse("resource:bar#viewer@user:sarah")), + }) + }) + require.NoError(t, err) + + afterWriteTime := time.Now() + + // Read them back and ensure they are found. + headRev, err := pds.HeadRevision(context.Background()) + require.NoError(t, err) + + reader := pds.SnapshotReader(headRev) + iter, err := reader.QueryRelationships( + context.Background(), + datastore.RelationshipsFilter{OptionalResourceType: "resource"}, + ) + t.Cleanup(iter.Close) + require.NoError(t, err) + + foundRelationships := mapz.NewSet[string]() + for { + require.NoError(t, iter.Err()) + + rel := iter.Next() + if rel == nil { + break + } + + require.NotNil(t, rel.Integrity) + require.Equal(t, defaultKeyForTesting.ID, rel.Integrity.KeyId) + require.NotNil(t, rel.Integrity.Hash) + + // Ensure the integrity is within the expected bounds. + require.True(t, beforeWriteTime.Before(rel.Integrity.HashedAt.AsTime()) || beforeWriteTime.Equal(rel.Integrity.HashedAt.AsTime())) + require.True(t, afterWriteTime.After(rel.Integrity.HashedAt.AsTime()) || afterWriteTime.Equal(rel.Integrity.HashedAt.AsTime())) + + foundRelationships.Add(tuple.MustString(rel)) + } + require.NoError(t, iter.Err()) + iter.Close() + + require.ElementsMatch(t, foundRelationships.AsSlice(), []string{ + "resource:foo#viewer@user:tom", + "resource:foo#viewer@user:fred", + "resource:bar#viewer@user:sarah", + }) + + // Invoke reverse query relationship to ensure the integrity is also present. + iter, err = reader.ReverseQueryRelationships( + context.Background(), + datastore.SubjectsFilter{ + SubjectType: "user", + }, + ) + t.Cleanup(iter.Close) + require.NoError(t, err) + + foundRelationships = mapz.NewSet[string]() + for { + require.NoError(t, iter.Err()) + + rel := iter.Next() + if rel == nil { + break + } + + require.NotNil(t, rel.Integrity) + require.Equal(t, defaultKeyForTesting.ID, rel.Integrity.KeyId) + require.NotNil(t, rel.Integrity.Hash) + + // Ensure the integrity is within the expected bounds. + require.True(t, beforeWriteTime.Before(rel.Integrity.HashedAt.AsTime()) || beforeWriteTime.Equal(rel.Integrity.HashedAt.AsTime())) + require.True(t, afterWriteTime.After(rel.Integrity.HashedAt.AsTime()) || afterWriteTime.Equal(rel.Integrity.HashedAt.AsTime())) + + foundRelationships.Add(tuple.MustString(rel)) + } + require.NoError(t, iter.Err()) + iter.Close() + + require.ElementsMatch(t, foundRelationships.AsSlice(), []string{ + "resource:foo#viewer@user:tom", + "resource:foo#viewer@user:fred", + "resource:bar#viewer@user:sarah", + }) +} + +func TestBasicIntegrityFailureDueToInvalidHashVersion(t *testing.T) { + ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + pds, err := NewRelationshipIntegrityProxy(ds, defaultKeyForTesting, nil) + require.NoError(t, err) + + // Write some relationships. + _, err = pds.ReadWriteTx(context.Background(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + return tx.WriteRelationships(context.Background(), []*core.RelationTupleUpdate{ + tuple.Create(tuple.MustParse("resource:foo#viewer@user:tom")), + tuple.Create(tuple.MustParse("resource:foo#viewer@user:fred")), + tuple.Touch(tuple.MustParse("resource:bar#viewer@user:sarah")), + }) + }) + require.NoError(t, err) + + // Insert an invalid integrity hash for one of the relationships to be invalid by bypassing + // the + _, err = ds.ReadWriteTx(context.Background(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + invalidTpl := tuple.MustParse("resource:foo#viewer@user:jimmy") + invalidTpl.Integrity = &core.RelationshipIntegrity{ + KeyId: "defaultfortest", + Hash: []byte("invalidhash"), + HashedAt: timestamppb.Now(), + } + + return tx.WriteRelationships(context.Background(), []*core.RelationTupleUpdate{ + tuple.Create(invalidTpl), + }) + }) + require.NoError(t, err) + + // Read them back and ensure the read fails. + headRev, err := pds.HeadRevision(context.Background()) + require.NoError(t, err) + + reader := pds.SnapshotReader(headRev) + iter, err := reader.QueryRelationships( + context.Background(), + datastore.RelationshipsFilter{OptionalResourceType: "resource"}, + ) + t.Cleanup(iter.Close) + require.NoError(t, err) + + var foundError error + for { + rel := iter.Next() + if rel == nil { + break + } + + err := iter.Err() + if err != nil { + foundError = err + break + } + } + + if foundError == nil { + foundError = iter.Err() + } + + require.Error(t, foundError) + require.ErrorContains(t, foundError, "has invalid integrity data") +} + +func TestBasicIntegrityFailureDueToInvalidHashSignature(t *testing.T) { + ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + pds, err := NewRelationshipIntegrityProxy(ds, defaultKeyForTesting, nil) + require.NoError(t, err) + + // Write some relationships. + _, err = pds.ReadWriteTx(context.Background(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + return tx.WriteRelationships(context.Background(), []*core.RelationTupleUpdate{ + tuple.Create(tuple.MustParse("resource:foo#viewer@user:tom")), + tuple.Create(tuple.MustParse("resource:foo#viewer@user:fred")), + tuple.Touch(tuple.MustParse("resource:bar#viewer@user:sarah")), + }) + }) + require.NoError(t, err) + + // Insert an invalid integrity hash for one of the relationships to be invalid by bypassing + // the + _, err = ds.ReadWriteTx(context.Background(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + invalidTpl := tuple.MustParse("resource:foo#viewer@user:jimmy") + invalidTpl.Integrity = &core.RelationshipIntegrity{ + KeyId: "defaultfortest", + Hash: append([]byte{0x01}, []byte("someinvalidhashaasd")[0:16]...), + HashedAt: timestamppb.Now(), + } + + return tx.WriteRelationships(context.Background(), []*core.RelationTupleUpdate{ + tuple.Create(invalidTpl), + }) + }) + require.NoError(t, err) + + // Read them back and ensure the read fails. + headRev, err := pds.HeadRevision(context.Background()) + require.NoError(t, err) + + reader := pds.SnapshotReader(headRev) + iter, err := reader.QueryRelationships( + context.Background(), + datastore.RelationshipsFilter{OptionalResourceType: "resource"}, + ) + t.Cleanup(iter.Close) + require.NoError(t, err) + + var foundError error + for { + rel := iter.Next() + if rel == nil { + break + } + + err := iter.Err() + if err != nil { + foundError = err + break + } + } + + if foundError == nil { + foundError = iter.Err() + } + + require.Error(t, foundError) + require.ErrorContains(t, foundError, "has invalid integrity hash") +} + +func TestBasicIntegrityWithReplacedKey(t *testing.T) { + ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + // Create a proxy with the to-be-expired key and write some relationships. + epds, err := NewRelationshipIntegrityProxy(ds, toBeExpiredKeyForTesting, nil) + require.NoError(t, err) + + // Write some relationships. + _, err = epds.ReadWriteTx(context.Background(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + return tx.WriteRelationships(context.Background(), []*core.RelationTupleUpdate{ + tuple.Create(tuple.MustParse("resource:foo#viewer@user:tom")), + tuple.Create(tuple.MustParse("resource:foo#viewer@user:fred")), + tuple.Touch(tuple.MustParse("resource:bar#viewer@user:sarah")), + }) + }) + require.NoError(t, err) + + // Create a proxy with the key now expired and ensure reads still work, as the key + // has not yet expired for the existing relationships. + pds, err := NewRelationshipIntegrityProxy(ds, defaultKeyForTesting, []KeyConfig{ + notYetExpiredKeyForTesting, + }) + require.NoError(t, err) + + // Read them back and ensure they are found. + headRev, err := pds.HeadRevision(context.Background()) + require.NoError(t, err) + + reader := pds.SnapshotReader(headRev) + iter, err := reader.QueryRelationships( + context.Background(), + datastore.RelationshipsFilter{OptionalResourceType: "resource"}, + ) + t.Cleanup(iter.Close) + require.NoError(t, err) + + foundRelationships := mapz.NewSet[string]() + for { + require.NoError(t, iter.Err()) + + rel := iter.Next() + if rel == nil { + break + } + + require.NotNil(t, rel.Integrity) + require.Equal(t, notYetExpiredKeyForTesting.ID, rel.Integrity.KeyId) + require.NotNil(t, rel.Integrity.Hash) + + foundRelationships.Add(tuple.MustString(rel)) + } + require.NoError(t, iter.Err()) + iter.Close() + + require.ElementsMatch(t, foundRelationships.AsSlice(), []string{ + "resource:foo#viewer@user:tom", + "resource:foo#viewer@user:fred", + "resource:bar#viewer@user:sarah", + }) +} + +func TestBasicIntegrityFailureDueToWriteWithExpiredKey(t *testing.T) { + ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + // Create a proxy with the to-be-expired key and write some relationships. + epds, err := NewRelationshipIntegrityProxy(ds, toBeExpiredKeyForTesting, nil) + require.NoError(t, err) + + // Write some relationships. + _, err = epds.ReadWriteTx(context.Background(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + return tx.WriteRelationships(context.Background(), []*core.RelationTupleUpdate{ + tuple.Create(tuple.MustParse("resource:foo#viewer@user:tom")), + tuple.Create(tuple.MustParse("resource:foo#viewer@user:fred")), + tuple.Touch(tuple.MustParse("resource:bar#viewer@user:sarah")), + }) + }) + require.NoError(t, err) + + pds, err := NewRelationshipIntegrityProxy(ds, defaultKeyForTesting, []KeyConfig{ + expiredKeyForTesting, + }) + require.NoError(t, err) + + // Read them back and ensure the read fails. + headRev, err := pds.HeadRevision(context.Background()) + require.NoError(t, err) + + reader := pds.SnapshotReader(headRev) + iter, err := reader.QueryRelationships( + context.Background(), + datastore.RelationshipsFilter{OptionalResourceType: "resource"}, + ) + t.Cleanup(iter.Close) + require.NoError(t, err) + + var foundError error + for { + rel := iter.Next() + if rel == nil { + break + } + + err := iter.Err() + if err != nil { + foundError = err + break + } + } + + if foundError == nil { + foundError = iter.Err() + } + + require.Error(t, foundError) + require.ErrorContains(t, foundError, "is signed by an expired key") +} + +func BenchmarkQueryRelsWithIntegrity(b *testing.B) { + for _, withIntegrity := range []bool{true, false} { + b.Run(fmt.Sprintf("withIntegrity=%t", withIntegrity), func(b *testing.B) { + ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + require.NoError(b, err) + + pds, err := NewRelationshipIntegrityProxy(ds, defaultKeyForTesting, nil) + require.NoError(b, err) + + _, err = pds.ReadWriteTx(context.Background(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + for i := 0; i < 1000; i++ { + tpl := tuple.MustParse(fmt.Sprintf("resource:foo#viewer@user:user-%d", i)) + if err := tx.WriteRelationships(context.Background(), []*core.RelationTupleUpdate{ + tuple.Create(tpl), + }); err != nil { + return err + } + } + + return nil + }) + require.NoError(b, err) + + headRev, err := pds.HeadRevision(context.Background()) + require.NoError(b, err) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var reader datastore.Reader + if withIntegrity { + reader = pds.SnapshotReader(headRev) + } else { + reader = ds.SnapshotReader(headRev) + } + iter, err := reader.QueryRelationships( + context.Background(), + datastore.RelationshipsFilter{OptionalResourceType: "resource"}, + ) + require.NoError(b, err) + + for { + err := iter.Err() + if err != nil { + require.NoError(b, err) + } + + rel := iter.Next() + if rel == nil { + break + } + } + + iter.Close() + } + b.StopTimer() + }) + } +} + +func BenchmarkComputeRelationshipHash(b *testing.B) { + config := hmacConfig{ + keyID: "defaultfortest", + hmacPool: &sync.Pool{ + New: func() any { return hmac.New(alg, []byte("sometestbytes")) }, + }, + } + + tpl := tuple.MustParse("resource:foo#viewer@user:tom") + for i := 0; i < b.N; i++ { + _, err := computeRelationshipHash(tpl, config) + require.NoError(b, err) + } +}