diff --git a/e2e/go.mod b/e2e/go.mod index 836bd1dd52..ee36cdeb76 100644 --- a/e2e/go.mod +++ b/e2e/go.mod @@ -51,7 +51,7 @@ require ( github.com/jzelinskie/stringz v0.0.3 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/planetscale/vtprotobuf v0.6.0 // indirect + github.com/planetscale/vtprotobuf v0.6.1-0.20240409071808-615f978279ca // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rs/zerolog v1.33.0 // indirect github.com/samber/lo v1.46.0 // indirect diff --git a/e2e/go.sum b/e2e/go.sum index 308ffca52f..3b4fd7f88e 100644 --- a/e2e/go.sum +++ b/e2e/go.sum @@ -210,8 +210,8 @@ github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79/go.mod h1:E26fwEtRNigB github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/planetscale/vtprotobuf v0.6.0 h1:nBeETjudeJ5ZgBHUz1fVHvbqUKnYOXNhsIEabROxmNA= -github.com/planetscale/vtprotobuf v0.6.0/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= +github.com/planetscale/vtprotobuf v0.6.1-0.20240409071808-615f978279ca h1:ujRGEVWJEoaxQ+8+HMl8YEpGaDAgohgZxJ5S+d2TTFQ= +github.com/planetscale/vtprotobuf v0.6.1-0.20240409071808-615f978279ca/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/go.mod b/go.mod index f8a1240d89..5de9fc7925 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( cloud.google.com/go/spanner v1.65.0 contrib.go.opencensus.io/exporter/prometheus v0.4.2 github.com/IBM/pgxpoolprometheus v1.1.1 - github.com/KimMachineGun/automemlimit v0.6.1 // indirect + github.com/KimMachineGun/automemlimit v0.6.1 github.com/Masterminds/squirrel v1.5.4 github.com/authzed/authzed-go v0.14.0 @@ -67,7 +67,7 @@ require ( github.com/ory/dockertest/v3 v3.10.0 github.com/outcaste-io/ristretto v0.2.3 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 - github.com/planetscale/vtprotobuf v0.6.0 + github.com/planetscale/vtprotobuf v0.6.1-0.20240409071808-615f978279ca github.com/prometheus/client_golang v1.19.1 github.com/prometheus/client_model v0.6.1 github.com/prometheus/common v0.55.0 diff --git a/go.sum b/go.sum index 0a5174be6e..7345f582e4 100644 --- a/go.sum +++ b/go.sum @@ -1246,10 +1246,6 @@ github.com/julz/importas v0.1.0 h1:F78HnrsjY3cR7j0etXy5+TU1Zuy7Xt08X/1aJnH5xXY= github.com/julz/importas v0.1.0/go.mod h1:oSFU2R4XK/P7kNBrnL/FEQlDGN1/6WoxXEjSSXO0DV0= github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= -github.com/jzelinskie/cobrautil/v2 v2.0.0-20240813173937-98b79ae0b499 h1:dXbwn1pwooxn2DAnPF3SR7tNPqC6N4VmwftMrapCYng= -github.com/jzelinskie/cobrautil/v2 v2.0.0-20240813173937-98b79ae0b499/go.mod h1:jsl6cEF6BT3UeQoSLreA7G0sZXemoI5XNqyxzWCohbE= -github.com/jzelinskie/cobrautil/v2 v2.0.0-20240816002907-ef0e64d7f25b h1:dUjc3twJXVQ7FILS1+KhHilbM7LQwIvVgH4E7h0AwTA= -github.com/jzelinskie/cobrautil/v2 v2.0.0-20240816002907-ef0e64d7f25b/go.mod h1:jsl6cEF6BT3UeQoSLreA7G0sZXemoI5XNqyxzWCohbE= github.com/jzelinskie/cobrautil/v2 v2.0.0-20240819150235-f7fe73942d0f h1:+WgAZQQXj+X8lcJdwcrQpD89Zd9ekdauOK3hWl3FkPU= github.com/jzelinskie/cobrautil/v2 v2.0.0-20240819150235-f7fe73942d0f/go.mod h1:jsl6cEF6BT3UeQoSLreA7G0sZXemoI5XNqyxzWCohbE= github.com/jzelinskie/persistent v0.0.0-20230816160542-1205ef8f0e15 h1:lFr5Krrc4LESaXK9yW15IQMZ4p2bZGw/+71Z1dV6tuk= @@ -1433,8 +1429,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= -github.com/planetscale/vtprotobuf v0.6.0 h1:nBeETjudeJ5ZgBHUz1fVHvbqUKnYOXNhsIEabROxmNA= -github.com/planetscale/vtprotobuf v0.6.0/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= +github.com/planetscale/vtprotobuf v0.6.1-0.20240409071808-615f978279ca h1:ujRGEVWJEoaxQ+8+HMl8YEpGaDAgohgZxJ5S+d2TTFQ= +github.com/planetscale/vtprotobuf v0.6.1-0.20240409071808-615f978279ca/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/internal/datastore/crdb/readwrite.go b/internal/datastore/crdb/readwrite.go index a3fd8917e1..425d5b1811 100644 --- a/internal/datastore/crdb/readwrite.go +++ b/internal/datastore/crdb/readwrite.go @@ -10,7 +10,6 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/jackc/pgx/v5" "github.com/jzelinskie/stringz" - "google.golang.org/protobuf/proto" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" "github.com/authzed/spicedb/internal/datastore/revisions" @@ -378,7 +377,7 @@ func (rwt *crdbReadWriteTXN) WriteNamespaces(ctx context.Context, newConfigs ... for _, newConfig := range newConfigs { rwt.addOverlapKey(newConfig.Name) - serialized, err := proto.Marshal(newConfig) + serialized, err := newConfig.MarshalVT() if err != nil { return fmt.Errorf(errUnableToWriteConfig, err) } diff --git a/internal/datastore/memdb/readwrite.go b/internal/datastore/memdb/readwrite.go index 6cabafe21a..4c1f7a148f 100644 --- a/internal/datastore/memdb/readwrite.go +++ b/internal/datastore/memdb/readwrite.go @@ -8,7 +8,6 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/hashicorp/go-memdb" "github.com/jzelinskie/stringz" - "google.golang.org/protobuf/proto" "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/pkg/datastore" @@ -265,7 +264,7 @@ func (rwt *memdbReadWriteTx) WriteNamespaces(_ context.Context, newConfigs ...*c } for _, newConfig := range newConfigs { - serialized, err := proto.Marshal(newConfig) + serialized, err := newConfig.MarshalVT() if err != nil { return err } diff --git a/internal/datastore/mysql/readwrite.go b/internal/datastore/mysql/readwrite.go index db83ac862a..26c7f4a017 100644 --- a/internal/datastore/mysql/readwrite.go +++ b/internal/datastore/mysql/readwrite.go @@ -15,7 +15,6 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/go-sql-driver/mysql" "github.com/jzelinskie/stringz" - "google.golang.org/protobuf/proto" "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/internal/datastore/revisions" @@ -387,7 +386,7 @@ func (rwt *mysqlReadWriteTXN) WriteNamespaces(ctx context.Context, newNamespaces writeQuery := rwt.WriteNamespaceQuery for _, newNamespace := range newNamespaces { - serialized, err := proto.Marshal(newNamespace) + serialized, err := newNamespace.MarshalVT() if err != nil { return fmt.Errorf(errUnableToWriteConfig, err) } diff --git a/internal/datastore/postgres/readwrite.go b/internal/datastore/postgres/readwrite.go index dd34218c26..c5c661d7d3 100644 --- a/internal/datastore/postgres/readwrite.go +++ b/internal/datastore/postgres/readwrite.go @@ -15,7 +15,6 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/jackc/pgx/v5" "github.com/jzelinskie/stringz" - "google.golang.org/protobuf/proto" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" "github.com/authzed/spicedb/pkg/datastore" @@ -496,7 +495,7 @@ func (rwt *pgReadWriteTXN) WriteNamespaces(ctx context.Context, newConfigs ...*c writeQuery := writeNamespace for _, newNamespace := range newConfigs { - serialized, err := proto.Marshal(newNamespace) + serialized, err := newNamespace.MarshalVT() if err != nil { return fmt.Errorf(errUnableToWriteConfig, err) } diff --git a/internal/datastore/spanner/readwrite.go b/internal/datastore/spanner/readwrite.go index 8a63d7b0a1..dac9b46b91 100644 --- a/internal/datastore/spanner/readwrite.go +++ b/internal/datastore/spanner/readwrite.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/jzelinskie/stringz" - "google.golang.org/protobuf/proto" "github.com/authzed/spicedb/internal/datastore/revisions" log "github.com/authzed/spicedb/internal/logging" @@ -326,7 +325,7 @@ func caveatVals(r *core.RelationTuple) []any { func (rwt spannerReadWriteTXN) WriteNamespaces(_ context.Context, newConfigs ...*core.NamespaceDefinition) error { mutations := make([]*spanner.Mutation, 0, len(newConfigs)) for _, newConfig := range newConfigs { - serialized, err := proto.Marshal(newConfig) + serialized, err := newConfig.MarshalVT() if err != nil { return fmt.Errorf(errUnableToWriteConfig, err) } diff --git a/internal/namespace/test/serialization_test.go b/internal/namespace/test/serialization_test.go new file mode 100644 index 0000000000..e22d147369 --- /dev/null +++ b/internal/namespace/test/serialization_test.go @@ -0,0 +1,142 @@ +package namespace + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/testutil" +) + +/* +NOTE: this test exists because we found a place where MarshalVT +was producing a different serialization than proto.Marshal. The +idea is that each time we regenerate our _vtproto.pb.go files, +we run this generation, and then the serialization_test will +use these snapshots to assert that nothing has changed. +*/ + +type serializationTest struct { + name string + schema string +} + +var serializationTests = []serializationTest{ + {"Basic serialization test", "basic"}, +} + +type definitionInterface interface { + protoreflect.ProtoMessage + UnmarshalVT([]byte) error +} + +func assertParity(t *testing.T, filenames []string, emptyProto definitionInterface) { + vtProtoDefinitions := make(map[string]definitionInterface) + standardProtoDefinitions := make(map[string]definitionInterface) + definitions := mapz.NewSet[string]() + for _, filename := range filenames { + definition := strings.Split(filename, ".")[0] + definitions.Add(definition) + bytes, err := os.ReadFile(fmt.Sprintf("testdata/proto/%s", filename)) + require.NoError(t, err) + + standardRepresentation := proto.Clone(emptyProto).(definitionInterface) + err = proto.Unmarshal(bytes, standardRepresentation) + require.NoError(t, err) + standardProtoDefinitions[filename] = standardRepresentation + + vtRepresentation := proto.Clone(emptyProto).(definitionInterface) + err = vtRepresentation.UnmarshalVT(bytes) + require.NoError(t, err) + vtProtoDefinitions[filename] = vtRepresentation + } + + // For each namespace, we want to assert that all of the following are equivalent: + // standard serialization -> standard deserialization + // vt serialization -> standard deserialization + // standard serialization -> vt deserialization + // This is to validate that the vt serialization/deserialization isn't doing anything unexpected + // compared to the "official" serialization/deserialization + + for _, definition := range definitions.AsSlice() { + vtFilename := fmt.Sprintf("%s.vtproto", definition) + standardFilename := fmt.Sprintf("%s.proto", definition) + + testutil.RequireProtoEqual(t, standardProtoDefinitions[vtFilename], standardProtoDefinitions[standardFilename], "vt and standard serializations disagree") + testutil.RequireProtoEqual(t, standardProtoDefinitions[standardFilename], vtProtoDefinitions[standardFilename], "vt and standard deserializations of standard proto disagree") + testutil.RequireProtoEqual(t, vtProtoDefinitions[standardFilename], vtProtoDefinitions[vtFilename], "vt and standard deserializations of vt proto disagree") + } +} + +func TestSerialization(t *testing.T) { + for _, test := range serializationTests { + test := test + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + if os.Getenv("REGEN") == "true" { + schema, err := os.ReadFile(fmt.Sprintf("testdata/schema/%s.zed", test.schema)) + require.NoError(err) + compiled, _ := compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: string(schema), + }, compiler.AllowUnprefixedObjectType()) + + for _, objectDef := range compiled.ObjectDefinitions { + protoSerialized, err := proto.Marshal(objectDef) + require.NoError(err) + err = os.WriteFile(fmt.Sprintf("testdata/proto/%s-definition-%s.proto", test.schema, objectDef.Name), protoSerialized, 0o600) + require.NoError(err) + + vtSerialized, err := objectDef.MarshalVT() + require.NoError(err) + err = os.WriteFile(fmt.Sprintf("testdata/proto/%s-definition-%s.vtproto", test.schema, objectDef.Name), vtSerialized, 0o600) + require.NoError(err) + } + for _, caveatDef := range compiled.CaveatDefinitions { + protoSerialized, err := proto.Marshal(caveatDef) + require.NoError(err) + err = os.WriteFile(fmt.Sprintf("testdata/proto/%s-caveat-%s.proto", test.schema, caveatDef.Name), protoSerialized, 0o600) + require.NoError(err) + + vtSerialized, err := caveatDef.MarshalVT() + require.NoError(err) + err = os.WriteFile(fmt.Sprintf("testdata/proto/%s-caveat-%s.vtproto", test.schema, caveatDef.Name), vtSerialized, 0o600) + require.NoError(err) + } + } else { + files, err := os.ReadDir("testdata/proto") + require.NoError(err) + + definitionFiles := mapz.NewSet[string]() + caveatFiles := mapz.NewSet[string]() + for _, file := range files { + filename := file.Name() + if strings.Contains(filename, test.schema) { + // NOTE: this makes some assumptions about the names of the files, + // namely that a schema file name will not have either of these + // keywords. + if strings.Contains(filename, "definition") { + definitionFiles.Add(filename) + } + if strings.Contains(filename, "caveat") { + caveatFiles.Add(filename) + } + } + } + + assertParity(t, definitionFiles.AsSlice(), &core.NamespaceDefinition{}) + assertParity(t, caveatFiles.AsSlice(), &core.CaveatDefinition{}) + } + }) + } +} diff --git a/internal/namespace/test/testdata/proto/basic-caveat-foo.proto b/internal/namespace/test/testdata/proto/basic-caveat-foo.proto new file mode 100644 index 0000000000..0bd5dc353f Binary files /dev/null and b/internal/namespace/test/testdata/proto/basic-caveat-foo.proto differ diff --git a/internal/namespace/test/testdata/proto/basic-caveat-foo.vtproto b/internal/namespace/test/testdata/proto/basic-caveat-foo.vtproto new file mode 100644 index 0000000000..0bd5dc353f Binary files /dev/null and b/internal/namespace/test/testdata/proto/basic-caveat-foo.vtproto differ diff --git a/internal/namespace/test/testdata/proto/basic-definition-document.proto b/internal/namespace/test/testdata/proto/basic-definition-document.proto new file mode 100644 index 0000000000..5d2bf44e63 Binary files /dev/null and b/internal/namespace/test/testdata/proto/basic-definition-document.proto differ diff --git a/internal/namespace/test/testdata/proto/basic-definition-document.vtproto b/internal/namespace/test/testdata/proto/basic-definition-document.vtproto new file mode 100644 index 0000000000..5d2bf44e63 Binary files /dev/null and b/internal/namespace/test/testdata/proto/basic-definition-document.vtproto differ diff --git a/internal/namespace/test/testdata/schema/basic.zed b/internal/namespace/test/testdata/schema/basic.zed new file mode 100644 index 0000000000..b757b8557f --- /dev/null +++ b/internal/namespace/test/testdata/schema/basic.zed @@ -0,0 +1,14 @@ +caveat foo(someParam int) { + someParam == 42 +} + +definition document { + relation viewer: user | user:* + relation editor: user | group#member with foo + relation parent: organization + permission edit = editor + permission view = viewer + edit + parent->view + permission other = viewer - edit + permission intersect = viewer & edit + permission with_nil = (viewer - edit) & parent->view & nil +} diff --git a/magefiles/go.mod b/magefiles/go.mod index d1739ee145..d97c6d71fe 100644 --- a/magefiles/go.mod +++ b/magefiles/go.mod @@ -8,7 +8,7 @@ require ( github.com/ecordell/optgen v0.0.9 github.com/envoyproxy/protoc-gen-validate v1.0.4 github.com/magefile/mage v1.15.0 - github.com/planetscale/vtprotobuf v0.5.1-0.20231212170721-e7d721933795 + github.com/planetscale/vtprotobuf v0.6.1-0.20240409071808-615f978279ca golang.org/x/tools v0.22.0 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0 google.golang.org/protobuf v1.34.2 diff --git a/magefiles/go.sum b/magefiles/go.sum index fc58e56421..616aef80aa 100644 --- a/magefiles/go.sum +++ b/magefiles/go.sum @@ -290,6 +290,8 @@ github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDj github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/planetscale/vtprotobuf v0.5.1-0.20231212170721-e7d721933795 h1:pH+U6pJP0BhxqQ4njBUjOg0++WMMvv3eByWzB+oATBY= github.com/planetscale/vtprotobuf v0.5.1-0.20231212170721-e7d721933795/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= +github.com/planetscale/vtprotobuf v0.6.1-0.20240409071808-615f978279ca h1:ujRGEVWJEoaxQ+8+HMl8YEpGaDAgohgZxJ5S+d2TTFQ= +github.com/planetscale/vtprotobuf v0.6.1-0.20240409071808-615f978279ca/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= diff --git a/magefiles/lint.go b/magefiles/lint.go index 27e4bb0614..e07b6215e7 100644 --- a/magefiles/lint.go +++ b/magefiles/lint.go @@ -93,6 +93,14 @@ func (Lint) Analyzers() error { "-paniccheck.skip-files=_test,zz_", "-zerologmarshalcheck", "-zerologmarshalcheck.skip-files=_test,zz_", + "-protomarshalcheck", + // Skip generated protobuf files for this check + // Also skip test where we're explicitly using proto.Marshal to assert + // that the proto.Marshal behavior matches foo.MarshalVT() + "-protomarshalcheck.skip-files=.pb,serialization_test.go", + // Skip our dispatch codec logic that explicitly calls MarshalVT with proto.Marshal as a fallback + // Skip our internal telemetry reporter which uses a prometheus proto definition that we don't control + "-protomarshalcheck.skip-pkg=github.com/authzed/spicedb/pkg/proto/dispatch/v1,github.com/authzed/spicedb/internal/telemetry", "github.com/authzed/spicedb/...", ) } diff --git a/pkg/proto/core/v1/core_vtproto.pb.go b/pkg/proto/core/v1/core_vtproto.pb.go index cb4c57d1d6..e977f09598 100644 --- a/pkg/proto/core/v1/core_vtproto.pb.go +++ b/pkg/proto/core/v1/core_vtproto.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-vtproto. DO NOT EDIT. -// protoc-gen-go-vtproto version: v0.5.1-0.20231212170721-e7d721933795 +// protoc-gen-go-vtproto version: v0.6.1-0.20240409071808-615f978279ca // source: core/v1/core.proto package corev1 @@ -2942,6 +2942,10 @@ func (m *RelationTupleTreeNode_IntermediateNode) MarshalToSizedBufferVT(dAtA []b i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -2961,6 +2965,10 @@ func (m *RelationTupleTreeNode_LeafNode) MarshalToSizedBufferVT(dAtA []byte) (in i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -3706,6 +3714,10 @@ func (m *AllowedRelation_PublicWildcard_) MarshalToSizedBufferVT(dAtA []byte) (i i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x22 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x22 } return len(dAtA) - i, nil } @@ -3817,6 +3829,10 @@ func (m *UsersetRewrite_Union) MarshalToSizedBufferVT(dAtA []byte) (int, error) i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -3836,6 +3852,10 @@ func (m *UsersetRewrite_Intersection) MarshalToSizedBufferVT(dAtA []byte) (int, i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -3855,6 +3875,10 @@ func (m *UsersetRewrite_Exclusion) MarshalToSizedBufferVT(dAtA []byte) (int, err i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x1a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x1a } return len(dAtA) - i, nil } @@ -4012,6 +4036,10 @@ func (m *SetOperation_Child_XThis) MarshalToSizedBufferVT(dAtA []byte) (int, err i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -4031,6 +4059,10 @@ func (m *SetOperation_Child_ComputedUserset) MarshalToSizedBufferVT(dAtA []byte) i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -4050,6 +4082,10 @@ func (m *SetOperation_Child_TupleToUserset) MarshalToSizedBufferVT(dAtA []byte) i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x1a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x1a } return len(dAtA) - i, nil } @@ -4069,6 +4105,10 @@ func (m *SetOperation_Child_UsersetRewrite) MarshalToSizedBufferVT(dAtA []byte) i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x22 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x22 } return len(dAtA) - i, nil } @@ -4088,6 +4128,10 @@ func (m *SetOperation_Child_XNil) MarshalToSizedBufferVT(dAtA []byte) (int, erro i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x32 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x32 } return len(dAtA) - i, nil } @@ -4107,6 +4151,10 @@ func (m *SetOperation_Child_FunctionedTupleToUserset) MarshalToSizedBufferVT(dAt i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x42 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x42 } return len(dAtA) - i, nil } @@ -4522,6 +4570,10 @@ func (m *CaveatExpression_Operation) MarshalToSizedBufferVT(dAtA []byte) (int, e i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -4541,6 +4593,10 @@ func (m *CaveatExpression_Caveat) MarshalToSizedBufferVT(dAtA []byte) (int, erro i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -4962,6 +5018,8 @@ func (m *RelationTupleTreeNode_IntermediateNode) SizeVT() (n int) { if m.IntermediateNode != nil { l = m.IntermediateNode.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -4974,6 +5032,8 @@ func (m *RelationTupleTreeNode_LeafNode) SizeVT() (n int) { if m.LeafNode != nil { l = m.LeafNode.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -5270,6 +5330,8 @@ func (m *AllowedRelation_PublicWildcard_) SizeVT() (n int) { if m.PublicWildcard != nil { l = m.PublicWildcard.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -5313,6 +5375,8 @@ func (m *UsersetRewrite_Union) SizeVT() (n int) { if m.Union != nil { l = m.Union.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -5325,6 +5389,8 @@ func (m *UsersetRewrite_Intersection) SizeVT() (n int) { if m.Intersection != nil { l = m.Intersection.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -5337,6 +5403,8 @@ func (m *UsersetRewrite_Exclusion) SizeVT() (n int) { if m.Exclusion != nil { l = m.Exclusion.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -5393,6 +5461,8 @@ func (m *SetOperation_Child_XThis) SizeVT() (n int) { if m.XThis != nil { l = m.XThis.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -5405,6 +5475,8 @@ func (m *SetOperation_Child_ComputedUserset) SizeVT() (n int) { if m.ComputedUserset != nil { l = m.ComputedUserset.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -5417,6 +5489,8 @@ func (m *SetOperation_Child_TupleToUserset) SizeVT() (n int) { if m.TupleToUserset != nil { l = m.TupleToUserset.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -5429,6 +5503,8 @@ func (m *SetOperation_Child_UsersetRewrite) SizeVT() (n int) { if m.UsersetRewrite != nil { l = m.UsersetRewrite.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -5441,6 +5517,8 @@ func (m *SetOperation_Child_XNil) SizeVT() (n int) { if m.XNil != nil { l = m.XNil.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -5453,6 +5531,8 @@ func (m *SetOperation_Child_FunctionedTupleToUserset) SizeVT() (n int) { if m.FunctionedTupleToUserset != nil { l = m.FunctionedTupleToUserset.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -5606,6 +5686,8 @@ func (m *CaveatExpression_Operation) SizeVT() (n int) { if m.Operation != nil { l = m.Operation.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -5618,6 +5700,8 @@ func (m *CaveatExpression_Caveat) SizeVT() (n int) { if m.Caveat != nil { l = m.Caveat.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } diff --git a/pkg/proto/developer/v1/developer_vtproto.pb.go b/pkg/proto/developer/v1/developer_vtproto.pb.go index c20d00e3e5..8f53dc6c4f 100644 --- a/pkg/proto/developer/v1/developer_vtproto.pb.go +++ b/pkg/proto/developer/v1/developer_vtproto.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-vtproto. DO NOT EDIT. -// protoc-gen-go-vtproto version: v0.5.1-0.20231212170721-e7d721933795 +// protoc-gen-go-vtproto version: v0.6.1-0.20240409071808-615f978279ca // source: developer/v1/developer.proto package developerv1 diff --git a/pkg/proto/dispatch/v1/dispatch_vtproto.pb.go b/pkg/proto/dispatch/v1/dispatch_vtproto.pb.go index 70fa90489f..8df373560a 100644 --- a/pkg/proto/dispatch/v1/dispatch_vtproto.pb.go +++ b/pkg/proto/dispatch/v1/dispatch_vtproto.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-vtproto. DO NOT EDIT. -// protoc-gen-go-vtproto version: v0.5.1-0.20231212170721-e7d721933795 +// protoc-gen-go-vtproto version: v0.6.1-0.20240409071808-615f978279ca // source: dispatch/v1/dispatch.proto package dispatchv1 diff --git a/pkg/proto/impl/v1/impl_vtproto.pb.go b/pkg/proto/impl/v1/impl_vtproto.pb.go index 951a7d19a0..95e368513f 100644 --- a/pkg/proto/impl/v1/impl_vtproto.pb.go +++ b/pkg/proto/impl/v1/impl_vtproto.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-vtproto. DO NOT EDIT. -// protoc-gen-go-vtproto version: v0.5.1-0.20231212170721-e7d721933795 +// protoc-gen-go-vtproto version: v0.6.1-0.20240409071808-615f978279ca // source: impl/v1/impl.proto package implv1 @@ -906,6 +906,10 @@ func (m *DecodedCaveat_Cel) MarshalToSizedBufferVT(dAtA []byte) (int, error) { } i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -1050,6 +1054,10 @@ func (m *DecodedZookie_V1) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -1069,6 +1077,10 @@ func (m *DecodedZookie_V2) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x1a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x1a } return len(dAtA) - i, nil } @@ -1208,6 +1220,10 @@ func (m *DecodedZedToken_DeprecatedV1Zookie) MarshalToSizedBufferVT(dAtA []byte) i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -1227,6 +1243,10 @@ func (m *DecodedZedToken_V1) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x1a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x1a } return len(dAtA) - i, nil } @@ -1288,6 +1308,10 @@ func (m *DecodedCursor_V1) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -1573,6 +1597,8 @@ func (m *DecodedCaveat_Cel) SizeVT() (n int) { l = proto.Size(m.Cel) } n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -1628,6 +1654,8 @@ func (m *DecodedZookie_V1) SizeVT() (n int) { if m.V1 != nil { l = m.V1.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -1640,6 +1668,8 @@ func (m *DecodedZookie_V2) SizeVT() (n int) { if m.V2 != nil { l = m.V2.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -1692,6 +1722,8 @@ func (m *DecodedZedToken_DeprecatedV1Zookie) SizeVT() (n int) { if m.DeprecatedV1Zookie != nil { l = m.DeprecatedV1Zookie.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -1704,6 +1736,8 @@ func (m *DecodedZedToken_V1) SizeVT() (n int) { if m.V1 != nil { l = m.V1.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } @@ -1729,6 +1763,8 @@ func (m *DecodedCursor_V1) SizeVT() (n int) { if m.V1 != nil { l = m.V1.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 } return n } diff --git a/pkg/proto/impl/v1/pgrevision_vtproto.pb.go b/pkg/proto/impl/v1/pgrevision_vtproto.pb.go index 25eec06927..c7a9dfff71 100644 --- a/pkg/proto/impl/v1/pgrevision_vtproto.pb.go +++ b/pkg/proto/impl/v1/pgrevision_vtproto.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-vtproto. DO NOT EDIT. -// protoc-gen-go-vtproto version: v0.5.1-0.20231212170721-e7d721933795 +// protoc-gen-go-vtproto version: v0.6.1-0.20240409071808-615f978279ca // source: impl/v1/pgrevision.proto package implv1 diff --git a/tools/analyzers/closeafterusagecheck/closeafterusagecheck.go b/tools/analyzers/closeafterusagecheck/closeafterusagecheck.go index 747596422a..752dd87d4c 100644 --- a/tools/analyzers/closeafterusagecheck/closeafterusagecheck.go +++ b/tools/analyzers/closeafterusagecheck/closeafterusagecheck.go @@ -6,19 +6,12 @@ import ( "slices" "strings" + "github.com/samber/lo" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" "golang.org/x/tools/go/ast/inspector" ) -func sliceMap(s []string, f func(value string) string) []string { - mapped := make([]string, 0, len(s)) - for _, value := range s { - mapped = append(mapped, f(value)) - } - return mapped -} - type nodeAndStack struct { node ast.Node stack []ast.Node @@ -39,7 +32,7 @@ func Analyzer() *analysis.Analyzer { Run: func(pass *analysis.Pass) (any, error) { // Check for a skipped package. if len(*skip) > 0 { - skipped := sliceMap(strings.Split(*skip, ","), strings.TrimSpace) + skipped := lo.Map(strings.Split(*skip, ","), func(skipped string, _ int) string { return strings.TrimSpace(skipped) }) for _, s := range skipped { if strings.Contains(pass.Pkg.Path(), s) { return nil, nil diff --git a/tools/analyzers/cmd/analyzers/main.go b/tools/analyzers/cmd/analyzers/main.go index a3f28aef1c..153ffac21d 100644 --- a/tools/analyzers/cmd/analyzers/main.go +++ b/tools/analyzers/cmd/analyzers/main.go @@ -6,6 +6,7 @@ import ( "github.com/authzed/spicedb/tools/analyzers/lendowncastcheck" "github.com/authzed/spicedb/tools/analyzers/nilvaluecheck" "github.com/authzed/spicedb/tools/analyzers/paniccheck" + "github.com/authzed/spicedb/tools/analyzers/protomarshalcheck" "github.com/authzed/spicedb/tools/analyzers/zerologmarshalcheck" "golang.org/x/tools/go/analysis/multichecker" ) @@ -17,6 +18,7 @@ func main() { closeafterusagecheck.Analyzer(), paniccheck.Analyzer(), lendowncastcheck.Analyzer(), + protomarshalcheck.Analyzer(), zerologmarshalcheck.Analyzer(), ) } diff --git a/tools/analyzers/exprstatementcheck/exprstatementcheck.go b/tools/analyzers/exprstatementcheck/exprstatementcheck.go index a8b1f9c3fc..6e2672f4d9 100644 --- a/tools/analyzers/exprstatementcheck/exprstatementcheck.go +++ b/tools/analyzers/exprstatementcheck/exprstatementcheck.go @@ -7,19 +7,12 @@ import ( "slices" "strings" + "github.com/samber/lo" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" "golang.org/x/tools/go/ast/inspector" ) -func sliceMap(s []string, f func(value string) string) []string { - mapped := make([]string, 0, len(s)) - for _, value := range s { - mapped = append(mapped, f(value)) - } - return mapped -} - type disallowedExprStatementConfig struct { fullTypePath string errorMessage string @@ -47,7 +40,7 @@ func Analyzer() *analysis.Analyzer { Run: func(pass *analysis.Pass) (any, error) { // Check for a skipped package. if len(*skip) > 0 { - skipped := sliceMap(strings.Split(*skip, ","), strings.TrimSpace) + skipped := lo.Map(strings.Split(*skip, ","), func(skipped string, _ int) string { return strings.TrimSpace(skipped) }) for _, s := range skipped { if strings.Contains(pass.Pkg.Path(), s) { return nil, nil diff --git a/tools/analyzers/lendowncastcheck/lendowncastcheck.go b/tools/analyzers/lendowncastcheck/lendowncastcheck.go index c4a8117090..b8e65b663f 100644 --- a/tools/analyzers/lendowncastcheck/lendowncastcheck.go +++ b/tools/analyzers/lendowncastcheck/lendowncastcheck.go @@ -7,19 +7,12 @@ import ( "regexp" "strings" + "github.com/samber/lo" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" "golang.org/x/tools/go/ast/inspector" ) -func sliceMap(s []string, f func(value string) string) []string { - mapped := make([]string, 0, len(s)) - for _, value := range s { - mapped = append(mapped, f(value)) - } - return mapped -} - var disallowedDowncastTypes = map[string]bool{ "int8": true, "int16": true, @@ -44,7 +37,7 @@ func Analyzer() *analysis.Analyzer { Run: func(pass *analysis.Pass) (any, error) { // Check for a skipped package. if len(*skipPkg) > 0 { - skipped := sliceMap(strings.Split(*skipPkg, ","), strings.TrimSpace) + skipped := lo.Map(strings.Split(*skipPkg, ","), func(skipped string, _ int) string { return strings.TrimSpace(skipped) }) for _, s := range skipped { if strings.Contains(pass.Pkg.Path(), s) { return nil, nil @@ -55,7 +48,7 @@ func Analyzer() *analysis.Analyzer { // Check for a skipped file. skipFilePatterns := make([]string, 0) if len(*skipFiles) > 0 { - skipFilePatterns = sliceMap(strings.Split(*skipFiles, ","), strings.TrimSpace) + skipFilePatterns = lo.Map(strings.Split(*skipPkg, ","), func(skipped string, _ int) string { return strings.TrimSpace(skipped) }) } for _, pattern := range skipFilePatterns { _, err := regexp.Compile(pattern) diff --git a/tools/analyzers/nilvaluecheck/nilvaluecheck.go b/tools/analyzers/nilvaluecheck/nilvaluecheck.go index a24b284cd3..be6fd782f9 100644 --- a/tools/analyzers/nilvaluecheck/nilvaluecheck.go +++ b/tools/analyzers/nilvaluecheck/nilvaluecheck.go @@ -7,19 +7,12 @@ import ( "slices" "strings" + "github.com/samber/lo" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" "golang.org/x/tools/go/ast/inspector" ) -func sliceMap(s []string, f func(value string) string) []string { - mapped := make([]string, 0, len(s)) - for _, value := range s { - mapped = append(mapped, f(value)) - } - return mapped -} - func Analyzer() *analysis.Analyzer { flagSet := flag.NewFlagSet("nilvaluecheck", flag.ExitOnError) disallowedPaths := flagSet.String("disallowed-nil-return-type-paths", "", "full paths of the types for whom nil returns are disallowed") @@ -31,7 +24,7 @@ func Analyzer() *analysis.Analyzer { Run: func(pass *analysis.Pass) (any, error) { // Check for a skipped package. if len(*skip) > 0 { - skipped := sliceMap(strings.Split(*skip, ","), strings.TrimSpace) + skipped := lo.Map(strings.Split(*skip, ","), func(skipped string, _ int) string { return strings.TrimSpace(skipped) }) for _, s := range skipped { if strings.Contains(pass.Pkg.Path(), s) { return nil, nil @@ -46,7 +39,7 @@ func Analyzer() *analysis.Analyzer { (*ast.DeclStmt)(nil), } - typePaths := sliceMap(strings.Split(*disallowedPaths, ","), strings.TrimSpace) + typePaths := lo.Map(strings.Split(*disallowedPaths, ","), func(path string, _ int) string { return strings.TrimSpace(path) }) hasTypePath := func(path string) bool { return slices.Contains(typePaths, path) } diff --git a/tools/analyzers/paniccheck/paniccheck.go b/tools/analyzers/paniccheck/paniccheck.go index 03cdadee34..470773338c 100644 --- a/tools/analyzers/paniccheck/paniccheck.go +++ b/tools/analyzers/paniccheck/paniccheck.go @@ -7,19 +7,12 @@ import ( "regexp" "strings" + "github.com/samber/lo" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" "golang.org/x/tools/go/ast/inspector" ) -func sliceMap(s []string, f func(value string) string) []string { - mapped := make([]string, 0, len(s)) - for _, value := range s { - mapped = append(mapped, f(value)) - } - return mapped -} - func Analyzer() *analysis.Analyzer { flagSet := flag.NewFlagSet("paniccheck", flag.ExitOnError) skipPkg := flagSet.String("skip-pkg", "", "package(s) to skip for linting") @@ -31,7 +24,7 @@ func Analyzer() *analysis.Analyzer { Run: func(pass *analysis.Pass) (any, error) { // Check for a skipped package. if len(*skipPkg) > 0 { - skipped := sliceMap(strings.Split(*skipPkg, ","), strings.TrimSpace) + skipped := lo.Map(strings.Split(*skipPkg, ","), func(skipped string, _ int) string { return strings.TrimSpace(skipped) }) for _, s := range skipped { if strings.Contains(pass.Pkg.Path(), s) { return nil, nil @@ -42,7 +35,7 @@ func Analyzer() *analysis.Analyzer { // Check for a skipped file. skipFilePatterns := make([]string, 0) if len(*skipFiles) > 0 { - skipFilePatterns = sliceMap(strings.Split(*skipFiles, ","), strings.TrimSpace) + skipFilePatterns = lo.Map(strings.Split(*skipPkg, ","), func(skipped string, _ int) string { return strings.TrimSpace(skipped) }) } for _, pattern := range skipFilePatterns { _, err := regexp.Compile(pattern) diff --git a/tools/analyzers/protomarshalcheck/protomarshalcheck.go b/tools/analyzers/protomarshalcheck/protomarshalcheck.go new file mode 100644 index 0000000000..926e6c57f6 --- /dev/null +++ b/tools/analyzers/protomarshalcheck/protomarshalcheck.go @@ -0,0 +1,99 @@ +package protomarshalcheck + +import ( + "flag" + "fmt" + "go/ast" + "regexp" + "strings" + + "github.com/samber/lo" + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" +) + +func Analyzer() *analysis.Analyzer { + flagSet := flag.NewFlagSet("protomarshalcheck", flag.ExitOnError) + skipPkg := flagSet.String("skip-pkg", "", "package(s) to skip for linting") + skipFiles := flagSet.String("skip-files", "", "patterns of files to skip for linting") + + return &analysis.Analyzer{ + Name: "protomarshalcheck", + Doc: "reports calls to `proto.Marshal` and `proto.Unmarshal`, which should be replaced with their VT counterparts", + Run: func(pass *analysis.Pass) (any, error) { + // Check for a skipped package. + if len(*skipPkg) > 0 { + skipped := lo.Map(strings.Split(*skipPkg, ","), func(skipped string, _ int) string { return strings.TrimSpace(skipped) }) + for _, s := range skipped { + if strings.Contains(pass.Pkg.Path(), s) { + return nil, nil + } + } + } + + // Check for a skipped file. + skipFilePatterns := make([]string, 0) + if len(*skipFiles) > 0 { + skipFilePatterns = lo.Map(strings.Split(*skipFiles, ","), func(skipped string, _ int) string { return strings.TrimSpace(skipped) }) + } + for _, pattern := range skipFilePatterns { + _, err := regexp.Compile(pattern) + if err != nil { + return nil, fmt.Errorf("invalid skip-files pattern `%s`: %w", pattern, err) + } + } + + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + + nodeFilter := []ast.Node{ + (*ast.File)(nil), + (*ast.CallExpr)(nil), + } + + inspect.WithStack(nodeFilter, func(n ast.Node, push bool, stack []ast.Node) bool { + switch s := n.(type) { + case *ast.File: + for _, pattern := range skipFilePatterns { + isMatch, _ := regexp.MatchString(pattern, pass.Fset.Position(s.Package).Filename) + if isMatch { + return false + } + } + return true + + case *ast.CallExpr: + selectorExpr, ok := s.Fun.(*ast.SelectorExpr) + if !ok { + return false + } + + expression, ok := selectorExpr.X.(*ast.Ident) + if !ok { + return false + } + if expression.Name != "proto" { + return false + } + + if selectorExpr.Sel.Name == "Unmarshal" { + pass.Reportf(s.Pos(), "`use someMessage.UnmarshalVT instead`") + } + + if selectorExpr.Sel.Name == "Marshal" { + pass.Reportf(s.Pos(), "`use someStruct.MarshalVT instead`") + } + + return false + + default: + return true + } + }) + + return nil, nil + }, + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Flags: *flagSet, + } +} diff --git a/tools/analyzers/protomarshalcheck/protomarshalcheck_test.go b/tools/analyzers/protomarshalcheck/protomarshalcheck_test.go new file mode 100644 index 0000000000..b678bdec8c --- /dev/null +++ b/tools/analyzers/protomarshalcheck/protomarshalcheck_test.go @@ -0,0 +1,15 @@ +package protomarshalcheck + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" +) + +func TestAnalyzer(t *testing.T) { + analyzer := Analyzer() + + testdata := analysistest.TestData() + analysistest.Run(t, testdata, analyzer, "disallowedmarshal") + analysistest.Run(t, testdata, analyzer, "validmarshal") +} diff --git a/tools/analyzers/protomarshalcheck/testdata/src/disallowedmarshal/disallowedmarshal.go b/tools/analyzers/protomarshalcheck/testdata/src/disallowedmarshal/disallowedmarshal.go new file mode 100644 index 0000000000..ca556e5a02 --- /dev/null +++ b/tools/analyzers/protomarshalcheck/testdata/src/disallowedmarshal/disallowedmarshal.go @@ -0,0 +1,36 @@ +package disallowedmarshal + +// NOTE: all of these defs are just to get us to `proto.Marshal` and +// `proto.Unmarshal` without needing to import things, because importing +// in these tests is difficult. +type SomeProto struct{} + +func (p SomeProto) Marshal(foo any) (any, error) { + return foo, nil +} + +func (p SomeProto) Unmarshal(foo any, bar any) (any, error) { + return foo, nil +} + +type ( + NamespaceDefinition struct{} + NamespaceMessage struct{} +) + +func WriteNamespaces(newConfigs ...*NamespaceDefinition) error { + proto := SomeProto{} + for _, newConfig := range newConfigs { + _, err := proto.Marshal(newConfig) // want "use someStruct.MarshalVT instead" + if err != nil { + return err + } + } + + return nil +} + +func DoAnUnmarshal() (any, error) { + proto := SomeProto{} + return proto.Unmarshal(make([]byte, 0), nil) // want "use someMessage.UnmarshalVT instead" +} diff --git a/tools/analyzers/protomarshalcheck/testdata/src/validmarshal/validmarshal.go b/tools/analyzers/protomarshalcheck/testdata/src/validmarshal/validmarshal.go new file mode 100644 index 0000000000..8d634d40dc --- /dev/null +++ b/tools/analyzers/protomarshalcheck/testdata/src/validmarshal/validmarshal.go @@ -0,0 +1,38 @@ +package validmarshal + +type NamespaceDefinition struct{} + +func (n *NamespaceDefinition) MarshalVT() (any, error) { + return nil, nil +} + +type SomeOtherObject struct{} + +func (s SomeOtherObject) Marshal() {} + +type NamespaceMessage struct{} + +func (n *NamespaceMessage) UnmarshalVT() (*NamespaceDefinition, error) { + return nil, nil +} + +func WriteNamespaces(newConfigs ...*NamespaceDefinition) error { + for _, newConfig := range newConfigs { + // This is the desired usage + _, err := newConfig.MarshalVT() + if err != nil { + return err + } + } + + return nil +} + +func DoAnUnmarshal(foo *NamespaceMessage) (*NamespaceDefinition, error) { + return foo.UnmarshalVT() +} + +func DoAMarshal(foo *SomeOtherObject) { + // Ensure that something else that isn't called proto isn't caught by the linter. + foo.Marshal() +} diff --git a/tools/analyzers/zerologmarshalcheck/zerologmarshalcheck.go b/tools/analyzers/zerologmarshalcheck/zerologmarshalcheck.go index ac64b83c08..6ec71d40e7 100644 --- a/tools/analyzers/zerologmarshalcheck/zerologmarshalcheck.go +++ b/tools/analyzers/zerologmarshalcheck/zerologmarshalcheck.go @@ -8,19 +8,12 @@ import ( "slices" "strings" + "github.com/samber/lo" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" "golang.org/x/tools/go/ast/inspector" ) -func sliceMap(s []string, f func(value string) string) []string { - mapped := make([]string, 0, len(s)) - for _, value := range s { - mapped = append(mapped, f(value)) - } - return mapped -} - func Analyzer() *analysis.Analyzer { flagSet := flag.NewFlagSet("zerologmarshalcheck", flag.ExitOnError) skipPkg := flagSet.String("skip-pkg", "", "package(s) to skip for linting") @@ -35,7 +28,7 @@ func Analyzer() *analysis.Analyzer { // Check for a skipped package. if len(*skipPkg) > 0 { - skipped := sliceMap(strings.Split(*skipPkg, ","), strings.TrimSpace) + skipped := lo.Map(strings.Split(*skipPkg, ","), func(skipped string, _ int) string { return strings.TrimSpace(skipped) }) for _, s := range skipped { if strings.Contains(pass.Pkg.Path(), s) { return nil, nil @@ -46,7 +39,7 @@ func Analyzer() *analysis.Analyzer { // Check for a skipped file. skipFilePatterns := make([]string, 0) if len(*skipFiles) > 0 { - skipFilePatterns = sliceMap(strings.Split(*skipFiles, ","), strings.TrimSpace) + skipFilePatterns = lo.Map(strings.Split(*skipPkg, ","), func(skipped string, _ int) string { return strings.TrimSpace(skipped) }) } for _, pattern := range skipFilePatterns { _, err := regexp.Compile(pattern) @@ -165,8 +158,6 @@ func Analyzer() *analysis.Analyzer { default: return true } - - return false }) return nil, nil