Skip to content

Commit

Permalink
Fix unsafe conversion errors
Browse files Browse the repository at this point in the history
  • Loading branch information
tstirrat15 committed Sep 11, 2024
1 parent 9bf5f3d commit 5731e26
Show file tree
Hide file tree
Showing 39 changed files with 378 additions and 145 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 // indirect
github.com/aws/smithy-go v1.20.4 // indirect
github.com/bombsimon/wsl/v4 v4.4.1 // indirect
github.com/ccoveille/go-safecast v1.1.0 // indirect
github.com/cilium/ebpf v0.9.1 // indirect
github.com/containerd/cgroups/v3 v3.0.1 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,8 @@ github.com/catenacyber/perfsprint v0.7.1 h1:PGW5G/Kxn+YrN04cRAZKC+ZuvlVwolYMrIyy
github.com/catenacyber/perfsprint v0.7.1/go.mod h1:/wclWYompEyjUD2FuIIDVKNkqz7IgBIWXIH3V0Zol50=
github.com/ccojocar/zxcvbn-go v1.0.2 h1:na/czXU8RrhXO4EZme6eQJLR4PzcGsahsBOAwU6I3Vg=
github.com/ccojocar/zxcvbn-go v1.0.2/go.mod h1:g1qkXtUSvHP8lhHp5GrSmTz6uWALGRMQdw6Qnz/hi60=
github.com/ccoveille/go-safecast v1.1.0 h1:iHKNWaZm+OznO7Eh6EljXPjGfGQsSfa6/sxPlIEKO+g=
github.com/ccoveille/go-safecast v1.1.0/go.mod h1:QqwNjxQ7DAqY0C721OIO9InMk9zCwcsO7tnRuHytad8=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
Expand Down
8 changes: 7 additions & 1 deletion internal/datastore/common/changes.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (

"golang.org/x/exp/maps"

"github.com/ccoveille/go-safecast"

log "github.com/authzed/spicedb/internal/logging"
"github.com/authzed/spicedb/pkg/datastore"
core "github.com/authzed/spicedb/pkg/proto/core/v1"
Expand Down Expand Up @@ -118,7 +120,11 @@ func (ch *Changes[R, K]) adjustByteSize(item sized, delta int) error {
return spiceerrors.MustBugf("byte size underflow")
}

if ch.currentByteSize > int64(ch.maxByteSize) {
// We checked for underflow above, so the current byte size
// should fit in a uint64
currentByteSize, _ := safecast.ToUint64(ch.currentByteSize)

if currentByteSize > ch.maxByteSize {
return datastore.NewMaximumChangesSizeExceededError(ch.maxByteSize)
}

Expand Down
12 changes: 6 additions & 6 deletions internal/datastore/common/gc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (gc *fakeGC) DeleteBeforeTx(_ context.Context, rev datastore.Revision) (Del

revInt := rev.(revisions.TransactionIDRevision).TransactionID()

return gc.deleter.DeleteBeforeTx(int64(revInt))
return gc.deleter.DeleteBeforeTx(revInt)
}

func (gc *fakeGC) HasGCRun() bool {
Expand Down Expand Up @@ -101,22 +101,22 @@ func (gc *fakeGC) GetMetrics() gcMetrics {

// Allows specifying different deletion behaviors for tests
type gcDeleter interface {
DeleteBeforeTx(revision int64) (DeletionCounts, error)
DeleteBeforeTx(revision uint64) (DeletionCounts, error)
}

// Always error trying to perform a delete
type alwaysErrorDeleter struct{}

func (alwaysErrorDeleter) DeleteBeforeTx(_ int64) (DeletionCounts, error) {
func (alwaysErrorDeleter) DeleteBeforeTx(_ uint64) (DeletionCounts, error) {
return DeletionCounts{}, fmt.Errorf("delete error")
}

// Only error on specific revisions
type revisionErrorDeleter struct {
errorOnRevisions []int64
errorOnRevisions []uint64
}

func (d revisionErrorDeleter) DeleteBeforeTx(revision int64) (DeletionCounts, error) {
func (d revisionErrorDeleter) DeleteBeforeTx(revision uint64) (DeletionCounts, error) {
if slices.Contains(d.errorOnRevisions, revision) {
return DeletionCounts{}, fmt.Errorf("delete error")
}
Expand Down Expand Up @@ -178,7 +178,7 @@ func TestGCFailureBackoffReset(t *testing.T) {
// Error on revisions 1 - 5, giving the exponential
// backoff enough time to fail the test if the interval
// is not reset properly.
errorOnRevisions: []int64{1, 2, 3, 4, 5},
errorOnRevisions: []uint64{1, 2, 3, 4, 5},
})

ctx, cancel := context.WithCancel(context.Background())
Expand Down
13 changes: 9 additions & 4 deletions internal/datastore/common/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

sq "github.com/Masterminds/squirrel"
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/ccoveille/go-safecast"
"github.com/jzelinskie/stringz"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -551,12 +552,13 @@ func (tqs QueryExecutor) ExecuteQuery(
query = query.After(queryOpts.After, queryOpts.Sort)
}

limit := math.MaxInt
var limit uint64
limit = math.MaxUint64
if queryOpts.Limit != nil {
limit = int(*queryOpts.Limit)
limit = *queryOpts.Limit
}

toExecute := query.limit(uint64(limit))
toExecute := query.limit(limit)
sql, args, err := toExecute.queryBuilder.ToSql()
if err != nil {
return nil, err
Expand All @@ -567,7 +569,10 @@ func (tqs QueryExecutor) ExecuteQuery(
return nil, err
}

if len(queryTuples) > limit {
// A length shouldn't be non-negative, so we can cast without
// checking here.
lenQueryTuples, _ := safecast.ToUint64(len(queryTuples))
if lenQueryTuples > limit {
queryTuples = queryTuples[:limit]
}

Expand Down
10 changes: 8 additions & 2 deletions internal/datastore/crdb/crdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,19 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas
if err != nil {
return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url)
}
config.readPoolOpts.ConfigurePgx(readPoolConfig)
err = config.readPoolOpts.ConfigurePgx(readPoolConfig)
if err != nil {
return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url)
}

writePoolConfig, err := pgxpool.ParseConfig(url)
if err != nil {
return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url)
}
config.writePoolOpts.ConfigurePgx(writePoolConfig)
err = config.writePoolOpts.ConfigurePgx(writePoolConfig)
if err != nil {
return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url)
}

initCtx, initCancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer initCancel()
Expand Down
18 changes: 15 additions & 3 deletions internal/datastore/crdb/pool/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package pool
import (
"context"
"hash/maphash"
"math"
"math/rand"
"slices"
"strconv"
"time"

"github.com/ccoveille/go-safecast"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -84,7 +86,15 @@ type nodeConnectionBalancer[P balancePoolConn[C], C balanceConn] struct {
// newNodeConnectionBalancer is generic over underlying connection types for
// testing purposes. Callers should use the exported NewNodeConnectionBalancer.
func newNodeConnectionBalancer[P balancePoolConn[C], C balanceConn](pool balanceablePool[P, C], healthTracker *NodeHealthTracker, interval time.Duration) *nodeConnectionBalancer[P, C] {
seed := int64(new(maphash.Hash).Sum64())
seed := int64(0)
for seed == 0 {
// Sum64 returns a uint64, and safecast will return 0 if it's not castable,
// which will happen about half the time (?). We just keep running it until
// we get a seed that fits in the box.
// Subtracting math.MaxInt64 should mean that we retain the entire range of
// possible values.
seed, _ = safecast.ToInt64(new(maphash.Hash).Sum64() - math.MaxInt64)
}
return &nodeConnectionBalancer[P, C]{
ticker: time.NewTicker(interval),
sem: semaphore.NewWeighted(1),
Expand Down Expand Up @@ -147,7 +157,9 @@ func (p *nodeConnectionBalancer[P, C]) mustPruneConnections(ctx context.Context)
}
}

nodeCount := uint32(p.healthTracker.HealthyNodeCount())
// It's highly unlikely that we'll ever have an overflow in
// this context, so we cast directly.
nodeCount, _ := safecast.ToUint32(p.healthTracker.HealthyNodeCount())
if nodeCount == 0 {
nodeCount = 1
}
Expand Down Expand Up @@ -203,7 +215,7 @@ func (p *nodeConnectionBalancer[P, C]) mustPruneConnections(ctx context.Context)
// it's possible for the difference in connections between nodes to differ by up to
// the number of nodes.
if p.healthTracker.HealthyNodeCount() == 0 ||
uint32(i) < p.pool.MaxConns()%uint32(p.healthTracker.HealthyNodeCount()) {
i < int(p.pool.MaxConns())%p.healthTracker.HealthyNodeCount() {
perNodeMax++
}

Expand Down
16 changes: 13 additions & 3 deletions internal/datastore/postgres/common/pgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"time"

"github.com/ccoveille/go-safecast"
"github.com/exaring/otelpgx"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry"
zerologadapter "github.com/jackc/pgx-zerolog"
Expand Down Expand Up @@ -300,15 +301,23 @@ type PoolOptions struct {
}

// ConfigurePgx applies PoolOptions to a pgx connection pool confiugration.
func (opts PoolOptions) ConfigurePgx(pgxConfig *pgxpool.Config) {
func (opts PoolOptions) ConfigurePgx(pgxConfig *pgxpool.Config) error {
if opts.MaxOpenConns != nil {
pgxConfig.MaxConns = int32(*opts.MaxOpenConns)
maxConns, err := safecast.ToInt32(*opts.MaxOpenConns)
if err != nil {
return err
}
pgxConfig.MaxConns = maxConns
}

// Default to keeping the pool maxed out at all times.
pgxConfig.MinConns = pgxConfig.MaxConns
if opts.MinOpenConns != nil {
pgxConfig.MinConns = int32(*opts.MinOpenConns)
minConns, err := safecast.ToInt32(*opts.MinOpenConns)
if err != nil {
return err
}
pgxConfig.MinConns = minConns
}

if pgxConfig.MaxConns > 0 && pgxConfig.MinConns > 0 && pgxConfig.MaxConns < pgxConfig.MinConns {
Expand All @@ -335,6 +344,7 @@ func (opts PoolOptions) ConfigurePgx(pgxConfig *pgxpool.Config) {

ConfigurePGXLogger(pgxConfig.ConnConfig)
ConfigureOTELTracer(pgxConfig.ConnConfig)
return nil
}

type QuerierFuncs struct {
Expand Down
10 changes: 8 additions & 2 deletions internal/datastore/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ func newPostgresDatastore(

// Setup the config for each of the read and write pools.
readPoolConfig := pgConfig.Copy()
config.readPoolOpts.ConfigurePgx(readPoolConfig)
err = config.readPoolOpts.ConfigurePgx(readPoolConfig)
if err != nil {
return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, pgURL)
}

readPoolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
RegisterTypes(conn.TypeMap())
Expand All @@ -195,7 +198,10 @@ func newPostgresDatastore(
var writePoolConfig *pgxpool.Config
if isPrimary {
writePoolConfig = pgConfig.Copy()
config.writePoolOpts.ConfigurePgx(writePoolConfig)
err = config.writePoolOpts.ConfigurePgx(writePoolConfig)
if err != nil {
return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, pgURL)
}

writePoolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
RegisterTypes(conn.TypeMap())
Expand Down
17 changes: 9 additions & 8 deletions internal/datastore/postgres/postgres_shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ import (
"time"

sq "github.com/Masterminds/squirrel"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/samber/lo"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
"golang.org/x/sync/errgroup"

"github.com/authzed/spicedb/internal/datastore/common"
pgcommon "github.com/authzed/spicedb/internal/datastore/postgres/common"
pgversion "github.com/authzed/spicedb/internal/datastore/postgres/version"
Expand All @@ -25,14 +34,6 @@ import (
"github.com/authzed/spicedb/pkg/namespace"
core "github.com/authzed/spicedb/pkg/proto/core/v1"
"github.com/authzed/spicedb/pkg/tuple"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/samber/lo"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
"golang.org/x/sync/errgroup"
)

const pgSerializationFailure = "40001"
Expand Down
10 changes: 9 additions & 1 deletion internal/datastore/postgres/readwrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"strings"

"github.com/ccoveille/go-safecast"

"github.com/authzed/spicedb/pkg/datastore/options"
"github.com/authzed/spicedb/pkg/genutil/mapz"
"github.com/authzed/spicedb/pkg/spiceerrors"
Expand Down Expand Up @@ -383,6 +385,12 @@ func (rwt *pgReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1.R
}

func (rwt *pgReadWriteTXN) deleteRelationshipsWithLimit(ctx context.Context, filter *v1.RelationshipFilter, limit uint64) (bool, error) {
// validate the limit
intLimit, err := safecast.ToInt64(limit)
if err != nil {
return false, fmt.Errorf("limit argument could not safely be cast to int64")
}

// Construct a select query for the relationships to be removed.
query := selectForDelete

Expand Down Expand Up @@ -444,7 +452,7 @@ func (rwt *pgReadWriteTXN) deleteRelationshipsWithLimit(ctx context.Context, fil
return false, fmt.Errorf(errUnableToDeleteRelationships, err)
}

return result.RowsAffected() == int64(limit), nil
return result.RowsAffected() == intLimit, nil
}

func (rwt *pgReadWriteTXN) deleteRelationships(ctx context.Context, filter *v1.RelationshipFilter) error {
Expand Down
24 changes: 20 additions & 4 deletions internal/datastore/postgres/revisions.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import (
"strings"
"time"

"github.com/ccoveille/go-safecast"
"github.com/jackc/pgx/v5"

"github.com/authzed/spicedb/pkg/datastore"
implv1 "github.com/authzed/spicedb/pkg/proto/impl/v1"
"github.com/authzed/spicedb/pkg/spiceerrors"
)

const (
Expand Down Expand Up @@ -153,7 +155,10 @@ func parseRevisionProto(revisionStr string) (datastore.Revision, error) {
return datastore.NoRevision, fmt.Errorf(errRevisionFormat, err)
}

xminInt := int64(decoded.Xmin)
xminInt, err := safecast.ToInt64(decoded.Xmin)
if err != nil {
return datastore.NoRevision, spiceerrors.MustBugf("Could not cast xmin to int64")
}

var xips []uint64
if len(decoded.RelativeXips) > 0 {
Expand Down Expand Up @@ -311,15 +316,26 @@ func (pr postgresRevision) OptionalNanosTimestamp() (uint64, bool) {
// for xmax and xip list values to save bytes when encoded as varint protos.
// For example, snapshot 1001:1004:1001,1003 becomes 1000:3:0,2.
func (pr postgresRevision) MarshalBinary() ([]byte, error) {
xminInt := int64(pr.snapshot.xmin)
xminInt, err := safecast.ToInt64(pr.snapshot.xmin)
if err != nil {
return nil, spiceerrors.MustBugf("could not safely cast snapshot xip to int64")
}
relativeXips := make([]int64, len(pr.snapshot.xipList))
for i, xip := range pr.snapshot.xipList {
relativeXips[i] = int64(xip) - xminInt
intXip, err := safecast.ToInt64(xip)
if err != nil {
return nil, spiceerrors.MustBugf("could not safely cast snapshot xip to int64")
}
relativeXips[i] = intXip - xminInt
}

relativeXmax, err := safecast.ToInt64(pr.snapshot.xmax)
if err != nil {
return nil, spiceerrors.MustBugf("could not safely cast snapshot xmax to int64")
}
protoRevision := implv1.PostgresRevision{
Xmin: pr.snapshot.xmin,
RelativeXmax: int64(pr.snapshot.xmax) - xminInt,
RelativeXmax: relativeXmax - xminInt,
RelativeXips: relativeXips,
}

Expand Down
4 changes: 2 additions & 2 deletions internal/datastore/proxy/schemacaching/watchingcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (p *watchingCachingProxy) startSync(ctx context.Context) error {

// Start watching for schema changes.
go (func() {
retryCount := 0
retryCount := uint8(0)

restartWatch:
for {
Expand Down Expand Up @@ -322,7 +322,7 @@ func (p *watchingCachingProxy) startSync(ctx context.Context) error {
log.Warn().Err(err).Msg("received retryable error in schema watch; sleeping for a bit and restarting watch")
retryCount++
wg.Add(1)
pgxcommon.SleepOnErr(ctx, err, uint8(retryCount))
pgxcommon.SleepOnErr(ctx, err, retryCount)
continue restartWatch
}

Expand Down
Loading

0 comments on commit 5731e26

Please sign in to comment.