Skip to content

Commit

Permalink
move all code to internal/tests/serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
illia-li committed Sep 28, 2024
1 parent 1af653d commit b1a4ee3
Show file tree
Hide file tree
Showing 17 changed files with 98 additions and 122 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ import (
"reflect"
)

// ErrFirstPtrChanged this error indicates that a double or single reference was passed to the Unmarshal function
// errFirstPtrChanged this error indicates that a double or single reference was passed to the Unmarshal function
// (example (**int)(**0) or (*int)(*0)) and Unmarshal overwritten first reference.
var ErrFirstPtrChanged = errors.New("unmarshal function rewrote first pointer")
var errFirstPtrChanged = errors.New("unmarshal function rewrote first pointer")

// ErrSecondPtrNotChanged this error indicates that a double reference was passed to the Unmarshal function
// errSecondPtrNotChanged this error indicates that a double reference was passed to the Unmarshal function
// (example (**int)(**0)) and the function did not overwrite the second reference.
// Of course, it's not friendly to the garbage collector, overwriting references to values all the time,
// but this is the current implementation `gocql` and changing it can lead to unexpected results in some cases.
var ErrSecondPtrNotChanged = errors.New("unmarshal function did not rewrite second pointer")
var errSecondPtrNotChanged = errors.New("unmarshal function did not rewrite second pointer")

func getPointers(i interface{}) *pointer {
rv := reflect.ValueOf(i)
Expand Down Expand Up @@ -45,10 +45,10 @@ func (p *pointer) NotNil() bool {
func (p *pointer) Valid(v interface{}) error {
p2 := getPointers(v)
if p.Fist != p2.Fist {
return ErrFirstPtrChanged
return errFirstPtrChanged
}
if p.Second != 0 && p2.Second != 0 && p2.Second == p.Second {
return ErrSecondPtrNotChanged
return errSecondPtrNotChanged
}
return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,8 @@ import (
"reflect"
"runtime/debug"
"testing"

"github.com/gocql/gocql/internal/tests/utils"
)

type Sets []*Set

// Set is a tool for generating test cases of marshal and unmarshall funcs.
// For cases when the function should no error,
// marshaled data from Set.Values should be equal with Set.Data,
Expand Down Expand Up @@ -41,15 +37,15 @@ func (s Set) RunTest(name string, t *testing.T, marshal func(interface{}) ([]byt

if unmarshal != nil {
if rt := reflect.TypeOf(val); rt.Kind() != reflect.Ptr {
unmarshalIn := utils.NewRef(val)
unmarshalIn := newRef(val)
s.runUnmarshalTest("unmarshal", t, unmarshal, val, unmarshalIn)
} else {
// Test unmarshal to (*type)(nil)
unmarshalIn := utils.NewRef(val)
unmarshalIn := newRef(val)
s.runUnmarshalTest("unmarshal**nil", t, unmarshal, val, unmarshalIn)

// Test unmarshal to &type{}
unmarshalInZero := utils.NewRefToZero(val)
unmarshalInZero := newRefToZero(val)
s.runUnmarshalTest("unmarshal**zero", t, unmarshal, val, unmarshalInZero)
}
}
Expand All @@ -70,22 +66,22 @@ func (s Set) RunCorruptTest(name string, t *testing.T, marshal func(interface{})
val := s.Values[m]

if marshal != nil {
t.Run(utils.StringValue(val), func(t *testing.T) {
t.Run(stringValue(val), func(t *testing.T) {
s.runMarshalCorruptTest(t, marshal, val)
})
continue
}

if rt := reflect.TypeOf(val); rt.Kind() != reflect.Ptr {
unmarshalIn := utils.NewRef(val)
unmarshalIn := newRef(val)
s.runUnmarshalCorruptTest(fmt.Sprintf("%T", val), t, unmarshal, unmarshalIn)
} else {
// Test unmarshal to (*type)(nil)
unmarshalIn := utils.NewRef(val)
unmarshalIn := newRef(val)
s.runUnmarshalCorruptTest(fmt.Sprintf("%T**nil", val), t, unmarshal, unmarshalIn)

// Test unmarshal to &type{}
unmarshalInZero := utils.NewRefToZero(val)
unmarshalInZero := newRefToZero(val)
s.runUnmarshalCorruptTest(fmt.Sprintf("%T**zero", val), t, unmarshal, unmarshalInZero)
}
}
Expand All @@ -98,19 +94,19 @@ func (s Set) runMarshalTest(t *testing.T, f func(interface{}) ([]byte, error), v
result, err := func() (d []byte, err error) {
defer func() {
if r := recover(); r != nil {
err = utils.PanicErr{Err: r.(error), Stack: debug.Stack()}
err = panicErr{err: r.(error), stack: debug.Stack()}
}
}()
return f(val)
}()

expected := bytes.Clone(s.Data)
if err != nil {
if !errors.As(err, &utils.PanicErr{}) {
err = errors.Join(MarshalErr, err)
if !errors.As(err, &panicErr{}) {
err = errors.Join(marshalErr, err)
}
} else if !utils.EqualData(expected, result) {
err = UnequalError{Expected: utils.StringData(s.Data), Got: utils.StringData(result)}
} else if !equalData(expected, result) {
err = unequalError{Expected: stringData(s.Data), Got: stringData(result)}
}

if isTypeOf(val, s.BrokenMarshalTypes) {
Expand All @@ -133,18 +129,18 @@ func (s Set) runUnmarshalTest(name string, t *testing.T, f func([]byte, interfac
err := func() (err error) {
defer func() {
if r := recover(); r != nil {
err = utils.PanicErr{Err: fmt.Errorf("%s", r), Stack: debug.Stack()}
err = panicErr{err: fmt.Errorf("%s", r), stack: debug.Stack()}
}
}()
return f(bytes.Clone(s.Data), result)
}()

if err != nil {
if !errors.As(err, &utils.PanicErr{}) {
err = errors.Join(UnmarshalErr, err)
if !errors.As(err, &panicErr{}) {
err = errors.Join(unmarshalErr, err)
}
} else if !utils.EqualVals(expected, utils.DeReference(result)) {
err = UnequalError{Expected: utils.StringValue(expected), Got: utils.StringValue(result)}
} else if !equalVals(expected, deReference(result)) {
err = unequalError{Expected: stringValue(expected), Got: stringValue(result)}
} else {
err = expectedPtr.Valid(result)
}
Expand All @@ -165,7 +161,7 @@ func (s Set) runMarshalCorruptTest(t *testing.T, f func(interface{}) ([]byte, er
_, err := func() (d []byte, err error) {
defer func() {
if r := recover(); r != nil {
err = utils.PanicErr{Err: r.(error), Stack: debug.Stack()}
err = panicErr{err: r.(error), stack: debug.Stack()}
}
}()
return f(val)
Expand All @@ -180,7 +176,7 @@ func (s Set) runMarshalCorruptTest(t *testing.T, f func(interface{}) ([]byte, er

if err == nil {
t.Errorf("expected to fail for (%T), but did not fail", val)
} else if errors.As(err, &utils.PanicErr{}) {
} else if errors.As(err, &panicErr{}) {
t.Errorf("was panic %s", err)
}
}
Expand All @@ -190,13 +186,13 @@ func (s Set) runUnmarshalCorruptTest(name string, t *testing.T, f func([]byte, i
err := func() (err error) {
defer func() {
if r := recover(); r != nil {
err = utils.PanicErr{Err: r.(error), Stack: debug.Stack()}
err = panicErr{err: r.(error), stack: debug.Stack()}
}
}()
return f(bytes.Clone(s.Data), val)
}()

if err != nil && !errors.As(err, &utils.PanicErr{}) {
if err != nil && !errors.As(err, &panicErr{}) {
err = nil
}

Expand All @@ -209,7 +205,7 @@ func (s Set) runUnmarshalCorruptTest(name string, t *testing.T, f func([]byte, i

if err == nil {
t.Errorf("expected to fail for (%T), but did not fail", val)
} else if errors.As(err, &utils.PanicErr{}) {
} else if errors.As(err, &panicErr{}) {
t.Errorf("was panic %s", err)
}
})
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
package utils
package serialization

import (
"reflect"
)

func DeReference(in interface{}) interface{} {
return reflect.Indirect(reflect.ValueOf(in)).Interface()
}

func Reference(val interface{}) interface{} {
out := reflect.New(reflect.TypeOf(val))
out.Elem().Set(reflect.ValueOf(val))
return out.Interface()
}

func GetTypes(values ...interface{}) []reflect.Type {
types := make([]reflect.Type, len(values))
for i, value := range values {
types[i] = reflect.TypeOf(value)
}
return types
}

func isTypeOf(value interface{}, types []reflect.Type) bool {
valueType := reflect.TypeOf(value)
for i := range types {
if types[i] == valueType {
return true
}
}
return false
}

func deReference(in interface{}) interface{} {
return reflect.Indirect(reflect.ValueOf(in)).Interface()
}
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
package utils
package serialization

import (
"bytes"
"fmt"
"github.com/gocql/gocql/internal/tests/serialization/mod"
"gopkg.in/inf.v0"
"math/big"
"reflect"
"unsafe"

"github.com/gocql/gocql/marshal/tests/mod"
)

func EqualData(in1, in2 []byte) bool {
func equalData(in1, in2 []byte) bool {
if in1 == nil || in2 == nil {
return in1 == nil && in2 == nil
}
return bytes.Equal(in1, in2)
}

func EqualVals(in1, in2 interface{}) bool {
func equalVals(in1, in2 interface{}) bool {
rin1 := reflect.ValueOf(in1)
rin2 := reflect.ValueOf(in2)
if rin1.Kind() != rin2.Kind() {
Expand Down
27 changes: 27 additions & 0 deletions internal/tests/serialization/utils_error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package serialization

import (
"errors"
"fmt"
)

var unmarshalErr = errors.New("unmarshal unexpectedly failed with error")
var marshalErr = errors.New("marshal unexpectedly failed with error")

type unequalError struct {
Expected string
Got string
}

func (e unequalError) Error() string {
return fmt.Sprintf("expect %s but got %s", e.Expected, e.Got)
}

type panicErr struct {
err error
stack []byte
}

func (e panicErr) Error() string {
return fmt.Sprintf("%v\n%s", e.err, e.stack)
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package utils
package serialization

import (
"reflect"
)

func NewRef(in interface{}) interface{} {
func newRef(in interface{}) interface{} {
out := reflect.New(reflect.TypeOf(in)).Interface()
return out
}

func NewRefToZero(in interface{}) interface{} {
func newRefToZero(in interface{}) interface{} {
rv := reflect.ValueOf(in)
nw := reflect.New(rv.Type().Elem())
out := reflect.New(rv.Type())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package utils
package serialization

import (
"fmt"
Expand All @@ -9,16 +9,25 @@ import (
"time"
)

// StringValue returns (value_type)(value) in the human-readable format.
func StringValue(in interface{}) string {
valStr := stringValue(in)
const printLimit = 100

// stringValue returns (value_type)(value) in the human-readable format.
func stringValue(in interface{}) string {
valStr := stringVal(in)
if len(valStr) > printLimit {
valStr = valStr[:printLimit]
}
return fmt.Sprintf("(%T)(%s)", in, valStr)
}

func stringValue(in interface{}) string {
func stringData(p []byte) string {
if len(p) > printLimit {
p = p[:printLimit]
}
return fmt.Sprintf("[%x]", p)
}

func stringVal(in interface{}) string {
switch i := in.(type) {
case string:
return i
Expand All @@ -40,7 +49,7 @@ func stringValue(in interface{}) string {
if rv.IsNil() {
return "*nil"
}
return fmt.Sprintf("*%s", stringValue(rv.Elem().Interface()))
return fmt.Sprintf("*%s", stringVal(rv.Elem().Interface()))
case reflect.Slice:
if rv.IsNil() {
return "[nil]"
Expand Down
14 changes: 0 additions & 14 deletions internal/tests/utils/panic_err.go

This file was deleted.

14 changes: 0 additions & 14 deletions internal/tests/utils/string.go

This file was deleted.

Loading

0 comments on commit b1a4ee3

Please sign in to comment.