Skip to content

Commit

Permalink
join and move serialization and utils packages
Browse files Browse the repository at this point in the history
`serialization` and `utils` packages are joined and moved to `internal/tests/serialization`
  • Loading branch information
illia-li committed Sep 30, 2024
1 parent 275f9d2 commit 14bc33e
Show file tree
Hide file tree
Showing 26 changed files with 150 additions and 172 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ import (
"reflect"
)

// ErrFirstPtrChanged this error indicates that a double or single reference was passed to the Unmarshal function
// errFirstPtrChanged this error indicates that a double or single reference was passed to the Unmarshal function
// (example (**int)(**0) or (*int)(*0)) and Unmarshal overwritten first reference.
var ErrFirstPtrChanged = errors.New("unmarshal function rewrote first pointer")
var errFirstPtrChanged = errors.New("unmarshal function rewrote first pointer")

// ErrSecondPtrNotChanged this error indicates that a double reference was passed to the Unmarshal function
// errSecondPtrNotChanged this error indicates that a double reference was passed to the Unmarshal function
// (example (**int)(**0)) and the function did not overwrite the second reference.
// Of course, it's not friendly to the garbage collector, overwriting references to values all the time,
// but this is the current implementation `gocql` and changing it can lead to unexpected results in some cases.
var ErrSecondPtrNotChanged = errors.New("unmarshal function did not rewrite second pointer")
var errSecondPtrNotChanged = errors.New("unmarshal function did not rewrite second pointer")

func getPointers(i interface{}) *pointer {
rv := reflect.ValueOf(i)
Expand Down Expand Up @@ -45,10 +45,10 @@ func (p *pointer) NotNil() bool {
func (p *pointer) Valid(v interface{}) error {
p2 := getPointers(v)
if p.Fist != p2.Fist {
return ErrFirstPtrChanged
return errFirstPtrChanged
}
if p.Second != 0 && p2.Second != 0 && p2.Second == p.Second {
return ErrSecondPtrNotChanged
return errSecondPtrNotChanged
}
return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,8 @@ import (
"reflect"
"runtime/debug"
"testing"

"github.com/gocql/gocql/internal/tests/utils"
)

type Sets []*Set

// Set is a tool for marshal and unmarshall funcs testing.
type Set struct {
Data []byte
Expand All @@ -22,6 +18,9 @@ type Set struct {
BrokenUnmarshalTypes []reflect.Type
}

// RunTest runs tests for cases when the function should no error,
// on marshal - marshaled data from Set.Values should be equal with Set.Data,
// on unmarshall - unmarshalled value from Set.Data should be equal with Set.Values.
func (s Set) RunTest(name string, t *testing.T, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) {
if name == "" {
t.Fatal("name should be provided")
Expand All @@ -38,15 +37,15 @@ func (s Set) RunTest(name string, t *testing.T, marshal func(interface{}) ([]byt

if unmarshal != nil {
if rt := reflect.TypeOf(val); rt.Kind() != reflect.Ptr {
unmarshalIn := utils.NewRef(val)
unmarshalIn := newRef(val)
s.runUnmarshalTest("unmarshal", t, unmarshal, val, unmarshalIn)
} else {
// Test unmarshal to (*type)(nil)
unmarshalIn := utils.NewRef(val)
unmarshalIn := newRef(val)
s.runUnmarshalTest("unmarshal**nil", t, unmarshal, val, unmarshalIn)

// Test unmarshal to &type{}
unmarshalInZero := utils.NewRefToZero(val)
unmarshalInZero := newRefToZero(val)
s.runUnmarshalTest("unmarshal**zero", t, unmarshal, val, unmarshalInZero)
}
}
Expand All @@ -68,22 +67,22 @@ func (s Set) RunCorruptTest(name string, t *testing.T, marshal func(interface{})
val := s.Values[m]

if marshal != nil {
t.Run(utils.StringValue(val), func(t *testing.T) {
t.Run(stringValue(val), func(t *testing.T) {
s.runMarshalCorruptTest(t, marshal, val)
})
continue
}

if rt := reflect.TypeOf(val); rt.Kind() != reflect.Ptr {
unmarshalIn := utils.NewRef(val)
unmarshalIn := newRef(val)
s.runUnmarshalCorruptTest(fmt.Sprintf("%T", val), t, unmarshal, val, unmarshalIn)
} else {
// Test unmarshal to (*type)(nil)
unmarshalIn := utils.NewRef(val)
unmarshalIn := newRef(val)
s.runUnmarshalCorruptTest(fmt.Sprintf("%T**nil", val), t, unmarshal, val, unmarshalIn)

// Test unmarshal to &type{}
unmarshalInZero := utils.NewRefToZero(val)
unmarshalInZero := newRefToZero(val)
s.runUnmarshalCorruptTest(fmt.Sprintf("%T**zero", val), t, unmarshal, val, unmarshalInZero)
}
}
Expand All @@ -96,19 +95,19 @@ func (s Set) runMarshalTest(t *testing.T, f func(interface{}) ([]byte, error), v
result, err := func() (d []byte, err error) {
defer func() {
if r := recover(); r != nil {
err = utils.PanicErr{Err: r.(error), Stack: debug.Stack()}
err = panicErr{err: r.(error), stack: debug.Stack()}
}
}()
return f(val)
}()

expected := bytes.Clone(s.Data)
if err != nil {
if !errors.As(err, &utils.PanicErr{}) {
err = errors.Join(MarshalErr, err)
if !errors.As(err, &panicErr{}) {
err = errors.Join(marshalErr, err)
}
} else if !utils.EqualData(expected, result) {
err = UnequalError{Expected: utils.StringData(s.Data), Got: utils.StringData(result)}
} else if !equalData(expected, result) {
err = unequalError{Expected: stringData(s.Data), Got: stringData(result)}
}

if isTypeOf(val, s.BrokenMarshalTypes) {
Expand All @@ -131,18 +130,18 @@ func (s Set) runUnmarshalTest(name string, t *testing.T, f func([]byte, interfac
err := func() (err error) {
defer func() {
if r := recover(); r != nil {
err = utils.PanicErr{Err: fmt.Errorf("%s", r), Stack: debug.Stack()}
err = panicErr{err: fmt.Errorf("%s", r), stack: debug.Stack()}
}
}()
return f(bytes.Clone(s.Data), result)
}()

if err != nil {
if !errors.As(err, &utils.PanicErr{}) {
err = errors.Join(UnmarshalErr, err)
if !errors.As(err, &panicErr{}) {
err = errors.Join(unmarshalErr, err)
}
} else if !utils.EqualVals(expected, utils.DeReference(result)) {
err = UnequalError{Expected: utils.StringValue(expected), Got: utils.StringValue(result)}
} else if !equalVals(expected, deReference(result)) {
err = unequalError{Expected: stringValue(expected), Got: stringValue(result)}
} else {
err = expectedPtr.Valid(result)
}
Expand All @@ -163,14 +162,15 @@ func (s Set) runMarshalCorruptTest(t *testing.T, f func(interface{}) ([]byte, er
_, err := func() (d []byte, err error) {
defer func() {
if r := recover(); r != nil {
err = utils.PanicErr{Err: r.(error), Stack: debug.Stack()}
err = panicErr{err: r.(error), stack: debug.Stack()}
}
}()
return f(val)
}()

testFailed := false
if err == nil || errors.As(err, &utils.PanicErr{}) {
wasPanic := errors.As(err, &panicErr{})
if err == nil || wasPanic {
testFailed = true
}

Expand All @@ -182,7 +182,7 @@ func (s Set) runMarshalCorruptTest(t *testing.T, f func(interface{}) ([]byte, er
}

if testFailed {
if errors.As(err, &utils.PanicErr{}) {
if wasPanic {
t.Fatalf("was panic %s", err)
}
t.Errorf("expected an error for (%T), but got no error", val)
Expand All @@ -194,14 +194,15 @@ func (s Set) runUnmarshalCorruptTest(name string, t *testing.T, f func([]byte, i
err := func() (err error) {
defer func() {
if r := recover(); r != nil {
err = utils.PanicErr{Err: r.(error), Stack: debug.Stack()}
err = panicErr{err: r.(error), stack: debug.Stack()}
}
}()
return f(bytes.Clone(s.Data), unmarshalIn)
}()

testFailed := false
if err == nil || errors.As(err, &utils.PanicErr{}) {
wasPanic := errors.As(err, &panicErr{})
if err == nil || wasPanic {
testFailed = true
}

Expand All @@ -213,7 +214,7 @@ func (s Set) runUnmarshalCorruptTest(name string, t *testing.T, f func([]byte, i
}

if testFailed {
if errors.As(err, &utils.PanicErr{}) {
if wasPanic {
t.Fatalf("was panic %s", err)
}
t.Errorf("expected an error for (%T), but got no error", unmarshalIn)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
package utils
package serialization

import (
"reflect"
)

func DeReference(in interface{}) interface{} {
return reflect.Indirect(reflect.ValueOf(in)).Interface()
}

func Reference(val interface{}) interface{} {
out := reflect.New(reflect.TypeOf(val))
out.Elem().Set(reflect.ValueOf(val))
return out.Interface()
}

func GetTypes(values ...interface{}) []reflect.Type {
types := make([]reflect.Type, len(values))
for i, value := range values {
types[i] = reflect.TypeOf(value)
}
return types
}

func isTypeOf(value interface{}, types []reflect.Type) bool {
valueType := reflect.TypeOf(value)
for i := range types {
if types[i] == valueType {
return true
}
}
return false
}

func deReference(in interface{}) interface{} {
return reflect.Indirect(reflect.ValueOf(in)).Interface()
}
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
package utils
package serialization

import (
"bytes"
"fmt"
"github.com/gocql/gocql/internal/tests/serialization/mod"
"gopkg.in/inf.v0"
"math/big"
"reflect"
"unsafe"

"github.com/gocql/gocql/marshal/tests/mod"
)

func EqualData(in1, in2 []byte) bool {
func equalData(in1, in2 []byte) bool {
if in1 == nil || in2 == nil {
return in1 == nil && in2 == nil
}
return bytes.Equal(in1, in2)
}

func EqualVals(in1, in2 interface{}) bool {
func equalVals(in1, in2 interface{}) bool {
rin1 := reflect.ValueOf(in1)
rin2 := reflect.ValueOf(in2)
if rin1.Kind() != rin2.Kind() {
Expand Down
27 changes: 27 additions & 0 deletions internal/tests/serialization/utils_error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package serialization

import (
"errors"
"fmt"
)

var unmarshalErr = errors.New("unmarshal unexpectedly failed with error")
var marshalErr = errors.New("marshal unexpectedly failed with error")

type unequalError struct {
Expected string
Got string
}

func (e unequalError) Error() string {
return fmt.Sprintf("expect %s but got %s", e.Expected, e.Got)
}

type panicErr struct {
err error
stack []byte
}

func (e panicErr) Error() string {
return fmt.Sprintf("%v\n%s", e.err, e.stack)
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package utils
package serialization

import (
"reflect"
)

func NewRef(in interface{}) interface{} {
func newRef(in interface{}) interface{} {
out := reflect.New(reflect.TypeOf(in)).Interface()
return out
}

func NewRefToZero(in interface{}) interface{} {
func newRefToZero(in interface{}) interface{} {
rv := reflect.ValueOf(in)
nw := reflect.New(rv.Type().Elem())
out := reflect.New(rv.Type())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package utils
package serialization

import (
"fmt"
Expand All @@ -9,16 +9,25 @@ import (
"time"
)

// StringValue returns (value_type)(value) in the human-readable format.
func StringValue(in interface{}) string {
valStr := stringValue(in)
const printLimit = 100

// stringValue returns (value_type)(value) in the human-readable format.
func stringValue(in interface{}) string {
valStr := stringVal(in)
if len(valStr) > printLimit {
valStr = valStr[:printLimit]
}
return fmt.Sprintf("(%T)(%s)", in, valStr)
}

func stringValue(in interface{}) string {
func stringData(p []byte) string {
if len(p) > printLimit {
p = p[:printLimit]
}
return fmt.Sprintf("[%x]", p)
}

func stringVal(in interface{}) string {
switch i := in.(type) {
case string:
return i
Expand All @@ -40,7 +49,7 @@ func stringValue(in interface{}) string {
if rv.IsNil() {
return "*nil"
}
return fmt.Sprintf("*%s", stringValue(rv.Elem().Interface()))
return fmt.Sprintf("*%s", stringVal(rv.Elem().Interface()))
case reflect.Slice:
if rv.IsNil() {
return "[nil]"
Expand Down
14 changes: 0 additions & 14 deletions internal/tests/utils/panic_err.go

This file was deleted.

14 changes: 0 additions & 14 deletions internal/tests/utils/string.go

This file was deleted.

Loading

0 comments on commit 14bc33e

Please sign in to comment.