From 5731e26b3b497edb3c6bfa2bae6c7fc14dd8242f Mon Sep 17 00:00:00 2001 From: Tanner Stirrat Date: Wed, 11 Sep 2024 15:39:57 -0600 Subject: [PATCH] Fix unsafe conversion errors --- go.mod | 1 + go.sum | 2 + internal/datastore/common/changes.go | 8 ++- internal/datastore/common/gc_test.go | 12 ++-- internal/datastore/common/sql.go | 13 ++-- internal/datastore/crdb/crdb.go | 10 ++- internal/datastore/crdb/pool/balancer.go | 18 ++++- internal/datastore/postgres/common/pgx.go | 16 ++++- internal/datastore/postgres/postgres.go | 10 ++- .../postgres/postgres_shared_test.go | 17 ++--- internal/datastore/postgres/readwrite.go | 10 ++- internal/datastore/postgres/revisions.go | 24 +++++-- .../proxy/schemacaching/watchingcache.go | 4 +- .../datastore/revisions/commonrevision.go | 7 -- internal/datastore/revisions/hlcrevision.go | 11 ++-- .../datastore/revisions/timestamprevision.go | 4 -- internal/datastore/revisions/txidrevision.go | 6 +- .../dispatch/graph/lookupresources2_test.go | 7 +- .../dispatch/graph/lookupresources_test.go | 6 +- .../dispatch/graph/reachableresources_test.go | 6 +- internal/graph/cursors.go | 8 ++- .../usagemetrics/usagemetrics_test.go | 4 +- internal/services/v1/experimental_test.go | 17 +++-- internal/services/v1/permissions_test.go | 5 +- internal/services/v1/relationships.go | 4 +- internal/services/v1/relationships_test.go | 9 ++- pkg/cache/cache_otter.go | 10 ++- pkg/cmd/datastore/datastore.go | 66 +++++++++++++++---- pkg/cmd/server/cacheconfig.go | 14 ++-- pkg/datastore/pagination/iterator_test.go | 6 +- pkg/development/assertions.go | 23 +++++-- pkg/development/devcontext.go | 37 +++++++++-- pkg/development/parsing.go | 9 ++- pkg/development/resolver.go | 42 ++++++++++-- pkg/development/schema.go | 9 ++- pkg/development/validation.go | 34 ++++++---- pkg/development/warnings.go | 11 ++-- pkg/development/wasm/request.go | 3 +- pkg/genutil/ensure.go | 20 ++++-- 39 files changed, 378 insertions(+), 145 deletions(-) diff --git a/go.mod b/go.mod index 00d3d81a46..10e47bb439 100644 --- a/go.mod +++ b/go.mod @@ -136,6 +136,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 // indirect github.com/aws/smithy-go v1.20.4 // indirect github.com/bombsimon/wsl/v4 v4.4.1 // indirect + github.com/ccoveille/go-safecast v1.1.0 // indirect github.com/cilium/ebpf v0.9.1 // indirect github.com/containerd/cgroups/v3 v3.0.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect diff --git a/go.sum b/go.sum index cd6077c721..5ee747adf8 100644 --- a/go.sum +++ b/go.sum @@ -770,6 +770,8 @@ github.com/catenacyber/perfsprint v0.7.1 h1:PGW5G/Kxn+YrN04cRAZKC+ZuvlVwolYMrIyy github.com/catenacyber/perfsprint v0.7.1/go.mod h1:/wclWYompEyjUD2FuIIDVKNkqz7IgBIWXIH3V0Zol50= github.com/ccojocar/zxcvbn-go v1.0.2 h1:na/czXU8RrhXO4EZme6eQJLR4PzcGsahsBOAwU6I3Vg= github.com/ccojocar/zxcvbn-go v1.0.2/go.mod h1:g1qkXtUSvHP8lhHp5GrSmTz6uWALGRMQdw6Qnz/hi60= +github.com/ccoveille/go-safecast v1.1.0 h1:iHKNWaZm+OznO7Eh6EljXPjGfGQsSfa6/sxPlIEKO+g= +github.com/ccoveille/go-safecast v1.1.0/go.mod h1:QqwNjxQ7DAqY0C721OIO9InMk9zCwcsO7tnRuHytad8= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= diff --git a/internal/datastore/common/changes.go b/internal/datastore/common/changes.go index 9449e47a5a..342155ed9d 100644 --- a/internal/datastore/common/changes.go +++ b/internal/datastore/common/changes.go @@ -6,6 +6,8 @@ import ( "golang.org/x/exp/maps" + "github.com/ccoveille/go-safecast" + log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" core "github.com/authzed/spicedb/pkg/proto/core/v1" @@ -118,7 +120,11 @@ func (ch *Changes[R, K]) adjustByteSize(item sized, delta int) error { return spiceerrors.MustBugf("byte size underflow") } - if ch.currentByteSize > int64(ch.maxByteSize) { + // We checked for underflow above, so the current byte size + // should fit in a uint64 + currentByteSize, _ := safecast.ToUint64(ch.currentByteSize) + + if currentByteSize > ch.maxByteSize { return datastore.NewMaximumChangesSizeExceededError(ch.maxByteSize) } diff --git a/internal/datastore/common/gc_test.go b/internal/datastore/common/gc_test.go index e7e794c69e..da8a09202f 100644 --- a/internal/datastore/common/gc_test.go +++ b/internal/datastore/common/gc_test.go @@ -68,7 +68,7 @@ func (gc *fakeGC) DeleteBeforeTx(_ context.Context, rev datastore.Revision) (Del revInt := rev.(revisions.TransactionIDRevision).TransactionID() - return gc.deleter.DeleteBeforeTx(int64(revInt)) + return gc.deleter.DeleteBeforeTx(revInt) } func (gc *fakeGC) HasGCRun() bool { @@ -101,22 +101,22 @@ func (gc *fakeGC) GetMetrics() gcMetrics { // Allows specifying different deletion behaviors for tests type gcDeleter interface { - DeleteBeforeTx(revision int64) (DeletionCounts, error) + DeleteBeforeTx(revision uint64) (DeletionCounts, error) } // Always error trying to perform a delete type alwaysErrorDeleter struct{} -func (alwaysErrorDeleter) DeleteBeforeTx(_ int64) (DeletionCounts, error) { +func (alwaysErrorDeleter) DeleteBeforeTx(_ uint64) (DeletionCounts, error) { return DeletionCounts{}, fmt.Errorf("delete error") } // Only error on specific revisions type revisionErrorDeleter struct { - errorOnRevisions []int64 + errorOnRevisions []uint64 } -func (d revisionErrorDeleter) DeleteBeforeTx(revision int64) (DeletionCounts, error) { +func (d revisionErrorDeleter) DeleteBeforeTx(revision uint64) (DeletionCounts, error) { if slices.Contains(d.errorOnRevisions, revision) { return DeletionCounts{}, fmt.Errorf("delete error") } @@ -178,7 +178,7 @@ func TestGCFailureBackoffReset(t *testing.T) { // Error on revisions 1 - 5, giving the exponential // backoff enough time to fail the test if the interval // is not reset properly. - errorOnRevisions: []int64{1, 2, 3, 4, 5}, + errorOnRevisions: []uint64{1, 2, 3, 4, 5}, }) ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index bfdfb17378..202379170e 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -7,6 +7,7 @@ import ( sq "github.com/Masterminds/squirrel" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/ccoveille/go-safecast" "github.com/jzelinskie/stringz" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -551,12 +552,13 @@ func (tqs QueryExecutor) ExecuteQuery( query = query.After(queryOpts.After, queryOpts.Sort) } - limit := math.MaxInt + var limit uint64 + limit = math.MaxUint64 if queryOpts.Limit != nil { - limit = int(*queryOpts.Limit) + limit = *queryOpts.Limit } - toExecute := query.limit(uint64(limit)) + toExecute := query.limit(limit) sql, args, err := toExecute.queryBuilder.ToSql() if err != nil { return nil, err @@ -567,7 +569,10 @@ func (tqs QueryExecutor) ExecuteQuery( return nil, err } - if len(queryTuples) > limit { + // A length shouldn't be non-negative, so we can cast without + // checking here. + lenQueryTuples, _ := safecast.ToUint64(len(queryTuples)) + if lenQueryTuples > limit { queryTuples = queryTuples[:limit] } diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index 08d5bafb23..dc255b61d9 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -99,13 +99,19 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas if err != nil { return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url) } - config.readPoolOpts.ConfigurePgx(readPoolConfig) + err = config.readPoolOpts.ConfigurePgx(readPoolConfig) + if err != nil { + return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url) + } writePoolConfig, err := pgxpool.ParseConfig(url) if err != nil { return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url) } - config.writePoolOpts.ConfigurePgx(writePoolConfig) + err = config.writePoolOpts.ConfigurePgx(writePoolConfig) + if err != nil { + return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url) + } initCtx, initCancel := context.WithTimeout(context.Background(), 5*time.Minute) defer initCancel() diff --git a/internal/datastore/crdb/pool/balancer.go b/internal/datastore/crdb/pool/balancer.go index cf825f1d44..dc5a6c59a1 100644 --- a/internal/datastore/crdb/pool/balancer.go +++ b/internal/datastore/crdb/pool/balancer.go @@ -3,11 +3,13 @@ package pool import ( "context" "hash/maphash" + "math" "math/rand" "slices" "strconv" "time" + "github.com/ccoveille/go-safecast" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/prometheus/client_golang/prometheus" @@ -84,7 +86,15 @@ type nodeConnectionBalancer[P balancePoolConn[C], C balanceConn] struct { // newNodeConnectionBalancer is generic over underlying connection types for // testing purposes. Callers should use the exported NewNodeConnectionBalancer. func newNodeConnectionBalancer[P balancePoolConn[C], C balanceConn](pool balanceablePool[P, C], healthTracker *NodeHealthTracker, interval time.Duration) *nodeConnectionBalancer[P, C] { - seed := int64(new(maphash.Hash).Sum64()) + seed := int64(0) + for seed == 0 { + // Sum64 returns a uint64, and safecast will return 0 if it's not castable, + // which will happen about half the time (?). We just keep running it until + // we get a seed that fits in the box. + // Subtracting math.MaxInt64 should mean that we retain the entire range of + // possible values. + seed, _ = safecast.ToInt64(new(maphash.Hash).Sum64() - math.MaxInt64) + } return &nodeConnectionBalancer[P, C]{ ticker: time.NewTicker(interval), sem: semaphore.NewWeighted(1), @@ -147,7 +157,9 @@ func (p *nodeConnectionBalancer[P, C]) mustPruneConnections(ctx context.Context) } } - nodeCount := uint32(p.healthTracker.HealthyNodeCount()) + // It's highly unlikely that we'll ever have an overflow in + // this context, so we cast directly. + nodeCount, _ := safecast.ToUint32(p.healthTracker.HealthyNodeCount()) if nodeCount == 0 { nodeCount = 1 } @@ -203,7 +215,7 @@ func (p *nodeConnectionBalancer[P, C]) mustPruneConnections(ctx context.Context) // it's possible for the difference in connections between nodes to differ by up to // the number of nodes. if p.healthTracker.HealthyNodeCount() == 0 || - uint32(i) < p.pool.MaxConns()%uint32(p.healthTracker.HealthyNodeCount()) { + i < int(p.pool.MaxConns())%p.healthTracker.HealthyNodeCount() { perNodeMax++ } diff --git a/internal/datastore/postgres/common/pgx.go b/internal/datastore/postgres/common/pgx.go index 94bdf08c66..5eca685998 100644 --- a/internal/datastore/postgres/common/pgx.go +++ b/internal/datastore/postgres/common/pgx.go @@ -7,6 +7,7 @@ import ( "fmt" "time" + "github.com/ccoveille/go-safecast" "github.com/exaring/otelpgx" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry" zerologadapter "github.com/jackc/pgx-zerolog" @@ -300,15 +301,23 @@ type PoolOptions struct { } // ConfigurePgx applies PoolOptions to a pgx connection pool confiugration. -func (opts PoolOptions) ConfigurePgx(pgxConfig *pgxpool.Config) { +func (opts PoolOptions) ConfigurePgx(pgxConfig *pgxpool.Config) error { if opts.MaxOpenConns != nil { - pgxConfig.MaxConns = int32(*opts.MaxOpenConns) + maxConns, err := safecast.ToInt32(*opts.MaxOpenConns) + if err != nil { + return err + } + pgxConfig.MaxConns = maxConns } // Default to keeping the pool maxed out at all times. pgxConfig.MinConns = pgxConfig.MaxConns if opts.MinOpenConns != nil { - pgxConfig.MinConns = int32(*opts.MinOpenConns) + minConns, err := safecast.ToInt32(*opts.MinOpenConns) + if err != nil { + return err + } + pgxConfig.MinConns = minConns } if pgxConfig.MaxConns > 0 && pgxConfig.MinConns > 0 && pgxConfig.MaxConns < pgxConfig.MinConns { @@ -335,6 +344,7 @@ func (opts PoolOptions) ConfigurePgx(pgxConfig *pgxpool.Config) { ConfigurePGXLogger(pgxConfig.ConnConfig) ConfigureOTELTracer(pgxConfig.ConnConfig) + return nil } type QuerierFuncs struct { diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index 717b67f864..3c73f75cda 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -185,7 +185,10 @@ func newPostgresDatastore( // Setup the config for each of the read and write pools. readPoolConfig := pgConfig.Copy() - config.readPoolOpts.ConfigurePgx(readPoolConfig) + err = config.readPoolOpts.ConfigurePgx(readPoolConfig) + if err != nil { + return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, pgURL) + } readPoolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { RegisterTypes(conn.TypeMap()) @@ -195,7 +198,10 @@ func newPostgresDatastore( var writePoolConfig *pgxpool.Config if isPrimary { writePoolConfig = pgConfig.Copy() - config.writePoolOpts.ConfigurePgx(writePoolConfig) + err = config.writePoolOpts.ConfigurePgx(writePoolConfig) + if err != nil { + return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, pgURL) + } writePoolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { RegisterTypes(conn.TypeMap()) diff --git a/internal/datastore/postgres/postgres_shared_test.go b/internal/datastore/postgres/postgres_shared_test.go index fd7a7fb3c3..62c40797ef 100644 --- a/internal/datastore/postgres/postgres_shared_test.go +++ b/internal/datastore/postgres/postgres_shared_test.go @@ -14,6 +14,15 @@ import ( "time" sq "github.com/Masterminds/squirrel" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/samber/lo" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "golang.org/x/sync/errgroup" + "github.com/authzed/spicedb/internal/datastore/common" pgcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" pgversion "github.com/authzed/spicedb/internal/datastore/postgres/version" @@ -25,14 +34,6 @@ import ( "github.com/authzed/spicedb/pkg/namespace" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/tuple" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/samber/lo" - "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/sdk/trace" - "go.opentelemetry.io/otel/sdk/trace/tracetest" - "golang.org/x/sync/errgroup" ) const pgSerializationFailure = "40001" diff --git a/internal/datastore/postgres/readwrite.go b/internal/datastore/postgres/readwrite.go index c5c661d7d3..d2f89dbaa0 100644 --- a/internal/datastore/postgres/readwrite.go +++ b/internal/datastore/postgres/readwrite.go @@ -6,6 +6,8 @@ import ( "fmt" "strings" + "github.com/ccoveille/go-safecast" + "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/genutil/mapz" "github.com/authzed/spicedb/pkg/spiceerrors" @@ -383,6 +385,12 @@ func (rwt *pgReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1.R } func (rwt *pgReadWriteTXN) deleteRelationshipsWithLimit(ctx context.Context, filter *v1.RelationshipFilter, limit uint64) (bool, error) { + // validate the limit + intLimit, err := safecast.ToInt64(limit) + if err != nil { + return false, fmt.Errorf("limit argument could not safely be cast to int64") + } + // Construct a select query for the relationships to be removed. query := selectForDelete @@ -444,7 +452,7 @@ func (rwt *pgReadWriteTXN) deleteRelationshipsWithLimit(ctx context.Context, fil return false, fmt.Errorf(errUnableToDeleteRelationships, err) } - return result.RowsAffected() == int64(limit), nil + return result.RowsAffected() == intLimit, nil } func (rwt *pgReadWriteTXN) deleteRelationships(ctx context.Context, filter *v1.RelationshipFilter) error { diff --git a/internal/datastore/postgres/revisions.go b/internal/datastore/postgres/revisions.go index 5cc0f1cf28..7be63c26c2 100644 --- a/internal/datastore/postgres/revisions.go +++ b/internal/datastore/postgres/revisions.go @@ -9,10 +9,12 @@ import ( "strings" "time" + "github.com/ccoveille/go-safecast" "github.com/jackc/pgx/v5" "github.com/authzed/spicedb/pkg/datastore" implv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" ) const ( @@ -153,7 +155,10 @@ func parseRevisionProto(revisionStr string) (datastore.Revision, error) { return datastore.NoRevision, fmt.Errorf(errRevisionFormat, err) } - xminInt := int64(decoded.Xmin) + xminInt, err := safecast.ToInt64(decoded.Xmin) + if err != nil { + return datastore.NoRevision, spiceerrors.MustBugf("Could not cast xmin to int64") + } var xips []uint64 if len(decoded.RelativeXips) > 0 { @@ -311,15 +316,26 @@ func (pr postgresRevision) OptionalNanosTimestamp() (uint64, bool) { // for xmax and xip list values to save bytes when encoded as varint protos. // For example, snapshot 1001:1004:1001,1003 becomes 1000:3:0,2. func (pr postgresRevision) MarshalBinary() ([]byte, error) { - xminInt := int64(pr.snapshot.xmin) + xminInt, err := safecast.ToInt64(pr.snapshot.xmin) + if err != nil { + return nil, spiceerrors.MustBugf("could not safely cast snapshot xip to int64") + } relativeXips := make([]int64, len(pr.snapshot.xipList)) for i, xip := range pr.snapshot.xipList { - relativeXips[i] = int64(xip) - xminInt + intXip, err := safecast.ToInt64(xip) + if err != nil { + return nil, spiceerrors.MustBugf("could not safely cast snapshot xip to int64") + } + relativeXips[i] = intXip - xminInt } + relativeXmax, err := safecast.ToInt64(pr.snapshot.xmax) + if err != nil { + return nil, spiceerrors.MustBugf("could not safely cast snapshot xmax to int64") + } protoRevision := implv1.PostgresRevision{ Xmin: pr.snapshot.xmin, - RelativeXmax: int64(pr.snapshot.xmax) - xminInt, + RelativeXmax: relativeXmax - xminInt, RelativeXips: relativeXips, } diff --git a/internal/datastore/proxy/schemacaching/watchingcache.go b/internal/datastore/proxy/schemacaching/watchingcache.go index d80320156d..2db8da5b0c 100644 --- a/internal/datastore/proxy/schemacaching/watchingcache.go +++ b/internal/datastore/proxy/schemacaching/watchingcache.go @@ -183,7 +183,7 @@ func (p *watchingCachingProxy) startSync(ctx context.Context) error { // Start watching for schema changes. go (func() { - retryCount := 0 + retryCount := uint8(0) restartWatch: for { @@ -322,7 +322,7 @@ func (p *watchingCachingProxy) startSync(ctx context.Context) error { log.Warn().Err(err).Msg("received retryable error in schema watch; sleeping for a bit and restarting watch") retryCount++ wg.Add(1) - pgxcommon.SleepOnErr(ctx, err, uint8(retryCount)) + pgxcommon.SleepOnErr(ctx, err, retryCount) continue restartWatch } diff --git a/internal/datastore/revisions/commonrevision.go b/internal/datastore/revisions/commonrevision.go index de85496d9b..709272818e 100644 --- a/internal/datastore/revisions/commonrevision.go +++ b/internal/datastore/revisions/commonrevision.go @@ -77,10 +77,3 @@ type WithTimestampRevision interface { TimestampNanoSec() int64 ConstructForTimestamp(timestampNanoSec int64) WithTimestampRevision } - -// WithIntegerRepresentation is an interface that can be implemented by a revision to -// provide an integer representation of the revision. -type WithIntegerRepresentation interface { - datastore.Revision - IntegerRepresentation() (int64, uint32) -} diff --git a/internal/datastore/revisions/hlcrevision.go b/internal/datastore/revisions/hlcrevision.go index 3d91c341de..5e0a70c8af 100644 --- a/internal/datastore/revisions/hlcrevision.go +++ b/internal/datastore/revisions/hlcrevision.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/ccoveille/go-safecast" "github.com/shopspring/decimal" "github.com/authzed/spicedb/pkg/datastore" @@ -55,12 +56,14 @@ func parseHLCRevisionString(revisionStr string) (datastore.Revision, error) { } paddedLogicalClockStr := pieces[1] + strings.Repeat("0", logicalClockLength-len(pieces[1])) - logicalclock, err := strconv.ParseInt(paddedLogicalClockStr, 10, 32) + logicalclock, err := strconv.ParseUint(paddedLogicalClockStr, 10, 32) if err != nil { return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr) } - return HLCRevision{timestamp, uint32(logicalclock) + logicalClockOffset}, nil + // Because we parsed with a bit size of 32 above, we know this range check should pass. + uintLogicalClock, _ := safecast.ToUint32(logicalclock) + return HLCRevision{timestamp, uintLogicalClock + logicalClockOffset}, nil } // HLCRevisionFromString parses a string into a hybrid logical clock revision. @@ -140,10 +143,6 @@ func (hlc HLCRevision) ConstructForTimestamp(timestamp int64) WithTimestampRevis return HLCRevision{timestamp, 0} } -func (hlc HLCRevision) IntegerRepresentation() (int64, uint32) { - return hlc.time, hlc.logicalclock -} - func (hlc HLCRevision) AsDecimal() (decimal.Decimal, error) { return decimal.NewFromString(hlc.String()) } diff --git a/internal/datastore/revisions/timestamprevision.go b/internal/datastore/revisions/timestamprevision.go index c916aa5c3c..d08ca22e68 100644 --- a/internal/datastore/revisions/timestamprevision.go +++ b/internal/datastore/revisions/timestamprevision.go @@ -77,10 +77,6 @@ func (ir TimestampRevision) ConstructForTimestamp(timestamp int64) WithTimestamp return TimestampRevision(timestamp) } -func (ir TimestampRevision) IntegerRepresentation() (int64, uint32) { - return int64(ir), 0 -} - var ( _ datastore.Revision = TimestampRevision(0) _ WithTimestampRevision = TimestampRevision(0) diff --git a/internal/datastore/revisions/txidrevision.go b/internal/datastore/revisions/txidrevision.go index 11de92fced..26deb64031 100644 --- a/internal/datastore/revisions/txidrevision.go +++ b/internal/datastore/revisions/txidrevision.go @@ -56,17 +56,13 @@ func (ir TransactionIDRevision) TransactionID() uint64 { } func (ir TransactionIDRevision) String() string { - return strconv.FormatInt(int64(ir), 10) + return strconv.FormatUint(uint64(ir), 10) } func (ir TransactionIDRevision) WithInexactFloat64() float64 { return float64(ir) } -func (ir TransactionIDRevision) IntegerRepresentation() (int64, uint32) { - return int64(ir), 0 -} - var _ datastore.Revision = TransactionIDRevision(0) // TransactionIDKeyFunc is used to create keys for transaction IDs. diff --git a/internal/dispatch/graph/lookupresources2_test.go b/internal/dispatch/graph/lookupresources2_test.go index 5d04642ae8..fdf123ff17 100644 --- a/internal/dispatch/graph/lookupresources2_test.go +++ b/internal/dispatch/graph/lookupresources2_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/ccoveille/go-safecast" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -701,6 +702,10 @@ func TestLookupResources2OverSchemaWithCursors(t *testing.T) { foundChunks := [][]*v1.DispatchLookupResources2Response{} for { stream := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupResources2Response](ctx) + + uintPageSize, err := safecast.ToUint32(pageSize) + require.NoError(err) + err = dispatcher.DispatchLookupResources2(&v1.DispatchLookupResources2Request{ ResourceRelation: tc.permission, SubjectRelation: RR(tc.subject.Namespace, "..."), @@ -710,7 +715,7 @@ func TestLookupResources2OverSchemaWithCursors(t *testing.T) { AtRevision: revision.String(), DepthRemaining: 50, }, - OptionalLimit: uint32(pageSize), + OptionalLimit: uintPageSize, OptionalCursor: currentCursor, }, stream) require.NoError(err) diff --git a/internal/dispatch/graph/lookupresources_test.go b/internal/dispatch/graph/lookupresources_test.go index e833c59cb3..009eccb6e9 100644 --- a/internal/dispatch/graph/lookupresources_test.go +++ b/internal/dispatch/graph/lookupresources_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/ccoveille/go-safecast" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -615,6 +616,9 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) { foundResourceIDs := mapz.NewSet[string]() for { stream := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupResourcesResponse](ctx) + uintPageSize, err := safecast.ToUint32(pageSize) + require.NoError(err) + err = dispatcher.DispatchLookupResources(&v1.DispatchLookupResourcesRequest{ ObjectRelation: tc.permission, Subject: tc.subject, @@ -622,7 +626,7 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) { AtRevision: revision.String(), DepthRemaining: 50, }, - OptionalLimit: uint32(pageSize), + OptionalLimit: uintPageSize, OptionalCursor: currentCursor, }, stream) require.NoError(err) diff --git a/internal/dispatch/graph/reachableresources_test.go b/internal/dispatch/graph/reachableresources_test.go index d8f6b18214..9ccb97b86c 100644 --- a/internal/dispatch/graph/reachableresources_test.go +++ b/internal/dispatch/graph/reachableresources_test.go @@ -10,6 +10,7 @@ import ( "sync" "testing" + "github.com/ccoveille/go-safecast" "github.com/stretchr/testify/require" "go.uber.org/goleak" "golang.org/x/sync/errgroup" @@ -1240,6 +1241,9 @@ func TestReachableResourcesOverSchema(t *testing.T) { var currentCursor *v1.Cursor for { stream := dispatch.NewCollectingDispatchStream[*v1.DispatchReachableResourcesResponse](ctx) + uintPageSize, err := safecast.ToUint32(pageSize) + require.NoError(err) + err = dispatcher.DispatchReachableResources(&v1.DispatchReachableResourcesRequest{ ResourceRelation: tc.permission, SubjectRelation: &core.RelationReference{ @@ -1252,7 +1256,7 @@ func TestReachableResourcesOverSchema(t *testing.T) { DepthRemaining: 50, }, OptionalCursor: currentCursor, - OptionalLimit: uint32(pageSize), + OptionalLimit: uintPageSize, }, stream) require.NoError(err) diff --git a/internal/graph/cursors.go b/internal/graph/cursors.go index 8b6560e28c..021f137674 100644 --- a/internal/graph/cursors.go +++ b/internal/graph/cursors.go @@ -6,6 +6,8 @@ import ( "strconv" "sync" + "github.com/ccoveille/go-safecast" + "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/taskrunner" "github.com/authzed/spicedb/pkg/datastore/options" @@ -501,7 +503,11 @@ func (ls *parallelLimitedIndexedStream[Q]) completedTaskIndex(index int) error { if ls.toPublishTaskIndex == 0 { // Remove the already emitted data from the overall limits. - if err := ls.ci.limits.markAlreadyPublished(uint32(ls.countingStream.PublishedCount())); err != nil { + publishedCount, err := safecast.ToUint32(ls.countingStream.PublishedCount()) + if err != nil { + return spiceerrors.MustBugf("cannot cast published count to uint32") + } + if err := ls.ci.limits.markAlreadyPublished(publishedCount); err != nil { return err } diff --git a/internal/middleware/usagemetrics/usagemetrics_test.go b/internal/middleware/usagemetrics/usagemetrics_test.go index e2dbec2a1e..1f8f03f07c 100644 --- a/internal/middleware/usagemetrics/usagemetrics_test.go +++ b/internal/middleware/usagemetrics/usagemetrics_test.go @@ -55,7 +55,7 @@ func (t testServer) PingList(_ *testpb.PingListRequest, server testpb.TestServic } func (t testServer) PingStream(stream testpb.TestService_PingStreamServer) error { - count := 0 + count := int32(0) for { _, err := stream.Recv() if errors.Is(err, io.EOF) { @@ -63,7 +63,7 @@ func (t testServer) PingStream(stream testpb.TestService_PingStreamServer) error } else if err != nil { return err } - _ = stream.Send(&testpb.PingStreamResponse{Value: "", Counter: int32(count)}) + _ = stream.Send(&testpb.PingStreamResponse{Value: "", Counter: count}) count++ } return nil diff --git a/internal/services/v1/experimental_test.go b/internal/services/v1/experimental_test.go index fa0abc5f8a..15292c0836 100644 --- a/internal/services/v1/experimental_test.go +++ b/internal/services/v1/experimental_test.go @@ -13,6 +13,7 @@ import ( "github.com/authzed/authzed-go/pkg/responsemeta" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/grpcutil" + "github.com/ccoveille/go-safecast" "github.com/scylladb/go-set" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -117,12 +118,12 @@ func constBatch(size int) func() int { } } -func randomBatch(minimum, max int) func() int { +func randomBatch(minimum, maximum int) func() int { return func() int { // nolint:gosec // G404 use of non cryptographically secure random number generator is not a security concern here, // as this is only used for generating fixtures in testing. - return rand.Intn(max-minimum) + minimum + return rand.Intn(maximum-minimum) + minimum } } @@ -159,8 +160,8 @@ func TestBulkExportRelationships(t *testing.T) { {tf.FolderNS.Name, "editor"}, } - totalToWrite := uint64(1_000) - expectedRels := set.NewStringSetWithSize(int(totalToWrite)) + totalToWrite := 1_000 + expectedRels := set.NewStringSetWithSize(totalToWrite) batch := make([]*v1.Relationship, totalToWrite) for i := range batch { nsAndRel := nsAndRels[i%len(nsAndRels)] @@ -291,7 +292,7 @@ func TestBulkExportRelationshipsWithFilter(t *testing.T) { }, } - batchSize := 14 + batchSize := uint32(14) for _, tc := range testCases { tc := tc @@ -362,7 +363,7 @@ func TestBulkExportRelationshipsWithFilter(t *testing.T) { stream, err := client.BulkExportRelationships(streamCtx, &v1.BulkExportRelationshipsRequest{ OptionalRelationshipFilter: tc.filter, - OptionalLimit: uint32(batchSize), + OptionalLimit: batchSize, OptionalCursor: cursor, }) require.NoError(err) @@ -374,7 +375,9 @@ func TestBulkExportRelationshipsWithFilter(t *testing.T) { } require.NoError(err) - require.LessOrEqual(uint32(len(batch.Relationships)), uint32(batchSize)) + relLength, err := safecast.ToUint32(len(batch.Relationships)) + require.NoError(err) + require.LessOrEqual(relLength, batchSize) require.NotNil(batch.AfterResultCursor) require.NotEmpty(batch.AfterResultCursor.Token) diff --git a/internal/services/v1/permissions_test.go b/internal/services/v1/permissions_test.go index 770b1a428b..4bdd6a1698 100644 --- a/internal/services/v1/permissions_test.go +++ b/internal/services/v1/permissions_test.go @@ -16,6 +16,7 @@ import ( "github.com/authzed/authzed-go/pkg/responsemeta" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/grpcutil" + "github.com/ccoveille/go-safecast" "github.com/stretchr/testify/require" "go.uber.org/goleak" "google.golang.org/genproto/googleapis/rpc/errdetails" @@ -1648,6 +1649,8 @@ func TestLookupResourcesWithCursors(t *testing.T) { for i := 0; i < 5; i++ { var trailer metadata.MD + uintLimit, err := safecast.ToUint32(limit) + require.NoError(err) lookupClient, err := client.LookupResources(context.Background(), &v1.LookupResourcesRequest{ ResourceObjectType: tc.objectType, Permission: tc.permission, @@ -1657,7 +1660,7 @@ func TestLookupResourcesWithCursors(t *testing.T) { AtLeastAsFresh: zedtoken.MustNewFromRevision(revision), }, }, - OptionalLimit: uint32(limit), + OptionalLimit: uintLimit, OptionalCursor: currentCursor, }, grpc.Trailer(&trailer)) diff --git a/internal/services/v1/relationships.go b/internal/services/v1/relationships.go index 6f6e308715..95e21a0294 100644 --- a/internal/services/v1/relationships.go +++ b/internal/services/v1/relationships.go @@ -423,13 +423,13 @@ func (ps *permissionServer) DeleteRelationships(ctx context.Context, req *v1.Del } defer iter.Close() - counter := 0 + counter := uint64(0) for tpl := iter.Next(); tpl != nil; tpl = iter.Next() { if iter.Err() != nil { return ps.rewriteError(ctx, err) } - if counter == int(limit) { + if counter == limit { return ps.rewriteError(ctx, NewCouldNotTransactionallyDeleteErr(req.RelationshipFilter, req.OptionalLimit)) } diff --git a/internal/services/v1/relationships_test.go b/internal/services/v1/relationships_test.go index 6d21d921f0..de0d7ba1c9 100644 --- a/internal/services/v1/relationships_test.go +++ b/internal/services/v1/relationships_test.go @@ -11,6 +11,7 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/grpcutil" + "github.com/ccoveille/go-safecast" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "google.golang.org/grpc/codes" @@ -292,6 +293,8 @@ func TestReadRelationships(t *testing.T) { testExpected[k] = struct{}{} } + uintPageSize, err := safecast.ToUint32(pageSize) + require.NoError(err) for i := 0; i < 20; i++ { stream, err := client.ReadRelationships(context.Background(), &v1.ReadRelationshipsRequest{ Consistency: &v1.Consistency{ @@ -300,7 +303,7 @@ func TestReadRelationships(t *testing.T) { }, }, RelationshipFilter: tc.filter, - OptionalLimit: uint32(pageSize), + OptionalLimit: uintPageSize, OptionalCursor: currentCursor, }) require.NoError(err) @@ -1240,6 +1243,8 @@ func TestDeleteRelationshipsBeyondLimitPartial(t *testing.T) { t.Cleanup(cleanup) iterations := 0 + uintBatchSize, err := safecast.ToUint32(batchSize) + require.NoError(err) for i := 0; i < 10; i++ { iterations++ @@ -1252,7 +1257,7 @@ func TestDeleteRelationshipsBeyondLimitPartial(t *testing.T) { RelationshipFilter: &v1.RelationshipFilter{ ResourceType: "document", }, - OptionalLimit: uint32(batchSize), + OptionalLimit: uintBatchSize, OptionalAllowPartialDeletions: true, }) require.NoError(err) diff --git a/pkg/cache/cache_otter.go b/pkg/cache/cache_otter.go index 14f83c4063..dda52b4388 100644 --- a/pkg/cache/cache_otter.go +++ b/pkg/cache/cache_otter.go @@ -4,8 +4,10 @@ package cache import ( + "math" "sync" + "github.com/ccoveille/go-safecast" "github.com/maypok86/otter" "github.com/rs/zerolog" ) @@ -66,7 +68,13 @@ func (wtc *otterCache[K, V]) Get(key K) (V, bool) { func (wtc *otterCache[K, V]) Set(key K, value V, cost int64) bool { keyString := key.KeyString() - return wtc.cache.Set(keyString, valueAndCost[V]{value, uint32(cost)}) + uintCost, err := safecast.ToUint32(cost) + if err != nil { + // We make an assumption that if the cast fails, it's because the value + // was too big, so we set to maxint in that case. + uintCost = math.MaxUint32 + } + return wtc.cache.Set(keyString, valueAndCost[V]{value, uintCost}) } func (wtc *otterCache[K, V]) Wait() { diff --git a/pkg/cmd/datastore/datastore.go b/pkg/cmd/datastore/datastore.go index df991f938d..6e9482db1b 100644 --- a/pkg/cmd/datastore/datastore.go +++ b/pkg/cmd/datastore/datastore.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/ccoveille/go-safecast" "github.com/spf13/pflag" "github.com/authzed/spicedb/internal/datastore/crdb" @@ -469,6 +470,11 @@ func newCRDBDatastore(ctx context.Context, opts Config) (datastore.Datastore, er return nil, errors.New("read replicas are not supported for the CockroachDB datastore engine") } + maxRetries, err := safecast.ToUint8(opts.MaxRetries) + if err != nil { + return nil, errors.New("max-retries could not be cast to uint8") + } + return crdb.NewCRDBDatastore( ctx, opts.URI, @@ -488,7 +494,7 @@ func newCRDBDatastore(ctx context.Context, opts Config) (datastore.Datastore, er crdb.WriteConnMaxLifetimeJitter(opts.WriteConnPool.MaxLifetimeJitter), crdb.WriteConnHealthCheckInterval(opts.WriteConnPool.HealthCheckInterval), crdb.FollowerReadDelay(opts.FollowerReadDelay), - crdb.MaxRetries(uint8(opts.MaxRetries)), + crdb.MaxRetries(maxRetries), crdb.OverlapKey(opts.OverlapKey), crdb.OverlapStrategy(opts.OverlapStrategy), crdb.WatchBufferLength(opts.WatchBufferLength), @@ -514,7 +520,11 @@ func newPostgresDatastore(ctx context.Context, opts Config) (datastore.Datastore replicas := make([]datastore.StrictReadDatastore, 0, len(opts.ReadReplicaURIs)) for index, replicaURI := range opts.ReadReplicaURIs { - replica, err := newPostgresReplicaDatastore(ctx, uint32(index), replicaURI, opts) + uintIndex, err := safecast.ToUint32(index) + if err != nil { + return nil, errors.New("too many replicas") + } + replica, err := newPostgresReplicaDatastore(ctx, uintIndex, replicaURI, opts) if err != nil { return nil, err } @@ -524,13 +534,18 @@ func newPostgresDatastore(ctx context.Context, opts Config) (datastore.Datastore return proxy.NewStrictReplicatedDatastore(primary, replicas...) } -func commonPostgresDatastoreOptions(opts Config) []postgres.Option { +func commonPostgresDatastoreOptions(opts Config) ([]postgres.Option, error) { + maxRetries, err := safecast.ToUint8(opts.MaxRetries) + if err != nil { + return nil, errors.New("max-retries could not be cast to uint8") + } + return []postgres.Option{ postgres.EnableTracing(), postgres.WithEnablePrometheusStats(opts.EnableDatastoreMetrics), - postgres.MaxRetries(uint8(opts.MaxRetries)), + postgres.MaxRetries(maxRetries), postgres.FilterMaximumIDCount(opts.FilterMaximumIDCount), - } + }, nil } func newPostgresReplicaDatastore(ctx context.Context, replicaIndex uint32, replicaURI string, opts Config) (datastore.StrictReadDatastore, error) { @@ -545,7 +560,11 @@ func newPostgresReplicaDatastore(ctx context.Context, replicaIndex uint32, repli postgres.ReadStrictMode( /* strict read mode is required for Postgres read replicas */ true), } - pgOpts = append(pgOpts, commonPostgresDatastoreOptions(opts)...) + commonOptions, err := commonPostgresDatastoreOptions(opts) + if err != nil { + return nil, err + } + pgOpts = append(pgOpts, commonOptions...) return postgres.NewReadOnlyPostgresDatastore(ctx, replicaURI, replicaIndex, pgOpts...) } @@ -575,7 +594,11 @@ func newPostgresPrimaryDatastore(ctx context.Context, opts Config) (datastore.Da postgres.MigrationPhase(opts.MigrationPhase), } - pgOpts = append(pgOpts, commonPostgresDatastoreOptions(opts)...) + commonOptions, err := commonPostgresDatastoreOptions(opts) + if err != nil { + return nil, err + } + pgOpts = append(pgOpts, commonOptions...) return postgres.NewPostgresDatastore(ctx, opts.URI, pgOpts...) } @@ -617,7 +640,11 @@ func newMySQLDatastore(ctx context.Context, opts Config) (datastore.Datastore, e replicas := make([]datastore.ReadOnlyDatastore, 0, len(opts.ReadReplicaURIs)) for index, replicaURI := range opts.ReadReplicaURIs { - replica, err := newMySQLReplicaDatastore(ctx, uint32(index), replicaURI, opts) + uintIndex, err := safecast.ToUint32(index) + if err != nil { + return nil, errors.New("too many replicas") + } + replica, err := newMySQLReplicaDatastore(ctx, uintIndex, replicaURI, opts) if err != nil { return nil, err } @@ -627,16 +654,21 @@ func newMySQLDatastore(ctx context.Context, opts Config) (datastore.Datastore, e return proxy.NewCheckingReplicatedDatastore(primary, replicas...) } -func commonMySQLDatastoreOptions(opts Config) []mysql.Option { +func commonMySQLDatastoreOptions(opts Config) ([]mysql.Option, error) { + maxRetries, err := safecast.ToUint8(opts.MaxRetries) + if err != nil { + return nil, errors.New("max-retries could not be cast to uint8") + } + return []mysql.Option{ mysql.TablePrefix(opts.TablePrefix), - mysql.MaxRetries(uint8(opts.MaxRetries)), + mysql.MaxRetries(maxRetries), mysql.OverrideLockWaitTimeout(1), mysql.WithEnablePrometheusStats(opts.EnableDatastoreMetrics), mysql.MaxRevisionStalenessPercent(opts.MaxRevisionStalenessPercent), mysql.RevisionQuantization(opts.RevisionQuantization), mysql.FilterMaximumIDCount(opts.FilterMaximumIDCount), - } + }, nil } func newMySQLReplicaDatastore(ctx context.Context, replicaIndex uint32, replicaURI string, opts Config) (datastore.ReadOnlyDatastore, error) { @@ -649,7 +681,11 @@ func newMySQLReplicaDatastore(ctx context.Context, replicaIndex uint32, replicaU mysql.CredentialsProviderName(opts.ReadReplicaCredentialsProviderName), } - mysqlOpts = append(mysqlOpts, commonMySQLDatastoreOptions(opts)...) + commonOptions, err := commonMySQLDatastoreOptions(opts) + if err != nil { + return nil, err + } + mysqlOpts = append(mysqlOpts, commonOptions...) return mysql.NewReadOnlyMySQLDatastore(ctx, replicaURI, replicaIndex, mysqlOpts...) } @@ -668,7 +704,11 @@ func newMySQLPrimaryDatastore(ctx context.Context, opts Config) (datastore.Datas mysql.CredentialsProviderName(opts.CredentialsProviderName), } - mysqlOpts = append(mysqlOpts, commonMySQLDatastoreOptions(opts)...) + commonOptions, err := commonMySQLDatastoreOptions(opts) + if err != nil { + return nil, err + } + mysqlOpts = append(mysqlOpts, commonOptions...) return mysql.NewMySQLDatastore(ctx, opts.URI, mysqlOpts...) } diff --git a/pkg/cmd/server/cacheconfig.go b/pkg/cmd/server/cacheconfig.go index 036f34015a..9795c354ed 100644 --- a/pkg/cmd/server/cacheconfig.go +++ b/pkg/cmd/server/cacheconfig.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/ccoveille/go-safecast" "github.com/dustin/go-humanize" "github.com/jzelinskie/stringz" "github.com/pbnjay/memory" @@ -74,18 +75,23 @@ func CompleteCache[K cache.KeyString, V any](cc *CacheConfig) (cache.Cache[K, V] return nil, fmt.Errorf("error parsing cache max memory: `%s`: %w", cc.MaxCost, err) } + intMaxCost, err := safecast.ToInt64(maxCost) + if err != nil { + return nil, fmt.Errorf("could not cast max cost to int64") + } + if cc.CacheKindForTesting != "" { switch cc.CacheKindForTesting { case "theine": return cache.NewTheineCache[K, V](&cache.Config{ - MaxCost: int64(maxCost), + MaxCost: intMaxCost, NumCounters: cc.NumCounters, DefaultTTL: cc.defaultTTL, }) case "otter": return cache.NewOtterCache[K, V](&cache.Config{ - MaxCost: int64(maxCost), + MaxCost: intMaxCost, NumCounters: cc.NumCounters, DefaultTTL: cc.defaultTTL, }) @@ -97,14 +103,14 @@ func CompleteCache[K cache.KeyString, V any](cc *CacheConfig) (cache.Cache[K, V] if cc.Metrics { return cache.NewStandardCacheWithMetrics[K, V](cc.Name, &cache.Config{ - MaxCost: int64(maxCost), + MaxCost: intMaxCost, NumCounters: cc.NumCounters, DefaultTTL: cc.defaultTTL, }) } return cache.NewStandardCache[K, V](&cache.Config{ - MaxCost: int64(maxCost), + MaxCost: intMaxCost, NumCounters: cc.NumCounters, DefaultTTL: cc.defaultTTL, }) diff --git a/pkg/datastore/pagination/iterator_test.go b/pkg/datastore/pagination/iterator_test.go index 2698589f41..4cde4d61b8 100644 --- a/pkg/datastore/pagination/iterator_test.go +++ b/pkg/datastore/pagination/iterator_test.go @@ -146,16 +146,16 @@ func TestPaginatedIterator(t *testing.T) { require := require.New(t) tpls := make([]*core.RelationTuple, 0, tc.totalRelationships) - for i := 0; i < int(tc.totalRelationships); i++ { + for i := uint64(0); i < tc.totalRelationships; i++ { tpls = append(tpls, &core.RelationTuple{ ResourceAndRelation: &core.ObjectAndRelation{ Namespace: "document", - ObjectId: strconv.Itoa(i), + ObjectId: strconv.FormatUint(i, 10), Relation: "owner", }, Subject: &core.ObjectAndRelation{ Namespace: "user", - ObjectId: strconv.Itoa(i), + ObjectId: strconv.FormatUint(i, 10), Relation: datastore.Ellipsis, }, }) diff --git a/pkg/development/assertions.go b/pkg/development/assertions.go index 5c600913f0..9a4ba35589 100644 --- a/pkg/development/assertions.go +++ b/pkg/development/assertions.go @@ -4,9 +4,11 @@ import ( "fmt" v1t "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/ccoveille/go-safecast" devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" "github.com/authzed/spicedb/pkg/tuple" "github.com/authzed/spicedb/pkg/validationfile/blocks" ) @@ -42,14 +44,23 @@ func runAssertions(devContext *DevContext, assertions []blocks.Assertion, expect for _, assertion := range assertions { tpl := tuple.MustFromRelationship[*v1t.ObjectReference, *v1t.SubjectReference, *v1t.ContextualizedCaveat](assertion.Relationship) + lineNumber, err := safecast.ToUint32(assertion.SourcePosition.LineNumber) + if err != nil { + return nil, spiceerrors.MustBugf("Line number could not be cast to uint32") + } + columnPosition, err := safecast.ToUint32(assertion.SourcePosition.ColumnPosition) + if err != nil { + return nil, spiceerrors.MustBugf("Column position could not be cast to uint32") + } + if tpl.Caveat != nil { failures = append(failures, &devinterface.DeveloperError{ Message: fmt.Sprintf("cannot specify a caveat on an assertion: `%s`", assertion.RelationshipWithContextString), Source: devinterface.DeveloperError_ASSERTION, Kind: devinterface.DeveloperError_UNKNOWN_RELATION, Context: assertion.RelationshipWithContextString, - Line: uint32(assertion.SourcePosition.LineNumber), - Column: uint32(assertion.SourcePosition.ColumnPosition), + Line: lineNumber, + Column: columnPosition, }) continue } @@ -60,8 +71,8 @@ func runAssertions(devContext *DevContext, assertions []blocks.Assertion, expect devContext, err, devinterface.DeveloperError_ASSERTION, - uint32(assertion.SourcePosition.LineNumber), - uint32(assertion.SourcePosition.ColumnPosition), + lineNumber, + columnPosition, assertion.RelationshipWithContextString, ) if wireErr != nil { @@ -76,8 +87,8 @@ func runAssertions(devContext *DevContext, assertions []blocks.Assertion, expect Source: devinterface.DeveloperError_ASSERTION, Kind: devinterface.DeveloperError_ASSERTION_FAILED, Context: assertion.RelationshipWithContextString, - Line: uint32(assertion.SourcePosition.LineNumber), - Column: uint32(assertion.SourcePosition.ColumnPosition), + Line: lineNumber, + Column: columnPosition, CheckDebugInformation: cr.DispatchDebugInfo, CheckResolvedDebugInformation: cr.V1DebugInfo, }) diff --git a/pkg/development/devcontext.go b/pkg/development/devcontext.go index 1d4eff68c5..c8402c0984 100644 --- a/pkg/development/devcontext.go +++ b/pkg/development/devcontext.go @@ -7,6 +7,7 @@ import ( "time" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/ccoveille/go-safecast" humanize "github.com/dustin/go-humanize" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc" @@ -244,13 +245,21 @@ func loadCompiled( errWithSource, ok := spiceerrors.AsErrorWithSource(cverr) if ok { + lineNumber, err := safecast.ToUint32(errWithSource.LineNumber) + if err != nil { + return nil, spiceerrors.MustBugf("Could not cast line number to uint32") + } + columnPosition, err := safecast.ToUint32(errWithSource.ColumnPosition) + if err != nil { + return nil, spiceerrors.MustBugf("Could not cast column position to uint32") + } errors = append(errors, &devinterface.DeveloperError{ Message: cverr.Error(), Kind: devinterface.DeveloperError_SCHEMA_ISSUE, Source: devinterface.DeveloperError_SCHEMA, Context: errWithSource.SourceCodeString, - Line: uint32(errWithSource.LineNumber), - Column: uint32(errWithSource.ColumnPosition), + Line: lineNumber, + Column: columnPosition, }) } else { errors = append(errors, &devinterface.DeveloperError{ @@ -266,14 +275,22 @@ func loadCompiled( ts, terr := typesystem.NewNamespaceTypeSystem(nsDef, resolver) if terr != nil { errWithSource, ok := spiceerrors.AsErrorWithSource(terr) + lineNumber, err := safecast.ToUint32(errWithSource.LineNumber) + if err != nil { + return nil, spiceerrors.MustBugf("could not cast line number to uint32") + } + columnPosition, err := safecast.ToUint32(errWithSource.ColumnPosition) + if err != nil { + return nil, spiceerrors.MustBugf("could not cast column position to uint32") + } if ok { errors = append(errors, &devinterface.DeveloperError{ Message: terr.Error(), Kind: devinterface.DeveloperError_SCHEMA_ISSUE, Source: devinterface.DeveloperError_SCHEMA, Context: errWithSource.SourceCodeString, - Line: uint32(errWithSource.LineNumber), - Column: uint32(errWithSource.ColumnPosition), + Line: lineNumber, + Column: columnPosition, }) continue } @@ -297,13 +314,21 @@ func loadCompiled( errWithSource, ok := spiceerrors.AsErrorWithSource(tverr) if ok { + lineNumber, err := safecast.ToUint32(errWithSource.LineNumber) + if err != nil { + return nil, spiceerrors.MustBugf("could not cast line number to uint32") + } + columnPosition, err := safecast.ToUint32(errWithSource.ColumnPosition) + if err != nil { + return nil, spiceerrors.MustBugf("could not cast column position to uint32") + } errors = append(errors, &devinterface.DeveloperError{ Message: tverr.Error(), Kind: devinterface.DeveloperError_SCHEMA_ISSUE, Source: devinterface.DeveloperError_SCHEMA, Context: errWithSource.SourceCodeString, - Line: uint32(errWithSource.LineNumber), - Column: uint32(errWithSource.ColumnPosition), + Line: lineNumber, + Column: columnPosition, }) } else { errors = append(errors, &devinterface.DeveloperError{ diff --git a/pkg/development/parsing.go b/pkg/development/parsing.go index 239c0a57d5..9774c31f7c 100644 --- a/pkg/development/parsing.go +++ b/pkg/development/parsing.go @@ -1,6 +1,8 @@ package development import ( + "github.com/ccoveille/go-safecast" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" "github.com/authzed/spicedb/pkg/spiceerrors" "github.com/authzed/spicedb/pkg/validationfile" @@ -46,12 +48,15 @@ func convertError(source devinterface.DeveloperError_Source, err error) *devinte } func convertSourceError(source devinterface.DeveloperError_Source, err *spiceerrors.ErrorWithSource) *devinterface.DeveloperError { + // NOTE: zeroes are fine here to mean "unknown" + lineNumber, _ := safecast.ToUint32(err.LineNumber) + columnPosition, _ := safecast.ToUint32(err.ColumnPosition) return &devinterface.DeveloperError{ Message: err.Error(), Kind: devinterface.DeveloperError_PARSE_ERROR, Source: source, - Line: uint32(err.LineNumber), - Column: uint32(err.ColumnPosition), + Line: lineNumber, + Column: columnPosition, Context: err.SourceCodeString, } } diff --git a/pkg/development/resolver.go b/pkg/development/resolver.go index e61eb74b63..7682d05507 100644 --- a/pkg/development/resolver.go +++ b/pkg/development/resolver.go @@ -4,6 +4,8 @@ import ( "fmt" "strings" + "github.com/ccoveille/go-safecast" + "github.com/authzed/spicedb/pkg/caveats" "github.com/authzed/spicedb/pkg/namespace" core "github.com/authzed/spicedb/pkg/proto/core/v1" @@ -11,6 +13,7 @@ import ( "github.com/authzed/spicedb/pkg/schemadsl/dslshape" "github.com/authzed/spicedb/pkg/schemadsl/generator" "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/spiceerrors" "github.com/authzed/spicedb/pkg/typesystem" ) @@ -91,9 +94,17 @@ func (r *Resolver) ReferenceAtPosition(source input.Source, position input.Posit } relationReference := func(relation *core.Relation, ts *typesystem.TypeSystem) (*SchemaReference, error) { + lineNumber, err := safecast.ToInt(relation.SourcePosition.ZeroIndexedLineNumber) + if err != nil { + return nil, spiceerrors.MustBugf("could not cast line number to int") + } + columnPosition, err := safecast.ToInt(relation.SourcePosition.ZeroIndexedColumnPosition) + if err != nil { + return nil, spiceerrors.MustBugf("could not cast column positiion to int") + } relationPosition := input.Position{ - LineNumber: int(relation.SourcePosition.ZeroIndexedLineNumber), - ColumnPosition: int(relation.SourcePosition.ZeroIndexedColumnPosition), + LineNumber: lineNumber, + ColumnPosition: columnPosition, } targetSourceCode, err := generator.GenerateRelationSource(relation) @@ -139,9 +150,19 @@ func (r *Resolver) ReferenceAtPosition(source input.Source, position input.Posit } def := ts.Namespace() + + lineNumber, err := safecast.ToInt(def.SourcePosition.ZeroIndexedLineNumber) + if err != nil { + return nil, spiceerrors.MustBugf("Could not cast line number to int") + } + columnPosition, err := safecast.ToInt(def.SourcePosition.ZeroIndexedColumnPosition) + if err != nil { + return nil, spiceerrors.MustBugf("Could not cast column position to int") + } + defPosition := input.Position{ - LineNumber: int(def.SourcePosition.ZeroIndexedLineNumber), - ColumnPosition: int(def.SourcePosition.ZeroIndexedColumnPosition), + LineNumber: lineNumber, + ColumnPosition: columnPosition, } docComment := "" @@ -172,9 +193,18 @@ func (r *Resolver) ReferenceAtPosition(source input.Source, position input.Posit // Caveat Type reference. if caveatDef, ok := r.caveatTypeReferenceChain(nodeChain); ok { + lineNumber, err := safecast.ToInt(caveatDef.SourcePosition.ZeroIndexedLineNumber) + if err != nil { + return nil, spiceerrors.MustBugf("Could not cast line number to int") + } + columnPosition, err := safecast.ToInt(caveatDef.SourcePosition.ZeroIndexedColumnPosition) + if err != nil { + return nil, spiceerrors.MustBugf("Could not cast column position to int") + } + defPosition := input.Position{ - LineNumber: int(caveatDef.SourcePosition.ZeroIndexedLineNumber), - ColumnPosition: int(caveatDef.SourcePosition.ZeroIndexedColumnPosition), + LineNumber: lineNumber, + ColumnPosition: columnPosition, } var caveatSourceCode strings.Builder diff --git a/pkg/development/schema.go b/pkg/development/schema.go index 4a99947333..a3f4b1c8e1 100644 --- a/pkg/development/schema.go +++ b/pkg/development/schema.go @@ -3,6 +3,8 @@ package development import ( "errors" + "github.com/ccoveille/go-safecast" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" "github.com/authzed/spicedb/pkg/schemadsl/compiler" "github.com/authzed/spicedb/pkg/schemadsl/input" @@ -24,12 +26,15 @@ func CompileSchema(schema string) (*compiler.CompiledSchema, *devinterface.Devel return nil, nil, lerr } + // NOTE: zeroes are fine here on failure. + uintLine, _ := safecast.ToUint32(line) + uintColumn, _ := safecast.ToUint32(col) return nil, &devinterface.DeveloperError{ Message: contextError.BaseCompilerError.BaseMessage, Kind: devinterface.DeveloperError_SCHEMA_ISSUE, Source: devinterface.DeveloperError_SCHEMA, - Line: uint32(line) + 1, // 0-indexed in parser. - Column: uint32(col) + 1, // 0-indexed in parser. + Line: uintLine + 1, // 0-indexed in parser. + Column: uintColumn + 1, // 0-indexed in parser. Context: contextError.ErrorSourceCode, }, nil } diff --git a/pkg/development/validation.go b/pkg/development/validation.go index 6eaa1641a9..0538e1f947 100644 --- a/pkg/development/validation.go +++ b/pkg/development/validation.go @@ -5,6 +5,7 @@ import ( "sort" "strings" + "github.com/ccoveille/go-safecast" "github.com/google/go-cmp/cmp" yaml "gopkg.in/yaml.v2" @@ -90,14 +91,17 @@ func validateSubjects(onrKey blocks.ObjectRelation, fs developmentmembership.Fou encounteredSubjects := map[string]struct{}{} for _, expectedSubject := range expectedSubjects { subjectWithExceptions := expectedSubject.SubjectWithExceptions + // NOTE: zeroes are fine here on failure. + lineNumber, _ := safecast.ToUint32(expectedSubject.SourcePosition.LineNumber) + columnPosition, _ := safecast.ToUint32(expectedSubject.SourcePosition.ColumnPosition) if subjectWithExceptions == nil { failures = append(failures, &devinterface.DeveloperError{ Message: fmt.Sprintf("For object and permission/relation `%s`, no expected subject specified in `%s`", tuple.StringONR(onr), expectedSubject.ValidationString), Source: devinterface.DeveloperError_VALIDATION_YAML, Kind: devinterface.DeveloperError_MISSING_EXPECTED_RELATIONSHIP, Context: string(expectedSubject.ValidationString), - Line: uint32(expectedSubject.SourcePosition.LineNumber), - Column: uint32(expectedSubject.SourcePosition.ColumnPosition), + Line: lineNumber, + Column: columnPosition, }) continue } @@ -111,8 +115,8 @@ func validateSubjects(onrKey blocks.ObjectRelation, fs developmentmembership.Fou Source: devinterface.DeveloperError_VALIDATION_YAML, Kind: devinterface.DeveloperError_MISSING_EXPECTED_RELATIONSHIP, Context: string(expectedSubject.ValidationString), - Line: uint32(expectedSubject.SourcePosition.LineNumber), - Column: uint32(expectedSubject.SourcePosition.ColumnPosition), + Line: lineNumber, + Column: columnPosition, }) continue } @@ -133,8 +137,8 @@ func validateSubjects(onrKey blocks.ObjectRelation, fs developmentmembership.Fou Source: devinterface.DeveloperError_VALIDATION_YAML, Kind: devinterface.DeveloperError_MISSING_EXPECTED_RELATIONSHIP, Context: string(expectedSubject.ValidationString), - Line: uint32(expectedSubject.SourcePosition.LineNumber), - Column: uint32(expectedSubject.SourcePosition.ColumnPosition), + Line: lineNumber, + Column: columnPosition, }) } @@ -159,8 +163,8 @@ func validateSubjects(onrKey blocks.ObjectRelation, fs developmentmembership.Fou Source: devinterface.DeveloperError_VALIDATION_YAML, Kind: devinterface.DeveloperError_MISSING_EXPECTED_RELATIONSHIP, Context: string(expectedSubject.ValidationString), - Line: uint32(expectedSubject.SourcePosition.LineNumber), - Column: uint32(expectedSubject.SourcePosition.ColumnPosition), + Line: lineNumber, + Column: columnPosition, }) } } else { @@ -172,8 +176,8 @@ func validateSubjects(onrKey blocks.ObjectRelation, fs developmentmembership.Fou Source: devinterface.DeveloperError_VALIDATION_YAML, Kind: devinterface.DeveloperError_EXTRA_RELATIONSHIP_FOUND, Context: string(expectedSubject.ValidationString), - Line: uint32(expectedSubject.SourcePosition.LineNumber), - Column: uint32(expectedSubject.SourcePosition.ColumnPosition), + Line: lineNumber, + Column: columnPosition, }) } } @@ -187,8 +191,8 @@ func validateSubjects(onrKey blocks.ObjectRelation, fs developmentmembership.Fou Source: devinterface.DeveloperError_VALIDATION_YAML, Kind: devinterface.DeveloperError_MISSING_EXPECTED_RELATIONSHIP, Context: string(expectedSubject.ValidationString), - Line: uint32(expectedSubject.SourcePosition.LineNumber), - Column: uint32(expectedSubject.SourcePosition.ColumnPosition), + Line: lineNumber, + Column: columnPosition, }) } } @@ -197,6 +201,8 @@ func validateSubjects(onrKey blocks.ObjectRelation, fs developmentmembership.Fou for _, foundSubject := range fs.ListFound() { _, ok := encounteredSubjects[tuple.StringONR(foundSubject.Subject())] if !ok { + onrLineNumber, _ := safecast.ToUint32(onrKey.SourcePosition.LineNumber) + onrColumnPosition, _ := safecast.ToUint32(onrKey.SourcePosition.ColumnPosition) failures = append(failures, &devinterface.DeveloperError{ Message: fmt.Sprintf("For object and permission/relation `%s`, subject `%s` found but missing from specified", tuple.StringONR(onr), @@ -205,8 +211,8 @@ func validateSubjects(onrKey blocks.ObjectRelation, fs developmentmembership.Fou Source: devinterface.DeveloperError_VALIDATION_YAML, Kind: devinterface.DeveloperError_EXTRA_RELATIONSHIP_FOUND, Context: tuple.StringONR(onr), - Line: uint32(onrKey.SourcePosition.LineNumber), - Column: uint32(onrKey.SourcePosition.ColumnPosition), + Line: onrLineNumber, + Column: onrColumnPosition, }) } } diff --git a/pkg/development/warnings.go b/pkg/development/warnings.go index 4fe5f606f7..91e1a404f3 100644 --- a/pkg/development/warnings.go +++ b/pkg/development/warnings.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + "github.com/ccoveille/go-safecast" + "github.com/authzed/spicedb/pkg/namespace" corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" @@ -37,13 +39,14 @@ func warningForPosition(warningName string, message string, sourceCode string, s } } - lineNumber := sourcePosition.ZeroIndexedLineNumber + 1 - columnNumber := sourcePosition.ZeroIndexedColumnPosition + 1 + // NOTE: zeroes on failure are fine here. + lineNumber, _ := safecast.ToUint32(sourcePosition.ZeroIndexedLineNumber) + columnNumber, _ := safecast.ToUint32(sourcePosition.ZeroIndexedColumnPosition) return &devinterface.DeveloperWarning{ Message: message + " (" + warningName + ")", - Line: uint32(lineNumber), - Column: uint32(columnNumber), + Line: lineNumber + 1, + Column: columnNumber + 1, SourceCode: sourceCode, } } diff --git a/pkg/development/wasm/request.go b/pkg/development/wasm/request.go index b507fa838e..734c96d9d7 100644 --- a/pkg/development/wasm/request.go +++ b/pkg/development/wasm/request.go @@ -10,8 +10,9 @@ import ( "github.com/authzed/spicedb/pkg/development" - devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" "google.golang.org/protobuf/encoding/protojson" + + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" ) // runDeveloperRequest is the function exported into the WASM environment for invoking diff --git a/pkg/genutil/ensure.go b/pkg/genutil/ensure.go index 580d17fc02..56467bae08 100644 --- a/pkg/genutil/ensure.go +++ b/pkg/genutil/ensure.go @@ -1,6 +1,10 @@ package genutil -import "github.com/authzed/spicedb/pkg/spiceerrors" +import ( + "github.com/ccoveille/go-safecast" + + "github.com/authzed/spicedb/pkg/spiceerrors" +) // MustEnsureUInt32 is a helper function that calls EnsureUInt32 and panics on error. func MustEnsureUInt32(value int) uint32 { @@ -13,16 +17,18 @@ func MustEnsureUInt32(value int) uint32 { // EnsureUInt32 ensures that the specified value can be represented as a uint32. func EnsureUInt32(value int) (uint32, error) { - if value > int(^uint32(0)) { - return 0, spiceerrors.MustBugf("specified value is too large to fit in a uint32") + uint32Value, err := safecast.ToUint32(value) + if err != nil { + return 0, spiceerrors.MustBugf("specified value could not be cast to a uint32") } - return uint32(value), nil + return uint32Value, nil } // EnsureUInt8 ensures that the specified value can be represented as a uint8. func EnsureUInt8(value int) (uint8, error) { - if value > int(^uint8(0)) { - return 0, spiceerrors.MustBugf("specified value is too large to fit in a uint8") + uint8Value, err := safecast.ToUint8(value) + if err != nil { + return 0, spiceerrors.MustBugf("specified value could not be cast to a uint8") } - return uint8(value), nil + return uint8Value, nil }