diff --git a/marshal/tests/mod/all.go b/internal/tests/serialization/mod/all.go similarity index 100% rename from marshal/tests/mod/all.go rename to internal/tests/serialization/mod/all.go diff --git a/marshal/tests/mod/custom.go b/internal/tests/serialization/mod/custom.go similarity index 100% rename from marshal/tests/mod/custom.go rename to internal/tests/serialization/mod/custom.go diff --git a/marshal/tests/mod/custom_refs.go b/internal/tests/serialization/mod/custom_refs.go similarity index 100% rename from marshal/tests/mod/custom_refs.go rename to internal/tests/serialization/mod/custom_refs.go diff --git a/marshal/tests/mod/refs.go b/internal/tests/serialization/mod/refs.go similarity index 100% rename from marshal/tests/mod/refs.go rename to internal/tests/serialization/mod/refs.go diff --git a/marshal/tests/serialization/pointers.go b/internal/tests/serialization/pointers.go similarity index 81% rename from marshal/tests/serialization/pointers.go rename to internal/tests/serialization/pointers.go index dc4f69ea4..d6a72ac41 100644 --- a/marshal/tests/serialization/pointers.go +++ b/internal/tests/serialization/pointers.go @@ -5,15 +5,15 @@ import ( "reflect" ) -// ErrFirstPtrChanged this error indicates that a double or single reference was passed to the Unmarshal function +// errFirstPtrChanged this error indicates that a double or single reference was passed to the Unmarshal function // (example (**int)(**0) or (*int)(*0)) and Unmarshal overwritten first reference. -var ErrFirstPtrChanged = errors.New("unmarshal function rewrote first pointer") +var errFirstPtrChanged = errors.New("unmarshal function rewrote first pointer") -// ErrSecondPtrNotChanged this error indicates that a double reference was passed to the Unmarshal function +// errSecondPtrNotChanged this error indicates that a double reference was passed to the Unmarshal function // (example (**int)(**0)) and the function did not overwrite the second reference. // Of course, it's not friendly to the garbage collector, overwriting references to values all the time, // but this is the current implementation `gocql` and changing it can lead to unexpected results in some cases. -var ErrSecondPtrNotChanged = errors.New("unmarshal function did not rewrite second pointer") +var errSecondPtrNotChanged = errors.New("unmarshal function did not rewrite second pointer") func getPointers(i interface{}) *pointer { rv := reflect.ValueOf(i) @@ -45,10 +45,10 @@ func (p *pointer) NotNil() bool { func (p *pointer) Valid(v interface{}) error { p2 := getPointers(v) if p.Fist != p2.Fist { - return ErrFirstPtrChanged + return errFirstPtrChanged } if p.Second != 0 && p2.Second != 0 && p2.Second == p.Second { - return ErrSecondPtrNotChanged + return errSecondPtrNotChanged } return nil } diff --git a/marshal/tests/serialization/pointers_test.go b/internal/tests/serialization/pointers_test.go similarity index 100% rename from marshal/tests/serialization/pointers_test.go rename to internal/tests/serialization/pointers_test.go diff --git a/marshal/tests/serialization/set.go b/internal/tests/serialization/set.go similarity index 78% rename from marshal/tests/serialization/set.go rename to internal/tests/serialization/set.go index 55f28a67a..3b3856e3c 100644 --- a/marshal/tests/serialization/set.go +++ b/internal/tests/serialization/set.go @@ -7,12 +7,8 @@ import ( "reflect" "runtime/debug" "testing" - - "github.com/gocql/gocql/internal/tests/utils" ) -type Sets []*Set - // Set is a tool for generating test cases of marshal and unmarshall funcs. // For cases when the function should no error, // marshaled data from Set.Values should be equal with Set.Data, @@ -41,15 +37,15 @@ func (s Set) RunTest(name string, t *testing.T, marshal func(interface{}) ([]byt if unmarshal != nil { if rt := reflect.TypeOf(val); rt.Kind() != reflect.Ptr { - unmarshalIn := utils.NewRef(val) + unmarshalIn := newRef(val) s.runUnmarshalTest("unmarshal", t, unmarshal, val, unmarshalIn) } else { // Test unmarshal to (*type)(nil) - unmarshalIn := utils.NewRef(val) + unmarshalIn := newRef(val) s.runUnmarshalTest("unmarshal**nil", t, unmarshal, val, unmarshalIn) // Test unmarshal to &type{} - unmarshalInZero := utils.NewRefToZero(val) + unmarshalInZero := newRefToZero(val) s.runUnmarshalTest("unmarshal**zero", t, unmarshal, val, unmarshalInZero) } } @@ -70,22 +66,22 @@ func (s Set) RunCorruptTest(name string, t *testing.T, marshal func(interface{}) val := s.Values[m] if marshal != nil { - t.Run(utils.StringValue(val), func(t *testing.T) { + t.Run(stringValue(val), func(t *testing.T) { s.runMarshalCorruptTest(t, marshal, val) }) continue } if rt := reflect.TypeOf(val); rt.Kind() != reflect.Ptr { - unmarshalIn := utils.NewRef(val) + unmarshalIn := newRef(val) s.runUnmarshalCorruptTest(fmt.Sprintf("%T", val), t, unmarshal, unmarshalIn) } else { // Test unmarshal to (*type)(nil) - unmarshalIn := utils.NewRef(val) + unmarshalIn := newRef(val) s.runUnmarshalCorruptTest(fmt.Sprintf("%T**nil", val), t, unmarshal, unmarshalIn) // Test unmarshal to &type{} - unmarshalInZero := utils.NewRefToZero(val) + unmarshalInZero := newRefToZero(val) s.runUnmarshalCorruptTest(fmt.Sprintf("%T**zero", val), t, unmarshal, unmarshalInZero) } } @@ -98,7 +94,7 @@ func (s Set) runMarshalTest(t *testing.T, f func(interface{}) ([]byte, error), v result, err := func() (d []byte, err error) { defer func() { if r := recover(); r != nil { - err = utils.PanicErr{Err: r.(error), Stack: debug.Stack()} + err = panicErr{err: r.(error), stack: debug.Stack()} } }() return f(val) @@ -106,11 +102,11 @@ func (s Set) runMarshalTest(t *testing.T, f func(interface{}) ([]byte, error), v expected := bytes.Clone(s.Data) if err != nil { - if !errors.As(err, &utils.PanicErr{}) { - err = errors.Join(MarshalErr, err) + if !errors.As(err, &panicErr{}) { + err = errors.Join(marshalErr, err) } - } else if !utils.EqualData(expected, result) { - err = UnequalError{Expected: utils.StringData(s.Data), Got: utils.StringData(result)} + } else if !equalData(expected, result) { + err = unequalError{Expected: stringData(s.Data), Got: stringData(result)} } if isTypeOf(val, s.BrokenMarshalTypes) { @@ -133,18 +129,18 @@ func (s Set) runUnmarshalTest(name string, t *testing.T, f func([]byte, interfac err := func() (err error) { defer func() { if r := recover(); r != nil { - err = utils.PanicErr{Err: fmt.Errorf("%s", r), Stack: debug.Stack()} + err = panicErr{err: fmt.Errorf("%s", r), stack: debug.Stack()} } }() return f(bytes.Clone(s.Data), result) }() if err != nil { - if !errors.As(err, &utils.PanicErr{}) { - err = errors.Join(UnmarshalErr, err) + if !errors.As(err, &panicErr{}) { + err = errors.Join(unmarshalErr, err) } - } else if !utils.EqualVals(expected, utils.DeReference(result)) { - err = UnequalError{Expected: utils.StringValue(expected), Got: utils.StringValue(result)} + } else if !equalVals(expected, deReference(result)) { + err = unequalError{Expected: stringValue(expected), Got: stringValue(result)} } else { err = expectedPtr.Valid(result) } @@ -165,7 +161,7 @@ func (s Set) runMarshalCorruptTest(t *testing.T, f func(interface{}) ([]byte, er _, err := func() (d []byte, err error) { defer func() { if r := recover(); r != nil { - err = utils.PanicErr{Err: r.(error), Stack: debug.Stack()} + err = panicErr{err: r.(error), stack: debug.Stack()} } }() return f(val) @@ -180,7 +176,7 @@ func (s Set) runMarshalCorruptTest(t *testing.T, f func(interface{}) ([]byte, er if err == nil { t.Errorf("expected to fail for (%T), but did not fail", val) - } else if errors.As(err, &utils.PanicErr{}) { + } else if errors.As(err, &panicErr{}) { t.Errorf("was panic %s", err) } } @@ -190,13 +186,13 @@ func (s Set) runUnmarshalCorruptTest(name string, t *testing.T, f func([]byte, i err := func() (err error) { defer func() { if r := recover(); r != nil { - err = utils.PanicErr{Err: r.(error), Stack: debug.Stack()} + err = panicErr{err: r.(error), stack: debug.Stack()} } }() return f(bytes.Clone(s.Data), val) }() - if err != nil && !errors.As(err, &utils.PanicErr{}) { + if err != nil && !errors.As(err, &panicErr{}) { err = nil } @@ -209,7 +205,7 @@ func (s Set) runUnmarshalCorruptTest(name string, t *testing.T, f func([]byte, i if err == nil { t.Errorf("expected to fail for (%T), but did not fail", val) - } else if errors.As(err, &utils.PanicErr{}) { + } else if errors.As(err, &panicErr{}) { t.Errorf("was panic %s", err) } }) diff --git a/internal/tests/utils/utils.go b/internal/tests/serialization/utils.go similarity index 51% rename from internal/tests/utils/utils.go rename to internal/tests/serialization/utils.go index 6b2239b48..667efe993 100644 --- a/internal/tests/utils/utils.go +++ b/internal/tests/serialization/utils.go @@ -1,19 +1,9 @@ -package utils +package serialization import ( "reflect" ) -func DeReference(in interface{}) interface{} { - return reflect.Indirect(reflect.ValueOf(in)).Interface() -} - -func Reference(val interface{}) interface{} { - out := reflect.New(reflect.TypeOf(val)) - out.Elem().Set(reflect.ValueOf(val)) - return out.Interface() -} - func GetTypes(values ...interface{}) []reflect.Type { types := make([]reflect.Type, len(values)) for i, value := range values { @@ -21,3 +11,17 @@ func GetTypes(values ...interface{}) []reflect.Type { } return types } + +func isTypeOf(value interface{}, types []reflect.Type) bool { + valueType := reflect.TypeOf(value) + for i := range types { + if types[i] == valueType { + return true + } + } + return false +} + +func deReference(in interface{}) interface{} { + return reflect.Indirect(reflect.ValueOf(in)).Interface() +} diff --git a/internal/tests/utils/equal.go b/internal/tests/serialization/utils_equal.go similarity index 91% rename from internal/tests/utils/equal.go rename to internal/tests/serialization/utils_equal.go index 71e48af80..9acf5035a 100644 --- a/internal/tests/utils/equal.go +++ b/internal/tests/serialization/utils_equal.go @@ -1,24 +1,23 @@ -package utils +package serialization import ( "bytes" "fmt" + "github.com/gocql/gocql/internal/tests/serialization/mod" "gopkg.in/inf.v0" "math/big" "reflect" "unsafe" - - "github.com/gocql/gocql/marshal/tests/mod" ) -func EqualData(in1, in2 []byte) bool { +func equalData(in1, in2 []byte) bool { if in1 == nil || in2 == nil { return in1 == nil && in2 == nil } return bytes.Equal(in1, in2) } -func EqualVals(in1, in2 interface{}) bool { +func equalVals(in1, in2 interface{}) bool { rin1 := reflect.ValueOf(in1) rin2 := reflect.ValueOf(in2) if rin1.Kind() != rin2.Kind() { diff --git a/internal/tests/serialization/utils_error.go b/internal/tests/serialization/utils_error.go new file mode 100644 index 000000000..12b4ff3f4 --- /dev/null +++ b/internal/tests/serialization/utils_error.go @@ -0,0 +1,27 @@ +package serialization + +import ( + "errors" + "fmt" +) + +var unmarshalErr = errors.New("unmarshal unexpectedly failed with error") +var marshalErr = errors.New("marshal unexpectedly failed with error") + +type unequalError struct { + Expected string + Got string +} + +func (e unequalError) Error() string { + return fmt.Sprintf("expect %s but got %s", e.Expected, e.Got) +} + +type panicErr struct { + err error + stack []byte +} + +func (e panicErr) Error() string { + return fmt.Sprintf("%v\n%s", e.err, e.stack) +} diff --git a/internal/tests/utils/new.go b/internal/tests/serialization/utils_new.go similarity index 67% rename from internal/tests/utils/new.go rename to internal/tests/serialization/utils_new.go index d2901db01..a821272ed 100644 --- a/internal/tests/utils/new.go +++ b/internal/tests/serialization/utils_new.go @@ -1,15 +1,15 @@ -package utils +package serialization import ( "reflect" ) -func NewRef(in interface{}) interface{} { +func newRef(in interface{}) interface{} { out := reflect.New(reflect.TypeOf(in)).Interface() return out } -func NewRefToZero(in interface{}) interface{} { +func newRefToZero(in interface{}) interface{} { rv := reflect.ValueOf(in) nw := reflect.New(rv.Type().Elem()) out := reflect.New(rv.Type()) diff --git a/internal/tests/utils/string_vals.go b/internal/tests/serialization/utils_str.go similarity index 69% rename from internal/tests/utils/string_vals.go rename to internal/tests/serialization/utils_str.go index 73bbf19ee..3634224ec 100644 --- a/internal/tests/utils/string_vals.go +++ b/internal/tests/serialization/utils_str.go @@ -1,4 +1,4 @@ -package utils +package serialization import ( "fmt" @@ -9,16 +9,25 @@ import ( "time" ) -// StringValue returns (value_type)(value) in the human-readable format. -func StringValue(in interface{}) string { - valStr := stringValue(in) +const printLimit = 100 + +// stringValue returns (value_type)(value) in the human-readable format. +func stringValue(in interface{}) string { + valStr := stringVal(in) if len(valStr) > printLimit { valStr = valStr[:printLimit] } return fmt.Sprintf("(%T)(%s)", in, valStr) } -func stringValue(in interface{}) string { +func stringData(p []byte) string { + if len(p) > printLimit { + p = p[:printLimit] + } + return fmt.Sprintf("[%x]", p) +} + +func stringVal(in interface{}) string { switch i := in.(type) { case string: return i @@ -40,7 +49,7 @@ func stringValue(in interface{}) string { if rv.IsNil() { return "*nil" } - return fmt.Sprintf("*%s", stringValue(rv.Elem().Interface())) + return fmt.Sprintf("*%s", stringVal(rv.Elem().Interface())) case reflect.Slice: if rv.IsNil() { return "[nil]" diff --git a/internal/tests/utils/panic_err.go b/internal/tests/utils/panic_err.go deleted file mode 100644 index 45d27cb9c..000000000 --- a/internal/tests/utils/panic_err.go +++ /dev/null @@ -1,14 +0,0 @@ -package utils - -import ( - "fmt" -) - -type PanicErr struct { - Err error - Stack []byte -} - -func (e PanicErr) Error() string { - return fmt.Sprintf("%v\n%s", e.Err, e.Stack) -} diff --git a/internal/tests/utils/string.go b/internal/tests/utils/string.go deleted file mode 100644 index 8d7712518..000000000 --- a/internal/tests/utils/string.go +++ /dev/null @@ -1,14 +0,0 @@ -package utils - -import ( - "fmt" -) - -const printLimit = 100 - -func StringData(p []byte) string { - if len(p) > printLimit { - p = p[:printLimit] - } - return fmt.Sprintf("[%x]", p) -} diff --git a/marshal/tests/serialization/utils.go b/marshal/tests/serialization/utils.go deleted file mode 100644 index 63010dd56..000000000 --- a/marshal/tests/serialization/utils.go +++ /dev/null @@ -1,29 +0,0 @@ -package serialization - -import ( - "errors" - "fmt" - "reflect" -) - -var UnmarshalErr = errors.New("unmarshal unexpectedly failed with error") -var MarshalErr = errors.New("marshal unexpectedly failed with error") - -type UnequalError struct { - Expected string - Got string -} - -func (e UnequalError) Error() string { - return fmt.Sprintf("expect %s but got %s", e.Expected, e.Got) -} - -func isTypeOf(value interface{}, types []reflect.Type) bool { - valueType := reflect.TypeOf(value) - for i := range types { - if types[i] == valueType { - return true - } - } - return false -} diff --git a/marshal_3_smallint_corrupt_test.go b/marshal_3_smallint_corrupt_test.go index 13dec5d24..cb103f0e9 100644 --- a/marshal_3_smallint_corrupt_test.go +++ b/marshal_3_smallint_corrupt_test.go @@ -4,9 +4,8 @@ import ( "math/big" "testing" - "github.com/gocql/gocql/internal/tests/utils" - "github.com/gocql/gocql/marshal/tests/mod" - "github.com/gocql/gocql/marshal/tests/serialization" + "github.com/gocql/gocql/internal/tests/serialization" + "github.com/gocql/gocql/internal/tests/serialization/mod" ) func TestMarshalSmallintCorrupt(t *testing.T) { @@ -15,7 +14,7 @@ func TestMarshalSmallintCorrupt(t *testing.T) { return Unmarshal(NativeType{proto: 4, typ: TypeSmallInt}, bytes, i) } - brokenUnmarshalTypes := utils.GetTypes( + brokenUnmarshalTypes := serialization.GetTypes( mod.Values(mod.Reference( int8(0), int16(0), int32(0), int64(0), int(0), uint8(0), uint16(0), uint32(0), uint64(0), uint(0), diff --git a/marshal_3_smallint_test.go b/marshal_3_smallint_test.go index d19bf9633..0fc84604d 100644 --- a/marshal_3_smallint_test.go +++ b/marshal_3_smallint_test.go @@ -3,9 +3,8 @@ package gocql import ( "testing" - "github.com/gocql/gocql/internal/tests/utils" - "github.com/gocql/gocql/marshal/tests/mod" - "github.com/gocql/gocql/marshal/tests/serialization" + "github.com/gocql/gocql/internal/tests/serialization" + "github.com/gocql/gocql/internal/tests/serialization/mod" ) func TestMarshalSmallint(t *testing.T) { @@ -14,7 +13,7 @@ func TestMarshalSmallint(t *testing.T) { return Unmarshal(NativeType{proto: 4, typ: TypeSmallInt}, bytes, i) } - brokenTypes := utils.GetTypes(mod.String(""), (*mod.String)(nil)) + brokenTypes := serialization.GetTypes(mod.String(""), (*mod.String)(nil)) serialization.Set{ Data: nil,